1
2
3
4
5 package mlkem_test
6
7 import (
8 "bytes"
9 "crypto/internal/fips140/mlkem"
10 "crypto/internal/fips140/sha3"
11 . "crypto/mlkem"
12 "crypto/mlkem/mlkemtest"
13 "crypto/rand"
14 "encoding/hex"
15 "flag"
16 "testing"
17 )
18
// encapsulationKey is the method set the generic test helpers in this file
// require from an encapsulation key type.
type encapsulationKey interface {
	// Encapsulate returns two byte slices. The helpers in this file treat
	// the first as the shared key and pass the second to Decapsulate as the
	// ciphertext.
	Encapsulate() (sharedKey, ciphertext []byte)

	// Bytes returns the encoded form of the key, which the helpers feed back
	// into the matching constructor.
	Bytes() []byte
}
23
24 type decapsulationKey[E encapsulationKey] interface {
25 Bytes() []byte
26 Decapsulate([]byte) ([]byte, error)
27 EncapsulationKey() E
28 }
29
30 func TestRoundTrip(t *testing.T) {
31 t.Run("768", func(t *testing.T) {
32 testRoundTrip(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
33 })
34 t.Run("1024", func(t *testing.T) {
35 testRoundTrip(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
36 })
37 }
38
39 func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
40 t *testing.T, generateKey func() (D, error),
41 newEncapsulationKey func([]byte) (E, error),
42 newDecapsulationKey func([]byte) (D, error)) {
43 dk, err := generateKey()
44 if err != nil {
45 t.Fatal(err)
46 }
47 ek := dk.EncapsulationKey()
48 Ke, c := ek.Encapsulate()
49 Kd, err := dk.Decapsulate(c)
50 if err != nil {
51 t.Fatal(err)
52 }
53 if !bytes.Equal(Ke, Kd) {
54 t.Fail()
55 }
56
57 ek1, err := newEncapsulationKey(ek.Bytes())
58 if err != nil {
59 t.Fatal(err)
60 }
61 if !bytes.Equal(ek.Bytes(), ek1.Bytes()) {
62 t.Fail()
63 }
64 dk1, err := newDecapsulationKey(dk.Bytes())
65 if err != nil {
66 t.Fatal(err)
67 }
68 if !bytes.Equal(dk.Bytes(), dk1.Bytes()) {
69 t.Fail()
70 }
71 Ke1, c1 := ek1.Encapsulate()
72 Kd1, err := dk1.Decapsulate(c1)
73 if err != nil {
74 t.Fatal(err)
75 }
76 if !bytes.Equal(Ke1, Kd1) {
77 t.Fail()
78 }
79
80 dk2, err := generateKey()
81 if err != nil {
82 t.Fatal(err)
83 }
84 if bytes.Equal(dk.EncapsulationKey().Bytes(), dk2.EncapsulationKey().Bytes()) {
85 t.Fail()
86 }
87 if bytes.Equal(dk.Bytes(), dk2.Bytes()) {
88 t.Fail()
89 }
90
91 Ke2, c2 := dk.EncapsulationKey().Encapsulate()
92 if bytes.Equal(c, c2) {
93 t.Fail()
94 }
95 if bytes.Equal(Ke, Ke2) {
96 t.Fail()
97 }
98 }
99
100 func TestBadLengths(t *testing.T) {
101 t.Run("768", func(t *testing.T) {
102 testBadLengths(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
103 })
104 t.Run("1024", func(t *testing.T) {
105 testBadLengths(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
106 })
107 }
108
109 func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
110 t *testing.T, generateKey func() (D, error),
111 newEncapsulationKey func([]byte) (E, error),
112 newDecapsulationKey func([]byte) (D, error)) {
113 dk, err := generateKey()
114 dkBytes := dk.Bytes()
115 if err != nil {
116 t.Fatal(err)
117 }
118 ek := dk.EncapsulationKey()
119 ekBytes := dk.EncapsulationKey().Bytes()
120 _, c := ek.Encapsulate()
121
122 for i := 0; i < len(dkBytes)-1; i++ {
123 if _, err := newDecapsulationKey(dkBytes[:i]); err == nil {
124 t.Errorf("expected error for dk length %d", i)
125 }
126 }
127 dkLong := dkBytes
128 for i := 0; i < 100; i++ {
129 dkLong = append(dkLong, 0)
130 if _, err := newDecapsulationKey(dkLong); err == nil {
131 t.Errorf("expected error for dk length %d", len(dkLong))
132 }
133 }
134
135 for i := 0; i < len(ekBytes)-1; i++ {
136 if _, err := newEncapsulationKey(ekBytes[:i]); err == nil {
137 t.Errorf("expected error for ek length %d", i)
138 }
139 }
140 ekLong := ekBytes
141 for i := 0; i < 100; i++ {
142 ekLong = append(ekLong, 0)
143 if _, err := newEncapsulationKey(ekLong); err == nil {
144 t.Errorf("expected error for ek length %d", len(ekLong))
145 }
146 }
147
148 for i := 0; i < len(c)-1; i++ {
149 if _, err := dk.Decapsulate(c[:i]); err == nil {
150 t.Errorf("expected error for c length %d", i)
151 }
152 }
153 cLong := c
154 for i := 0; i < 100; i++ {
155 cLong = append(cLong, 0)
156 if _, err := dk.Decapsulate(cLong); err == nil {
157 t.Errorf("expected error for c length %d", len(cLong))
158 }
159 }
160 }
161
// millionFlag registers the boolean "million" command-line flag (default
// false). When it is set, TestAccumulated runs 1000000 iterations.
var millionFlag = flag.Bool("million", false, "run the million vector test")
163
164
165
166 func TestAccumulated(t *testing.T) {
167 n := 10000
168 expected := "8a518cc63da366322a8e7a818c7a0d63483cb3528d34a4cf42f35d5ad73f22fc"
169 if testing.Short() {
170 n = 100
171 expected = "1114b1b6699ed191734fa339376afa7e285c9e6acf6ff0177d346696ce564415"
172 }
173 if *millionFlag {
174 n = 1000000
175 expected = "424bf8f0e8ae99b78d788a6e2e8e9cdaf9773fc0c08a6f433507cb559edfd0f0"
176 }
177
178 s := sha3.NewShake128()
179 o := sha3.NewShake128()
180 seed := make([]byte, SeedSize)
181 msg := make([]byte, 32)
182 ct1 := make([]byte, CiphertextSize768)
183
184 for i := 0; i < n; i++ {
185 s.Read(seed)
186 dk, err := NewDecapsulationKey768(seed)
187 if err != nil {
188 t.Fatal(err)
189 }
190 ek := dk.EncapsulationKey()
191 o.Write(ek.Bytes())
192
193 s.Read(msg)
194 k, ct, err := mlkemtest.Encapsulate768(ek, msg)
195 if err != nil {
196 t.Fatal(err)
197 }
198 o.Write(ct)
199 o.Write(k)
200
201 kk, err := dk.Decapsulate(ct)
202 if err != nil {
203 t.Fatal(err)
204 }
205 if !bytes.Equal(kk, k) {
206 t.Errorf("k: got %x, expected %x", kk, k)
207 }
208
209 s.Read(ct1)
210 k1, err := dk.Decapsulate(ct1)
211 if err != nil {
212 t.Fatal(err)
213 }
214 o.Write(k1)
215 }
216
217 got := hex.EncodeToString(o.Sum(nil))
218 if got != expected {
219 t.Errorf("got %s, expected %s", got, expected)
220 }
221 }
222
// sink is XORed with one byte of each benchmark iteration's output,
// presumably so that the compiler cannot discard the benchmarked calls.
var sink byte
224
225 func BenchmarkKeyGen(b *testing.B) {
226 var d, z [32]byte
227 rand.Read(d[:])
228 rand.Read(z[:])
229 b.ResetTimer()
230 for i := 0; i < b.N; i++ {
231 dk := mlkem.GenerateKeyInternal768(&d, &z)
232 sink ^= dk.EncapsulationKey().Bytes()[0]
233 }
234 }
235
236 func BenchmarkEncaps(b *testing.B) {
237 seed := make([]byte, SeedSize)
238 rand.Read(seed)
239 dk, err := NewDecapsulationKey768(seed)
240 if err != nil {
241 b.Fatal(err)
242 }
243 ekBytes := dk.EncapsulationKey().Bytes()
244 b.ResetTimer()
245 for i := 0; i < b.N; i++ {
246 ek, err := NewEncapsulationKey768(ekBytes)
247 if err != nil {
248 b.Fatal(err)
249 }
250 K, c := ek.Encapsulate()
251 sink ^= c[0] ^ K[0]
252 }
253 }
254
255 func BenchmarkDecaps(b *testing.B) {
256 dk, err := GenerateKey768()
257 if err != nil {
258 b.Fatal(err)
259 }
260 ek := dk.EncapsulationKey()
261 _, c := ek.Encapsulate()
262 b.ResetTimer()
263 for i := 0; i < b.N; i++ {
264 K, _ := dk.Decapsulate(c)
265 sink ^= K[0]
266 }
267 }
268
269 func BenchmarkRoundTrip(b *testing.B) {
270 dk, err := GenerateKey768()
271 if err != nil {
272 b.Fatal(err)
273 }
274 ek := dk.EncapsulationKey()
275 ekBytes := ek.Bytes()
276 _, c := ek.Encapsulate()
277 if err != nil {
278 b.Fatal(err)
279 }
280 b.Run("Alice", func(b *testing.B) {
281 for i := 0; i < b.N; i++ {
282 dkS, err := GenerateKey768()
283 if err != nil {
284 b.Fatal(err)
285 }
286 ekS := dkS.EncapsulationKey().Bytes()
287 sink ^= ekS[0]
288
289 Ks, err := dk.Decapsulate(c)
290 if err != nil {
291 b.Fatal(err)
292 }
293 sink ^= Ks[0]
294 }
295 })
296 b.Run("Bob", func(b *testing.B) {
297 for i := 0; i < b.N; i++ {
298 ek, err := NewEncapsulationKey768(ekBytes)
299 if err != nil {
300 b.Fatal(err)
301 }
302 Ks, cS := ek.Encapsulate()
303 if err != nil {
304 b.Fatal(err)
305 }
306 sink ^= cS[0] ^ Ks[0]
307 }
308 })
309 }
310
311
312 func TestConstantSizes(t *testing.T) {
313 if SharedKeySize != mlkem.SharedKeySize {
314 t.Errorf("SharedKeySize mismatch: got %d, want %d", SharedKeySize, mlkem.SharedKeySize)
315 }
316
317 if SeedSize != mlkem.SeedSize {
318 t.Errorf("SeedSize mismatch: got %d, want %d", SeedSize, mlkem.SeedSize)
319 }
320
321 if CiphertextSize768 != mlkem.CiphertextSize768 {
322 t.Errorf("CiphertextSize768 mismatch: got %d, want %d", CiphertextSize768, mlkem.CiphertextSize768)
323 }
324
325 if EncapsulationKeySize768 != mlkem.EncapsulationKeySize768 {
326 t.Errorf("EncapsulationKeySize768 mismatch: got %d, want %d", EncapsulationKeySize768, mlkem.EncapsulationKeySize768)
327 }
328
329 if CiphertextSize1024 != mlkem.CiphertextSize1024 {
330 t.Errorf("CiphertextSize1024 mismatch: got %d, want %d", CiphertextSize1024, mlkem.CiphertextSize1024)
331 }
332
333 if EncapsulationKeySize1024 != mlkem.EncapsulationKeySize1024 {
334 t.Errorf("EncapsulationKeySize1024 mismatch: got %d, want %d", EncapsulationKeySize1024, mlkem.EncapsulationKeySize1024)
335 }
336 }
337
View as plain text