Source file src/crypto/internal/hpke/hpke.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hpke
     6  
     7  import (
     8  	"crypto"
     9  	"crypto/aes"
    10  	"crypto/cipher"
    11  	"crypto/ecdh"
    12  	"crypto/rand"
    13  	"errors"
    14  	"internal/byteorder"
    15  	"math/bits"
    16  
    17  	"golang.org/x/crypto/chacha20poly1305"
    18  	"golang.org/x/crypto/hkdf"
    19  )
    20  
    21  // testingOnlyGenerateKey is only used during testing, to provide
    22  // a fixed test key to use when checking the RFC 9180 vectors.
    23  var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)
    24  
    25  type hkdfKDF struct {
    26  	hash crypto.Hash
    27  }
    28  
    29  func (kdf *hkdfKDF) LabeledExtract(suiteID []byte, salt []byte, label string, inputKey []byte) []byte {
    30  	labeledIKM := make([]byte, 0, 7+len(suiteID)+len(label)+len(inputKey))
    31  	labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
    32  	labeledIKM = append(labeledIKM, suiteID...)
    33  	labeledIKM = append(labeledIKM, label...)
    34  	labeledIKM = append(labeledIKM, inputKey...)
    35  	return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
    36  }
    37  
    38  func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte {
    39  	labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
    40  	labeledInfo = byteorder.BeAppendUint16(labeledInfo, length)
    41  	labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
    42  	labeledInfo = append(labeledInfo, suiteID...)
    43  	labeledInfo = append(labeledInfo, label...)
    44  	labeledInfo = append(labeledInfo, info...)
    45  	out := make([]byte, length)
    46  	n, err := hkdf.Expand(kdf.hash.New, randomKey, labeledInfo).Read(out)
    47  	if err != nil || n != int(length) {
    48  		panic("hpke: LabeledExpand failed unexpectedly")
    49  	}
    50  	return out
    51  }
    52  
// dhKEM implements the KEM specified in RFC 9180, Section 4.1.
// Instances are built by newDHKem from an entry in SupportedKEMs.
type dhKEM struct {
	dh  ecdh.Curve // curve used to generate the ephemeral key in Encap
	kdf hkdfKDF    // KDF used by ExtractAndExpand

	suiteID []byte // KEM suite_id: "KEM" followed by the big-endian KEM id (set in newDHKem)
	nSecret uint16 // length in bytes of the shared secret returned by ExtractAndExpand
}
    61  
    62  var SupportedKEMs = map[uint16]struct {
    63  	curve   ecdh.Curve
    64  	hash    crypto.Hash
    65  	nSecret uint16
    66  }{
    67  	// RFC 9180 Section 7.1
    68  	0x0020: {ecdh.X25519(), crypto.SHA256, 32},
    69  }
    70  
    71  func newDHKem(kemID uint16) (*dhKEM, error) {
    72  	suite, ok := SupportedKEMs[kemID]
    73  	if !ok {
    74  		return nil, errors.New("unsupported suite ID")
    75  	}
    76  	return &dhKEM{
    77  		dh:      suite.curve,
    78  		kdf:     hkdfKDF{suite.hash},
    79  		suiteID: byteorder.BeAppendUint16([]byte("KEM"), kemID),
    80  		nSecret: suite.nSecret,
    81  	}, nil
    82  }
    83  
    84  func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte {
    85  	eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
    86  	return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
    87  }
    88  
    89  func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) {
    90  	var privEph *ecdh.PrivateKey
    91  	if testingOnlyGenerateKey != nil {
    92  		privEph, err = testingOnlyGenerateKey()
    93  	} else {
    94  		privEph, err = dh.dh.GenerateKey(rand.Reader)
    95  	}
    96  	if err != nil {
    97  		return nil, nil, err
    98  	}
    99  	dhVal, err := privEph.ECDH(pubRecipient)
   100  	if err != nil {
   101  		return nil, nil, err
   102  	}
   103  	encPubEph := privEph.PublicKey().Bytes()
   104  
   105  	encPubRecip := pubRecipient.Bytes()
   106  	kemContext := append(encPubEph, encPubRecip...)
   107  
   108  	return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil
   109  }
   110  
