Source file src/crypto/internal/fips140/mlkem/mlkem1024.go

     1  // Code generated by generate1024.go. DO NOT EDIT.
     2  
     3  package mlkem
     4  
     5  import (
     6  	"bytes"
     7  	"crypto/internal/fips140"
     8  	"crypto/internal/fips140/drbg"
     9  	"crypto/internal/fips140/sha3"
    10  	"crypto/internal/fips140/subtle"
    11  	"errors"
    12  )
    13  
// A DecapsulationKey1024 is the secret key used to decapsulate a shared key from a
// ciphertext. It includes various precomputed values.
//
// Besides the 64-byte seed (d, z) returned by Bytes, it caches ρ, H(ek), the
// expanded encryption key (t and the matrix A) and the secret vector s, so
// decapsulation and EncapsulationKey can use them without re-deriving them.
type DecapsulationKey1024 struct {
	d [32]byte // decapsulation key seed
	z [32]byte // implicit rejection sampling seed

	ρ [32]byte // sampleNTT seed for A, stored for the encapsulation key
	h [32]byte // H(ek), stored for ML-KEM.Decaps_internal

	encryptionKey1024 // t and A; copied into the EncapsulationKey1024 view
	decryptionKey1024 // s, the K-PKE secret vector (NTT domain)
}
    26  
    27  // Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
    28  //
    29  // The decapsulation key must be kept secret.
    30  func (dk *DecapsulationKey1024) Bytes() []byte {
    31  	var b [SeedSize]byte
    32  	copy(b[:], dk.d[:])
    33  	copy(b[32:], dk.z[:])
    34  	return b[:]
    35  }
    36  
    37  // TestingOnlyExpandedBytes1024 returns the decapsulation key as a byte slice
    38  // using the full expanded NIST encoding.
    39  //
    40  // This should only be used for ACVP testing. For all other purposes prefer
    41  // the Bytes method that returns the (much smaller) seed.
    42  func TestingOnlyExpandedBytes1024(dk *DecapsulationKey1024) []byte {
    43  	b := make([]byte, 0, decapsulationKeySize1024)
    44  
    45  	// ByteEncode₁₂(s)
    46  	for i := range dk.s {
    47  		b = polyByteEncode(b, dk.s[i])
    48  	}
    49  
    50  	// ByteEncode₁₂(t) || ρ
    51  	for i := range dk.t {
    52  		b = polyByteEncode(b, dk.t[i])
    53  	}
    54  	b = append(b, dk.ρ[:]...)
    55  
    56  	// H(ek) || z
    57  	b = append(b, dk.h[:]...)
    58  	b = append(b, dk.z[:]...)
    59  
    60  	return b
    61  }
    62  
    63  // EncapsulationKey returns the public encapsulation key necessary to produce
    64  // ciphertexts.
    65  func (dk *DecapsulationKey1024) EncapsulationKey() *EncapsulationKey1024 {
    66  	return &EncapsulationKey1024{
    67  		ρ:                 dk.ρ,
    68  		h:                 dk.h,
    69  		encryptionKey1024: dk.encryptionKey1024,
    70  	}
    71  }
    72  
// An EncapsulationKey1024 is the public key used to produce ciphertexts to be
// decapsulated by the corresponding [DecapsulationKey1024].
//
// It stores the parsed key along with values derived from it (H(ek) and the
// expanded matrix A), so encapsulation neither re-hashes nor re-parses the
// encoded key.
type EncapsulationKey1024 struct {
	ρ [32]byte // sampleNTT seed for A
	h [32]byte // H(ek)
	encryptionKey1024 // t and the expanded matrix A
}
    80  
    81  // Bytes returns the encapsulation key as a byte slice.
    82  func (ek *EncapsulationKey1024) Bytes() []byte {
    83  	// The actual logic is in a separate function to outline this allocation.
    84  	b := make([]byte, 0, EncapsulationKeySize1024)
    85  	return ek.bytes(b)
    86  }
    87  
    88  func (ek *EncapsulationKey1024) bytes(b []byte) []byte {
    89  	for i := range ek.t {
    90  		b = polyByteEncode(b, ek.t[i])
    91  	}
    92  	b = append(b, ek.ρ[:]...)
    93  	return b
    94  }
    95  
