1
2
3
4
5
6
7
8
9 package mlkem
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26 import (
27 "bytes"
28 "crypto/internal/fips140"
29 "crypto/internal/fips140/drbg"
30 "crypto/internal/fips140/sha3"
31 "crypto/internal/fips140/subtle"
32 "errors"
33 )
34
// Ring parameters and encoding sizes shared by every parameter set.
const (
	// n is the number of coefficients in each ring polynomial. q (3329) is
	// presumably the coefficient modulus used by the polynomial helpers
	// defined elsewhere in the package.
	n = 256
	q = 3329

	// encodingSizeB is the byte length of one polynomial whose n
	// coefficients are each packed into B bits.
	encodingSize12 = 12 * n / 8
	encodingSize11 = 11 * n / 8
	encodingSize10 = 10 * n / 8
	encodingSize5  = 5 * n / 8
	encodingSize4  = 4 * n / 8
	encodingSize1  = 1 * n / 8

	// messageSize is the byte length of a plaintext message: one bit per
	// coefficient.
	messageSize = encodingSize1

	// SharedKeySize is the byte length of the shared key produced by
	// encapsulation and decapsulation.
	SharedKeySize = 32
	// SeedSize is the byte length of a decapsulation key seed, d || z.
	SeedSize = 2 * 32
)
54
55
56 const (
57 k = 3
58
59 CiphertextSize768 = k*encodingSize10 + encodingSize4
60 EncapsulationKeySize768 = k*encodingSize12 + 32
61 decapsulationKeySize768 = k*encodingSize12 + EncapsulationKeySize768 + 32 + 32
62 )
63
64
65 const (
66 k1024 = 4
67
68 CiphertextSize1024 = k1024*encodingSize11 + encodingSize5
69 EncapsulationKeySize1024 = k1024*encodingSize12 + 32
70 decapsulationKeySize1024 = k1024*encodingSize12 + EncapsulationKeySize1024 + 32 + 32
71 )
72
73
74
// DecapsulationKey768 is the secret key used for ML-KEM-768 decapsulation.
//
// It stores the 64-byte seed (d, z) returned by Bytes together with every
// value kemKeyGen derives from that seed. Decapsulation can therefore use
// them directly.
type DecapsulationKey768 struct {
	d [32]byte // seed hashed by kemKeyGen to derive ρ and σ
	z [32]byte // secret that kemDecaps uses to derive the rejection key

	ρ [32]byte // seed from which the matrix a is expanded
	h [32]byte // SHA3-256 of the encoded encapsulation key

	encryptionKey
	decryptionKey
}
85
86
87
88
89 func (dk *DecapsulationKey768) Bytes() []byte {
90 var b [SeedSize]byte
91 copy(b[:], dk.d[:])
92 copy(b[32:], dk.z[:])
93 return b[:]
94 }
95
96
97
98
99
100
101 func TestingOnlyExpandedBytes768(dk *DecapsulationKey768) []byte {
102 b := make([]byte, 0, decapsulationKeySize768)
103
104
105 for i := range dk.s {
106 b = polyByteEncode(b, dk.s[i])
107 }
108
109
110 for i := range dk.t {
111 b = polyByteEncode(b, dk.t[i])
112 }
113 b = append(b, dk.ρ[:]...)
114
115
116 b = append(b, dk.h[:]...)
117 b = append(b, dk.z[:]...)
118
119 return b
120 }
121
122
123
124 func (dk *DecapsulationKey768) EncapsulationKey() *EncapsulationKey768 {
125 return &EncapsulationKey768{
126 ρ: dk.ρ,
127 h: dk.h,
128 encryptionKey: dk.encryptionKey,
129 }
130 }
131
132
133
// EncapsulationKey768 is the public key used to encapsulate shared keys for
// the holder of the matching DecapsulationKey768.
type EncapsulationKey768 struct {
	ρ [32]byte // seed from which the matrix a is expanded
	h [32]byte // SHA3-256 of the encoded key; mixed into kemEncaps
	encryptionKey
}
139
140
141 func (ek *EncapsulationKey768) Bytes() []byte {
142
143 b := make([]byte, 0, EncapsulationKeySize768)
144 return ek.bytes(b)
145 }
146
147 func (ek *EncapsulationKey768) bytes(b []byte) []byte {
148 for i := range ek.t {
149 b = polyByteEncode(b, ek.t[i])
150 }
151 b = append(b, ek.ρ[:]...)
152 return b
153 }
154
155
// encryptionKey is the public part of the underlying encryption scheme.
// Both fields are held in the NTT domain, so pkeEncrypt can use them directly.
type encryptionKey struct {
	t [k]nttElement     // t = a ◦ s + e, as computed by kemKeyGen
	a [k * k]nttElement // k×k matrix, row-major: entry (i, j) is a[i*k+j]
}
160
161
// decryptionKey is the secret part of the underlying encryption scheme.
type decryptionKey struct {
	s [k]nttElement // secret vector, stored in the NTT domain
}
165
166
167
168 func GenerateKey768() (*DecapsulationKey768, error) {
169
170 dk := &DecapsulationKey768{}
171 return generateKey(dk)
172 }
173
174 func generateKey(dk *DecapsulationKey768) (*DecapsulationKey768, error) {
175 var d [32]byte
176 drbg.Read(d[:])
177 var z [32]byte
178 drbg.Read(z[:])
179 kemKeyGen(dk, &d, &z)
180 if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT(dk) }); err != nil {
181
182 panic(err)
183 }
184 fips140.RecordApproved()
185 return dk, nil
186 }
187
188
189
190 func GenerateKeyInternal768(d, z *[32]byte) *DecapsulationKey768 {
191 dk := &DecapsulationKey768{}
192 kemKeyGen(dk, d, z)
193 return dk
194 }
195
196
197
198 func NewDecapsulationKey768(seed []byte) (*DecapsulationKey768, error) {
199
200 dk := &DecapsulationKey768{}
201 return newKeyFromSeed(dk, seed)
202 }
203
204 func newKeyFromSeed(dk *DecapsulationKey768, seed []byte) (*DecapsulationKey768, error) {
205 if len(seed) != SeedSize {
206 return nil, errors.New("mlkem: invalid seed length")
207 }
208 d := (*[32]byte)(seed[:32])
209 z := (*[32]byte)(seed[32:])
210 kemKeyGen(dk, d, z)
211 if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT(dk) }); err != nil {
212
213 panic(err)
214 }
215 fips140.RecordApproved()
216 return dk, nil
217 }
218
219
220
221
222
223
224
225
226 func TestingOnlyNewDecapsulationKey768(b []byte) (*DecapsulationKey768, error) {
227 if len(b) != decapsulationKeySize768 {
228 return nil, errors.New("mlkem: invalid NIST decapsulation key length")
229 }
230
231 dk := &DecapsulationKey768{}
232 for i := range dk.s {
233 var err error
234 dk.s[i], err = polyByteDecode[nttElement](b[:encodingSize12])
235 if err != nil {
236 return nil, errors.New("mlkem: invalid secret key encoding")
237 }
238 b = b[encodingSize12:]
239 }
240
241 ek, err := NewEncapsulationKey768(b[:EncapsulationKeySize768])
242 if err != nil {
243 return nil, err
244 }
245 dk.ρ = ek.ρ
246 dk.h = ek.h
247 dk.encryptionKey = ek.encryptionKey
248 b = b[EncapsulationKeySize768:]
249
250 if !bytes.Equal(dk.h[:], b[:32]) {
251 return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
252 }
253 b = b[32:]
254
255 copy(dk.z[:], b)
256
257
258
259
260
261 drbg.Read(dk.d[:])
262
263 return dk, nil
264 }
265
266
267
268
269
270
// kemKeyGen deterministically derives every field of dk from the seeds d and
// z. It runs no self-test and records no approved service; callers do that.
//
//   - (ρ, σ) are the two 32-byte halves of SHA3-512(d || k).
//   - a is expanded from ρ with sampleNTT and stored row-major (entry (i, j)
//     at index i*k+j).
//   - s and e are sampled from σ with samplePolyCBD, using nonces 0..k-1
//     and k..2k-1 respectively, then transformed with ntt.
//   - t = a ◦ s + e, computed in the NTT domain.
//   - h is SHA3-256 of the encoded encapsulation key.
func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) {
	dk.d = *d
	dk.z = *z

	g := sha3.New512()
	g.Write(d[:])
	// The module dimension k is hashed after d, tying ρ and σ to this
	// parameter set.
	g.Write([]byte{k})
	G := g.Sum(make([]byte, 0, 64))
	ρ, σ := G[:32], G[32:]
	dk.ρ = [32]byte(ρ)

	// Expand the matrix. Entry (i, j) is produced by sampleNTT(ρ, j, i);
	// parseEK uses the same argument order, so both sides build the same a.
	A := &dk.a
	for i := byte(0); i < k; i++ {
		for j := byte(0); j < k; j++ {
			A[i*k+j] = sampleNTT(ρ, j, i)
		}
	}

	// N is the samplePolyCBD nonce and keeps counting from s into e.
	var N byte
	s := &dk.s
	for i := range s {
		s[i] = ntt(samplePolyCBD(σ, N))
		N++
	}
	e := make([]nttElement, k)
	for i := range e {
		e[i] = ntt(samplePolyCBD(σ, N))
		N++
	}

	// t = a ◦ s + e
	t := &dk.t
	for i := range t {
		t[i] = e[i]
		for j := range s {
			t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j]))
		}
	}

	// h = H(ek). It is cached in the key; kemEncaps and kemDecaps read it
	// instead of rehashing.
	H := sha3.New256()
	ek := dk.EncapsulationKey().Bytes()
	H.Write(ek)
	H.Sum(dk.h[:0])
}
314
315
316
317
318
319
320
321
322
323 func kemPCT(dk *DecapsulationKey768) error {
324 ek := dk.EncapsulationKey()
325 K, c := ek.Encapsulate()
326 K1, err := dk.Decapsulate(c)
327 if err != nil {
328 return err
329 }
330 if subtle.ConstantTimeCompare(K, K1) != 1 {
331 return errors.New("mlkem: PCT failed")
332 }
333 return nil
334 }
335
336
337
338
339
340 func (ek *EncapsulationKey768) Encapsulate() (sharedKey, ciphertext []byte) {
341
342 var cc [CiphertextSize768]byte
343 return ek.encapsulate(&cc)
344 }
345
346 func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (sharedKey, ciphertext []byte) {
347 var m [messageSize]byte
348 drbg.Read(m[:])
349
350
351 fips140.RecordApproved()
352 return kemEncaps(cc, ek, &m)
353 }
354
355
356
357 func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
358 cc := &[CiphertextSize768]byte{}
359 return kemEncaps(cc, ek, m)
360 }
361
362
363
364
365 func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (K, c []byte) {
366 g := sha3.New512()
367 g.Write(m[:])
368 g.Write(ek.h[:])
369 G := g.Sum(nil)
370 K, r := G[:SharedKeySize], G[SharedKeySize:]
371 c = pkeEncrypt(cc, &ek.encryptionKey, m, r)
372 return K, c
373 }
374
375
376
377 func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, error) {
378
379 ek := &EncapsulationKey768{}
380 return parseEK(ek, encapsulationKey)
381 }
382
383
384
385
386
387 func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) {
388 if len(ekPKE) != EncapsulationKeySize768 {
389 return nil, errors.New("mlkem: invalid encapsulation key length")
390 }
391
392 h := sha3.New256()
393 h.Write(ekPKE)
394 h.Sum(ek.h[:0])
395
396 for i := range ek.t {
397 var err error
398 ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
399 if err != nil {
400 return nil, err
401 }
402 ekPKE = ekPKE[encodingSize12:]
403 }
404 copy(ek.ρ[:], ekPKE)
405
406 for i := byte(0); i < k; i++ {
407 for j := byte(0); j < k; j++ {
408 ek.a[i*k+j] = sampleNTT(ek.ρ[:], j, i)
409 }
410 }
411
412 return ek, nil
413 }
414
415
416
417
418
// pkeEncrypt encrypts the 32-byte message m under the public key ex. It uses
// rnd as the seed for all sampled noise and encodes the ciphertext into cc.
//
// The returned slice is built by appending to cc[:0], so it shares cc's
// storage. kemDecaps re-encrypts with the same (ex, m, rnd) and compares the
// output against the received ciphertext, so it relies on this function being
// reproducible.
func pkeEncrypt(cc *[CiphertextSize768]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte {
	// N is the samplePolyCBD nonce and keeps counting across all samples:
	// r uses 0..k-1, e1 uses k..2k-1, and e2 uses 2k. Only r is moved into the
	// NTT domain.
	var N byte
	r, e1 := make([]nttElement, k), make([]ringElement, k)
	for i := range r {
		r[i] = ntt(samplePolyCBD(rnd, N))
		N++
	}
	for i := range e1 {
		e1[i] = samplePolyCBD(rnd, N)
		N++
	}
	e2 := samplePolyCBD(rnd, N)

	// u = NTT⁻¹(a⊺ ◦ r) + e1
	u := make([]ringElement, k)
	for i := range u {
		u[i] = e1[i]
		for j := range r {
			// a is stored row-major, so a[j*k+i] reads the transpose of a.
			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k+i], r[j])))
		}
	}

	// μ is the message decoded from one bit per coefficient.
	μ := ringDecodeAndDecompress1(m)

	// v = NTT⁻¹(t⊺ ◦ r) + e2 + μ
	var vNTT nttElement
	for i := range ex.t {
		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
	}
	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)

	// Ciphertext: each u[i] compressed to 10 bits per coefficient, then v
	// compressed to 4 bits per coefficient.
	c := cc[:0]
	for _, f := range u {
		c = ringCompressAndEncode10(c, f)
	}
	c = ringCompressAndEncode4(c, v)

	return c
}
457
458
459
460
461
462 func (dk *DecapsulationKey768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
463 if len(ciphertext) != CiphertextSize768 {
464 return nil, errors.New("mlkem: invalid ciphertext length")
465 }
466 c := (*[CiphertextSize768]byte)(ciphertext)
467
468
469
470 return kemDecaps(dk, c), nil
471 }
472
473
474
475
// kemDecaps decapsulates the ciphertext c with dk and returns the
// SharedKeySize-byte shared key. It records an approved service and never
// fails.
//
// It recovers a candidate message m with pkeDecrypt, derives
// (K', r) = SHA3-512(m || h), and re-encrypts m with r. If the re-encryption
// equals c, the result is K'; otherwise it is the rejection key
// SHAKE256(z || c). Both keys are always computed. The choice is made with
// subtle.ConstantTimeCompare and subtle.ConstantTimeCopy rather than a branch.
func kemDecaps(dk *DecapsulationKey768, c *[CiphertextSize768]byte) (K []byte) {
	fips140.RecordApproved()
	m := pkeDecrypt(&dk.decryptionKey, c)
	// (K', r) = G(m || H(ek))
	g := sha3.New512()
	g.Write(m[:])
	g.Write(dk.h[:])
	G := g.Sum(make([]byte, 0, 64))
	Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
	// Rejection key: SHAKE256(z || c).
	J := sha3.NewShake256()
	J.Write(dk.z[:])
	J.Write(c[:])
	Kout := make([]byte, SharedKeySize)
	J.Read(Kout)
	// Re-encrypt the candidate message. m is expected to be messageSize (32)
	// bytes; the array-pointer conversion would panic if it were shorter.
	var cc [CiphertextSize768]byte
	c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r)

	// Overwrite the rejection key with K' only when c == c1.
	subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
	return Kout
}
495
496
497
498
499
500 func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize768]byte) []byte {
501 u := make([]ringElement, k)
502 for i := range u {
503 b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
504 u[i] = ringDecodeAndDecompress10(b)
505 }
506
507 b := (*[encodingSize4]byte)(c[encodingSize10*k:])
508 v := ringDecodeAndDecompress4(b)
509
510 var mask nttElement
511 for i := range dx.s {
512 mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
513 }
514 w := polySub(v, inverseNTT(mask))
515
516 return ringCompressAndEncode1(nil, w)
517 }
518
View as plain text