Source file src/crypto/mlkem/mlkem_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mlkem_test
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/internal/fips140/mlkem"
    10  	"crypto/internal/fips140/sha3"
    11  	. "crypto/mlkem"
    12  	"crypto/mlkem/mlkemtest"
    13  	"crypto/rand"
    14  	"encoding/hex"
    15  	"flag"
    16  	"testing"
    17  )
    18  
    19  type encapsulationKey interface {
    20  	Bytes() []byte
    21  	Encapsulate() ([]byte, []byte)
    22  }
    23  
    24  type decapsulationKey[E encapsulationKey] interface {
    25  	Bytes() []byte
    26  	Decapsulate([]byte) ([]byte, error)
    27  	EncapsulationKey() E
    28  }
    29  
    30  func TestRoundTrip(t *testing.T) {
    31  	t.Run("768", func(t *testing.T) {
    32  		testRoundTrip(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
    33  	})
    34  	t.Run("1024", func(t *testing.T) {
    35  		testRoundTrip(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
    36  	})
    37  }
    38  
    39  func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
    40  	t *testing.T, generateKey func() (D, error),
    41  	newEncapsulationKey func([]byte) (E, error),
    42  	newDecapsulationKey func([]byte) (D, error)) {
    43  	dk, err := generateKey()
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  	ek := dk.EncapsulationKey()
    48  	Ke, c := ek.Encapsulate()
    49  	Kd, err := dk.Decapsulate(c)
    50  	if err != nil {
    51  		t.Fatal(err)
    52  	}
    53  	if !bytes.Equal(Ke, Kd) {
    54  		t.Fail()
    55  	}
    56  
    57  	ek1, err := newEncapsulationKey(ek.Bytes())
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  	if !bytes.Equal(ek.Bytes(), ek1.Bytes()) {
    62  		t.Fail()
    63  	}
    64  	dk1, err := newDecapsulationKey(dk.Bytes())
    65  	if err != nil {
    66  		t.Fatal(err)
    67  	}
    68  	if !bytes.Equal(dk.Bytes(), dk1.Bytes()) {
    69  		t.Fail()
    70  	}
    71  	Ke1, c1 := ek1.Encapsulate()
    72  	Kd1, err := dk1.Decapsulate(c1)
    73  	if err != nil {
    74  		t.Fatal(err)
    75  	}
    76  	if !bytes.Equal(Ke1, Kd1) {
    77  		t.Fail()
    78  	}
    79  
    80  	dk2, err := generateKey()
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	if bytes.Equal(dk.EncapsulationKey().Bytes(), dk2.EncapsulationKey().Bytes()) {
    85  		t.Fail()
    86  	}
    87  	if bytes.Equal(dk.Bytes(), dk2.Bytes()) {
    88  		t.Fail()
    89  	}
    90  
    91  	Ke2, c2 := dk.EncapsulationKey().Encapsulate()
    92  	if bytes.Equal(c, c2) {
    93  		t.Fail()
    94  	}
    95  	if bytes.Equal(Ke, Ke2) {
    96  		t.Fail()
    97  	}
    98  }
    99  
   100  func TestBadLengths(t *testing.T) {
   101  	t.Run("768", func(t *testing.T) {
   102  		testBadLengths(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
   103  	})
   104  	t.Run("1024", func(t *testing.T) {
   105  		testBadLengths(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
   106  	})
   107  }
   108  
   109  func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
   110  	t *testing.T, generateKey func() (D, error),
   111  	newEncapsulationKey func([]byte) (E, error),
   112  	newDecapsulationKey func([]byte) (D, error)) {
   113  	dk, err := generateKey()
   114  	dkBytes := dk.Bytes()
   115  	if err != nil {
   116  		t.Fatal(err)
   117  	}
   118  	ek := dk.EncapsulationKey()
   119  	ekBytes := dk.EncapsulationKey().Bytes()
   120  	_, c := ek.Encapsulate()
   121  
   122  	for i := 0; i < len(dkBytes)-1; i++ {
   123  		if _, err := newDecapsulationKey(dkBytes[:i]); err == nil {
   124  			t.Errorf("expected error for dk length %d", i)
   125  		}
   126  	}
   127  	dkLong := dkBytes
   128  	for i := 0; i < 100; i++ {
   129  		dkLong = append(dkLong, 0)
   130  		if _, err := newDecapsulationKey(dkLong); err == nil {
   131  			t.Errorf("expected error for dk length %d", len(dkLong))
   132  		}
   133  	}
   134  
   135  	for i := 0; i < len(ekBytes)-1; i++ {
   136  		if _, err := newEncapsulationKey(ekBytes[:i]); err == nil {
   137  			t.Errorf("expected error for ek length %d", i)
   138  		}
   139  	}
   140  	ekLong := ekBytes
   141  	for i := 0; i < 100; i++ {
   142  		ekLong = append(ekLong, 0)
   143  		if _, err := newEncapsulationKey(ekLong); err == nil {
   144  			t.Errorf("expected error for ek length %d", len(ekLong))
   145  		}
   146  	}
   147  
   148  	for i := 0; i < len(c)-1; i++ {
   149  		if _, err := dk.Decapsulate(c[:i]); err == nil {
   150  			t.Errorf("expected error for c length %d", i)
   151  		}
   152  	}
   153  	cLong := c
   154  	for i := 0; i < 100; i++ {
   155  		cLong = append(cLong, 0)
   156  		if _, err := dk.Decapsulate(cLong); err == nil {
   157  			t.Errorf("expected error for c length %d", len(cLong))
   158  		}
   159  	}
   160  }
   161  
   162  var millionFlag = flag.Bool("million", false, "run the million vector test")
   163  
   164  // TestAccumulated accumulates 10k (or 100, or 1M) random vectors and checks the
   165  // hash of the result, to avoid checking in 150MB of test vectors.
   166  func TestAccumulated(t *testing.T) {
   167  	n := 10000
   168  	expected := "8a518cc63da366322a8e7a818c7a0d63483cb3528d34a4cf42f35d5ad73f22fc"
   169  	if testing.Short() {
   170  		n = 100
   171  		expected = "1114b1b6699ed191734fa339376afa7e285c9e6acf6ff0177d346696ce564415"
   172  	}
   173  	if *millionFlag {
   174  		n = 1000000
   175  		expected = "424bf8f0e8ae99b78d788a6e2e8e9cdaf9773fc0c08a6f433507cb559edfd0f0"
   176  	}
   177  
   178  	s := sha3.NewShake128()
   179  	o := sha3.NewShake128()
   180  	seed := make([]byte, SeedSize)
   181  	msg := make([]byte, 32)
   182  	ct1 := make([]byte, CiphertextSize768)
   183  
   184  	for i := 0; i < n; i++ {
   185  		s.Read(seed)
   186  		dk, err := NewDecapsulationKey768(seed)
   187  		if err != nil {
   188  			t.Fatal(err)
   189  		}
   190  		ek := dk.EncapsulationKey()
   191  		o.Write(ek.Bytes())
   192  
   193  		s.Read(msg)
   194  		k, ct, err := mlkemtest.Encapsulate768(ek, msg)
   195  		if err != nil {
   196  			t.Fatal(err)
   197  		}
   198  		o.Write(ct)
   199  		o.Write(k)
   200  
   201  		kk, err := dk.Decapsulate(ct)
   202  		if err != nil {
   203  			t.Fatal(err)
   204  		}
   205  		if !bytes.Equal(kk, k) {
   206  			t.Errorf("k: got %x, expected %x", kk, k)
   207  		}
   208  
   209  		s.Read(ct1)
   210  		k1, err := dk.Decapsulate(ct1)
   211  		if err != nil {
   212  			t.Fatal(err)
   213  		}
   214  		o.Write(k1)
   215  	}
   216  
   217  	got := hex.EncodeToString(o.Sum(nil))
   218  	if got != expected {
   219  		t.Errorf("got %s, expected %s", got, expected)
   220  	}
   221  }
   222  
   223  var sink byte
   224  
   225  func BenchmarkKeyGen(b *testing.B) {
   226  	var d, z [32]byte
   227  	rand.Read(d[:])
   228  	rand.Read(z[:])
   229  	b.ResetTimer()
   230  	for i := 0; i < b.N; i++ {
   231  		dk := mlkem.GenerateKeyInternal768(&d, &z)
   232  		sink ^= dk.EncapsulationKey().Bytes()[0]
   233  	}
   234  }
   235  
   236  func BenchmarkEncaps(b *testing.B) {
   237  	seed := make([]byte, SeedSize)
   238  	rand.Read(seed)
   239  	dk, err := NewDecapsulationKey768(seed)
   240  	if err != nil {
   241  		b.Fatal(err)
   242  	}
   243  	ekBytes := dk.EncapsulationKey().Bytes()
   244  	b.ResetTimer()
   245  	for i := 0; i < b.N; i++ {
   246  		ek, err := NewEncapsulationKey768(ekBytes)
   247  		if err != nil {
   248  			b.Fatal(err)
   249  		}
   250  		K, c := ek.Encapsulate()
   251  		sink ^= c[0] ^ K[0]
   252  	}
   253  }
   254  
   255  func BenchmarkDecaps(b *testing.B) {
   256  	dk, err := GenerateKey768()
   257  	if err != nil {
   258  		b.Fatal(err)
   259  	}
   260  	ek := dk.EncapsulationKey()
   261  	_, c := ek.Encapsulate()
   262  	b.ResetTimer()
   263  	for i := 0; i < b.N; i++ {
   264  		K, _ := dk.Decapsulate(c)
   265  		sink ^= K[0]
   266  	}
   267  }
   268  
   269  func BenchmarkRoundTrip(b *testing.B) {
   270  	dk, err := GenerateKey768()
   271  	if err != nil {
   272  		b.Fatal(err)
   273  	}
   274  	ek := dk.EncapsulationKey()
   275  	ekBytes := ek.Bytes()
   276  	_, c := ek.Encapsulate()
   277  	if err != nil {
   278  		b.Fatal(err)
   279  	}
   280  	b.Run("Alice", func(b *testing.B) {
   281  		for i := 0; i < b.N; i++ {
   282  			dkS, err := GenerateKey768()
   283  			if err != nil {
   284  				b.Fatal(err)
   285  			}
   286  			ekS := dkS.EncapsulationKey().Bytes()
   287  			sink ^= ekS[0]
   288  
   289  			Ks, err := dk.Decapsulate(c)
   290  			if err != nil {
   291  				b.Fatal(err)
   292  			}
   293  			sink ^= Ks[0]
   294  		}
   295  	})
   296  	b.Run("Bob", func(b *testing.B) {
   297  		for i := 0; i < b.N; i++ {
   298  			ek, err := NewEncapsulationKey768(ekBytes)
   299  			if err != nil {
   300  				b.Fatal(err)
   301  			}
   302  			Ks, cS := ek.Encapsulate()
   303  			if err != nil {
   304  				b.Fatal(err)
   305  			}
   306  			sink ^= cS[0] ^ Ks[0]
   307  		}
   308  	})
   309  }
   310  
   311  // Test that the constants from the public API match the corresponding values from the internal API.
   312  func TestConstantSizes(t *testing.T) {
   313  	if SharedKeySize != mlkem.SharedKeySize {
   314  		t.Errorf("SharedKeySize mismatch: got %d, want %d", SharedKeySize, mlkem.SharedKeySize)
   315  	}
   316  
   317  	if SeedSize != mlkem.SeedSize {
   318  		t.Errorf("SeedSize mismatch: got %d, want %d", SeedSize, mlkem.SeedSize)
   319  	}
   320  
   321  	if CiphertextSize768 != mlkem.CiphertextSize768 {
   322  		t.Errorf("CiphertextSize768 mismatch: got %d, want %d", CiphertextSize768, mlkem.CiphertextSize768)
   323  	}
   324  
   325  	if EncapsulationKeySize768 != mlkem.EncapsulationKeySize768 {
   326  		t.Errorf("EncapsulationKeySize768 mismatch: got %d, want %d", EncapsulationKeySize768, mlkem.EncapsulationKeySize768)
   327  	}
   328  
   329  	if CiphertextSize1024 != mlkem.CiphertextSize1024 {
   330  		t.Errorf("CiphertextSize1024 mismatch: got %d, want %d", CiphertextSize1024, mlkem.CiphertextSize1024)
   331  	}
   332  
   333  	if EncapsulationKeySize1024 != mlkem.EncapsulationKeySize1024 {
   334  		t.Errorf("EncapsulationKeySize1024 mismatch: got %d, want %d", EncapsulationKeySize1024, mlkem.EncapsulationKeySize1024)
   335  	}
   336  }
   337  

View as plain text