// Sender holds the sender-side HPKE context created by SetupSender. Each
// call to Seal consumes the next sequence number. Seal updates seqNum with
// no locking; NOTE(review): presumably not safe for concurrent use — confirm.
type Sender struct {
	aead cipher.AEAD // AEAD constructed from key; used by Seal
	kem  *dhKEM      // KEM used to encapsulate to the recipient

	sharedSecret []byte // shared secret returned by the KEM's Encap

	suiteID []byte // HPKE suite_id, as returned by SuiteID

	key            []byte // AEAD key derived with the "key" label
	baseNonce      []byte // XORed with the encoded sequence number to form each nonce
	exporterSecret []byte // secret derived with the "exp" label

	seqNum uint128 // starts at zero; advanced by nextNonce on every Seal
}
   125  
   126  var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
   127  	block, err := aes.NewCipher(key)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  	return cipher.NewGCM(block)
   132  }
   133  
   134  var SupportedAEADs = map[uint16]struct {
   135  	keySize   int
   136  	nonceSize int
   137  	aead      func([]byte) (cipher.AEAD, error)
   138  }{
   139  	// RFC 9180, Section 7.3
   140  	0x0001: {keySize: 16, nonceSize: 12, aead: aesGCMNew},
   141  	0x0002: {keySize: 32, nonceSize: 12, aead: aesGCMNew},
   142  	0x0003: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
   143  }
   144  
   145  var SupportedKDFs = map[uint16]func() *hkdfKDF{
   146  	// RFC 9180, Section 7.2
   147  	0x0001: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
   148  }
   149  
   150  func SetupSender(kemID, kdfID, aeadID uint16, pub crypto.PublicKey, info []byte) ([]byte, *Sender, error) {
   151  	suiteID := SuiteID(kemID, kdfID, aeadID)
   152  
   153  	kem, err := newDHKem(kemID)
   154  	if err != nil {
   155  		return nil, nil, err
   156  	}
   157  	pubRecipient, ok := pub.(*ecdh.PublicKey)
   158  	if !ok {
   159  		return nil, nil, errors.New("incorrect public key type")
   160  	}
   161  	sharedSecret, encapsulatedKey, err := kem.Encap(pubRecipient)
   162  	if err != nil {
   163  		return nil, nil, err
   164  	}
   165  
   166  	kdfInit, ok := SupportedKDFs[kdfID]
   167  	if !ok {
   168  		return nil, nil, errors.New("unsupported KDF id")
   169  	}
   170  	kdf := kdfInit()
   171  
   172  	aeadInfo, ok := SupportedAEADs[aeadID]
   173  	if !ok {
   174  		return nil, nil, errors.New("unsupported AEAD id")
   175  	}
   176  
   177  	pskIDHash := kdf.LabeledExtract(suiteID, nil, "psk_id_hash", nil)
   178  	infoHash := kdf.LabeledExtract(suiteID, nil, "info_hash", info)
   179  	ksContext := append([]byte{0}, pskIDHash...)
   180  	ksContext = append(ksContext, infoHash...)
   181  
   182  	secret := kdf.LabeledExtract(suiteID, sharedSecret, "secret", nil)
   183  
   184  	key := kdf.LabeledExpand(suiteID, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
   185  	baseNonce := kdf.LabeledExpand(suiteID, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
   186  	exporterSecret := kdf.LabeledExpand(suiteID, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
   187  
   188  	aead, err := aeadInfo.aead(key)
   189  	if err != nil {
   190  		return nil, nil, err
   191  	}
   192  
   193  	return encapsulatedKey, &Sender{
   194  		kem:            kem,
   195  		aead:           aead,
   196  		sharedSecret:   sharedSecret,
   197  		suiteID:        suiteID,
   198  		key:            key,
   199  		baseNonce:      baseNonce,
   200  		exporterSecret: exporterSecret,
   201  	}, nil
   202  }
   203  
   204  func (s *Sender) nextNonce() []byte {
   205  	nonce := s.seqNum.bytes()[16-s.aead.NonceSize():]
   206  	for i := range s.baseNonce {
   207  		nonce[i] ^= s.baseNonce[i]
   208  	}
   209  	// Message limit is, according to the RFC, 2^95+1, which
   210  	// is somewhat confusing, but we do as we're told.
   211  	if s.seqNum.bitLen() >= (s.aead.NonceSize()*8)-1 {
   212  		panic("message limit reached")
   213  	}
   214  	s.seqNum = s.seqNum.addOne()
   215  	return nonce
   216  }
   217  
   218  func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
   219  
   220  	ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
   221  	return ciphertext, nil
   222  }
   223  
   224  func SuiteID(kemID, kdfID, aeadID uint16) []byte {
   225  	suiteID := make([]byte, 0, 4+2+2+2)
   226  	suiteID = append(suiteID, []byte("HPKE")...)
   227  	suiteID = byteorder.BeAppendUint16(suiteID, kemID)
   228  	suiteID = byteorder.BeAppendUint16(suiteID, kdfID)
   229  	suiteID = byteorder.BeAppendUint16(suiteID, aeadID)
   230  	return suiteID
   231  }
   232  
   233  func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
   234  	kemInfo, ok := SupportedKEMs[kemID]
   235  	if !ok {
   236  		return nil, errors.New("unsupported KEM id")
   237  	}
   238  	return kemInfo.curve.NewPublicKey(bytes)
   239  }
   240  
   241  type uint128 struct {
   242  	hi, lo uint64
   243  }
   244  
   245  func (u uint128) addOne() uint128 {
   246  	lo, carry := bits.Add64(u.lo, 1, 0)
   247  	return uint128{u.hi + carry, lo}
   248  }
   249  
   250  func (u uint128) bitLen() int {
   251  	return bits.Len64(u.hi) + bits.Len64(u.lo)
   252  }
   253  
   254  func (u uint128) bytes() []byte {
   255  	b := make([]byte, 16)
   256  	byteorder.BePutUint64(b[0:], u.hi)
   257  	byteorder.BePutUint64(b[8:], u.lo)
   258  	return b
   259  }
   260  

View as plain text