1
2
3
4
5 package mlkem768
6
7 import (
8 "bytes"
9 "crypto/rand"
10 _ "embed"
11 "encoding/hex"
12 "errors"
13 "flag"
14 "math/big"
15 "strconv"
16 "testing"
17
18 "golang.org/x/crypto/sha3"
19 )
20
21 func TestFieldReduce(t *testing.T) {
22 for a := uint32(0); a < 2*q*q; a++ {
23 got := fieldReduce(a)
24 exp := fieldElement(a % q)
25 if got != exp {
26 t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp)
27 }
28 }
29 }
30
31 func TestFieldAdd(t *testing.T) {
32 for a := fieldElement(0); a < q; a++ {
33 for b := fieldElement(0); b < q; b++ {
34 got := fieldAdd(a, b)
35 exp := (a + b) % q
36 if got != exp {
37 t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
38 }
39 }
40 }
41 }
42
43 func TestFieldSub(t *testing.T) {
44 for a := fieldElement(0); a < q; a++ {
45 for b := fieldElement(0); b < q; b++ {
46 got := fieldSub(a, b)
47 exp := (a - b + q) % q
48 if got != exp {
49 t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
50 }
51 }
52 }
53 }
54
55 func TestFieldMul(t *testing.T) {
56 for a := fieldElement(0); a < q; a++ {
57 for b := fieldElement(0); b < q; b++ {
58 got := fieldMul(a, b)
59 exp := fieldElement((uint32(a) * uint32(b)) % q)
60 if got != exp {
61 t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
62 }
63 }
64 }
65 }
66
67 func TestDecompressCompress(t *testing.T) {
68 for _, bits := range []uint8{1, 4, 10} {
69 for a := uint16(0); a < 1<<bits; a++ {
70 f := decompress(a, bits)
71 if f >= q {
72 t.Fatalf("decompress(%d, %d) = %d >= q", a, bits, f)
73 }
74 got := compress(f, bits)
75 if got != a {
76 t.Fatalf("compress(decompress(%d, %d), %d) = %d", a, bits, bits, got)
77 }
78 }
79
80 for a := fieldElement(0); a < q; a++ {
81 c := compress(a, bits)
82 if c >= 1<<bits {
83 t.Fatalf("compress(%d, %d) = %d >= 2^bits", a, bits, c)
84 }
85 got := decompress(c, bits)
86 diff := min(a-got, got-a, a-got+q, got-a+q)
87 ceil := q / (1 << bits)
88 if diff > fieldElement(ceil) {
89 t.Fatalf("decompress(compress(%d, %d), %d) = %d (diff %d, max diff %d)",
90 a, bits, bits, got, diff, ceil)
91 }
92 }
93 }
94 }
95
96 func CompressRat(x fieldElement, d uint8) uint16 {
97 if x >= q {
98 panic("x out of range")
99 }
100 if d <= 0 || d >= 12 {
101 panic("d out of range")
102 }
103
104 precise := big.NewRat((1<<d)*int64(x), q)
105
106
107
108 rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
109 if err != nil {
110 panic(err)
111 }
112
113
114 return uint16(rounded % (1 << d))
115 }
116
117 func TestCompress(t *testing.T) {
118 for d := 1; d < 12; d++ {
119 for n := 0; n < q; n++ {
120 expected := CompressRat(fieldElement(n), uint8(d))
121 result := compress(fieldElement(n), uint8(d))
122 if result != expected {
123 t.Errorf("compress(%d, %d): got %d, expected %d", n, d, result, expected)
124 }
125 }
126 }
127 }
128
129 func DecompressRat(y uint16, d uint8) fieldElement {
130 if y >= 1<<d {
131 panic("y out of range")
132 }
133 if d <= 0 || d >= 12 {
134 panic("d out of range")
135 }
136
137 precise := big.NewRat(q*int64(y), 1<<d)
138
139
140
141 rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
142 if err != nil {
143 panic(err)
144 }
145
146
147 return fieldElement(rounded % q)
148 }
149
150 func TestDecompress(t *testing.T) {
151 for d := 1; d < 12; d++ {
152 for n := 0; n < (1 << d); n++ {
153 expected := DecompressRat(uint16(n), uint8(d))
154 result := decompress(uint16(n), uint8(d))
155 if result != expected {
156 t.Errorf("decompress(%d, %d): got %d, expected %d", n, d, result, expected)
157 }
158 }
159 }
160 }
161
// BitRev7 returns n with its low 7 bits in reverse order: bit 0 moves to bit 6,
// bit 1 to bit 5, and so on. It panics if n does not fit in 7 bits.
func BitRev7(n uint8) uint8 {
	if n >= 1<<7 {
		panic("not 7 bits")
	}
	var rev uint8
	// Move the lowest remaining bit of n into the bottom of rev, seven times;
	// the first bit taken ends up at the top (bit 6).
	for i := 0; i < 7; i++ {
		rev = rev<<1 | n&1
		n >>= 1
	}
	return rev
}
176
177 func TestZetas(t *testing.T) {
178 ζ := big.NewInt(17)
179 q := big.NewInt(q)
180 for k, zeta := range zetas {
181
182 exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))), q)
183 if big.NewInt(int64(zeta)).Cmp(exp) != 0 {
184 t.Errorf("zetas[%d] = %v, expected %v", k, zeta, exp)
185 }
186 }
187 }
188
189 func TestGammas(t *testing.T) {
190 ζ := big.NewInt(17)
191 q := big.NewInt(q)
192 for k, gamma := range gammas {
193
194 exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))*2+1), q)
195 if big.NewInt(int64(gamma)).Cmp(exp) != 0 {
196 t.Errorf("gammas[%d] = %v, expected %v", k, gamma, exp)
197 }
198 }
199 }
200
201 func TestRoundTrip(t *testing.T) {
202 dk, err := GenerateKey()
203 if err != nil {
204 t.Fatal(err)
205 }
206 c, Ke, err := Encapsulate(dk.EncapsulationKey())
207 if err != nil {
208 t.Fatal(err)
209 }
210 Kd, err := Decapsulate(dk, c)
211 if err != nil {
212 t.Fatal(err)
213 }
214 if !bytes.Equal(Ke, Kd) {
215 t.Fail()
216 }
217
218 dk1, err := GenerateKey()
219 if err != nil {
220 t.Fatal(err)
221 }
222 if bytes.Equal(dk.EncapsulationKey(), dk1.EncapsulationKey()) {
223 t.Fail()
224 }
225 if bytes.Equal(dk.Bytes(), dk1.Bytes()) {
226 t.Fail()
227 }
228 if bytes.Equal(dk.Bytes()[EncapsulationKeySize-32:], dk1.Bytes()[EncapsulationKeySize-32:]) {
229 t.Fail()
230 }
231
232 c1, Ke1, err := Encapsulate(dk.EncapsulationKey())
233 if err != nil {
234 t.Fatal(err)
235 }
236 if bytes.Equal(c, c1) {
237 t.Fail()
238 }
239 if bytes.Equal(Ke, Ke1) {
240 t.Fail()
241 }
242 }
243
244 func TestBadLengths(t *testing.T) {
245 dk, err := GenerateKey()
246 if err != nil {
247 t.Fatal(err)
248 }
249 ek := dk.EncapsulationKey()
250
251 for i := 0; i < len(ek)-1; i++ {
252 if _, _, err := Encapsulate(ek[:i]); err == nil {
253 t.Errorf("expected error for ek length %d", i)
254 }
255 }
256 ekLong := ek
257 for i := 0; i < 100; i++ {
258 ekLong = append(ekLong, 0)
259 if _, _, err := Encapsulate(ekLong); err == nil {
260 t.Errorf("expected error for ek length %d", len(ekLong))
261 }
262 }
263
264 c, _, err := Encapsulate(ek)
265 if err != nil {
266 t.Fatal(err)
267 }
268
269 for i := 0; i < len(dk.Bytes())-1; i++ {
270 if _, err := NewKeyFromExtendedEncoding(dk.Bytes()[:i]); err == nil {
271 t.Errorf("expected error for dk length %d", i)
272 }
273 }
274 dkLong := dk.Bytes()
275 for i := 0; i < 100; i++ {
276 dkLong = append(dkLong, 0)
277 if _, err := NewKeyFromExtendedEncoding(dkLong); err == nil {
278 t.Errorf("expected error for dk length %d", len(dkLong))
279 }
280 }
281
282 for i := 0; i < len(c)-1; i++ {
283 if _, err := Decapsulate(dk, c[:i]); err == nil {
284 t.Errorf("expected error for c length %d", i)
285 }
286 }
287 cLong := c
288 for i := 0; i < 100; i++ {
289 cLong = append(cLong, 0)
290 if _, err := Decapsulate(dk, cLong); err == nil {
291 t.Errorf("expected error for c length %d", len(cLong))
292 }
293 }
294 }
295
296 func EncapsulateDerand(ek, m []byte) (c, K []byte, err error) {
297 if len(m) != messageSize {
298 return nil, nil, errors.New("bad message length")
299 }
300 return kemEncaps(nil, ek, (*[messageSize]byte)(m))
301 }
302
303 func DecapsulateFromBytes(dkBytes []byte, c []byte) ([]byte, error) {
304 dk, err := NewKeyFromExtendedEncoding(dkBytes)
305 if err != nil {
306 return nil, err
307 }
308 return Decapsulate(dk, c)
309 }
310
311 func GenerateKeyDerand(t testing.TB, d, z []byte) ([]byte, *DecapsulationKey) {
312 if len(d) != 32 || len(z) != 32 {
313 t.Fatal("bad length")
314 }
315 dk := kemKeyGen(nil, (*[32]byte)(d), (*[32]byte)(z))
316 return dk.EncapsulationKey(), dk
317 }
318
var (
	// millionFlag (-million) makes TestPQCrystalsAccumulated run 1,000,000
	// iterations instead of its default count.
	millionFlag = flag.Bool("million", false, "run the million vector test")
)
320
321
322
323
324 func TestPQCrystalsAccumulated(t *testing.T) {
325 n := 10000
326 expected := "f7db260e1137a742e05fe0db9525012812b004d29040a5b606aad3d134b548d3"
327 if testing.Short() {
328 n = 100
329 expected = "8d0c478ead6037897a0da6be21e5399545babf5fc6dd10c061c99b7dee2bf0dc"
330 }
331 if *millionFlag {
332 n = 1000000
333 expected = "70090cc5842aad0ec43d5042c783fae9bc320c047b5dafcb6e134821db02384d"
334 }
335
336 s := sha3.NewShake128()
337 o := sha3.NewShake128()
338 d := make([]byte, 32)
339 z := make([]byte, 32)
340 msg := make([]byte, 32)
341 ct1 := make([]byte, CiphertextSize)
342
343 for i := 0; i < n; i++ {
344 s.Read(d)
345 s.Read(z)
346 ek, dk := GenerateKeyDerand(t, d, z)
347 o.Write(ek)
348 o.Write(dk.Bytes())
349
350 s.Read(msg)
351 ct, k, err := EncapsulateDerand(ek, msg)
352 if err != nil {
353 t.Fatal(err)
354 }
355 o.Write(ct)
356 o.Write(k)
357
358 kk, err := Decapsulate(dk, ct)
359 if err != nil {
360 t.Fatal(err)
361 }
362 if !bytes.Equal(kk, k) {
363 t.Errorf("k: got %x, expected %x", kk, k)
364 }
365
366 s.Read(ct1)
367 k1, err := Decapsulate(dk, ct1)
368 if err != nil {
369 t.Fatal(err)
370 }
371 o.Write(k1)
372 }
373
374 got := hex.EncodeToString(o.Sum(nil))
375 if got != expected {
376 t.Errorf("got %s, expected %s", got, expected)
377 }
378 }
379
var (
	// sink is XOR-accumulated with a byte of each benchmark result,
	// presumably so the compiler cannot discard the benchmarked calls as dead
	// code.
	sink byte
)
381
382 func BenchmarkKeyGen(b *testing.B) {
383 var dk DecapsulationKey
384 var d, z [32]byte
385 rand.Read(d[:])
386 rand.Read(z[:])
387 b.ResetTimer()
388 for i := 0; i < b.N; i++ {
389 dk := kemKeyGen(&dk, &d, &z)
390 sink ^= dk.EncapsulationKey()[0]
391 }
392 }
393
394 func BenchmarkEncaps(b *testing.B) {
395 d := make([]byte, 32)
396 rand.Read(d)
397 z := make([]byte, 32)
398 rand.Read(z)
399 var m [messageSize]byte
400 rand.Read(m[:])
401 ek, _ := GenerateKeyDerand(b, d, z)
402 var c [CiphertextSize]byte
403 b.ResetTimer()
404 for i := 0; i < b.N; i++ {
405 c, K, err := kemEncaps(&c, ek, &m)
406 if err != nil {
407 b.Fatal(err)
408 }
409 sink ^= c[0] ^ K[0]
410 }
411 }
412
413 func BenchmarkDecaps(b *testing.B) {
414 d := make([]byte, 32)
415 rand.Read(d)
416 z := make([]byte, 32)
417 rand.Read(z)
418 m := make([]byte, 32)
419 rand.Read(m)
420 ek, dk := GenerateKeyDerand(b, d, z)
421 c, _, err := EncapsulateDerand(ek, m)
422 if err != nil {
423 b.Fatal(err)
424 }
425 b.ResetTimer()
426 for i := 0; i < b.N; i++ {
427 K := kemDecaps(dk, (*[CiphertextSize]byte)(c))
428 sink ^= K[0]
429 }
430 }
431
432 func BenchmarkRoundTrip(b *testing.B) {
433 dk, err := GenerateKey()
434 if err != nil {
435 b.Fatal(err)
436 }
437 ek := dk.EncapsulationKey()
438 c, _, err := Encapsulate(ek)
439 if err != nil {
440 b.Fatal(err)
441 }
442 b.Run("Alice", func(b *testing.B) {
443 for i := 0; i < b.N; i++ {
444 dkS, err := GenerateKey()
445 if err != nil {
446 b.Fatal(err)
447 }
448 ekS := dkS.EncapsulationKey()
449 sink ^= ekS[0]
450
451 Ks, err := Decapsulate(dk, c)
452 if err != nil {
453 b.Fatal(err)
454 }
455 sink ^= Ks[0]
456 }
457 })
458 b.Run("Bob", func(b *testing.B) {
459 for i := 0; i < b.N; i++ {
460 cS, Ks, err := Encapsulate(ek)
461 if err != nil {
462 b.Fatal(err)
463 }
464 sink ^= cS[0] ^ Ks[0]
465 }
466 })
467 }
468
View as plain text