1
2
3 package mlkem
4
5 import (
6 "bytes"
7 "crypto/internal/fips140"
8 "crypto/internal/fips140/drbg"
9 "crypto/internal/fips140/sha3"
10 "crypto/internal/fips140/subtle"
11 "errors"
12 )
13
14
15
16 type DecapsulationKey1024 struct {
17 d [32]byte
18 z [32]byte
19
20 ρ [32]byte
21 h [32]byte
22
23 encryptionKey1024
24 decryptionKey1024
25 }
26
27
28
29
30 func (dk *DecapsulationKey1024) Bytes() []byte {
31 var b [SeedSize]byte
32 copy(b[:], dk.d[:])
33 copy(b[32:], dk.z[:])
34 return b[:]
35 }
36
37
38
39
40
41
42 func TestingOnlyExpandedBytes1024(dk *DecapsulationKey1024) []byte {
43 b := make([]byte, 0, decapsulationKeySize1024)
44
45
46 for i := range dk.s {
47 b = polyByteEncode(b, dk.s[i])
48 }
49
50
51 for i := range dk.t {
52 b = polyByteEncode(b, dk.t[i])
53 }
54 b = append(b, dk.ρ[:]...)
55
56
57 b = append(b, dk.h[:]...)
58 b = append(b, dk.z[:]...)
59
60 return b
61 }
62
63
64
65 func (dk *DecapsulationKey1024) EncapsulationKey() *EncapsulationKey1024 {
66 return &EncapsulationKey1024{
67 ρ: dk.ρ,
68 h: dk.h,
69 encryptionKey1024: dk.encryptionKey1024,
70 }
71 }
72
73
74
75 type EncapsulationKey1024 struct {
76 ρ [32]byte
77 h [32]byte
78 encryptionKey1024
79 }
80
81
82 func (ek *EncapsulationKey1024) Bytes() []byte {
83
84 b := make([]byte, 0, EncapsulationKeySize1024)
85 return ek.bytes(b)
86 }
87
88 func (ek *EncapsulationKey1024) bytes(b []byte) []byte {
89 for i := range ek.t {
90 b = polyByteEncode(b, ek.t[i])
91 }
92 b = append(b, ek.ρ[:]...)
93 return b
94 }
95
96
97 type encryptionKey1024 struct {
98 t [k1024]nttElement
99 a [k1024 * k1024]nttElement
100 }
101
102
103 type decryptionKey1024 struct {
104 s [k1024]nttElement
105 }
106
107
108
109 func GenerateKey1024() (*DecapsulationKey1024, error) {
110
111 dk := &DecapsulationKey1024{}
112 return generateKey1024(dk)
113 }
114
115 func generateKey1024(dk *DecapsulationKey1024) (*DecapsulationKey1024, error) {
116 var d [32]byte
117 drbg.Read(d[:])
118 var z [32]byte
119 drbg.Read(z[:])
120 kemKeyGen1024(dk, &d, &z)
121 fips140.PCT("ML-KEM PCT", func() error { return kemPCT1024(dk) })
122 fips140.RecordApproved()
123 return dk, nil
124 }
125
126
127
128 func GenerateKeyInternal1024(d, z *[32]byte) *DecapsulationKey1024 {
129 dk := &DecapsulationKey1024{}
130 kemKeyGen1024(dk, d, z)
131 return dk
132 }
133
134
135
136 func NewDecapsulationKey1024(seed []byte) (*DecapsulationKey1024, error) {
137
138 dk := &DecapsulationKey1024{}
139 return newKeyFromSeed1024(dk, seed)
140 }
141
142 func newKeyFromSeed1024(dk *DecapsulationKey1024, seed []byte) (*DecapsulationKey1024, error) {
143 if len(seed) != SeedSize {
144 return nil, errors.New("mlkem: invalid seed length")
145 }
146 d := (*[32]byte)(seed[:32])
147 z := (*[32]byte)(seed[32:])
148 kemKeyGen1024(dk, d, z)
149 fips140.RecordApproved()
150 return dk, nil
151 }
152
153
154
155
156
157
158
159
160 func TestingOnlyNewDecapsulationKey1024(b []byte) (*DecapsulationKey1024, error) {
161 if len(b) != decapsulationKeySize1024 {
162 return nil, errors.New("mlkem: invalid NIST decapsulation key length")
163 }
164
165 dk := &DecapsulationKey1024{}
166 for i := range dk.s {
167 var err error
168 dk.s[i], err = polyByteDecode[nttElement](b[:encodingSize12])
169 if err != nil {
170 return nil, errors.New("mlkem: invalid secret key encoding")
171 }
172 b = b[encodingSize12:]
173 }
174
175 ek, err := NewEncapsulationKey1024(b[:EncapsulationKeySize1024])
176 if err != nil {
177 return nil, err
178 }
179 dk.ρ = ek.ρ
180 dk.h = ek.h
181 dk.encryptionKey1024 = ek.encryptionKey1024
182 b = b[EncapsulationKeySize1024:]
183
184 if !bytes.Equal(dk.h[:], b[:32]) {
185 return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
186 }
187 b = b[32:]
188
189 copy(dk.z[:], b)
190
191
192
193
194
195 drbg.Read(dk.d[:])
196
197 return dk, nil
198 }
199
200
201
202
203
204
// kemKeyGen1024 deterministically derives the full decapsulation key dk from
// the 32-byte seeds d and z, overwriting every field of dk.
//
// The derivation, as implemented below:
//   - (ρ, σ) = SHA3-512(d || k1024); ρ seeds the matrix A, σ seeds the noise.
//   - A[i*k1024+j] = sampleNTT(ρ, j, i).
//   - s and e are sampled with samplePolyCBD(σ, N) for N = 0, 1, 2, ...
//     and moved into the NTT domain.
//   - t = A ◦ s + e.
//   - h = SHA3-256 of the encoded encapsulation key.
func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) {
	dk.d = *d
	dk.z = *z

	g := sha3.New512()
	g.Write(d[:])
	// The module dimension is hashed after d, so the derived key depends on
	// the parameter set as well as on d.
	g.Write([]byte{k1024})
	G := g.Sum(make([]byte, 0, 64))
	ρ, σ := G[:32], G[32:]
	dk.ρ = [32]byte(ρ)

	A := &dk.a
	for i := byte(0); i < k1024; i++ {
		for j := byte(0); j < k1024; j++ {
			// Entry (i, j) is sampled with the index arguments in (j, i)
			// order; parseEK1024 uses the same convention.
			A[i*k1024+j] = sampleNTT(ρ, j, i)
		}
	}

	// N is a single counter shared across the s and e samplings, so every
	// polynomial gets a distinct input to samplePolyCBD.
	var N byte
	s := &dk.s
	for i := range s {
		s[i] = ntt(samplePolyCBD(σ, N))
		N++
	}
	e := make([]nttElement, k1024)
	for i := range e {
		e[i] = ntt(samplePolyCBD(σ, N))
		N++
	}

	// t = A ◦ s + e, computed entirely in the NTT domain.
	t := &dk.t
	for i := range t {
		t[i] = e[i]
		for j := range s {
			t[i] = polyAdd(t[i], nttMul(A[i*k1024+j], s[j]))
		}
	}

	// Cache h = SHA3-256(ek); encapsulation and decapsulation both hash it.
	H := sha3.New256()
	ek := dk.EncapsulationKey().Bytes()
	H.Write(ek)
	H.Sum(dk.h[:0])
}
248
249
250
251
252
253
254
255
256
257 func kemPCT1024(dk *DecapsulationKey1024) error {
258 ek := dk.EncapsulationKey()
259 K, c := ek.Encapsulate()
260 K1, err := dk.Decapsulate(c)
261 if err != nil {
262 return err
263 }
264 if subtle.ConstantTimeCompare(K, K1) != 1 {
265 return errors.New("mlkem: PCT failed")
266 }
267 return nil
268 }
269
270
271
272
273
274 func (ek *EncapsulationKey1024) Encapsulate() (sharedKey, ciphertext []byte) {
275
276 var cc [CiphertextSize1024]byte
277 return ek.encapsulate(&cc)
278 }
279
280 func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (sharedKey, ciphertext []byte) {
281 var m [messageSize]byte
282 drbg.Read(m[:])
283
284
285 fips140.RecordApproved()
286 return kemEncaps1024(cc, ek, &m)
287 }
288
289
290
291 func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
292 cc := &[CiphertextSize1024]byte{}
293 return kemEncaps1024(cc, ek, m)
294 }
295
296
297
298
299 func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (K, c []byte) {
300 g := sha3.New512()
301 g.Write(m[:])
302 g.Write(ek.h[:])
303 G := g.Sum(nil)
304 K, r := G[:SharedKeySize], G[SharedKeySize:]
305 c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r)
306 return K, c
307 }
308
309
310
311 func NewEncapsulationKey1024(encapsulationKey []byte) (*EncapsulationKey1024, error) {
312
313 ek := &EncapsulationKey1024{}
314 return parseEK1024(ek, encapsulationKey)
315 }
316
317
318
319
320
321 func parseEK1024(ek *EncapsulationKey1024, ekPKE []byte) (*EncapsulationKey1024, error) {
322 if len(ekPKE) != EncapsulationKeySize1024 {
323 return nil, errors.New("mlkem: invalid encapsulation key length")
324 }
325
326 h := sha3.New256()
327 h.Write(ekPKE)
328 h.Sum(ek.h[:0])
329
330 for i := range ek.t {
331 var err error
332 ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
333 if err != nil {
334 return nil, err
335 }
336 ekPKE = ekPKE[encodingSize12:]
337 }
338 copy(ek.ρ[:], ekPKE)
339
340 for i := byte(0); i < k1024; i++ {
341 for j := byte(0); j < k1024; j++ {
342 ek.a[i*k1024+j] = sampleNTT(ek.ρ[:], j, i)
343 }
344 }
345
346 return ek, nil
347 }
348
349
350
351
352
// pkeEncrypt1024 encrypts the message m under the encryption key ex using the
// randomness rnd. The ciphertext is written into cc, and the returned slice
// aliases cc.
//
// With r, e1 and e2 sampled from rnd via samplePolyCBD (one shared counter N,
// in that order), it computes
//   - u = NTT⁻¹(Aᵀ ◦ r) + e1
//   - v = NTT⁻¹(tᵀ ◦ r) + e2 + μ, where μ = ringDecodeAndDecompress1(m)
//
// and emits each u[i] with ringCompressAndEncode11 followed by v with
// ringCompressAndEncode5.
func pkeEncrypt1024(cc *[CiphertextSize1024]byte, ex *encryptionKey1024, m *[messageSize]byte, rnd []byte) []byte {
	var N byte
	r, e1 := make([]nttElement, k1024), make([]ringElement, k1024)
	for i := range r {
		r[i] = ntt(samplePolyCBD(rnd, N))
		N++
	}
	for i := range e1 {
		e1[i] = samplePolyCBD(rnd, N)
		N++
	}
	e2 := samplePolyCBD(rnd, N)

	u := make([]ringElement, k1024) // NTT⁻¹(Aᵀ ◦ r) + e1
	for i := range u {
		u[i] = e1[i]
		for j := range r {
			// i and j are swapped when indexing a (a[j*k1024+i]), which
			// multiplies by the transpose of A.
			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k1024+i], r[j])))
		}
	}

	μ := ringDecodeAndDecompress1(m)

	var vNTT nttElement // tᵀ ◦ r
	for i := range ex.t {
		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
	}
	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)

	// Reuse cc's storage for the output: c starts empty and is filled by
	// appending.
	c := cc[:0]
	for _, f := range u {
		c = ringCompressAndEncode11(c, f)
	}
	c = ringCompressAndEncode5(c, v)

	return c
}
391
392
393
394
395
396 func (dk *DecapsulationKey1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
397 if len(ciphertext) != CiphertextSize1024 {
398 return nil, errors.New("mlkem: invalid ciphertext length")
399 }
400 c := (*[CiphertextSize1024]byte)(ciphertext)
401
402
403
404 return kemDecaps1024(dk, c), nil
405 }
406
407
408
409
// kemDecaps1024 recovers the shared key encapsulated in c using dk, with
// implicit rejection. It works as follows:
//   - decrypt c to a candidate message m;
//   - derive (K', r) = SHA3-512(m || h);
//   - re-encrypt m with r;
//   - return K' if the re-encryption equals c, and otherwise the rejection
//     key SHAKE256(z || c).
//
// The final choice uses subtle.ConstantTimeCompare and
// subtle.ConstantTimeCopy, so there is no data-dependent branch on whether c
// was valid.
func kemDecaps1024(dk *DecapsulationKey1024, c *[CiphertextSize1024]byte) (K []byte) {
	fips140.RecordApproved()
	m := pkeDecrypt1024(&dk.decryptionKey1024, c)
	g := sha3.New512()
	g.Write(m[:])
	g.Write(dk.h[:])
	G := g.Sum(make([]byte, 0, 64))
	Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
	// Rejection key, computed unconditionally and used as the default output.
	J := sha3.NewShake256()
	J.Write(dk.z[:])
	J.Write(c[:])
	Kout := make([]byte, SharedKeySize)
	J.Read(Kout)
	var cc [CiphertextSize1024]byte
	c1 := pkeEncrypt1024(&cc, &dk.encryptionKey1024, (*[32]byte)(m), r)

	// Overwrite Kout with K' only if the re-encryption matches c exactly.
	subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
	return Kout
}
429
430
431
432
433
// pkeDecrypt1024 decrypts the ciphertext c with the decryption key dx. It
// returns the recovered message as a newly allocated byte slice.
//
// c is read as k1024 chunks of encodingSize11 bytes (the compressed vector u)
// followed by encodingSize5 bytes (the compressed polynomial v). The message
// is w = v − NTT⁻¹(sᵀ ◦ NTT(u)), encoded with ringCompressAndEncode1.
func pkeDecrypt1024(dx *decryptionKey1024, c *[CiphertextSize1024]byte) []byte {
	u := make([]ringElement, k1024)
	for i := range u {
		b := (*[encodingSize11]byte)(c[encodingSize11*i : encodingSize11*(i+1)])
		u[i] = ringDecodeAndDecompress11(b)
	}

	// v occupies the bytes after the k1024 encoded u polynomials.
	b := (*[encodingSize5]byte)(c[encodingSize11*k1024:])
	v := ringDecodeAndDecompress5(b)

	var mask nttElement // sᵀ ◦ NTT(u)
	for i := range dx.s {
		mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
	}
	w := polySub(v, inverseNTT(mask))

	return ringCompressAndEncode1(nil, w)
}
452
View as plain text