1
2
3
4
5 package elliptic
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "encoding/hex"
11 "math/big"
12 "testing"
13 )
14
15
16
17
18
19 func genericParamsForCurve(c Curve) *CurveParams {
20 d := *(c.Params())
21 return &d
22 }
23
24 func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
25 tests := []struct {
26 name string
27 curve Curve
28 }{
29 {"P256", P256()},
30 {"P256/Params", genericParamsForCurve(P256())},
31 {"P224", P224()},
32 {"P224/Params", genericParamsForCurve(P224())},
33 {"P384", P384()},
34 {"P384/Params", genericParamsForCurve(P384())},
35 {"P521", P521()},
36 {"P521/Params", genericParamsForCurve(P521())},
37 }
38 if testing.Short() {
39 tests = tests[:1]
40 }
41 for _, test := range tests {
42 curve := test.curve
43 t.Run(test.name, func(t *testing.T) {
44 t.Parallel()
45 f(t, curve)
46 })
47 }
48 }
49
50 func TestOnCurve(t *testing.T) {
51 t.Parallel()
52 testAllCurves(t, func(t *testing.T, curve Curve) {
53 if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) {
54 t.Error("basepoint is not on the curve")
55 }
56 })
57 }
58
59 func TestOffCurve(t *testing.T) {
60 t.Parallel()
61 testAllCurves(t, func(t *testing.T, curve Curve) {
62 x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
63 if curve.IsOnCurve(x, y) {
64 t.Errorf("point off curve is claimed to be on the curve")
65 }
66
67 byteLen := (curve.Params().BitSize + 7) / 8
68 b := make([]byte, 1+2*byteLen)
69 b[0] = 4
70 x.FillBytes(b[1 : 1+byteLen])
71 y.FillBytes(b[1+byteLen : 1+2*byteLen])
72
73 x1, y1 := Unmarshal(curve, b)
74 if x1 != nil || y1 != nil {
75 t.Errorf("unmarshaling a point not on the curve succeeded")
76 }
77 })
78 }
79
80 func TestInfinity(t *testing.T) {
81 t.Parallel()
82 testAllCurves(t, testInfinity)
83 }
84
// isInfinity reports whether both coordinates are zero, which is how the
// tests in this file represent the point at infinity.
func isInfinity(x, y *big.Int) bool {
	if x.Sign() != 0 {
		return false
	}
	return y.Sign() == 0
}
88
89 func testInfinity(t *testing.T, curve Curve) {
90 x0, y0 := new(big.Int), new(big.Int)
91 xG, yG := curve.Params().Gx, curve.Params().Gy
92
93 if !isInfinity(curve.ScalarMult(xG, yG, curve.Params().N.Bytes())) {
94 t.Errorf("x^q != ∞")
95 }
96 if !isInfinity(curve.ScalarMult(xG, yG, []byte{0})) {
97 t.Errorf("x^0 != ∞")
98 }
99
100 if !isInfinity(curve.ScalarMult(x0, y0, []byte{1, 2, 3})) {
101 t.Errorf("∞^k != ∞")
102 }
103 if !isInfinity(curve.ScalarMult(x0, y0, []byte{0})) {
104 t.Errorf("∞^0 != ∞")
105 }
106
107 if !isInfinity(curve.ScalarBaseMult(curve.Params().N.Bytes())) {
108 t.Errorf("b^q != ∞")
109 }
110 if !isInfinity(curve.ScalarBaseMult([]byte{0})) {
111 t.Errorf("b^0 != ∞")
112 }
113
114 if !isInfinity(curve.Double(x0, y0)) {
115 t.Errorf("2∞ != ∞")
116 }
117
118
119
120 nMinusOne := new(big.Int).Sub(curve.Params().N, big.NewInt(1))
121 x, y := curve.ScalarMult(xG, yG, nMinusOne.Bytes())
122 x, y = curve.Add(x, y, xG, yG)
123 if !isInfinity(x, y) {
124 t.Errorf("x^(q-1) + x != ∞")
125 }
126 x, y = curve.Add(xG, yG, x0, y0)
127 if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
128 t.Errorf("x+∞ != x")
129 }
130 x, y = curve.Add(x0, y0, xG, yG)
131 if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
132 t.Errorf("∞+x != x")
133 }
134
135 if curve.IsOnCurve(x0, y0) {
136 t.Errorf("IsOnCurve(∞) == true")
137 }
138
139 if xx, yy := Unmarshal(curve, Marshal(curve, x0, y0)); xx != nil || yy != nil {
140 t.Errorf("Unmarshal(Marshal(∞)) did not return an error")
141 }
142
143
144 if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil {
145 t.Errorf("Unmarshal(∞) did not return an error")
146 }
147 byteLen := (curve.Params().BitSize + 7) / 8
148 buf := make([]byte, byteLen*2+1)
149 buf[0] = 4
150 if xx, yy := Unmarshal(curve, buf); xx != nil || yy != nil {
151 t.Errorf("Unmarshal((0,0)) did not return an error")
152 }
153 }
154
155 func TestMarshal(t *testing.T) {
156 t.Parallel()
157 testAllCurves(t, func(t *testing.T, curve Curve) {
158 _, x, y, err := GenerateKey(curve, rand.Reader)
159 if err != nil {
160 t.Fatal(err)
161 }
162 serialized := Marshal(curve, x, y)
163 xx, yy := Unmarshal(curve, serialized)
164 if xx == nil {
165 t.Fatal("failed to unmarshal")
166 }
167 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
168 t.Fatal("unmarshal returned different values")
169 }
170 })
171 }
172
173 func TestUnmarshalToLargeCoordinates(t *testing.T) {
174 t.Parallel()
175
176 testAllCurves(t, testUnmarshalToLargeCoordinates)
177 }
178
179 func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) {
180 p := curve.Params().P
181 byteLen := (p.BitLen() + 7) / 8
182
183
184
185
186 x := new(big.Int).Add(p, big.NewInt(5))
187 y := curve.Params().polynomial(x)
188 y.ModSqrt(y, p)
189
190 invalid := make([]byte, byteLen*2+1)
191 invalid[0] = 4
192 x.FillBytes(invalid[1 : 1+byteLen])
193 y.FillBytes(invalid[1+byteLen:])
194
195 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
196 t.Errorf("Unmarshal accepts invalid X coordinate")
197 }
198
199 if curve == p256 {
200
201
202 x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10)
203 y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10)
204 y.Add(y, p)
205
206 if p.Cmp(y) > 0 || y.BitLen() != 256 {
207 t.Fatal("y not within expected range")
208 }
209
210
211 x.FillBytes(invalid[1 : 1+byteLen])
212 y.FillBytes(invalid[1+byteLen:])
213
214 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
215 t.Errorf("Unmarshal accepts invalid Y coordinate")
216 }
217 }
218 }
219
220
221
222
223 func TestInvalidCoordinates(t *testing.T) {
224 t.Parallel()
225 testAllCurves(t, testInvalidCoordinates)
226 }
227
228 func testInvalidCoordinates(t *testing.T, curve Curve) {
229 checkIsOnCurveFalse := func(name string, x, y *big.Int) {
230 if curve.IsOnCurve(x, y) {
231 t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
232 }
233 }
234
235 p := curve.Params().P
236 _, x, y, _ := GenerateKey(curve, rand.Reader)
237 xx, yy := new(big.Int), new(big.Int)
238
239
240 xx.Neg(x)
241 checkIsOnCurveFalse("-x, y", xx, y)
242 yy.Neg(y)
243 checkIsOnCurveFalse("x, -y", x, yy)
244
245
246 xx.Sub(x, p)
247 checkIsOnCurveFalse("x-P, y", xx, y)
248 yy.Sub(y, p)
249 checkIsOnCurveFalse("x, y-P", x, yy)
250
251
252 xx.Add(x, p)
253 checkIsOnCurveFalse("x+P, y", xx, y)
254 yy.Add(y, p)
255 checkIsOnCurveFalse("x, y+P", x, yy)
256
257
258 xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
259 checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
260 yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
261 checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
262
263
264
265
266
267
268 if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil {
269 if !curve.IsOnCurve(big.NewInt(0), yy) {
270 t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
271 }
272 checkIsOnCurveFalse("P, y", p, yy)
273 }
274 }
275
276 func TestMarshalCompressed(t *testing.T) {
277 t.Parallel()
278 t.Run("P-256/03", func(t *testing.T) {
279 data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
280 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
281 y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10)
282 testMarshalCompressed(t, P256(), x, y, data)
283 })
284 t.Run("P-256/02", func(t *testing.T) {
285 data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
286 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
287 y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10)
288 testMarshalCompressed(t, P256(), x, y, data)
289 })
290
291 t.Run("Invalid", func(t *testing.T) {
292 data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
293 X, Y := UnmarshalCompressed(P256(), data)
294 if X != nil || Y != nil {
295 t.Error("expected an error for invalid encoding")
296 }
297 })
298
299 if testing.Short() {
300 t.Skip("skipping other curves on short test")
301 }
302
303 testAllCurves(t, func(t *testing.T, curve Curve) {
304 _, x, y, err := GenerateKey(curve, rand.Reader)
305 if err != nil {
306 t.Fatal(err)
307 }
308 testMarshalCompressed(t, curve, x, y, nil)
309 })
310
311 }
312
313 func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
314 if !curve.IsOnCurve(x, y) {
315 t.Fatal("invalid test point")
316 }
317 got := MarshalCompressed(curve, x, y)
318 if want != nil && !bytes.Equal(got, want) {
319 t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
320 }
321
322 X, Y := UnmarshalCompressed(curve, got)
323 if X == nil || Y == nil {
324 t.Fatalf("UnmarshalCompressed failed unexpectedly")
325 }
326
327 if !curve.IsOnCurve(X, Y) {
328 t.Error("UnmarshalCompressed returned a point not on the curve")
329 }
330 if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
331 t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
332 }
333 }
334
335 func TestLargeIsOnCurve(t *testing.T) {
336 t.Parallel()
337 testAllCurves(t, func(t *testing.T, curve Curve) {
338 large := big.NewInt(1)
339 large.Lsh(large, 1000)
340 if curve.IsOnCurve(large, large) {
341 t.Errorf("(2^1000, 2^1000) is reported on the curve")
342 }
343 })
344 }
345
346 func benchmarkAllCurves(b *testing.B, f func(*testing.B, Curve)) {
347 tests := []struct {
348 name string
349 curve Curve
350 }{
351 {"P256", P256()},
352 {"P224", P224()},
353 {"P384", P384()},
354 {"P521", P521()},
355 }
356 for _, test := range tests {
357 curve := test.curve
358 b.Run(test.name, func(b *testing.B) {
359 f(b, curve)
360 })
361 }
362 }
363
364 func BenchmarkScalarBaseMult(b *testing.B) {
365 benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
366 priv, _, _, _ := GenerateKey(curve, rand.Reader)
367 b.ReportAllocs()
368 b.ResetTimer()
369 for i := 0; i < b.N; i++ {
370 x, _ := curve.ScalarBaseMult(priv)
371
372 priv[0] ^= byte(x.Bits()[0])
373 }
374 })
375 }
376
377 func BenchmarkScalarMult(b *testing.B) {
378 benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
379 _, x, y, _ := GenerateKey(curve, rand.Reader)
380 priv, _, _, _ := GenerateKey(curve, rand.Reader)
381 b.ReportAllocs()
382 b.ResetTimer()
383 for i := 0; i < b.N; i++ {
384 x, y = curve.ScalarMult(x, y, priv)
385 }
386 })
387 }
388
389 func BenchmarkMarshalUnmarshal(b *testing.B) {
390 benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
391 _, x, y, _ := GenerateKey(curve, rand.Reader)
392 b.Run("Uncompressed", func(b *testing.B) {
393 b.ReportAllocs()
394 for i := 0; i < b.N; i++ {
395 buf := Marshal(curve, x, y)
396 xx, yy := Unmarshal(curve, buf)
397 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
398 b.Error("Unmarshal output different from Marshal input")
399 }
400 }
401 })
402 b.Run("Compressed", func(b *testing.B) {
403 b.ReportAllocs()
404 for i := 0; i < b.N; i++ {
405 buf := MarshalCompressed(curve, x, y)
406 xx, yy := UnmarshalCompressed(curve, buf)
407 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
408 b.Error("Unmarshal output different from Marshal input")
409 }
410 }
411 })
412 })
413 }
414
View as plain text