Source file src/crypto/internal/fips140/mlkem/mlkem768.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package mlkem implements the quantum-resistant key encapsulation method
     6  // ML-KEM (formerly known as Kyber), as specified in [NIST FIPS 203].
     7  //
     8  // [NIST FIPS 203]: https://doi.org/10.6028/NIST.FIPS.203
     9  package mlkem
    10  
    11  // This package targets security, correctness, simplicity, readability, and
    12  // reviewability as its primary goals. All critical operations are performed in
    13  // constant time.
    14  //
    15  // Variable and function names, as well as code layout, are selected to
    16  // facilitate reviewing the implementation against the NIST FIPS 203 document.
    17  //
    18  // Reviewers unfamiliar with polynomials or linear algebra might find the
    19  // background at https://words.filippo.io/kyber-math/ useful.
    20  //
    21  // This file implements the recommended parameter set ML-KEM-768. The ML-KEM-1024
    22  // parameter set implementation is auto-generated from this file.
    23  //
    24  //go:generate go run generate1024.go -input mlkem768.go -output mlkem1024.go
    25  
    26  import (
    27  	"bytes"
    28  	"crypto/internal/fips140"
    29  	"crypto/internal/fips140/drbg"
    30  	"crypto/internal/fips140/sha3"
    31  	"crypto/internal/fips140/subtle"
    32  	"errors"
    33  )
    34  
    35  const (
    36  	// ML-KEM global constants.
    37  	n = 256
    38  	q = 3329
    39  
    40  	// encodingSizeX is the byte size of a ringElement or nttElement encoded
    41  	// by ByteEncode_X (FIPS 203, Algorithm 5).
    42  	encodingSize12 = n * 12 / 8
    43  	encodingSize11 = n * 11 / 8
    44  	encodingSize10 = n * 10 / 8
    45  	encodingSize5  = n * 5 / 8
    46  	encodingSize4  = n * 4 / 8
    47  	encodingSize1  = n * 1 / 8
    48  
    49  	messageSize = encodingSize1
    50  
    51  	SharedKeySize = 32
    52  	SeedSize      = 32 + 32
    53  )
    54  
    55  // ML-KEM-768 parameters.
    56  const (
    57  	k = 3
    58  
    59  	CiphertextSize768       = k*encodingSize10 + encodingSize4
    60  	EncapsulationKeySize768 = k*encodingSize12 + 32
    61  	decapsulationKeySize768 = k*encodingSize12 + EncapsulationKeySize768 + 32 + 32
    62  )
    63  
    64  // ML-KEM-1024 parameters.
    65  const (
    66  	k1024 = 4
    67  
    68  	CiphertextSize1024       = k1024*encodingSize11 + encodingSize5
    69  	EncapsulationKeySize1024 = k1024*encodingSize12 + 32
    70  	decapsulationKeySize1024 = k1024*encodingSize12 + EncapsulationKeySize1024 + 32 + 32
    71  )
    72  
