Source file src/crypto/internal/bigmod/nat_test.go

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bigmod
     6  
     7  import (
     8  	"fmt"
     9  	"math/big"
    10  	"math/bits"
    11  	"math/rand"
    12  	"reflect"
    13  	"strings"
    14  	"testing"
    15  	"testing/quick"
    16  )
    17  
    18  func (n *Nat) String() string {
    19  	var limbs []string
    20  	for i := range n.limbs {
    21  		limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i]))
    22  	}
    23  	return "{" + strings.Join(limbs, " ") + "}"
    24  }
    25  
    26  // Generate generates an even nat. It's used by testing/quick to produce random
    27  // *nat values for quick.Check invocations.
    28  func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
    29  	limbs := make([]uint, size)
    30  	for i := 0; i < size; i++ {
    31  		limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
    32  	}
    33  	return reflect.ValueOf(&Nat{limbs})
    34  }
    35  
    36  func testModAddCommutative(a *Nat, b *Nat) bool {
    37  	m := maxModulus(uint(len(a.limbs)))
    38  	aPlusB := new(Nat).set(a)
    39  	aPlusB.Add(b, m)
    40  	bPlusA := new(Nat).set(b)
    41  	bPlusA.Add(a, m)
    42  	return aPlusB.Equal(bPlusA) == 1
    43  }
    44  
    45  func TestModAddCommutative(t *testing.T) {
    46  	err := quick.Check(testModAddCommutative, &quick.Config{})
    47  	if err != nil {
    48  		t.Error(err)
    49  	}
    50  }
    51  
    52  func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
    53  	m := maxModulus(uint(len(a.limbs)))
    54  	original := new(Nat).set(a)
    55  	a.Sub(b, m)
    56  	a.Add(b, m)
    57  	return a.Equal(original) == 1
    58  }
    59  
    60  func TestModSubThenAddIdentity(t *testing.T) {
    61  	err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
    62  	if err != nil {
    63  		t.Error(err)
    64  	}
    65  }
    66  
    67  func TestMontgomeryRoundtrip(t *testing.T) {
    68  	err := quick.Check(func(a *Nat) bool {
    69  		one := &Nat{make([]uint, len(a.limbs))}
    70  		one.limbs[0] = 1
    71  		aPlusOne := new(big.Int).SetBytes(natBytes(a))
    72  		aPlusOne.Add(aPlusOne, big.NewInt(1))
    73  		m, _ := NewModulusFromBig(aPlusOne)
    74  		monty := new(Nat).set(a)
    75  		monty.montgomeryRepresentation(m)
    76  		aAgain := new(Nat).set(monty)
    77  		aAgain.montgomeryMul(monty, one, m)
    78  		if a.Equal(aAgain) != 1 {
    79  			t.Errorf("%v != %v", a, aAgain)
    80  			return false
    81  		}
    82  		return true
    83  	}, &quick.Config{})
    84  	if err != nil {
    85  		t.Error(err)
    86  	}
    87  }
    88  
    89  func TestShiftIn(t *testing.T) {
    90  	if bits.UintSize != 64 {
    91  		t.Skip("examples are only valid in 64 bit")
    92  	}
    93  	examples := []struct {
    94  		m, x, expected []byte
    95  		y              uint64
    96  	}{{
    97  		m:        []byte{13},
    98  		x:        []byte{0},
    99  		y:        0xFFFF_FFFF_FFFF_FFFF,
   100  		expected: []byte{2},
   101  	}, {
   102  		m:        []byte{13},
   103  		x:        []byte{7},
   104  		y:        0xFFFF_FFFF_FFFF_FFFF,
   105  		expected: []byte{10},
   106  	}, {
   107  		m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
   108  		x:        make([]byte, 9),
   109  		y:        0xFFFF_FFFF_FFFF_FFFF,
   110  		expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   111  	}, {
   112  		m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
   113  		x:        []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   114  		y:        0,
   115  		expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
   116  	}}
   117  
   118  	for i, tt := range examples {
   119  		m := modulusFromBytes(tt.m)
   120  		got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
   121  		if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 {
   122  			t.Errorf("%d: got %v, expected %v", i, got, exp)
   123  		}
   124  	}
   125  }
   126  
   127  func TestModulusAndNatSizes(t *testing.T) {
   128  	// These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
   129  	// 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
   130  	// limbs, if they are not, they fit in three. This can be a problem because
   131  	// modulus strips leading zeroes and nat does not.
   132  	m := modulusFromBytes([]byte{
   133  		0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
   134  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
   135  	xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
   136  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
   137  	natFromBytes(xb).ExpandFor(m) // must not panic for shrinking
   138  	NewNat().SetBytes(xb, m)
   139  }
   140  
   141  func TestSetBytes(t *testing.T) {
   142  	tests := []struct {
   143  		m, b []byte
   144  		fail bool
   145  	}{{
   146  		m: []byte{0xff, 0xff},
   147  		b: []byte{0x00, 0x01},
   148  	}, {
   149  		m:    []byte{0xff, 0xff},
   150  		b:    []byte{0xff, 0xff},
   151  		fail: true,
   152  	}, {
   153  		m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   154  		b: []byte{0x00, 0x01},
   155  	}, {
   156  		m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   157  		b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   158  	}, {
   159  		m:    []byte{0xff, 0xff},
   160  		b:    []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
   161  		fail: true,
   162  	}, {
   163  		m:    []byte{0xff, 0xff},
   164  		b:    []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
   165  		fail: true,
   166  	}, {
   167  		m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   168  		b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   169  	}, {
   170  		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   171  		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   172  		fail: true,
   173  	}, {
   174  		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   175  		b:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   176  		fail: true,
   177  	}, {
   178  		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   179  		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   180  		fail: true,
   181  	}, {
   182  		m:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
   183  		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   184  		fail: true,
   185  	}}
   186  
   187  	for i, tt := range tests {
   188  		m := modulusFromBytes(tt.m)
   189  		got, err := NewNat().SetBytes(tt.b, m)
   190  		if err != nil {
   191  			if !tt.fail {
   192  				t.Errorf("%d: unexpected error: %v", i, err)
   193  			}
   194  			continue
   195  		}
   196  		if tt.fail {
   197  			t.Errorf("%d: unexpected success", i)
   198  			continue
   199  		}
   200  		if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
   201  			t.Errorf("%d: got %v, expected %v", i, got, expected)
   202  		}
   203  	}
   204  
   205  	f := func(xBytes []byte) bool {
   206  		m := maxModulus(uint(len(xBytes)*8/_W + 1))
   207  		got, err := NewNat().SetBytes(xBytes, m)
   208  		if err != nil {
   209  			return false
   210  		}
   211  		return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
   212  	}
   213  
   214  	err := quick.Check(f, &quick.Config{})
   215  	if err != nil {
   216  		t.Error(err)
   217  	}
   218  }
   219  
   220  func TestExpand(t *testing.T) {
   221  	sliced := []uint{1, 2, 3, 4}
   222  	examples := []struct {
   223  		in  []uint
   224  		n   int
   225  		out []uint
   226  	}{{
   227  		[]uint{1, 2},
   228  		4,
   229  		[]uint{1, 2, 0, 0},
   230  	}, {
   231  		sliced[:2],
   232  		4,
   233  		[]uint{1, 2, 0, 0},
   234  	}, {
   235  		[]uint{1, 2},
   236  		2,
   237  		[]uint{1, 2},
   238  	}}
   239  
   240  	for i, tt := range examples {
   241  		got := (&Nat{tt.in}).expand(tt.n)
   242  		if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
   243  			t.Errorf("%d: got %v, expected %v", i, got, tt.out)
   244  		}
   245  	}
   246  }
   247  
   248  func TestMod(t *testing.T) {
   249  	m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
   250  	x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
   251  	out := new(Nat)
   252  	out.Mod(x, m)
   253  	expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
   254  	if out.Equal(expected) != 1 {
   255  		t.Errorf("%+v != %+v", out, expected)
   256  	}
   257  }
   258  
   259  func TestModSub(t *testing.T) {
   260  	m := modulusFromBytes([]byte{13})
   261  	x := &Nat{[]uint{6}}
   262  	y := &Nat{[]uint{7}}
   263  	x.Sub(y, m)
   264  	expected := &Nat{[]uint{12}}
   265  	if x.Equal(expected) != 1 {
   266  		t.Errorf("%+v != %+v", x, expected)
   267  	}
   268  	x.Sub(y, m)
   269  	expected = &Nat{[]uint{5}}
   270  	if x.Equal(expected) != 1 {
   271  		t.Errorf("%+v != %+v", x, expected)
   272  	}
   273  }
   274  
   275  func TestModAdd(t *testing.T) {
   276  	m := modulusFromBytes([]byte{13})
   277  	x := &Nat{[]uint{6}}
   278  	y := &Nat{[]uint{7}}
   279  	x.Add(y, m)
   280  	expected := &Nat{[]uint{0}}
   281  	if x.Equal(expected) != 1 {
   282  		t.Errorf("%+v != %+v", x, expected)
   283  	}
   284  	x.Add(y, m)
   285  	expected = &Nat{[]uint{7}}
   286  	if x.Equal(expected) != 1 {
   287  		t.Errorf("%+v != %+v", x, expected)
   288  	}
   289  }
   290  
   291  func TestExp(t *testing.T) {
   292  	m := modulusFromBytes([]byte{13})
   293  	x := &Nat{[]uint{3}}
   294  	out := &Nat{[]uint{0}}
   295  	out.Exp(x, []byte{12}, m)
   296  	expected := &Nat{[]uint{1}}
   297  	if out.Equal(expected) != 1 {
   298  		t.Errorf("%+v != %+v", out, expected)
   299  	}
   300  }
   301  
   302  func TestExpShort(t *testing.T) {
   303  	m := modulusFromBytes([]byte{13})
   304  	x := &Nat{[]uint{3}}
   305  	out := &Nat{[]uint{0}}
   306  	out.ExpShortVarTime(x, 12, m)
   307  	expected := &Nat{[]uint{1}}
   308  	if out.Equal(expected) != 1 {
   309  		t.Errorf("%+v != %+v", out, expected)
   310  	}
   311  }
   312  
   313  // TestMulReductions tests that Mul reduces results equal or slightly greater
   314  // than the modulus. Some Montgomery algorithms don't and need extra care to
   315  // return correct results. See https://go.dev/issue/13907.
   316  func TestMulReductions(t *testing.T) {
   317  	// Two short but multi-limb primes.
   318  	a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
   319  	b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
   320  	n := new(big.Int).Mul(a, b)
   321  
   322  	N, _ := NewModulusFromBig(n)
   323  	A := NewNat().setBig(a).ExpandFor(N)
   324  	B := NewNat().setBig(b).ExpandFor(N)
   325  
   326  	if A.Mul(B, N).IsZero() != 1 {
   327  		t.Error("a * b mod (a * b) != 0")
   328  	}
   329  
   330  	i := new(big.Int).ModInverse(a, b)
   331  	N, _ = NewModulusFromBig(b)
   332  	A = NewNat().setBig(a).ExpandFor(N)
   333  	I := NewNat().setBig(i).ExpandFor(N)
   334  	one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
   335  
   336  	if A.Mul(I, N).Equal(one) != 1 {
   337  		t.Error("a * inv(a) mod b != 1")
   338  	}
   339  }
   340  
   341  func natBytes(n *Nat) []byte {
   342  	return n.Bytes(maxModulus(uint(len(n.limbs))))
   343  }
   344  
   345  func natFromBytes(b []byte) *Nat {
   346  	// Must not use Nat.SetBytes as it's used in TestSetBytes.
   347  	bb := new(big.Int).SetBytes(b)
   348  	return NewNat().setBig(bb)
   349  }
   350  
   351  func modulusFromBytes(b []byte) *Modulus {
   352  	bb := new(big.Int).SetBytes(b)
   353  	m, _ := NewModulusFromBig(bb)
   354  	return m
   355  }
   356  
   357  // maxModulus returns the biggest modulus that can fit in n limbs.
   358  func maxModulus(n uint) *Modulus {
   359  	b := big.NewInt(1)
   360  	b.Lsh(b, n*_W)
   361  	b.Sub(b, big.NewInt(1))
   362  	m, _ := NewModulusFromBig(b)
   363  	return m
   364  }
   365  
   366  func makeBenchmarkModulus() *Modulus {
   367  	return maxModulus(32)
   368  }
   369  
   370  func makeBenchmarkValue() *Nat {
   371  	x := make([]uint, 32)
   372  	for i := 0; i < 32; i++ {
   373  		x[i]--
   374  	}
   375  	return &Nat{limbs: x}
   376  }
   377  
   378  func makeBenchmarkExponent() []byte {
   379  	e := make([]byte, 256)
   380  	for i := 0; i < 32; i++ {
   381  		e[i] = 0xFF
   382  	}
   383  	return e
   384  }
   385  
   386  func BenchmarkModAdd(b *testing.B) {
   387  	x := makeBenchmarkValue()
   388  	y := makeBenchmarkValue()
   389  	m := makeBenchmarkModulus()
   390  
   391  	b.ResetTimer()
   392  	for i := 0; i < b.N; i++ {
   393  		x.Add(y, m)
   394  	}
   395  }
   396  
   397  func BenchmarkModSub(b *testing.B) {
   398  	x := makeBenchmarkValue()
   399  	y := makeBenchmarkValue()
   400  	m := makeBenchmarkModulus()
   401  
   402  	b.ResetTimer()
   403  	for i := 0; i < b.N; i++ {
   404  		x.Sub(y, m)
   405  	}
   406  }
   407  
   408  func BenchmarkMontgomeryRepr(b *testing.B) {
   409  	x := makeBenchmarkValue()
   410  	m := makeBenchmarkModulus()
   411  
   412  	b.ResetTimer()
   413  	for i := 0; i < b.N; i++ {
   414  		x.montgomeryRepresentation(m)
   415  	}
   416  }
   417  
   418  func BenchmarkMontgomeryMul(b *testing.B) {
   419  	x := makeBenchmarkValue()
   420  	y := makeBenchmarkValue()
   421  	out := makeBenchmarkValue()
   422  	m := makeBenchmarkModulus()
   423  
   424  	b.ResetTimer()
   425  	for i := 0; i < b.N; i++ {
   426  		out.montgomeryMul(x, y, m)
   427  	}
   428  }
   429  
   430  func BenchmarkModMul(b *testing.B) {
   431  	x := makeBenchmarkValue()
   432  	y := makeBenchmarkValue()
   433  	m := makeBenchmarkModulus()
   434  
   435  	b.ResetTimer()
   436  	for i := 0; i < b.N; i++ {
   437  		x.Mul(y, m)
   438  	}
   439  }
   440  
   441  func BenchmarkExpBig(b *testing.B) {
   442  	out := new(big.Int)
   443  	exponentBytes := makeBenchmarkExponent()
   444  	x := new(big.Int).SetBytes(exponentBytes)
   445  	e := new(big.Int).SetBytes(exponentBytes)
   446  	n := new(big.Int).SetBytes(exponentBytes)
   447  	one := new(big.Int).SetUint64(1)
   448  	n.Add(n, one)
   449  
   450  	b.ResetTimer()
   451  	for i := 0; i < b.N; i++ {
   452  		out.Exp(x, e, n)
   453  	}
   454  }
   455  
   456  func BenchmarkExp(b *testing.B) {
   457  	x := makeBenchmarkValue()
   458  	e := makeBenchmarkExponent()
   459  	out := makeBenchmarkValue()
   460  	m := makeBenchmarkModulus()
   461  
   462  	b.ResetTimer()
   463  	for i := 0; i < b.N; i++ {
   464  		out.Exp(x, e, m)
   465  	}
   466  }
   467  
   468  func TestNewModFromBigZero(t *testing.T) {
   469  	expected := "modulus must be >= 0"
   470  	_, err := NewModulusFromBig(big.NewInt(0))
   471  	if err == nil || err.Error() != expected {
   472  		t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
   473  	}
   474  
   475  	expected = "modulus must be odd"
   476  	_, err = NewModulusFromBig(big.NewInt(2))
   477  	if err == nil || err.Error() != expected {
   478  		t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
   479  	}
   480  }
   481  

View as plain text