1
2
3
4
5 package bigmod
6
7 import (
8 "fmt"
9 "math/big"
10 "math/bits"
11 "math/rand"
12 "reflect"
13 "strings"
14 "testing"
15 "testing/quick"
16 )
17
18 func (n *Nat) String() string {
19 var limbs []string
20 for i := range n.limbs {
21 limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i]))
22 }
23 return "{" + strings.Join(limbs, " ") + "}"
24 }
25
26
27
28 func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
29 limbs := make([]uint, size)
30 for i := 0; i < size; i++ {
31 limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
32 }
33 return reflect.ValueOf(&Nat{limbs})
34 }
35
36 func testModAddCommutative(a *Nat, b *Nat) bool {
37 m := maxModulus(uint(len(a.limbs)))
38 aPlusB := new(Nat).set(a)
39 aPlusB.Add(b, m)
40 bPlusA := new(Nat).set(b)
41 bPlusA.Add(a, m)
42 return aPlusB.Equal(bPlusA) == 1
43 }
44
45 func TestModAddCommutative(t *testing.T) {
46 err := quick.Check(testModAddCommutative, &quick.Config{})
47 if err != nil {
48 t.Error(err)
49 }
50 }
51
52 func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
53 m := maxModulus(uint(len(a.limbs)))
54 original := new(Nat).set(a)
55 a.Sub(b, m)
56 a.Add(b, m)
57 return a.Equal(original) == 1
58 }
59
60 func TestModSubThenAddIdentity(t *testing.T) {
61 err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
62 if err != nil {
63 t.Error(err)
64 }
65 }
66
67 func TestMontgomeryRoundtrip(t *testing.T) {
68 err := quick.Check(func(a *Nat) bool {
69 one := &Nat{make([]uint, len(a.limbs))}
70 one.limbs[0] = 1
71 aPlusOne := new(big.Int).SetBytes(natBytes(a))
72 aPlusOne.Add(aPlusOne, big.NewInt(1))
73 m, _ := NewModulusFromBig(aPlusOne)
74 monty := new(Nat).set(a)
75 monty.montgomeryRepresentation(m)
76 aAgain := new(Nat).set(monty)
77 aAgain.montgomeryMul(monty, one, m)
78 if a.Equal(aAgain) != 1 {
79 t.Errorf("%v != %v", a, aAgain)
80 return false
81 }
82 return true
83 }, &quick.Config{})
84 if err != nil {
85 t.Error(err)
86 }
87 }
88
89 func TestShiftIn(t *testing.T) {
90 if bits.UintSize != 64 {
91 t.Skip("examples are only valid in 64 bit")
92 }
93 examples := []struct {
94 m, x, expected []byte
95 y uint64
96 }{{
97 m: []byte{13},
98 x: []byte{0},
99 y: 0xFFFF_FFFF_FFFF_FFFF,
100 expected: []byte{2},
101 }, {
102 m: []byte{13},
103 x: []byte{7},
104 y: 0xFFFF_FFFF_FFFF_FFFF,
105 expected: []byte{10},
106 }, {
107 m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
108 x: make([]byte, 9),
109 y: 0xFFFF_FFFF_FFFF_FFFF,
110 expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
111 }, {
112 m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
113 x: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
114 y: 0,
115 expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
116 }}
117
118 for i, tt := range examples {
119 m := modulusFromBytes(tt.m)
120 got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
121 if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 {
122 t.Errorf("%d: got %v, expected %v", i, got, exp)
123 }
124 }
125 }
126
127 func TestModulusAndNatSizes(t *testing.T) {
128
129
130
131
132 m := modulusFromBytes([]byte{
133 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
134 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
135 xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
136 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
137 natFromBytes(xb).ExpandFor(m)
138 NewNat().SetBytes(xb, m)
139 }
140
141 func TestSetBytes(t *testing.T) {
142 tests := []struct {
143 m, b []byte
144 fail bool
145 }{{
146 m: []byte{0xff, 0xff},
147 b: []byte{0x00, 0x01},
148 }, {
149 m: []byte{0xff, 0xff},
150 b: []byte{0xff, 0xff},
151 fail: true,
152 }, {
153 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
154 b: []byte{0x00, 0x01},
155 }, {
156 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
157 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
158 }, {
159 m: []byte{0xff, 0xff},
160 b: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
161 fail: true,
162 }, {
163 m: []byte{0xff, 0xff},
164 b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
165 fail: true,
166 }, {
167 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
168 b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
169 }, {
170 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
171 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
172 fail: true,
173 }, {
174 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
175 b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
176 fail: true,
177 }, {
178 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
179 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
180 fail: true,
181 }, {
182 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
183 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
184 fail: true,
185 }}
186
187 for i, tt := range tests {
188 m := modulusFromBytes(tt.m)
189 got, err := NewNat().SetBytes(tt.b, m)
190 if err != nil {
191 if !tt.fail {
192 t.Errorf("%d: unexpected error: %v", i, err)
193 }
194 continue
195 }
196 if tt.fail {
197 t.Errorf("%d: unexpected success", i)
198 continue
199 }
200 if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
201 t.Errorf("%d: got %v, expected %v", i, got, expected)
202 }
203 }
204
205 f := func(xBytes []byte) bool {
206 m := maxModulus(uint(len(xBytes)*8/_W + 1))
207 got, err := NewNat().SetBytes(xBytes, m)
208 if err != nil {
209 return false
210 }
211 return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
212 }
213
214 err := quick.Check(f, &quick.Config{})
215 if err != nil {
216 t.Error(err)
217 }
218 }
219
220 func TestExpand(t *testing.T) {
221 sliced := []uint{1, 2, 3, 4}
222 examples := []struct {
223 in []uint
224 n int
225 out []uint
226 }{{
227 []uint{1, 2},
228 4,
229 []uint{1, 2, 0, 0},
230 }, {
231 sliced[:2],
232 4,
233 []uint{1, 2, 0, 0},
234 }, {
235 []uint{1, 2},
236 2,
237 []uint{1, 2},
238 }}
239
240 for i, tt := range examples {
241 got := (&Nat{tt.in}).expand(tt.n)
242 if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
243 t.Errorf("%d: got %v, expected %v", i, got, tt.out)
244 }
245 }
246 }
247
248 func TestMod(t *testing.T) {
249 m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
250 x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
251 out := new(Nat)
252 out.Mod(x, m)
253 expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
254 if out.Equal(expected) != 1 {
255 t.Errorf("%+v != %+v", out, expected)
256 }
257 }
258
259 func TestModSub(t *testing.T) {
260 m := modulusFromBytes([]byte{13})
261 x := &Nat{[]uint{6}}
262 y := &Nat{[]uint{7}}
263 x.Sub(y, m)
264 expected := &Nat{[]uint{12}}
265 if x.Equal(expected) != 1 {
266 t.Errorf("%+v != %+v", x, expected)
267 }
268 x.Sub(y, m)
269 expected = &Nat{[]uint{5}}
270 if x.Equal(expected) != 1 {
271 t.Errorf("%+v != %+v", x, expected)
272 }
273 }
274
275 func TestModAdd(t *testing.T) {
276 m := modulusFromBytes([]byte{13})
277 x := &Nat{[]uint{6}}
278 y := &Nat{[]uint{7}}
279 x.Add(y, m)
280 expected := &Nat{[]uint{0}}
281 if x.Equal(expected) != 1 {
282 t.Errorf("%+v != %+v", x, expected)
283 }
284 x.Add(y, m)
285 expected = &Nat{[]uint{7}}
286 if x.Equal(expected) != 1 {
287 t.Errorf("%+v != %+v", x, expected)
288 }
289 }
290
291 func TestExp(t *testing.T) {
292 m := modulusFromBytes([]byte{13})
293 x := &Nat{[]uint{3}}
294 out := &Nat{[]uint{0}}
295 out.Exp(x, []byte{12}, m)
296 expected := &Nat{[]uint{1}}
297 if out.Equal(expected) != 1 {
298 t.Errorf("%+v != %+v", out, expected)
299 }
300 }
301
302 func TestExpShort(t *testing.T) {
303 m := modulusFromBytes([]byte{13})
304 x := &Nat{[]uint{3}}
305 out := &Nat{[]uint{0}}
306 out.ExpShortVarTime(x, 12, m)
307 expected := &Nat{[]uint{1}}
308 if out.Equal(expected) != 1 {
309 t.Errorf("%+v != %+v", out, expected)
310 }
311 }
312
313
314
315
316 func TestMulReductions(t *testing.T) {
317
318 a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
319 b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
320 n := new(big.Int).Mul(a, b)
321
322 N, _ := NewModulusFromBig(n)
323 A := NewNat().setBig(a).ExpandFor(N)
324 B := NewNat().setBig(b).ExpandFor(N)
325
326 if A.Mul(B, N).IsZero() != 1 {
327 t.Error("a * b mod (a * b) != 0")
328 }
329
330 i := new(big.Int).ModInverse(a, b)
331 N, _ = NewModulusFromBig(b)
332 A = NewNat().setBig(a).ExpandFor(N)
333 I := NewNat().setBig(i).ExpandFor(N)
334 one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
335
336 if A.Mul(I, N).Equal(one) != 1 {
337 t.Error("a * inv(a) mod b != 1")
338 }
339 }
340
341 func natBytes(n *Nat) []byte {
342 return n.Bytes(maxModulus(uint(len(n.limbs))))
343 }
344
345 func natFromBytes(b []byte) *Nat {
346
347 bb := new(big.Int).SetBytes(b)
348 return NewNat().setBig(bb)
349 }
350
351 func modulusFromBytes(b []byte) *Modulus {
352 bb := new(big.Int).SetBytes(b)
353 m, _ := NewModulusFromBig(bb)
354 return m
355 }
356
357
358 func maxModulus(n uint) *Modulus {
359 b := big.NewInt(1)
360 b.Lsh(b, n*_W)
361 b.Sub(b, big.NewInt(1))
362 m, _ := NewModulusFromBig(b)
363 return m
364 }
365
366 func makeBenchmarkModulus() *Modulus {
367 return maxModulus(32)
368 }
369
370 func makeBenchmarkValue() *Nat {
371 x := make([]uint, 32)
372 for i := 0; i < 32; i++ {
373 x[i]--
374 }
375 return &Nat{limbs: x}
376 }
377
// makeBenchmarkExponent returns a 256-byte big-endian exponent whose first 32
// bytes are 0xFF and whose remaining 224 bytes are zero.
func makeBenchmarkExponent() []byte {
	exp := make([]byte, 256)
	for i := range exp[:32] {
		exp[i] = 0xFF
	}
	return exp
}
385
386 func BenchmarkModAdd(b *testing.B) {
387 x := makeBenchmarkValue()
388 y := makeBenchmarkValue()
389 m := makeBenchmarkModulus()
390
391 b.ResetTimer()
392 for i := 0; i < b.N; i++ {
393 x.Add(y, m)
394 }
395 }
396
397 func BenchmarkModSub(b *testing.B) {
398 x := makeBenchmarkValue()
399 y := makeBenchmarkValue()
400 m := makeBenchmarkModulus()
401
402 b.ResetTimer()
403 for i := 0; i < b.N; i++ {
404 x.Sub(y, m)
405 }
406 }
407
408 func BenchmarkMontgomeryRepr(b *testing.B) {
409 x := makeBenchmarkValue()
410 m := makeBenchmarkModulus()
411
412 b.ResetTimer()
413 for i := 0; i < b.N; i++ {
414 x.montgomeryRepresentation(m)
415 }
416 }
417
418 func BenchmarkMontgomeryMul(b *testing.B) {
419 x := makeBenchmarkValue()
420 y := makeBenchmarkValue()
421 out := makeBenchmarkValue()
422 m := makeBenchmarkModulus()
423
424 b.ResetTimer()
425 for i := 0; i < b.N; i++ {
426 out.montgomeryMul(x, y, m)
427 }
428 }
429
430 func BenchmarkModMul(b *testing.B) {
431 x := makeBenchmarkValue()
432 y := makeBenchmarkValue()
433 m := makeBenchmarkModulus()
434
435 b.ResetTimer()
436 for i := 0; i < b.N; i++ {
437 x.Mul(y, m)
438 }
439 }
440
441 func BenchmarkExpBig(b *testing.B) {
442 out := new(big.Int)
443 exponentBytes := makeBenchmarkExponent()
444 x := new(big.Int).SetBytes(exponentBytes)
445 e := new(big.Int).SetBytes(exponentBytes)
446 n := new(big.Int).SetBytes(exponentBytes)
447 one := new(big.Int).SetUint64(1)
448 n.Add(n, one)
449
450 b.ResetTimer()
451 for i := 0; i < b.N; i++ {
452 out.Exp(x, e, n)
453 }
454 }
455
456 func BenchmarkExp(b *testing.B) {
457 x := makeBenchmarkValue()
458 e := makeBenchmarkExponent()
459 out := makeBenchmarkValue()
460 m := makeBenchmarkModulus()
461
462 b.ResetTimer()
463 for i := 0; i < b.N; i++ {
464 out.Exp(x, e, m)
465 }
466 }
467
468 func TestNewModFromBigZero(t *testing.T) {
469 expected := "modulus must be >= 0"
470 _, err := NewModulusFromBig(big.NewInt(0))
471 if err == nil || err.Error() != expected {
472 t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
473 }
474
475 expected = "modulus must be odd"
476 _, err = NewModulusFromBig(big.NewInt(2))
477 if err == nil || err.Error() != expected {
478 t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
479 }
480 }
481
View as plain text