Source file src/crypto/internal/nistec/nistec_test.go

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package nistec_test
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/elliptic"
    10  	"crypto/internal/nistec"
    11  	"fmt"
    12  	"internal/testenv"
    13  	"math/big"
    14  	"math/rand"
    15  	"testing"
    16  )
    17  
    18  func TestAllocations(t *testing.T) {
    19  	testenv.SkipIfOptimizationOff(t)
    20  
    21  	t.Run("P224", func(t *testing.T) {
    22  		if allocs := testing.AllocsPerRun(10, func() {
    23  			p := nistec.NewP224Point().SetGenerator()
    24  			scalar := make([]byte, 28)
    25  			rand.Read(scalar)
    26  			p.ScalarBaseMult(scalar)
    27  			p.ScalarMult(p, scalar)
    28  			out := p.Bytes()
    29  			if _, err := nistec.NewP224Point().SetBytes(out); err != nil {
    30  				t.Fatal(err)
    31  			}
    32  			out = p.BytesCompressed()
    33  			if _, err := p.SetBytes(out); err != nil {
    34  				t.Fatal(err)
    35  			}
    36  		}); allocs > 0 {
    37  			t.Errorf("expected zero allocations, got %0.1f", allocs)
    38  		}
    39  	})
    40  	t.Run("P256", func(t *testing.T) {
    41  		if allocs := testing.AllocsPerRun(10, func() {
    42  			p := nistec.NewP256Point().SetGenerator()
    43  			scalar := make([]byte, 32)
    44  			rand.Read(scalar)
    45  			p.ScalarBaseMult(scalar)
    46  			p.ScalarMult(p, scalar)
    47  			out := p.Bytes()
    48  			if _, err := nistec.NewP256Point().SetBytes(out); err != nil {
    49  				t.Fatal(err)
    50  			}
    51  			out = p.BytesCompressed()
    52  			if _, err := p.SetBytes(out); err != nil {
    53  				t.Fatal(err)
    54  			}
    55  		}); allocs > 0 {
    56  			t.Errorf("expected zero allocations, got %0.1f", allocs)
    57  		}
    58  	})
    59  	t.Run("P384", func(t *testing.T) {
    60  		if allocs := testing.AllocsPerRun(10, func() {
    61  			p := nistec.NewP384Point().SetGenerator()
    62  			scalar := make([]byte, 48)
    63  			rand.Read(scalar)
    64  			p.ScalarBaseMult(scalar)
    65  			p.ScalarMult(p, scalar)
    66  			out := p.Bytes()
    67  			if _, err := nistec.NewP384Point().SetBytes(out); err != nil {
    68  				t.Fatal(err)
    69  			}
    70  			out = p.BytesCompressed()
    71  			if _, err := p.SetBytes(out); err != nil {
    72  				t.Fatal(err)
    73  			}
    74  		}); allocs > 0 {
    75  			t.Errorf("expected zero allocations, got %0.1f", allocs)
    76  		}
    77  	})
    78  	t.Run("P521", func(t *testing.T) {
    79  		if allocs := testing.AllocsPerRun(10, func() {
    80  			p := nistec.NewP521Point().SetGenerator()
    81  			scalar := make([]byte, 66)
    82  			rand.Read(scalar)
    83  			p.ScalarBaseMult(scalar)
    84  			p.ScalarMult(p, scalar)
    85  			out := p.Bytes()
    86  			if _, err := nistec.NewP521Point().SetBytes(out); err != nil {
    87  				t.Fatal(err)
    88  			}
    89  			out = p.BytesCompressed()
    90  			if _, err := p.SetBytes(out); err != nil {
    91  				t.Fatal(err)
    92  			}
    93  		}); allocs > 0 {
    94  			t.Errorf("expected zero allocations, got %0.1f", allocs)
    95  		}
    96  	})
    97  }
    98  
    99  type nistPoint[T any] interface {
   100  	Bytes() []byte
   101  	SetGenerator() T
   102  	SetBytes([]byte) (T, error)
   103  	Add(T, T) T
   104  	Double(T) T
   105  	ScalarMult(T, []byte) (T, error)
   106  	ScalarBaseMult([]byte) (T, error)
   107  }
   108  
   109  func TestEquivalents(t *testing.T) {
   110  	t.Run("P224", func(t *testing.T) {
   111  		testEquivalents(t, nistec.NewP224Point, elliptic.P224())
   112  	})
   113  	t.Run("P256", func(t *testing.T) {
   114  		testEquivalents(t, nistec.NewP256Point, elliptic.P256())
   115  	})
   116  	t.Run("P384", func(t *testing.T) {
   117  		testEquivalents(t, nistec.NewP384Point, elliptic.P384())
   118  	})
   119  	t.Run("P521", func(t *testing.T) {
   120  		testEquivalents(t, nistec.NewP521Point, elliptic.P521())
   121  	})
   122  }
   123  
   124  func testEquivalents[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic.Curve) {
   125  	p := newPoint().SetGenerator()
   126  
   127  	elementSize := (c.Params().BitSize + 7) / 8
   128  	two := make([]byte, elementSize)
   129  	two[len(two)-1] = 2
   130  	nPlusTwo := make([]byte, elementSize)
   131  	new(big.Int).Add(c.Params().N, big.NewInt(2)).FillBytes(nPlusTwo)
   132  
   133  	p1 := newPoint().Double(p)
   134  	p2 := newPoint().Add(p, p)
   135  	p3, err := newPoint().ScalarMult(p, two)
   136  	fatalIfErr(t, err)
   137  	p4, err := newPoint().ScalarBaseMult(two)
   138  	fatalIfErr(t, err)
   139  	p5, err := newPoint().ScalarMult(p, nPlusTwo)
   140  	fatalIfErr(t, err)
   141  	p6, err := newPoint().ScalarBaseMult(nPlusTwo)
   142  	fatalIfErr(t, err)
   143  
   144  	if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
   145  		t.Error("P+P != 2*P")
   146  	}
   147  	if !bytes.Equal(p1.Bytes(), p3.Bytes()) {
   148  		t.Error("P+P != [2]P")
   149  	}
   150  	if !bytes.Equal(p1.Bytes(), p4.Bytes()) {
   151  		t.Error("G+G != [2]G")
   152  	}
   153  	if !bytes.Equal(p1.Bytes(), p5.Bytes()) {
   154  		t.Error("P+P != [N+2]P")
   155  	}
   156  	if !bytes.Equal(p1.Bytes(), p6.Bytes()) {
   157  		t.Error("G+G != [N+2]G")
   158  	}
   159  }
   160  
   161  func TestScalarMult(t *testing.T) {
   162  	t.Run("P224", func(t *testing.T) {
   163  		testScalarMult(t, nistec.NewP224Point, elliptic.P224())
   164  	})
   165  	t.Run("P256", func(t *testing.T) {
   166  		testScalarMult(t, nistec.NewP256Point, elliptic.P256())
   167  	})
   168  	t.Run("P384", func(t *testing.T) {
   169  		testScalarMult(t, nistec.NewP384Point, elliptic.P384())
   170  	})
   171  	t.Run("P521", func(t *testing.T) {
   172  		testScalarMult(t, nistec.NewP521Point, elliptic.P521())
   173  	})
   174  }
   175  
   176  func testScalarMult[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic.Curve) {
   177  	G := newPoint().SetGenerator()
   178  	checkScalar := func(t *testing.T, scalar []byte) {
   179  		p1, err := newPoint().ScalarBaseMult(scalar)
   180  		fatalIfErr(t, err)
   181  		p2, err := newPoint().ScalarMult(G, scalar)
   182  		fatalIfErr(t, err)
   183  		if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
   184  			t.Error("[k]G != ScalarBaseMult(k)")
   185  		}
   186  
   187  		expectInfinity := new(big.Int).Mod(new(big.Int).SetBytes(scalar), c.Params().N).Sign() == 0
   188  		if expectInfinity {
   189  			if !bytes.Equal(p1.Bytes(), newPoint().Bytes()) {
   190  				t.Error("ScalarBaseMult(k) != ∞")
   191  			}
   192  			if !bytes.Equal(p2.Bytes(), newPoint().Bytes()) {
   193  				t.Error("[k]G != ∞")
   194  			}
   195  		} else {
   196  			if bytes.Equal(p1.Bytes(), newPoint().Bytes()) {
   197  				t.Error("ScalarBaseMult(k) == ∞")
   198  			}
   199  			if bytes.Equal(p2.Bytes(), newPoint().Bytes()) {
   200  				t.Error("[k]G == ∞")
   201  			}
   202  		}
   203  
   204  		d := new(big.Int).SetBytes(scalar)
   205  		d.Sub(c.Params().N, d)
   206  		d.Mod(d, c.Params().N)
   207  		g1, err := newPoint().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar))))
   208  		fatalIfErr(t, err)
   209  		g1.Add(g1, p1)
   210  		if !bytes.Equal(g1.Bytes(), newPoint().Bytes()) {
   211  			t.Error("[N - k]G + [k]G != ∞")
   212  		}
   213  	}
   214  
   215  	byteLen := len(c.Params().N.Bytes())
   216  	bitLen := c.Params().N.BitLen()
   217  	t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) })
   218  	t.Run("1", func(t *testing.T) {
   219  		checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
   220  	})
   221  	t.Run("N-1", func(t *testing.T) {
   222  		checkScalar(t, new(big.Int).Sub(c.Params().N, big.NewInt(1)).Bytes())
   223  	})
   224  	t.Run("N", func(t *testing.T) { checkScalar(t, c.Params().N.Bytes()) })
   225  	t.Run("N+1", func(t *testing.T) {
   226  		checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(1)).Bytes())
   227  	})
   228  	t.Run("all1s", func(t *testing.T) {
   229  		s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
   230  		s.Sub(s, big.NewInt(1))
   231  		checkScalar(t, s.Bytes())
   232  	})
   233  	if testing.Short() {
   234  		return
   235  	}
   236  	for i := 0; i < bitLen; i++ {
   237  		t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) {
   238  			s := new(big.Int).Lsh(big.NewInt(1), uint(i))
   239  			checkScalar(t, s.FillBytes(make([]byte, byteLen)))
   240  		})
   241  	}
   242  	for i := 0; i <= 64; i++ {
   243  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   244  			checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
   245  		})
   246  	}
   247  	// Test N-64...N+64 since they risk overlapping with precomputed table values
   248  	// in the final additions.
   249  	for i := int64(-64); i <= 64; i++ {
   250  		t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) {
   251  			checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(i)).Bytes())
   252  		})
   253  	}
   254  }
   255  
   256  func fatalIfErr(t *testing.T, err error) {
   257  	t.Helper()
   258  	if err != nil {
   259  		t.Fatal(err)
   260  	}
   261  }
   262  
   263  func BenchmarkScalarMult(b *testing.B) {
   264  	b.Run("P224", func(b *testing.B) {
   265  		benchmarkScalarMult(b, nistec.NewP224Point().SetGenerator(), 28)
   266  	})
   267  	b.Run("P256", func(b *testing.B) {
   268  		benchmarkScalarMult(b, nistec.NewP256Point().SetGenerator(), 32)
   269  	})
   270  	b.Run("P384", func(b *testing.B) {
   271  		benchmarkScalarMult(b, nistec.NewP384Point().SetGenerator(), 48)
   272  	})
   273  	b.Run("P521", func(b *testing.B) {
   274  		benchmarkScalarMult(b, nistec.NewP521Point().SetGenerator(), 66)
   275  	})
   276  }
   277  
   278  func benchmarkScalarMult[P nistPoint[P]](b *testing.B, p P, scalarSize int) {
   279  	scalar := make([]byte, scalarSize)
   280  	rand.Read(scalar)
   281  	b.ReportAllocs()
   282  	b.ResetTimer()
   283  	for i := 0; i < b.N; i++ {
   284  		p.ScalarMult(p, scalar)
   285  	}
   286  }
   287  
   288  func BenchmarkScalarBaseMult(b *testing.B) {
   289  	b.Run("P224", func(b *testing.B) {
   290  		benchmarkScalarBaseMult(b, nistec.NewP224Point().SetGenerator(), 28)
   291  	})
   292  	b.Run("P256", func(b *testing.B) {
   293  		benchmarkScalarBaseMult(b, nistec.NewP256Point().SetGenerator(), 32)
   294  	})
   295  	b.Run("P384", func(b *testing.B) {
   296  		benchmarkScalarBaseMult(b, nistec.NewP384Point().SetGenerator(), 48)
   297  	})
   298  	b.Run("P521", func(b *testing.B) {
   299  		benchmarkScalarBaseMult(b, nistec.NewP521Point().SetGenerator(), 66)
   300  	})
   301  }
   302  
   303  func benchmarkScalarBaseMult[P nistPoint[P]](b *testing.B, p P, scalarSize int) {
   304  	scalar := make([]byte, scalarSize)
   305  	rand.Read(scalar)
   306  	b.ReportAllocs()
   307  	b.ResetTimer()
   308  	for i := 0; i < b.N; i++ {
   309  		p.ScalarBaseMult(scalar)
   310  	}
   311  }
   312  

View as plain text