Source file src/crypto/mlkem/mlkem_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mlkem
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/internal/fips140/mlkem"
    10  	"crypto/internal/fips140/sha3"
    11  	"crypto/rand"
    12  	"encoding/hex"
    13  	"flag"
    14  	"testing"
    15  )
    16  
    17  type encapsulationKey interface {
    18  	Bytes() []byte
    19  	Encapsulate() ([]byte, []byte)
    20  }
    21  
    22  type decapsulationKey[E encapsulationKey] interface {
    23  	Bytes() []byte
    24  	Decapsulate([]byte) ([]byte, error)
    25  	EncapsulationKey() E
    26  }
    27  
    28  func TestRoundTrip(t *testing.T) {
    29  	t.Run("768", func(t *testing.T) {
    30  		testRoundTrip(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
    31  	})
    32  	t.Run("1024", func(t *testing.T) {
    33  		testRoundTrip(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
    34  	})
    35  }
    36  
    37  func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
    38  	t *testing.T, generateKey func() (D, error),
    39  	newEncapsulationKey func([]byte) (E, error),
    40  	newDecapsulationKey func([]byte) (D, error)) {
    41  	dk, err := generateKey()
    42  	if err != nil {
    43  		t.Fatal(err)
    44  	}
    45  	ek := dk.EncapsulationKey()
    46  	Ke, c := ek.Encapsulate()
    47  	Kd, err := dk.Decapsulate(c)
    48  	if err != nil {
    49  		t.Fatal(err)
    50  	}
    51  	if !bytes.Equal(Ke, Kd) {
    52  		t.Fail()
    53  	}
    54  
    55  	ek1, err := newEncapsulationKey(ek.Bytes())
    56  	if err != nil {
    57  		t.Fatal(err)
    58  	}
    59  	if !bytes.Equal(ek.Bytes(), ek1.Bytes()) {
    60  		t.Fail()
    61  	}
    62  	dk1, err := newDecapsulationKey(dk.Bytes())
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  	if !bytes.Equal(dk.Bytes(), dk1.Bytes()) {
    67  		t.Fail()
    68  	}
    69  	Ke1, c1 := ek1.Encapsulate()
    70  	Kd1, err := dk1.Decapsulate(c1)
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  	if !bytes.Equal(Ke1, Kd1) {
    75  		t.Fail()
    76  	}
    77  
    78  	dk2, err := generateKey()
    79  	if err != nil {
    80  		t.Fatal(err)
    81  	}
    82  	if bytes.Equal(dk.EncapsulationKey().Bytes(), dk2.EncapsulationKey().Bytes()) {
    83  		t.Fail()
    84  	}
    85  	if bytes.Equal(dk.Bytes(), dk2.Bytes()) {
    86  		t.Fail()
    87  	}
    88  
    89  	Ke2, c2 := dk.EncapsulationKey().Encapsulate()
    90  	if bytes.Equal(c, c2) {
    91  		t.Fail()
    92  	}
    93  	if bytes.Equal(Ke, Ke2) {
    94  		t.Fail()
    95  	}
    96  }
    97  
    98  func TestBadLengths(t *testing.T) {
    99  	t.Run("768", func(t *testing.T) {
   100  		testBadLengths(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
   101  	})
   102  	t.Run("1024", func(t *testing.T) {
   103  		testBadLengths(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
   104  	})
   105  }
   106  
   107  func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
   108  	t *testing.T, generateKey func() (D, error),
   109  	newEncapsulationKey func([]byte) (E, error),
   110  	newDecapsulationKey func([]byte) (D, error)) {
   111  	dk, err := generateKey()
   112  	dkBytes := dk.Bytes()
   113  	if err != nil {
   114  		t.Fatal(err)
   115  	}
   116  	ek := dk.EncapsulationKey()
   117  	ekBytes := dk.EncapsulationKey().Bytes()
   118  	_, c := ek.Encapsulate()
   119  
   120  	for i := 0; i < len(dkBytes)-1; i++ {
   121  		if _, err := newDecapsulationKey(dkBytes[:i]); err == nil {
   122  			t.Errorf("expected error for dk length %d", i)
   123  		}
   124  	}
   125  	dkLong := dkBytes
   126  	for i := 0; i < 100; i++ {
   127  		dkLong = append(dkLong, 0)
   128  		if _, err := newDecapsulationKey(dkLong); err == nil {
   129  			t.Errorf("expected error for dk length %d", len(dkLong))
   130  		}
   131  	}
   132  
   133  	for i := 0; i < len(ekBytes)-1; i++ {
   134  		if _, err := newEncapsulationKey(ekBytes[:i]); err == nil {
   135  			t.Errorf("expected error for ek length %d", i)
   136  		}
   137  	}
   138  	ekLong := ekBytes
   139  	for i := 0; i < 100; i++ {
   140  		ekLong = append(ekLong, 0)
   141  		if _, err := newEncapsulationKey(ekLong); err == nil {
   142  			t.Errorf("expected error for ek length %d", len(ekLong))
   143  		}
   144  	}
   145  
   146  	for i := 0; i < len(c)-1; i++ {
   147  		if _, err := dk.Decapsulate(c[:i]); err == nil {
   148  			t.Errorf("expected error for c length %d", i)
   149  		}
   150  	}
   151  	cLong := c
   152  	for i := 0; i < 100; i++ {
   153  		cLong = append(cLong, 0)
   154  		if _, err := dk.Decapsulate(cLong); err == nil {
   155  			t.Errorf("expected error for c length %d", len(cLong))
   156  		}
   157  	}
   158  }
   159  
   160  var millionFlag = flag.Bool("million", false, "run the million vector test")
   161  
   162  // TestAccumulated accumulates 10k (or 100, or 1M) random vectors and checks the
   163  // hash of the result, to avoid checking in 150MB of test vectors.
   164  func TestAccumulated(t *testing.T) {
   165  	n := 10000
   166  	expected := "8a518cc63da366322a8e7a818c7a0d63483cb3528d34a4cf42f35d5ad73f22fc"
   167  	if testing.Short() {
   168  		n = 100
   169  		expected = "1114b1b6699ed191734fa339376afa7e285c9e6acf6ff0177d346696ce564415"
   170  	}
   171  	if *millionFlag {
   172  		n = 1000000
   173  		expected = "424bf8f0e8ae99b78d788a6e2e8e9cdaf9773fc0c08a6f433507cb559edfd0f0"
   174  	}
   175  
   176  	s := sha3.NewShake128()
   177  	o := sha3.NewShake128()
   178  	seed := make([]byte, SeedSize)
   179  	var msg [32]byte
   180  	ct1 := make([]byte, CiphertextSize768)
   181  
   182  	for i := 0; i < n; i++ {
   183  		s.Read(seed)
   184  		dk, err := NewDecapsulationKey768(seed)
   185  		if err != nil {
   186  			t.Fatal(err)
   187  		}
   188  		ek := dk.EncapsulationKey()
   189  		o.Write(ek.Bytes())
   190  
   191  		s.Read(msg[:])
   192  		k, ct := ek.key.EncapsulateInternal(&msg)
   193  		o.Write(ct)
   194  		o.Write(k)
   195  
   196  		kk, err := dk.Decapsulate(ct)
   197  		if err != nil {
   198  			t.Fatal(err)
   199  		}
   200  		if !bytes.Equal(kk, k) {
   201  			t.Errorf("k: got %x, expected %x", kk, k)
   202  		}
   203  
   204  		s.Read(ct1)
   205  		k1, err := dk.Decapsulate(ct1)
   206  		if err != nil {
   207  			t.Fatal(err)
   208  		}
   209  		o.Write(k1)
   210  	}
   211  
   212  	got := hex.EncodeToString(o.Sum(nil))
   213  	if got != expected {
   214  		t.Errorf("got %s, expected %s", got, expected)
   215  	}
   216  }
   217  
   218  var sink byte
   219  
   220  func BenchmarkKeyGen(b *testing.B) {
   221  	var d, z [32]byte
   222  	rand.Read(d[:])
   223  	rand.Read(z[:])
   224  	b.ResetTimer()
   225  	for i := 0; i < b.N; i++ {
   226  		dk := mlkem.GenerateKeyInternal768(&d, &z)
   227  		sink ^= dk.EncapsulationKey().Bytes()[0]
   228  	}
   229  }
   230  
   231  func BenchmarkEncaps(b *testing.B) {
   232  	seed := make([]byte, SeedSize)
   233  	rand.Read(seed)
   234  	var m [32]byte
   235  	rand.Read(m[:])
   236  	dk, err := NewDecapsulationKey768(seed)
   237  	if err != nil {
   238  		b.Fatal(err)
   239  	}
   240  	ekBytes := dk.EncapsulationKey().Bytes()
   241  	b.ResetTimer()
   242  	for i := 0; i < b.N; i++ {
   243  		ek, err := NewEncapsulationKey768(ekBytes)
   244  		if err != nil {
   245  			b.Fatal(err)
   246  		}
   247  		K, c := ek.key.EncapsulateInternal(&m)
   248  		sink ^= c[0] ^ K[0]
   249  	}
   250  }
   251  
   252  func BenchmarkDecaps(b *testing.B) {
   253  	dk, err := GenerateKey768()
   254  	if err != nil {
   255  		b.Fatal(err)
   256  	}
   257  	ek := dk.EncapsulationKey()
   258  	_, c := ek.Encapsulate()
   259  	b.ResetTimer()
   260  	for i := 0; i < b.N; i++ {
   261  		K, _ := dk.Decapsulate(c)
   262  		sink ^= K[0]
   263  	}
   264  }
   265  
   266  func BenchmarkRoundTrip(b *testing.B) {
   267  	dk, err := GenerateKey768()
   268  	if err != nil {
   269  		b.Fatal(err)
   270  	}
   271  	ek := dk.EncapsulationKey()
   272  	ekBytes := ek.Bytes()
   273  	_, c := ek.Encapsulate()
   274  	if err != nil {
   275  		b.Fatal(err)
   276  	}
   277  	b.Run("Alice", func(b *testing.B) {
   278  		for i := 0; i < b.N; i++ {
   279  			dkS, err := GenerateKey768()
   280  			if err != nil {
   281  				b.Fatal(err)
   282  			}
   283  			ekS := dkS.EncapsulationKey().Bytes()
   284  			sink ^= ekS[0]
   285  
   286  			Ks, err := dk.Decapsulate(c)
   287  			if err != nil {
   288  				b.Fatal(err)
   289  			}
   290  			sink ^= Ks[0]
   291  		}
   292  	})
   293  	b.Run("Bob", func(b *testing.B) {
   294  		for i := 0; i < b.N; i++ {
   295  			ek, err := NewEncapsulationKey768(ekBytes)
   296  			if err != nil {
   297  				b.Fatal(err)
   298  			}
   299  			Ks, cS := ek.Encapsulate()
   300  			if err != nil {
   301  				b.Fatal(err)
   302  			}
   303  			sink ^= cS[0] ^ Ks[0]
   304  		}
   305  	})
   306  }
   307  
   308  // Test that the constants from the public API match the corresponding values from the internal API.
   309  func TestConstantSizes(t *testing.T) {
   310  	if SharedKeySize != mlkem.SharedKeySize {
   311  		t.Errorf("SharedKeySize mismatch: got %d, want %d", SharedKeySize, mlkem.SharedKeySize)
   312  	}
   313  
   314  	if SeedSize != mlkem.SeedSize {
   315  		t.Errorf("SeedSize mismatch: got %d, want %d", SeedSize, mlkem.SeedSize)
   316  	}
   317  
   318  	if CiphertextSize768 != mlkem.CiphertextSize768 {
   319  		t.Errorf("CiphertextSize768 mismatch: got %d, want %d", CiphertextSize768, mlkem.CiphertextSize768)
   320  	}
   321  
   322  	if EncapsulationKeySize768 != mlkem.EncapsulationKeySize768 {
   323  		t.Errorf("EncapsulationKeySize768 mismatch: got %d, want %d", EncapsulationKeySize768, mlkem.EncapsulationKeySize768)
   324  	}
   325  
   326  	if CiphertextSize1024 != mlkem.CiphertextSize1024 {
   327  		t.Errorf("CiphertextSize1024 mismatch: got %d, want %d", CiphertextSize1024, mlkem.CiphertextSize1024)
   328  	}
   329  
   330  	if EncapsulationKeySize1024 != mlkem.EncapsulationKeySize1024 {
   331  		t.Errorf("EncapsulationKeySize1024 mismatch: got %d, want %d", EncapsulationKeySize1024, mlkem.EncapsulationKeySize1024)
   332  	}
   333  }
   334  

View as plain text