// encryptionKey1024 is the parsed and expanded form of a PKE encryption key.
//
// a holds the whole k×k matrix A in row-major order. It is sampled once, in
// kemKeyGen1024 or parseEK1024, and reused by every pkeEncrypt1024 call.
type encryptionKey1024 struct {
	t [k1024]nttElement         // ByteDecode₁₂(ek[:384k])
	a [k1024 * k1024]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
}
   101  
// decryptionKey1024 is the parsed and expanded form of a PKE decryption key.
//
// s is kept in the NTT domain, which is the form pkeDecrypt1024 multiplies
// against.
type decryptionKey1024 struct {
	s [k1024]nttElement // ByteDecode₁₂(dk[:384k])
}
   106  
   107  // GenerateKey1024 generates a new decapsulation key, drawing random bytes from
   108  // a DRBG. The decapsulation key must be kept secret.
   109  func GenerateKey1024() (*DecapsulationKey1024, error) {
   110  	// The actual logic is in a separate function to outline this allocation.
   111  	dk := &DecapsulationKey1024{}
   112  	return generateKey1024(dk)
   113  }
   114  
   115  func generateKey1024(dk *DecapsulationKey1024) (*DecapsulationKey1024, error) {
   116  	var d [32]byte
   117  	drbg.Read(d[:])
   118  	var z [32]byte
   119  	drbg.Read(z[:])
   120  	kemKeyGen1024(dk, &d, &z)
   121  	fips140.PCT("ML-KEM PCT", func() error { return kemPCT1024(dk) })
   122  	fips140.RecordApproved()
   123  	return dk, nil
   124  }
   125  
   126  // GenerateKeyInternal1024 is a derandomized version of GenerateKey1024,
   127  // exclusively for use in tests.
   128  func GenerateKeyInternal1024(d, z *[32]byte) *DecapsulationKey1024 {
   129  	dk := &DecapsulationKey1024{}
   130  	kemKeyGen1024(dk, d, z)
   131  	return dk
   132  }
   133  
   134  // NewDecapsulationKey1024 parses a decapsulation key from a 64-byte
   135  // seed in the "d || z" form. The seed must be uniformly random.
   136  func NewDecapsulationKey1024(seed []byte) (*DecapsulationKey1024, error) {
   137  	// The actual logic is in a separate function to outline this allocation.
   138  	dk := &DecapsulationKey1024{}
   139  	return newKeyFromSeed1024(dk, seed)
   140  }
   141  
   142  func newKeyFromSeed1024(dk *DecapsulationKey1024, seed []byte) (*DecapsulationKey1024, error) {
   143  	if len(seed) != SeedSize {
   144  		return nil, errors.New("mlkem: invalid seed length")
   145  	}
   146  	d := (*[32]byte)(seed[:32])
   147  	z := (*[32]byte)(seed[32:])
   148  	kemKeyGen1024(dk, d, z)
   149  	fips140.RecordApproved()
   150  	return dk, nil
   151  }
   152  
   153  // TestingOnlyNewDecapsulationKey1024 parses a decapsulation key from its expanded NIST format.
   154  //
   155  // Bytes() must not be called on the returned key, as it will not produce the
   156  // original seed.
   157  //
   158  // This function should only be used for ACVP testing. Prefer NewDecapsulationKey1024 for all
   159  // other purposes.
   160  func TestingOnlyNewDecapsulationKey1024(b []byte) (*DecapsulationKey1024, error) {
   161  	if len(b) != decapsulationKeySize1024 {
   162  		return nil, errors.New("mlkem: invalid NIST decapsulation key length")
   163  	}
   164  
   165  	dk := &DecapsulationKey1024{}
   166  	for i := range dk.s {
   167  		var err error
   168  		dk.s[i], err = polyByteDecode[nttElement](b[:encodingSize12])
   169  		if err != nil {
   170  			return nil, errors.New("mlkem: invalid secret key encoding")
   171  		}
   172  		b = b[encodingSize12:]
   173  	}
   174  
   175  	ek, err := NewEncapsulationKey1024(b[:EncapsulationKeySize1024])
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  	dk.ρ = ek.ρ
   180  	dk.h = ek.h
   181  	dk.encryptionKey1024 = ek.encryptionKey1024
   182  	b = b[EncapsulationKeySize1024:]
   183  
   184  	if !bytes.Equal(dk.h[:], b[:32]) {
   185  		return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
   186  	}
   187  	b = b[32:]
   188  
   189  	copy(dk.z[:], b)
   190  
   191  	// Generate a random d value for use in Bytes(). This is a safety mechanism
   192  	// that avoids returning a broken key vs a random key if this function is
   193  	// called in contravention of the TestingOnlyNewDecapsulationKey1024 function
   194  	// comment advising against it.
   195  	drbg.Read(dk.d[:])
   196  
   197  	return dk, nil
   198  }
   199  