// A DecapsulationKey768 is the secret key used to decapsulate a shared key from a
// ciphertext. It includes various precomputed values.
//
// Besides the d || z seed returned by Bytes, it caches everything kemKeyGen
// derives from it: ρ, H(ek), the expanded encryption key (t and A) and the
// secret vector s. Keys built by TestingOnlyNewDecapsulationKey768 hold a
// random d that does not correspond to the other fields.
type DecapsulationKey768 struct {
	d [32]byte // decapsulation key seed
	z [32]byte // implicit rejection sampling seed

	ρ [32]byte // sampleNTT seed for A, stored for the encapsulation key
	h [32]byte // H(ek), stored for ML-KEM.Decaps_internal

	// encryptionKey holds t and A; EncapsulationKey copies it into the
	// returned EncapsulationKey768.
	encryptionKey
	// decryptionKey holds the secret vector s, used by pkeDecrypt.
	decryptionKey
}
    85  
    86  // Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
    87  //
    88  // The decapsulation key must be kept secret.
    89  func (dk *DecapsulationKey768) Bytes() []byte {
    90  	var b [SeedSize]byte
    91  	copy(b[:], dk.d[:])
    92  	copy(b[32:], dk.z[:])
    93  	return b[:]
    94  }
    95  
    96  // TestingOnlyExpandedBytes768 returns the decapsulation key as a byte slice
    97  // using the full expanded NIST encoding.
    98  //
    99  // This should only be used for ACVP testing. For all other purposes prefer
   100  // the Bytes method that returns the (much smaller) seed.
   101  func TestingOnlyExpandedBytes768(dk *DecapsulationKey768) []byte {
   102  	b := make([]byte, 0, decapsulationKeySize768)
   103  
   104  	// ByteEncode₁₂(s)
   105  	for i := range dk.s {
   106  		b = polyByteEncode(b, dk.s[i])
   107  	}
   108  
   109  	// ByteEncode₁₂(t) || ρ
   110  	for i := range dk.t {
   111  		b = polyByteEncode(b, dk.t[i])
   112  	}
   113  	b = append(b, dk.ρ[:]...)
   114  
   115  	// H(ek) || z
   116  	b = append(b, dk.h[:]...)
   117  	b = append(b, dk.z[:]...)
   118  
   119  	return b
   120  }
   121  
   122  // EncapsulationKey returns the public encapsulation key necessary to produce
   123  // ciphertexts.
   124  func (dk *DecapsulationKey768) EncapsulationKey() *EncapsulationKey768 {
   125  	return &EncapsulationKey768{
   126  		ρ:             dk.ρ,
   127  		h:             dk.h,
   128  		encryptionKey: dk.encryptionKey,
   129  	}
   130  }
   131  
// An EncapsulationKey768 is the public key used to produce ciphertexts to be
// decapsulated by the corresponding [DecapsulationKey768].
type EncapsulationKey768 struct {
	ρ [32]byte // sampleNTT seed for A
	h [32]byte // H(ek)
	// encryptionKey holds the decoded t and the matrix A expanded from ρ, so
	// that encapsulation does not need to re-derive them.
	encryptionKey
}
   139  
   140  // Bytes returns the encapsulation key as a byte slice.
   141  func (ek *EncapsulationKey768) Bytes() []byte {
   142  	// The actual logic is in a separate function to outline this allocation.
   143  	b := make([]byte, 0, EncapsulationKeySize768)
   144  	return ek.bytes(b)
   145  }
   146  
   147  func (ek *EncapsulationKey768) bytes(b []byte) []byte {
   148  	for i := range ek.t {
   149  		b = polyByteEncode(b, ek.t[i])
   150  	}
   151  	b = append(b, ek.ρ[:]...)
   152  	return b
   153  }
   154  
