Source file
src/crypto/mlkem/mlkem_wycheproof_test.go
1
2
3
4
5 package mlkem_test
6
7 import (
8 "bytes"
9 "crypto/internal/cryptotest/wycheproof"
10 "crypto/internal/fips140/mlkem"
11 . "crypto/mlkem"
12 "crypto/mlkem/mlkemtest"
13 "testing"
14 )
15
16 func TestKeyGenWycheproof(t *testing.T) {
17 for _, file := range []string{
18
19 "mlkem_768_keygen_seed_test.json",
20 "mlkem_1024_keygen_seed_test.json",
21 } {
22 var testdata wycheproof.MlkemKeygenSeedTestSchemaJson
23 wycheproof.LoadVectorFile(t, file, &testdata)
24
25 for _, tg := range testdata.TestGroups {
26 for _, tv := range tg.Tests {
27 t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
28 t.Parallel()
29 runKeyGenTest(t, tg.ParameterSet, tv)
30 })
31 }
32 }
33 }
34 }
35
36 func runKeyGenTest(t *testing.T, paramSet wycheproof.MLKEMKeyGenTestGroupParameterSet, tv wycheproof.MLKEMKeyGenTestGroupTestsElem) {
37 t.Helper()
38
39 seed := wycheproof.MustDecodeHex(tv.Seed)
40 expectedEk := wycheproof.MustDecodeHex(tv.Ek)
41 expectedDk := wycheproof.MustDecodeHex(tv.Dk)
42
43 switch paramSet {
44 case wycheproof.MLKEMKeyGenTestGroupParameterSetMLKEM768:
45 dk, err := mlkem.NewDecapsulationKey768(seed)
46 if err != nil {
47 if tv.Result == "valid" {
48 t.Fatalf("NewDecapsulationKey768: %v", err)
49 }
50 return
51 }
52 if !bytes.Equal(dk.Bytes(), seed) {
53 t.Errorf("decapsulation key seed roundtrip mismatch")
54 }
55 ek := dk.EncapsulationKey()
56 if !bytes.Equal(ek.Bytes(), expectedEk) {
57 t.Errorf("encapsulation key mismatch")
58 }
59 if !bytes.Equal(mlkem.TestingOnlyExpandedBytes768(dk), expectedDk) {
60 t.Errorf("expanded decapsulation key mismatch")
61 }
62 ek2, err := mlkem.NewEncapsulationKey768(expectedEk)
63 if err != nil {
64 t.Fatalf("NewEncapsulationKey768: %v", err)
65 }
66 if !bytes.Equal(ek2.Bytes(), expectedEk) {
67 t.Errorf("encapsulation key roundtrip mismatch")
68 }
69 k, c := ek.Encapsulate()
70 k2, err := dk.Decapsulate(c)
71 if err != nil {
72 t.Fatalf("Decapsulate: %v", err)
73 }
74 if !bytes.Equal(k, k2) {
75 t.Errorf("encaps/decaps roundtrip key mismatch")
76 }
77
78 case wycheproof.MLKEMKeyGenTestGroupParameterSetMLKEM1024:
79 dk, err := mlkem.NewDecapsulationKey1024(seed)
80 if err != nil {
81 if tv.Result == "valid" {
82 t.Fatalf("NewDecapsulationKey1024: %v", err)
83 }
84 return
85 }
86 if !bytes.Equal(dk.Bytes(), seed) {
87 t.Errorf("decapsulation key seed roundtrip mismatch")
88 }
89 ek := dk.EncapsulationKey()
90 if !bytes.Equal(ek.Bytes(), expectedEk) {
91 t.Errorf("encapsulation key mismatch")
92 }
93 if !bytes.Equal(mlkem.TestingOnlyExpandedBytes1024(dk), expectedDk) {
94 t.Errorf("expanded decapsulation key mismatch")
95 }
96 ek2, err := mlkem.NewEncapsulationKey1024(expectedEk)
97 if err != nil {
98 t.Fatalf("NewEncapsulationKey1024: %v", err)
99 }
100 if !bytes.Equal(ek2.Bytes(), expectedEk) {
101 t.Errorf("encapsulation key roundtrip mismatch")
102 }
103 k, c := ek.Encapsulate()
104 k2, err := dk.Decapsulate(c)
105 if err != nil {
106 t.Fatalf("Decapsulate: %v", err)
107 }
108 if !bytes.Equal(k, k2) {
109 t.Errorf("encaps/decaps roundtrip key mismatch")
110 }
111
112 default:
113 t.Fatalf("parameter set %s unsupported", paramSet)
114 }
115 }
116
117 func TestMLKEMEncapsWycheproof(t *testing.T) {
118 for _, file := range []string{
119
120 "mlkem_768_encaps_test.json",
121 "mlkem_1024_encaps_test.json",
122 } {
123 var testdata wycheproof.MlkemEncapsTestSchemaJson
124 wycheproof.LoadVectorFile(t, file, &testdata)
125
126 for _, tg := range testdata.TestGroups {
127 for _, tv := range tg.Tests {
128 t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
129 t.Parallel()
130 runEncapsTest(t, tg.ParameterSet, tv)
131 })
132 }
133 }
134 }
135 }
136
137 func runEncapsTest(t *testing.T, paramSet wycheproof.MLKEMEncapsTestGroupParameterSet, tv wycheproof.MLKEMEncapsTestGroupTestsElem) {
138 t.Helper()
139
140 shouldPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
141 ekBytes := wycheproof.MustDecodeHex(tv.Ek)
142 m := wycheproof.MustDecodeHex(tv.M)
143 expectedC := wycheproof.MustDecodeHex(tv.C)
144 expectedK := wycheproof.MustDecodeHex(tv.K)
145
146 switch paramSet {
147 case wycheproof.MLKEMEncapsTestGroupParameterSetMLKEM768:
148 ek, err := NewEncapsulationKey768(ekBytes)
149 if err != nil {
150 if shouldPass {
151 t.Fatalf("NewEncapsulationKey768: %v", err)
152 }
153 return
154 }
155 if !bytes.Equal(ek.Bytes(), ekBytes) {
156 t.Errorf("encapsulation key roundtrip mismatch")
157 }
158 k, c, err := mlkemtest.Encapsulate768(ek, m)
159 if err != nil {
160 if shouldPass {
161 t.Fatalf("Encapsulate768: %v", err)
162 }
163 return
164 }
165 if !shouldPass {
166 t.Errorf("Encapsulate768 unexpectedly succeeded")
167 return
168 }
169 if !bytes.Equal(c, expectedC) {
170 t.Errorf("ciphertext mismatch")
171 }
172 if !bytes.Equal(k, expectedK) {
173 t.Errorf("shared key mismatch")
174 }
175
176 case wycheproof.MLKEMEncapsTestGroupParameterSetMLKEM1024:
177 ek, err := NewEncapsulationKey1024(ekBytes)
178 if err != nil {
179 if shouldPass {
180 t.Fatalf("NewEncapsulationKey1024: %v", err)
181 }
182 return
183 }
184 if !bytes.Equal(ek.Bytes(), ekBytes) {
185 t.Errorf("encapsulation key roundtrip mismatch")
186 }
187 k, c, err := mlkemtest.Encapsulate1024(ek, m)
188 if err != nil {
189 if shouldPass {
190 t.Fatalf("Encapsulate1024: %v", err)
191 }
192 return
193 }
194 if !shouldPass {
195 t.Errorf("Encapsulate1024 unexpectedly succeeded")
196 return
197 }
198 if !bytes.Equal(c, expectedC) {
199 t.Errorf("ciphertext mismatch")
200 }
201 if !bytes.Equal(k, expectedK) {
202 t.Errorf("shared key mismatch")
203 }
204
205 default:
206 t.Fatalf("parameter set %s unsupported", paramSet)
207 }
208 }
209
210 func TestMLKEMDecapsWycheproof(t *testing.T) {
211 for _, file := range []string{
212
213 "mlkem_768_test.json",
214 "mlkem_1024_test.json",
215 } {
216 var testdata wycheproof.MlkemTestSchemaJson
217 wycheproof.LoadVectorFile(t, file, &testdata)
218
219 for _, tg := range testdata.TestGroups {
220 for _, tv := range tg.Tests {
221 t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
222 t.Parallel()
223 runDecapsTest(t, tg.ParameterSet, tv)
224 })
225 }
226 }
227 }
228 }
229
230 func runDecapsTest(t *testing.T, paramSet wycheproof.MLKEMTestGroupParameterSet, tv wycheproof.MLKEMTestGroupTestsElem) {
231 t.Helper()
232
233 shouldPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
234 seed := wycheproof.MustDecodeHex(tv.Seed)
235 ciphertext := wycheproof.MustDecodeHex(tv.C)
236 expectedK := wycheproof.MustDecodeHex(tv.K)
237
238 switch paramSet {
239 case wycheproof.MLKEMTestGroupParameterSetMLKEM768:
240 dk, err := NewDecapsulationKey768(seed)
241 if err != nil {
242 if shouldPass {
243 t.Fatalf("NewDecapsulationKey768: %v", err)
244 }
245 return
246 }
247 if !bytes.Equal(dk.Bytes(), seed) {
248 t.Errorf("decapsulation key seed roundtrip mismatch")
249 }
250 if tv.Ek != nil {
251 expectedEk := wycheproof.MustDecodeHex(*tv.Ek)
252 if !bytes.Equal(dk.EncapsulationKey().Bytes(), expectedEk) {
253 t.Errorf("encapsulation key mismatch")
254 }
255 }
256 k, err := dk.Decapsulate(ciphertext)
257 if err != nil {
258 if shouldPass {
259 t.Fatalf("Decapsulate: %v", err)
260 }
261 return
262 }
263 if shouldPass {
264 if !bytes.Equal(k, expectedK) {
265 t.Errorf("shared key mismatch: got %x, want %x", k, expectedK)
266 }
267 kFresh, cFresh := dk.EncapsulationKey().Encapsulate()
268 kRT, err := dk.Decapsulate(cFresh)
269 if err != nil {
270 t.Fatalf("Decapsulate of fresh ciphertext: %v", err)
271 }
272 if !bytes.Equal(kFresh, kRT) {
273 t.Errorf("encaps/decaps roundtrip key mismatch")
274 }
275 }
276
277 case wycheproof.MLKEMTestGroupParameterSetMLKEM1024:
278 dk, err := NewDecapsulationKey1024(seed)
279 if err != nil {
280 if shouldPass {
281 t.Fatalf("NewDecapsulationKey1024: %v", err)
282 }
283 return
284 }
285 if !bytes.Equal(dk.Bytes(), seed) {
286 t.Errorf("decapsulation key seed roundtrip mismatch")
287 }
288 if tv.Ek != nil {
289 expectedEk := wycheproof.MustDecodeHex(*tv.Ek)
290 if !bytes.Equal(dk.EncapsulationKey().Bytes(), expectedEk) {
291 t.Errorf("encapsulation key mismatch")
292 }
293 }
294 k, err := dk.Decapsulate(ciphertext)
295 if err != nil {
296 if shouldPass {
297 t.Fatalf("Decapsulate: %v", err)
298 }
299 return
300 }
301 if shouldPass {
302 if !bytes.Equal(k, expectedK) {
303 t.Errorf("shared key mismatch: got %x, want %x", k, expectedK)
304 }
305 kFresh, cFresh := dk.EncapsulationKey().Encapsulate()
306 kRT, err := dk.Decapsulate(cFresh)
307 if err != nil {
308 t.Fatalf("Decapsulate of fresh ciphertext: %v", err)
309 }
310 if !bytes.Equal(kFresh, kRT) {
311 t.Errorf("encaps/decaps roundtrip key mismatch")
312 }
313 }
314
315 default:
316 t.Fatalf("parameter set %s unsupported", paramSet)
317 }
318 }
319
320 func TestMLKEMSemiExpandedDecapsWycheproof(t *testing.T) {
321 for _, file := range []string{
322
323 "mlkem_768_semi_expanded_decaps_test.json",
324 "mlkem_1024_semi_expanded_decaps_test.json",
325 } {
326 var testdata wycheproof.MlkemSemiExpandedDecapsTestSchemaJson
327 wycheproof.LoadVectorFile(t, file, &testdata)
328
329 for _, tg := range testdata.TestGroups {
330 for _, tv := range tg.Tests {
331 t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
332 t.Parallel()
333 runSemiExpandedDecapsTest(t, tg.ParameterSet, tv)
334 })
335 }
336 }
337 }
338 }
339
340 func runSemiExpandedDecapsTest(t *testing.T, paramSet wycheproof.MLKEMDecapsTestGroupParameterSet, tv wycheproof.MLKEMDecapsTestGroupTestsElem) {
341 t.Helper()
342
343 shouldPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
344 dkBytes := wycheproof.MustDecodeHex(tv.Dk)
345 ciphertext := wycheproof.MustDecodeHex(tv.C)
346
347 switch paramSet {
348 case wycheproof.MLKEMDecapsTestGroupParameterSetMLKEM768:
349 dk, err := mlkem.TestingOnlyNewDecapsulationKey768(dkBytes)
350 if err != nil {
351 if shouldPass {
352 t.Fatalf("TestingOnlyNewDecapsulationKey768: %v", err)
353 }
354 return
355 }
356 if !bytes.Equal(mlkem.TestingOnlyExpandedBytes768(dk), dkBytes) {
357 t.Errorf("expanded decapsulation key roundtrip mismatch")
358 }
359 k, err := dk.Decapsulate(ciphertext)
360 if err != nil {
361 if shouldPass {
362 t.Fatalf("Decapsulate: %v", err)
363 }
364 return
365 }
366 if !shouldPass {
367 t.Errorf("Decapsulate unexpectedly succeeded")
368 return
369 }
370 if len(k) != SharedKeySize {
371 t.Errorf("shared key has wrong length: got %d, want %d", len(k), SharedKeySize)
372 }
373 kFresh, cFresh := dk.EncapsulationKey().Encapsulate()
374 kRT, err := dk.Decapsulate(cFresh)
375 if err != nil {
376 t.Fatalf("Decapsulate of fresh ciphertext: %v", err)
377 }
378 if !bytes.Equal(kFresh, kRT) {
379 t.Errorf("encaps/decaps roundtrip key mismatch")
380 }
381
382 case wycheproof.MLKEMDecapsTestGroupParameterSetMLKEM1024:
383 dk, err := mlkem.TestingOnlyNewDecapsulationKey1024(dkBytes)
384 if err != nil {
385 if shouldPass {
386 t.Fatalf("TestingOnlyNewDecapsulationKey1024: %v", err)
387 }
388 return
389 }
390 if !bytes.Equal(mlkem.TestingOnlyExpandedBytes1024(dk), dkBytes) {
391 t.Errorf("expanded decapsulation key roundtrip mismatch")
392 }
393 k, err := dk.Decapsulate(ciphertext)
394 if err != nil {
395 if shouldPass {
396 t.Fatalf("Decapsulate: %v", err)
397 }
398 return
399 }
400 if !shouldPass {
401 t.Errorf("Decapsulate unexpectedly succeeded")
402 return
403 }
404 if len(k) != SharedKeySize {
405 t.Errorf("shared key has wrong length: got %d, want %d", len(k), SharedKeySize)
406 }
407 kFresh, cFresh := dk.EncapsulationKey().Encapsulate()
408 kRT, err := dk.Decapsulate(cFresh)
409 if err != nil {
410 t.Fatalf("Decapsulate of fresh ciphertext: %v", err)
411 }
412 if !bytes.Equal(kFresh, kRT) {
413 t.Errorf("encaps/decaps roundtrip key mismatch")
414 }
415
416 default:
417 t.Fatalf("parameter set %s unsupported", paramSet)
418 }
419 }
420
View as plain text