// kemKeyGen1024 generates a decapsulation key.
//
// It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and
// K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save
// copies and allocations.
//
// Every field of dk is overwritten: the seeds d and z, ρ, the matrix A, s, t,
// and h = H(ek). The result is fully determined by d and z.
func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) {
	dk.d = *d
	dk.z = *z

	// (ρ, σ) = G(d || k), where G is SHA3-512.
	g := sha3.New512()
	g.Write(d[:])
	g.Write([]byte{k1024}) // Module dimension as a domain separator.
	G := g.Sum(make([]byte, 0, 64))
	ρ, σ := G[:32], G[32:]
	dk.ρ = [32]byte(ρ)

	// Expand A from ρ in row-major order. Note the (j, i) argument order to
	// sampleNTT, which parseEK1024 uses as well.
	A := &dk.a
	for i := byte(0); i < k1024; i++ {
		for j := byte(0); j < k1024; j++ {
			A[i*k1024+j] = sampleNTT(ρ, j, i)
		}
	}

	// Sample s, then e, from σ with one running counter N. The order of the
	// draws is part of the algorithm and must not change.
	var N byte
	s := &dk.s
	for i := range s {
		s[i] = ntt(samplePolyCBD(σ, N))
		N++
	}
	e := make([]nttElement, k1024)
	for i := range e {
		e[i] = ntt(samplePolyCBD(σ, N))
		N++
	}

	t := &dk.t
	for i := range t { // t = A ◦ s + e
		t[i] = e[i]
		for j := range s {
			t[i] = polyAdd(t[i], nttMul(A[i*k1024+j], s[j]))
		}
	}

	// h = H(ek), where H is SHA3-256. It is cached for encapsulation (through
	// EncapsulationKey) and for decapsulation.
	H := sha3.New256()
	ek := dk.EncapsulationKey().Bytes()
	H.Write(ek)
	H.Sum(dk.h[:0])
}
   248  
   249  // kemPCT1024 performs a Pairwise Consistency Test per FIPS 140-3 IG 10.3.A
   250  // Additional Comment 1: "For key pairs generated for use with approved KEMs in
   251  // FIPS 203, the PCT shall consist of applying the encapsulation key ek to
   252  // encapsulate a shared secret K leading to ciphertext c, and then applying
   253  // decapsulation key dk to retrieve the same shared secret K. The PCT passes if
   254  // the two shared secret K values are equal. The PCT shall be performed either
   255  // when keys are generated/imported, prior to the first exportation, or prior to
   256  // the first operational use (if not exported before the first use)."
   257  func kemPCT1024(dk *DecapsulationKey1024) error {
   258  	ek := dk.EncapsulationKey()
   259  	K, c := ek.Encapsulate()
   260  	K1, err := dk.Decapsulate(c)
   261  	if err != nil {
   262  		return err
   263  	}
   264  	if subtle.ConstantTimeCompare(K, K1) != 1 {
   265  		return errors.New("mlkem: PCT failed")
   266  	}
   267  	return nil
   268  }
   269  
   270  // Encapsulate generates a shared key and an associated ciphertext from an
   271  // encapsulation key, drawing random bytes from a DRBG.
   272  //
   273  // The shared key must be kept secret.
   274  func (ek *EncapsulationKey1024) Encapsulate() (sharedKey, ciphertext []byte) {
   275  	// The actual logic is in a separate function to outline this allocation.
   276  	var cc [CiphertextSize1024]byte
   277  	return ek.encapsulate(&cc)
   278  }
   279  
   280  func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (sharedKey, ciphertext []byte) {
   281  	var m [messageSize]byte
   282  	drbg.Read(m[:])
   283  	// Note that the modulus check (step 2 of the encapsulation key check from
   284  	// FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK1024.
   285  	fips140.RecordApproved()
   286  	return kemEncaps1024(cc, ek, &m)
   287  }
   288  
   289  // EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
   290  // use in tests.
   291  func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
   292  	cc := &[CiphertextSize1024]byte{}
   293  	return kemEncaps1024(cc, ek, m)
   294  }
   295  
   296  // kemEncaps1024 generates a shared key and an associated ciphertext.
   297  //
   298  // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
   299  func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (K, c []byte) {
   300  	g := sha3.New512()
   301  	g.Write(m[:])
   302  	g.Write(ek.h[:])
   303  	G := g.Sum(nil)
   304  	K, r := G[:SharedKeySize], G[SharedKeySize:]
   305  	c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r)
   306  	return K, c
   307  }
   308  
   309  // NewEncapsulationKey1024 parses an encapsulation key from its encoded form.
   310  // If the encapsulation key is not valid, NewEncapsulationKey1024 returns an error.
   311  func NewEncapsulationKey1024(encapsulationKey []byte) (*EncapsulationKey1024, error) {
   312  	// The actual logic is in a separate function to outline this allocation.
   313  	ek := &EncapsulationKey1024{}
   314  	return parseEK1024(ek, encapsulationKey)
   315  }
   316  
   317  // parseEK1024 parses an encryption key from its encoded form.
   318  //
   319  // It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
   320  // Algorithm 14.
   321  func parseEK1024(ek *EncapsulationKey1024, ekPKE []byte) (*EncapsulationKey1024, error) {
   322  	if len(ekPKE) != EncapsulationKeySize1024 {
   323  		return nil, errors.New("mlkem: invalid encapsulation key length")
   324  	}
   325  
   326  	h := sha3.New256()
   327  	h.Write(ekPKE)
   328  	h.Sum(ek.h[:0])
   329  
   330  	for i := range ek.t {
   331  		var err error
   332  		ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
   333  		if err != nil {
   334  			return nil, err
   335  		}
   336  		ekPKE = ekPKE[encodingSize12:]
   337  	}
   338  	copy(ek.ρ[:], ekPKE)
   339  
   340  	for i := byte(0); i < k1024; i++ {
   341  		for j := byte(0); j < k1024; j++ {
   342  			ek.a[i*k1024+j] = sampleNTT(ek.ρ[:], j, i)
   343  		}
   344  	}
   345  
   346  	return ek, nil
   347  }
   348  
   349  // pkeEncrypt1024 encrypt a plaintext message.
   350  //
   351  // It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the
   352  // computation of t and AT is done in parseEK1024.
   353  func pkeEncrypt1024(cc *[CiphertextSize1024]byte, ex *encryptionKey1024, m *[messageSize]byte, rnd []byte) []byte {
   354  	var N byte
   355  	r, e1 := make([]nttElement, k1024), make([]ringElement, k1024)
   356  	for i := range r {
   357  		r[i] = ntt(samplePolyCBD(rnd, N))
   358  		N++
   359  	}
   360  	for i := range e1 {
   361  		e1[i] = samplePolyCBD(rnd, N)
   362  		N++
   363  	}
   364  	e2 := samplePolyCBD(rnd, N)
   365  
   366  	u := make([]ringElement, k1024) // NTT⁻¹(AT ◦ r) + e1
   367  	for i := range u {
   368  		u[i] = e1[i]
   369  		for j := range r {
   370  			// Note that i and j are inverted, as we need the transposed of A.
   371  			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k1024+i], r[j])))
   372  		}
   373  	}
   374  
   375  	μ := ringDecodeAndDecompress1(m)
   376  
   377  	var vNTT nttElement // t⊺ ◦ r
   378  	for i := range ex.t {
   379  		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
   380  	}
   381  	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)
   382  
   383  	c := cc[:0]
   384  	for _, f := range u {
   385  		c = ringCompressAndEncode11(c, f)
   386  	}
   387  	c = ringCompressAndEncode5(c, v)
   388  
   389  	return c
   390  }
   391  
   392  // Decapsulate generates a shared key from a ciphertext and a decapsulation key.
   393  // If the ciphertext is not valid, Decapsulate returns an error.
   394  //
   395  // The shared key must be kept secret.
   396  func (dk *DecapsulationKey1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
   397  	if len(ciphertext) != CiphertextSize1024 {
   398  		return nil, errors.New("mlkem: invalid ciphertext length")
   399  	}
   400  	c := (*[CiphertextSize1024]byte)(ciphertext)
   401  	// Note that the hash check (step 3 of the decapsulation input check from
   402  	// FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always
   403  	// validly generated by ML-KEM.KeyGen_internal.
   404  	return kemDecaps1024(dk, c), nil
   405  }
   406  