// encryptionKey is the parsed and expanded form of a PKE encryption key.
//
// The k×k matrix A is stored row-major in a flat array; pkeEncrypt reads it
// transposed, as a[j*k+i].
type encryptionKey struct {
	t [k]nttElement     // ByteDecode₁₂(ek[:384k])
	a [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
}
   160  
// decryptionKey is the parsed and expanded form of a PKE decryption key.
//
// s is kept in the NTT domain, as produced by kemKeyGen, so that pkeDecrypt
// can multiply it directly with NTT(u).
type decryptionKey struct {
	s [k]nttElement // ByteDecode₁₂(dk[:decryptionKeySize])
}
   165  
   166  // GenerateKey768 generates a new decapsulation key, drawing random bytes from
   167  // a DRBG. The decapsulation key must be kept secret.
   168  func GenerateKey768() (*DecapsulationKey768, error) {
   169  	// The actual logic is in a separate function to outline this allocation.
   170  	dk := &DecapsulationKey768{}
   171  	return generateKey(dk)
   172  }
   173  
   174  func generateKey(dk *DecapsulationKey768) (*DecapsulationKey768, error) {
   175  	var d [32]byte
   176  	drbg.Read(d[:])
   177  	var z [32]byte
   178  	drbg.Read(z[:])
   179  	kemKeyGen(dk, &d, &z)
   180  	if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT(dk) }); err != nil {
   181  		// This clearly can't happen, but FIPS 140-3 requires us to check.
   182  		panic(err)
   183  	}
   184  	fips140.RecordApproved()
   185  	return dk, nil
   186  }
   187  
   188  // GenerateKeyInternal768 is a derandomized version of GenerateKey768,
   189  // exclusively for use in tests.
   190  func GenerateKeyInternal768(d, z *[32]byte) *DecapsulationKey768 {
   191  	dk := &DecapsulationKey768{}
   192  	kemKeyGen(dk, d, z)
   193  	return dk
   194  }
   195  
   196  // NewDecapsulationKey768 parses a decapsulation key from a 64-byte
   197  // seed in the "d || z" form. The seed must be uniformly random.
   198  func NewDecapsulationKey768(seed []byte) (*DecapsulationKey768, error) {
   199  	// The actual logic is in a separate function to outline this allocation.
   200  	dk := &DecapsulationKey768{}
   201  	return newKeyFromSeed(dk, seed)
   202  }
   203  
   204  func newKeyFromSeed(dk *DecapsulationKey768, seed []byte) (*DecapsulationKey768, error) {
   205  	if len(seed) != SeedSize {
   206  		return nil, errors.New("mlkem: invalid seed length")
   207  	}
   208  	d := (*[32]byte)(seed[:32])
   209  	z := (*[32]byte)(seed[32:])
   210  	kemKeyGen(dk, d, z)
   211  	if err := fips140.PCT("ML-KEM PCT", func() error { return kemPCT(dk) }); err != nil {
   212  		// This clearly can't happen, but FIPS 140-3 requires us to check.
   213  		panic(err)
   214  	}
   215  	fips140.RecordApproved()
   216  	return dk, nil
   217  }
   218  
   219  // TestingOnlyNewDecapsulationKey768 parses a decapsulation key from its expanded NIST format.
   220  //
   221  // Bytes() must not be called on the returned key, as it will not produce the
   222  // original seed.
   223  //
   224  // This function should only be used for ACVP testing. Prefer NewDecapsulationKey768 for all
   225  // other purposes.
   226  func TestingOnlyNewDecapsulationKey768(b []byte) (*DecapsulationKey768, error) {
   227  	if len(b) != decapsulationKeySize768 {
   228  		return nil, errors.New("mlkem: invalid NIST decapsulation key length")
   229  	}
   230  
   231  	dk := &DecapsulationKey768{}
   232  	for i := range dk.s {
   233  		var err error
   234  		dk.s[i], err = polyByteDecode[nttElement](b[:encodingSize12])
   235  		if err != nil {
   236  			return nil, errors.New("mlkem: invalid secret key encoding")
   237  		}
   238  		b = b[encodingSize12:]
   239  	}
   240  
   241  	ek, err := NewEncapsulationKey768(b[:EncapsulationKeySize768])
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  	dk.ρ = ek.ρ
   246  	dk.h = ek.h
   247  	dk.encryptionKey = ek.encryptionKey
   248  	b = b[EncapsulationKeySize768:]
   249  
   250  	if !bytes.Equal(dk.h[:], b[:32]) {
   251  		return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
   252  	}
   253  	b = b[32:]
   254  
   255  	copy(dk.z[:], b)
   256  
   257  	// Generate a random d value for use in Bytes(). This is a safety mechanism
   258  	// that avoids returning a broken key vs a random key if this function is
   259  	// called in contravention of the TestingOnlyNewDecapsulationKey768 function
   260  	// comment advising against it.
   261  	drbg.Read(dk.d[:])
   262  
   263  	return dk, nil
   264  }
   265  
   266  // kemKeyGen generates a decapsulation key.
   267  //
   268  // It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and
   269  // K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save
   270  // copies and allocations.
   271  func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) {
   272  	dk.d = *d
   273  	dk.z = *z
   274  
   275  	g := sha3.New512()
   276  	g.Write(d[:])
   277  	g.Write([]byte{k}) // Module dimension as a domain separator.
   278  	G := g.Sum(make([]byte, 0, 64))
   279  	ρ, σ := G[:32], G[32:]
   280  	dk.ρ = [32]byte(ρ)
   281  
   282  	A := &dk.a
   283  	for i := byte(0); i < k; i++ {
   284  		for j := byte(0); j < k; j++ {
   285  			A[i*k+j] = sampleNTT(ρ, j, i)
   286  		}
   287  	}
   288  
   289  	var N byte
   290  	s := &dk.s
   291  	for i := range s {
   292  		s[i] = ntt(samplePolyCBD(σ, N))
   293  		N++
   294  	}
   295  	e := make([]nttElement, k)
   296  	for i := range e {
   297  		e[i] = ntt(samplePolyCBD(σ, N))
   298  		N++
   299  	}
   300  
   301  	t := &dk.t
   302  	for i := range t { // t = A ◦ s + e
   303  		t[i] = e[i]
   304  		for j := range s {
   305  			t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j]))
   306  		}
   307  	}
   308  
   309  	H := sha3.New256()
   310  	ek := dk.EncapsulationKey().Bytes()
   311  	H.Write(ek)
   312  	H.Sum(dk.h[:0])
   313  }
   314  
   315  // kemPCT performs a Pairwise Consistency Test per FIPS 140-3 IG 10.3.A
   316  // Additional Comment 1: "For key pairs generated for use with approved KEMs in
   317  // FIPS 203, the PCT shall consist of applying the encapsulation key ek to
   318  // encapsulate a shared secret K leading to ciphertext c, and then applying
   319  // decapsulation key dk to retrieve the same shared secret K. The PCT passes if
   320  // the two shared secret K values are equal. The PCT shall be performed either
   321  // when keys are generated/imported, prior to the first exportation, or prior to
   322  // the first operational use (if not exported before the first use)."
   323  func kemPCT(dk *DecapsulationKey768) error {
   324  	ek := dk.EncapsulationKey()
   325  	K, c := ek.Encapsulate()
   326  	K1, err := dk.Decapsulate(c)
   327  	if err != nil {
   328  		return err
   329  	}
   330  	if subtle.ConstantTimeCompare(K, K1) != 1 {
   331  		return errors.New("mlkem: PCT failed")
   332  	}
   333  	return nil
   334  }
   335  
   336  // Encapsulate generates a shared key and an associated ciphertext from an
   337  // encapsulation key, drawing random bytes from a DRBG.
   338  //
   339  // The shared key must be kept secret.
   340  func (ek *EncapsulationKey768) Encapsulate() (sharedKey, ciphertext []byte) {
   341  	// The actual logic is in a separate function to outline this allocation.
   342  	var cc [CiphertextSize768]byte
   343  	return ek.encapsulate(&cc)
   344  }
   345  
   346  func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (sharedKey, ciphertext []byte) {
   347  	var m [messageSize]byte
   348  	drbg.Read(m[:])
   349  	// Note that the modulus check (step 2 of the encapsulation key check from
   350  	// FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK.
   351  	fips140.RecordApproved()
   352  	return kemEncaps(cc, ek, &m)
   353  }
   354  
   355  // EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
   356  // use in tests.
   357  func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
   358  	cc := &[CiphertextSize768]byte{}
   359  	return kemEncaps(cc, ek, m)
   360  }
   361  
   362  // kemEncaps generates a shared key and an associated ciphertext.
   363  //
   364  // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
   365  func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (K, c []byte) {
   366  	g := sha3.New512()
   367  	g.Write(m[:])
   368  	g.Write(ek.h[:])
   369  	G := g.Sum(nil)
   370  	K, r := G[:SharedKeySize], G[SharedKeySize:]
   371  	c = pkeEncrypt(cc, &ek.encryptionKey, m, r)
   372  	return K, c
   373  }
   374  
   375  // NewEncapsulationKey768 parses an encapsulation key from its encoded form.
   376  // If the encapsulation key is not valid, NewEncapsulationKey768 returns an error.
   377  func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, error) {
   378  	// The actual logic is in a separate function to outline this allocation.
   379  	ek := &EncapsulationKey768{}
   380  	return parseEK(ek, encapsulationKey)
   381  }
   382  
   383  // parseEK parses an encryption key from its encoded form.
   384  //
   385  // It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
   386  // Algorithm 14.
   387  func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) {
   388  	if len(ekPKE) != EncapsulationKeySize768 {
   389  		return nil, errors.New("mlkem: invalid encapsulation key length")
   390  	}
   391  
   392  	h := sha3.New256()
   393  	h.Write(ekPKE)
   394  	h.Sum(ek.h[:0])
   395  
   396  	for i := range ek.t {
   397  		var err error
   398  		ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
   399  		if err != nil {
   400  			return nil, err
   401  		}
   402  		ekPKE = ekPKE[encodingSize12:]
   403  	}
   404  	copy(ek.ρ[:], ekPKE)
   405  
   406  	for i := byte(0); i < k; i++ {
   407  		for j := byte(0); j < k; j++ {
   408  			ek.a[i*k+j] = sampleNTT(ek.ρ[:], j, i)
   409  		}
   410  	}
   411  
   412  	return ek, nil
   413  }
   414  
   415  // pkeEncrypt encrypt a plaintext message.
   416  //
   417  // It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the
   418  // computation of t and AT is done in parseEK.
   419  func pkeEncrypt(cc *[CiphertextSize768]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte {
   420  	var N byte
   421  	r, e1 := make([]nttElement, k), make([]ringElement, k)
   422  	for i := range r {
   423  		r[i] = ntt(samplePolyCBD(rnd, N))
   424  		N++
   425  	}
   426  	for i := range e1 {
   427  		e1[i] = samplePolyCBD(rnd, N)
   428  		N++
   429  	}
   430  	e2 := samplePolyCBD(rnd, N)
   431  
   432  	u := make([]ringElement, k) // NTT⁻¹(AT ◦ r) + e1
   433  	for i := range u {
   434  		u[i] = e1[i]
   435  		for j := range r {
   436  			// Note that i and j are inverted, as we need the transposed of A.
   437  			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k+i], r[j])))
   438  		}
   439  	}
   440  
   441  	μ := ringDecodeAndDecompress1(m)
   442  
   443  	var vNTT nttElement // t⊺ ◦ r
   444  	for i := range ex.t {
   445  		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
   446  	}
   447  	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)
   448  
   449  	c := cc[:0]
   450  	for _, f := range u {
   451  		c = ringCompressAndEncode10(c, f)
   452  	}
   453  	c = ringCompressAndEncode4(c, v)
   454  
   455  	return c
   456  }
   457  
   458  // Decapsulate generates a shared key from a ciphertext and a decapsulation key.
   459  // If the ciphertext is not valid, Decapsulate returns an error.
   460  //
   461  // The shared key must be kept secret.
   462  func (dk *DecapsulationKey768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
   463  	if len(ciphertext) != CiphertextSize768 {
   464  		return nil, errors.New("mlkem: invalid ciphertext length")
   465  	}
   466  	c := (*[CiphertextSize768]byte)(ciphertext)
   467  	// Note that the hash check (step 3 of the decapsulation input check from
   468  	// FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always
   469  	// validly generated by ML-KEM.KeyGen_internal.
   470  	return kemDecaps(dk, c), nil
   471  }
   472  
   473  // kemDecaps produces a shared key from a ciphertext.
   474  //
   475  // It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18.
   476  func kemDecaps(dk *DecapsulationKey768, c *[CiphertextSize768]byte) (K []byte) {
   477  	fips140.RecordApproved()
   478  	m := pkeDecrypt(&dk.decryptionKey, c)
   479  	g := sha3.New512()
   480  	g.Write(m[:])
   481  	g.Write(dk.h[:])
   482  	G := g.Sum(make([]byte, 0, 64))
   483  	Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
   484  	J := sha3.NewShake256()
   485  	J.Write(dk.z[:])
   486  	J.Write(c[:])
   487  	Kout := make([]byte, SharedKeySize)
   488  	J.Read(Kout)
   489  	var cc [CiphertextSize768]byte
   490  	c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r)
   491  
   492  	subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
   493  	return Kout
   494  }
   495  
   496  // pkeDecrypt decrypts a ciphertext.
   497  //
   498  // It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15,
   499  // although s is retained from kemKeyGen.
   500  func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize768]byte) []byte {
   501  	u := make([]ringElement, k)
   502  	for i := range u {
   503  		b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
   504  		u[i] = ringDecodeAndDecompress10(b)
   505  	}
   506  
   507  	b := (*[encodingSize4]byte)(c[encodingSize10*k:])
   508  	v := ringDecodeAndDecompress4(b)
   509  
   510  	var mask nttElement // s⊺ ◦ NTT(u)
   511  	for i := range dx.s {
   512  		mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
   513  	}
   514  	w := polySub(v, inverseNTT(mask))
   515  
   516  	return ringCompressAndEncode1(nil, w)
   517  }
   518  

View as plain text