Source file src/crypto/mlkem/mlkem_wycheproof_test.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mlkem_test
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/internal/cryptotest/wycheproof"
    10  	"crypto/internal/fips140/mlkem"
    11  	. "crypto/mlkem"
    12  	"crypto/mlkem/mlkemtest"
    13  	"testing"
    14  )
    15  
    16  func TestKeyGenWycheproof(t *testing.T) {
    17  	for _, file := range []string{
    18  		// mlkem_512_keygen_seed_test omitted - no ML-KEM 512 support.
    19  		"mlkem_768_keygen_seed_test.json",
    20  		"mlkem_1024_keygen_seed_test.json",
    21  	} {
    22  		var testdata wycheproof.MlkemKeygenSeedTestSchemaJson
    23  		wycheproof.LoadVectorFile(t, file, &testdata)
    24  
    25  		for _, tg := range testdata.TestGroups {
    26  			for _, tv := range tg.Tests {
    27  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
    28  					t.Parallel()
    29  					runKeyGenTest(t, tg.ParameterSet, tv)
    30  				})
    31  			}
    32  		}
    33  	}
    34  }
    35  
    36  func runKeyGenTest(t *testing.T, paramSet wycheproof.MLKEMKeyGenTestGroupParameterSet, tv wycheproof.MLKEMKeyGenTestGroupTestsElem) {
    37  	t.Helper()
    38  
    39  	seed := wycheproof.MustDecodeHex(tv.Seed)
    40  	expectedEk := wycheproof.MustDecodeHex(tv.Ek)
    41  	expectedDk := wycheproof.MustDecodeHex(tv.Dk)
    42  
    43  	switch paramSet {
    44  	case wycheproof.MLKEMKeyGenTestGroupParameterSetMLKEM768:
    45  		dk, err := mlkem.NewDecapsulationKey768(seed)
    46  		if err != nil {
    47  			if tv.Result == "valid" {
    48  				t.Fatalf("NewDecapsulationKey768: %v", err)
    49  			}
    50  			return
    51  		}
    52  		if !bytes.Equal(dk.Bytes(), seed) {
    53  			t.Errorf("decapsulation key seed roundtrip mismatch")
    54  		}
    55  		ek := dk.EncapsulationKey()
    56  		if !bytes.Equal(ek.Bytes(), expectedEk) {
    57  			t.Errorf("encapsulation key mismatch")
    58  		}
    59  		if !bytes.Equal(mlkem.TestingOnlyExpandedBytes768(dk), expectedDk) {
    60  			t.Errorf("expanded decapsulation key mismatch")
    61  		}
    62  		ek2, err := mlkem.NewEncapsulationKey768(expectedEk)
    63  		if err != nil {
    64  			t.Fatalf("NewEncapsulationKey768: %v", err)
    65  		}
    66  		if !bytes.Equal(ek2.Bytes(), expectedEk) {
    67  			t.Errorf("encapsulation key roundtrip mismatch")
    68  		}
    69  		k, c := ek.Encapsulate()
    70  		k2, err := dk.Decapsulate(c)
    71  		if err != nil {
    72  			t.Fatalf("Decapsulate: %v", err)
    73  		}
    74  		if !bytes.Equal(k, k2) {
    75  			t.Errorf("encaps/decaps roundtrip key mismatch")
    76  		}
    77  
    78  	case wycheproof.MLKEMKeyGenTestGroupParameterSetMLKEM1024:
    79  		dk, err := mlkem.NewDecapsulationKey1024(seed)
    80  		if err != nil {
    81  			if tv.Result == "valid" {
    82  				t.Fatalf("NewDecapsulationKey1024: %v", err)
    83  			}
    84  			return
    85  		}
    86  		if !bytes.Equal(dk.Bytes(), seed) {
    87  			t.Errorf("decapsulation key seed roundtrip mismatch")
    88  		}
    89  		ek := dk.EncapsulationKey()
    90  		if !bytes.Equal(ek.Bytes(), expectedEk) {
    91  			t.Errorf("encapsulation key mismatch")
    92  		}
    93  		if !bytes.Equal(mlkem.TestingOnlyExpandedBytes1024(dk), expectedDk) {
    94  			t.Errorf("expanded decapsulation key mismatch")
    95  		}
    96  		ek2, err := mlkem.NewEncapsulationKey1024(expectedEk)
    97  		if err != nil {
    98  			t.Fatalf("NewEncapsulationKey1024: %v", err)
    99  		}
   100  		if !bytes.Equal(ek2.Bytes(), expectedEk) {
   101  			t.Errorf("encapsulation key roundtrip mismatch")
   102  		}
   103  		k, c := ek.Encapsulate()
   104  		k2, err := dk.Decapsulate(c)
   105  		if err != nil {
   106  			t.Fatalf("Decapsulate: %v", err)
   107  		}
   108  		if !bytes.Equal(k, k2) {
   109  			t.Errorf("encaps/decaps roundtrip key mismatch")
   110  		}
   111  
   112  	default:
   113  		t.Fatalf("parameter set %s unsupported", paramSet)
   114  	}
   115  }
   116  
   117  func TestMLKEMEncapsWycheproof(t *testing.T) {
   118  	for _, file := range []string{
   119  		// mlkem_512_encaps_test omitted - no ML-KEM 512 support.
   120  		"mlkem_768_encaps_test.json",
   121  		"mlkem_1024_encaps_test.json",
   122  	} {
   123  		var testdata wycheproof.MlkemEncapsTestSchemaJson
   124  		wycheproof.LoadVectorFile(t, file, &testdata)
   125  
   126  		for _, tg := range testdata.TestGroups {
   127  			for _, tv := range tg.Tests {
   128  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
   129  					t.Parallel()
   130  					runEncapsTest(t, tg.ParameterSet, tv)
   131  				})
   132  			}
   133  		}
   134  	}
   135  }
   136  
   137  func runEncapsTest(t *testing.T, paramSet wycheproof.MLKEMEncapsTestGroupParameterSet, tv wycheproof.MLKEMEncapsTestGroupTestsElem) {
   138  	t.Helper()
   139  
   140  	shouldPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
   141  	ekBytes := wycheproof.MustDecodeHex(tv.Ek)
   142  	m := wycheproof.MustDecodeHex(tv.M)
   143  	expectedC := wycheproof.MustDecodeHex(tv.C)
   144  	expectedK := wycheproof.MustDecodeHex(tv.K)
   145  
   146  	switch paramSet {
   147  	case wycheproof.MLKEMEncapsTestGroupParameterSetMLKEM768:
   148  		ek, err := NewEncapsulationKey768(ekBytes)
   149  		if err != nil {
   150  			if shouldPass {
   151  				t.Fatalf("NewEncapsulationKey768: %v", err)
   152  			}
   153  			return
   154  		}
   155  		if !bytes.Equal(ek.Bytes(), ekBytes) {
   156  			t.Errorf("encapsulation key roundtrip mismatch")
   157  		}
   158  		k, c, err := mlkemtest.Encapsulate768(ek, m)
   159  		if err != nil {
   160  			if shouldPass {
   161  				t.Fatalf("Encapsulate768: %v", err)
   162  			}
   163  			return
   164  		}
   165  		if !shouldPass {
   166  			t.Errorf("Encapsulate768 unexpectedly succeeded")
   167  			return
   168  		}
   169  		if !bytes.Equal(c, expectedC) {
   170  			t.Errorf("ciphertext mismatch")
   171  		}
   172  		if !bytes.Equal(k, expectedK) {
   173  			t.Errorf("shared key mismatch")
   174  		}
   175  
   176  	case wycheproof.MLKEMEncapsTestGroupParameterSetMLKEM1024:
   177  		ek, err := NewEncapsulationKey1024(ekBytes)
   178  		if err != nil {
   179  			if shouldPass {
   180  				t.Fatalf("NewEncapsulationKey1024: %v", err)
   181  			}
   182  			return
   183  		}
   184  		if !bytes.Equal(ek.Bytes(), ekBytes) {
   185  			t.Errorf("encapsulation key roundtrip mismatch")
   186  		}
   187  		k, c, err := mlkemtest.Encapsulate1024(ek, m)
   188  		if err != nil {
   189  			if shouldPass {
   190  				t.Fatalf("Encapsulate1024: %v", err)
   191  			}
   192  			return
   193  		}
   194  		if !shouldPass {
   195  			t.Errorf("Encapsulate1024 unexpectedly succeeded")
   196  			return
   197  		}
   198  		if !bytes.Equal(c, expectedC) {
   199  			t.Errorf("ciphertext mismatch")
   200  		}
   201  		if !bytes.Equal(k, expectedK) {
   202  			t.Errorf("shared key mismatch")
   203  		}
   204  
   205  	default:
   206  		t.Fatalf("parameter set %s unsupported", paramSet)
   207  	}
   208  }
   209  
   210  func TestMLKEMDecapsWycheproof(t *testing.T) {
   211  	for _, file := range []string{
   212  		// mlkem_512_test omitted - no ML-KEM 512 support.
   213  		"mlkem_768_test.json",
   214  		"mlkem_1024_test.json",
   215  	} {
   216  		var testdata wycheproof.MlkemTestSchemaJson
   217  		wycheproof.LoadVectorFile(t, file, &testdata)
   218  
   219  		for _, tg := range testdata.TestGroups {
   220  			for _, tv := range tg.Tests {
   221  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
   222  					t.Parallel()
   223  					runDecapsTest(t, tg.ParameterSet, tv)
   224  				})
   225  			}
   226  		}
   227  	}
   228  }
   229  
   230  func runDecapsTest(t *testing.T, paramSet wycheproof.MLKEMTestGroupParameterSet, tv wycheproof.MLKEMTestGroupTestsElem) {
   231  	t.Helper()
   232  
   233  	shouldPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
   234  	seed := wycheproof.MustDecodeHex(tv.Seed)
   235  	ciphertext := wycheproof.MustDecodeHex(tv.C)
   236  	expectedK := wycheproof.MustDecodeHex(tv.K)
   237  
   238  	switch paramSet {
   239  	case wycheproof.MLKEMTestGroupParameterSetMLKEM768:
   240  		dk, err := NewDecapsulationKey768(seed)
   241  		if err != nil {
   242  			if shouldPass {
   243  				t.Fatalf("NewDecapsulationKey768: %v", err)
   244  			}
   245  			return
   246  		}
   247  		if !bytes.Equal(dk.Bytes(), seed) {
   248  			t.Errorf("decapsulation key seed roundtrip mismatch")
   249  		}
   250  		if tv.Ek != nil {
   251  			expectedEk := wycheproof.MustDecodeHex(*tv.Ek)
   252  			if !bytes.Equal(dk.EncapsulationKey().Bytes(), expectedEk) {
   253  				t.Errorf("encapsulation key mismatch")
   254  			}
   255  		}
   256  		k, err := dk.Decapsulate(ciphertext)
   257  		if err != nil {
   258  			if shouldPass {
   259  				t.Fatalf("Decapsulate: %v", err)
   260  			}
   261  			return
   262  		}
   263  		if shouldPass {
   264  			if !bytes.Equal(k, expectedK) {
   265  				t.Errorf("shared key mismatch: got %x, want %x", k, expectedK)
   266  			}
   267  			kFresh, cFresh := dk.EncapsulationKey().Encapsulate()
   268  			kRT, err := dk.Decapsulate(cFresh)
   269  			if err != nil {
   270  				t.Fatalf("Decapsulate of fresh ciphertext: %v", err)
   271  			}
   272  			if !bytes.Equal(kFresh, kRT) {
   273  				t.Errorf("encaps/decaps roundtrip key mismatch")
   274  			}
   275  		}
   276  
   277  	case wycheproof.MLKEMTestGroupParameterSetMLKEM1024:
   278  		dk, err := NewDecapsulationKey1024(seed)
   279  		if err != nil {
   280  			if shouldPass {
   281  				t.Fatalf("NewDecapsulationKey1024: %v", err)
   282  			}
   283  			return
   284  		}
   285  		if !bytes.Equal(dk.Bytes(), seed) {
   286  			t.Errorf("decapsulation key seed roundtrip mismatch")
   287  		}
   288  		if tv.Ek != nil {
   289  			expectedEk := wycheproof.MustDecodeHex(*tv.Ek)
   290  			if !bytes.Equal(dk.EncapsulationKey().Bytes(), expectedEk) {
   291  				t.Errorf("encapsulation key mismatch")
   292  			}
   293  		}
   294  		k, err := dk.Decapsulate(ciphertext)
   295  		if err != nil {
   296  			if shouldPass {
   297  				t.Fatalf("Decapsulate: %v", err)
   298  			}
   299  			return
   300  		}
   301  		if shouldPass {
   302  			if !bytes.Equal(k, expectedK) {
   303  				t.Errorf("shared key mismatch: got %x, want %x", k, expectedK)
   304  			}
   305  			kFresh, cFresh := dk.EncapsulationKey().Encapsulate()
   306  			kRT, err := dk.Decapsulate(cFresh)
   307  			if err != nil {
   308  				t.Fatalf("Decapsulate of fresh ciphertext: %v", err)
   309  			}
   310  			if !bytes.Equal(kFresh, kRT) {
   311  				t.Errorf("encaps/decaps roundtrip key mismatch")
   312  			}
   313  		}
   314  
   315  	default:
   316  		t.Fatalf("parameter set %s unsupported", paramSet)
   317  	}
   318  }
   319  
   320  func TestMLKEMSemiExpandedDecapsWycheproof(t *testing.T) {
   321  	for _, file := range []string{
   322  		// mlkem_512_semi_expanded_decaps_test omitted - no ML-KEM 512 support.
   323  		"mlkem_768_semi_expanded_decaps_test.json",
   324  		"mlkem_1024_semi_expanded_decaps_test.json",
   325  	} {
   326  		var testdata wycheproof.MlkemSemiExpandedDecapsTestSchemaJson
   327  		wycheproof.LoadVectorFile(t, file, &testdata)
   328  
   329  		for _, tg := range testdata.TestGroups {
   330  			for _, tv := range tg.Tests {
   331  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
   332  					t.Parallel()
   333  					runSemiExpandedDecapsTest(t, tg.ParameterSet, tv)
   334  				})
   335  			}
   336  		}
   337  	}
   338  }
   339  
   340  func runSemiExpandedDecapsTest(t *testing.T, paramSet wycheproof.MLKEMDecapsTestGroupParameterSet, tv wycheproof.MLKEMDecapsTestGroupTestsElem) {
   341  	t.Helper()
   342  
   343  	shouldPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
   344  	dkBytes := wycheproof.MustDecodeHex(tv.Dk)
   345  	ciphertext := wycheproof.MustDecodeHex(tv.C)
   346  
   347  	switch paramSet {
   348  	case wycheproof.MLKEMDecapsTestGroupParameterSetMLKEM768:
   349  		dk, err := mlkem.TestingOnlyNewDecapsulationKey768(dkBytes)
   350  		if err != nil {
   351  			if shouldPass {
   352  				t.Fatalf("TestingOnlyNewDecapsulationKey768: %v", err)
   353  			}
   354  			return
   355  		}
   356  		if !bytes.Equal(mlkem.TestingOnlyExpandedBytes768(dk), dkBytes) {
   357  			t.Errorf("expanded decapsulation key roundtrip mismatch")
   358  		}
   359  		k, err := dk.Decapsulate(ciphertext)
   360  		if err != nil {
   361  			if shouldPass {
   362  				t.Fatalf("Decapsulate: %v", err)
   363  			}
   364  			return
   365  		}
   366  		if !shouldPass {
   367  			t.Errorf("Decapsulate unexpectedly succeeded")
   368  			return
   369  		}
   370  		if len(k) != SharedKeySize {
   371  			t.Errorf("shared key has wrong length: got %d, want %d", len(k), SharedKeySize)
   372  		}
   373  		kFresh, cFresh := dk.EncapsulationKey().Encapsulate()
   374  		kRT, err := dk.Decapsulate(cFresh)
   375  		if err != nil {
   376  			t.Fatalf("Decapsulate of fresh ciphertext: %v", err)
   377  		}
   378  		if !bytes.Equal(kFresh, kRT) {
   379  			t.Errorf("encaps/decaps roundtrip key mismatch")
   380  		}
   381  
   382  	case wycheproof.MLKEMDecapsTestGroupParameterSetMLKEM1024:
   383  		dk, err := mlkem.TestingOnlyNewDecapsulationKey1024(dkBytes)
   384  		if err != nil {
   385  			if shouldPass {
   386  				t.Fatalf("TestingOnlyNewDecapsulationKey1024: %v", err)
   387  			}
   388  			return
   389  		}
   390  		if !bytes.Equal(mlkem.TestingOnlyExpandedBytes1024(dk), dkBytes) {
   391  			t.Errorf("expanded decapsulation key roundtrip mismatch")
   392  		}
   393  		k, err := dk.Decapsulate(ciphertext)
   394  		if err != nil {
   395  			if shouldPass {
   396  				t.Fatalf("Decapsulate: %v", err)
   397  			}
   398  			return
   399  		}
   400  		if !shouldPass {
   401  			t.Errorf("Decapsulate unexpectedly succeeded")
   402  			return
   403  		}
   404  		if len(k) != SharedKeySize {
   405  			t.Errorf("shared key has wrong length: got %d, want %d", len(k), SharedKeySize)
   406  		}
   407  		kFresh, cFresh := dk.EncapsulationKey().Encapsulate()
   408  		kRT, err := dk.Decapsulate(cFresh)
   409  		if err != nil {
   410  			t.Fatalf("Decapsulate of fresh ciphertext: %v", err)
   411  		}
   412  		if !bytes.Equal(kFresh, kRT) {
   413  			t.Errorf("encaps/decaps roundtrip key mismatch")
   414  		}
   415  
   416  	default:
   417  		t.Fatalf("parameter set %s unsupported", paramSet)
   418  	}
   419  }
   420  

View as plain text