// kemDecaps1024 produces a shared key from a ciphertext.
//
// It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18.
//
// This uses implicit rejection. c is decrypted to m′ and re-encrypted, and the
// result is compared with c in constant time. On a match the returned key is K′
// from G(m′ || h); otherwise it is J(z || c). An invalid ciphertext therefore
// produces an unrelated key, not an error.
func kemDecaps1024(dk *DecapsulationKey1024, c *[CiphertextSize1024]byte) (K []byte) {
	fips140.RecordApproved()
	m := pkeDecrypt1024(&dk.decryptionKey1024, c)
	// (K′, r′) = G(m′ || h), where G is SHA3-512.
	g := sha3.New512()
	g.Write(m[:])
	g.Write(dk.h[:])
	G := g.Sum(make([]byte, 0, 64))
	Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
	// K̄ = J(z || c), where J is SHAKE256. The rejection key is always
	// computed, whether or not c turns out to be valid.
	J := sha3.NewShake256()
	J.Write(dk.z[:])
	J.Write(c[:])
	Kout := make([]byte, SharedKeySize)
	J.Read(Kout)
	// c′ = K-PKE.Encrypt(ek, m′, r′)
	var cc [CiphertextSize1024]byte
	c1 := pkeEncrypt1024(&cc, &dk.encryptionKey1024, (*[32]byte)(m), r)

	// Replace K̄ with K′ only when c == c′, using constant-time compare and copy
	// instead of a branch.
	subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
	return Kout
}
   429  
   430  // pkeDecrypt1024 decrypts a ciphertext.
   431  //
   432  // It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15,
   433  // although s is retained from kemKeyGen1024.
   434  func pkeDecrypt1024(dx *decryptionKey1024, c *[CiphertextSize1024]byte) []byte {
   435  	u := make([]ringElement, k1024)
   436  	for i := range u {
   437  		b := (*[encodingSize11]byte)(c[encodingSize11*i : encodingSize11*(i+1)])
   438  		u[i] = ringDecodeAndDecompress11(b)
   439  	}
   440  
   441  	b := (*[encodingSize5]byte)(c[encodingSize11*k1024:])
   442  	v := ringDecodeAndDecompress5(b)
   443  
   444  	var mask nttElement // s⊺ ◦ NTT(u)
   445  	for i := range dx.s {
   446  		mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
   447  	}
   448  	w := polySub(v, inverseNTT(mask))
   449  
   450  	return ringCompressAndEncode1(nil, w)
   451  }
   452  

View as plain text