1
2
3
4
5 package hpke
6
7 import (
8 "bytes"
9 "encoding/hex"
10 "encoding/json"
11 "os"
12 "strconv"
13 "strings"
14 "testing"
15
16 "crypto/ecdh"
17 _ "crypto/sha256"
18 _ "crypto/sha512"
19 )
20
// mustDecodeHex returns the bytes encoded by the hex string in. If in is not
// valid hex, the calling test is aborted with t.Fatal; the error is reported
// at the caller's line because this function marks itself as a helper.
func mustDecodeHex(t *testing.T, in string) []byte {
	t.Helper()
	decoded, err := hex.DecodeString(in)
	if err == nil {
		return decoded
	}
	t.Fatal(err)
	return nil // unreachable: t.Fatal stops the test goroutine
}
29
// parseVectorSetup parses the "Setup" text of a test vector into a map.
//
// Each line is expected to look like "key: value". It is split at the first
// ": ", so a value that itself contains ": " is kept intact. Lines with no
// ": " separator are ignored. This includes blank lines, such as the empty
// line left by a trailing newline, which previously caused an index-out-of-
// range panic. If a key appears more than once, the last value wins.
func parseVectorSetup(vector string) map[string]string {
	vals := map[string]string{}
	for _, line := range strings.Split(vector, "\n") {
		key, val, ok := strings.Cut(line, ": ")
		if !ok {
			// Blank or separator-less line: nothing to record. A required
			// key that ends up missing shows up to callers as an absent
			// map entry.
			continue
		}
		vals[key] = val
	}
	return vals
}
38
// parseVectorEncryptions parses the "Encryptions" text of a test vector.
//
// Sections are separated by a blank line ("\n\n"), and each section becomes
// one map. Within a section every line is expected to look like
// "key: value" and is split at the first ": ". Lines without a separator
// are ignored. Sections that yield no entries, such as those produced by
// trailing newlines, are dropped instead of being returned as empty maps,
// so callers never see a record with no fields.
func parseVectorEncryptions(vector string) []map[string]string {
	var vals []map[string]string
	for _, section := range strings.Split(vector, "\n\n") {
		entry := map[string]string{}
		for _, line := range strings.Split(section, "\n") {
			if key, val, ok := strings.Cut(line, ": "); ok {
				entry[key] = val
			}
		}
		if len(entry) == 0 {
			continue
		}
		vals = append(vals, entry)
	}
	return vals
}
51
52 func TestRFC9180Vectors(t *testing.T) {
53 vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json")
54 if err != nil {
55 t.Fatal(err)
56 }
57
58 var vectors []struct {
59 Name string
60 Setup string
61 Encryptions string
62 }
63 if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
64 t.Fatal(err)
65 }
66
67 for _, vector := range vectors {
68 t.Run(vector.Name, func(t *testing.T) {
69 setup := parseVectorSetup(vector.Setup)
70
71 kemID, err := strconv.Atoi(setup["kem_id"])
72 if err != nil {
73 t.Fatal(err)
74 }
75 if _, ok := SupportedKEMs[uint16(kemID)]; !ok {
76 t.Skip("unsupported KEM")
77 }
78 kdfID, err := strconv.Atoi(setup["kdf_id"])
79 if err != nil {
80 t.Fatal(err)
81 }
82 if _, ok := SupportedKDFs[uint16(kdfID)]; !ok {
83 t.Skip("unsupported KDF")
84 }
85 aeadID, err := strconv.Atoi(setup["aead_id"])
86 if err != nil {
87 t.Fatal(err)
88 }
89 if _, ok := SupportedAEADs[uint16(aeadID)]; !ok {
90 t.Skip("unsupported AEAD")
91 }
92
93 info := mustDecodeHex(t, setup["info"])
94 pubKeyBytes := mustDecodeHex(t, setup["pkRm"])
95 pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes)
96 if err != nil {
97 t.Fatal(err)
98 }
99
100 ephemeralPrivKey := mustDecodeHex(t, setup["skEm"])
101
102 testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) {
103 return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey)
104 }
105 t.Cleanup(func() { testingOnlyGenerateKey = nil })
106
107 encap, context, err := SetupSender(
108 uint16(kemID),
109 uint16(kdfID),
110 uint16(aeadID),
111 pub,
112 info,
113 )
114 if err != nil {
115 t.Fatal(err)
116 }
117
118 expectedEncap := mustDecodeHex(t, setup["enc"])
119 if !bytes.Equal(encap, expectedEncap) {
120 t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
121 }
122 expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
123 if !bytes.Equal(context.sharedSecret, expectedSharedSecret) {
124 t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret)
125 }
126 expectedKey := mustDecodeHex(t, setup["key"])
127 if !bytes.Equal(context.key, expectedKey) {
128 t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey)
129 }
130 expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
131 if !bytes.Equal(context.baseNonce, expectedBaseNonce) {
132 t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce)
133 }
134 expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
135 if !bytes.Equal(context.exporterSecret, expectedExporterSecret) {
136 t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret)
137 }
138
139 for _, enc := range parseVectorEncryptions(vector.Encryptions) {
140 t.Run("seq num "+enc["sequence number"], func(t *testing.T) {
141 seqNum, err := strconv.Atoi(enc["sequence number"])
142 if err != nil {
143 t.Fatal(err)
144 }
145 context.seqNum = uint128{lo: uint64(seqNum)}
146 expectedNonce := mustDecodeHex(t, enc["nonce"])
147
148
149 computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():]
150 for i := range context.baseNonce {
151 computedNonce[i] ^= context.baseNonce[i]
152 }
153 if !bytes.Equal(computedNonce, expectedNonce) {
154 t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
155 }
156
157 expectedCiphertext := mustDecodeHex(t, enc["ct"])
158 ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
159 if err != nil {
160 t.Fatal(err)
161 }
162 if !bytes.Equal(ciphertext, expectedCiphertext) {
163 t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext)
164 }
165 })
166 }
167 })
168 }
169 }
170
View as plain text