// Copyright 2024 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package hpke import ( "bytes" "encoding/hex" "encoding/json" "os" "strconv" "strings" "testing" "crypto/ecdh" _ "crypto/sha256" _ "crypto/sha512" ) func mustDecodeHex(t *testing.T, in string) []byte { b, err := hex.DecodeString(in) if err != nil { t.Fatal(err) } return b } func parseVectorSetup(vector string) map[string]string { vals := map[string]string{} for _, l := range strings.Split(vector, "\n") { fields := strings.Split(l, ": ") vals[fields[0]] = fields[1] } return vals } func parseVectorEncryptions(vector string) []map[string]string { vals := []map[string]string{} for _, section := range strings.Split(vector, "\n\n") { e := map[string]string{} for _, l := range strings.Split(section, "\n") { fields := strings.Split(l, ": ") e[fields[0]] = fields[1] } vals = append(vals, e) } return vals } func TestRFC9180Vectors(t *testing.T) { vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json") if err != nil { t.Fatal(err) } var vectors []struct { Name string Setup string Encryptions string } if err := json.Unmarshal(vectorsJSON, &vectors); err != nil { t.Fatal(err) } for _, vector := range vectors { t.Run(vector.Name, func(t *testing.T) { setup := parseVectorSetup(vector.Setup) kemID, err := strconv.Atoi(setup["kem_id"]) if err != nil { t.Fatal(err) } if _, ok := SupportedKEMs[uint16(kemID)]; !ok { t.Skip("unsupported KEM") } kdfID, err := strconv.Atoi(setup["kdf_id"]) if err != nil { t.Fatal(err) } if _, ok := SupportedKDFs[uint16(kdfID)]; !ok { t.Skip("unsupported KDF") } aeadID, err := strconv.Atoi(setup["aead_id"]) if err != nil { t.Fatal(err) } if _, ok := SupportedAEADs[uint16(aeadID)]; !ok { t.Skip("unsupported AEAD") } info := mustDecodeHex(t, setup["info"]) pubKeyBytes := mustDecodeHex(t, setup["pkRm"]) pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes) if err != nil { t.Fatal(err) } ephemeralPrivKey := mustDecodeHex(t, setup["skEm"]) testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) { return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey) } t.Cleanup(func() { testingOnlyGenerateKey = nil }) encap, context, err := SetupSender( uint16(kemID), uint16(kdfID), uint16(aeadID), pub, info, ) if err != nil { t.Fatal(err) } expectedEncap := mustDecodeHex(t, setup["enc"]) if !bytes.Equal(encap, expectedEncap) { t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap) } expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"]) if !bytes.Equal(context.sharedSecret, expectedSharedSecret) { t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret) } expectedKey := mustDecodeHex(t, setup["key"]) if !bytes.Equal(context.key, expectedKey) { t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey) } expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"]) if !bytes.Equal(context.baseNonce, expectedBaseNonce) { t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce) } expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"]) if !bytes.Equal(context.exporterSecret, expectedExporterSecret) { t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret) } for _, enc := range parseVectorEncryptions(vector.Encryptions) { t.Run("seq num "+enc["sequence number"], func(t *testing.T) { seqNum, err := strconv.Atoi(enc["sequence number"]) if err != nil { t.Fatal(err) } context.seqNum = uint128{lo: uint64(seqNum)} expectedNonce := mustDecodeHex(t, enc["nonce"]) // We can't call nextNonce, because it increments the sequence number, // so just compute it directly. computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():] for i := range context.baseNonce { computedNonce[i] ^= context.baseNonce[i] } if !bytes.Equal(computedNonce, expectedNonce) { t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce) } expectedCiphertext := mustDecodeHex(t, enc["ct"]) ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"])) if err != nil { t.Fatal(err) } if !bytes.Equal(ciphertext, expectedCiphertext) { t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext) } }) } }) } }