1
2
3
4
5 package hpke
6
7 import (
8 "bytes"
9 "encoding/hex"
10 "encoding/json"
11 "os"
12 "strconv"
13 "strings"
14 "testing"
15
16 "crypto/ecdh"
17 _ "crypto/sha256"
18 _ "crypto/sha512"
19 )
20
// mustDecodeHex decodes the hex string in and returns the resulting bytes,
// failing the test immediately if in is not valid hex.
func mustDecodeHex(t *testing.T, in string) []byte {
	// Mark this as a helper so a decode failure is reported at the
	// caller's line rather than inside this function.
	t.Helper()
	b, err := hex.DecodeString(in)
	if err != nil {
		t.Fatal(err)
	}
	return b
}
28
// parseVectorSetup parses newline-separated "key: value" lines into a map.
//
// Only the first ": " on a line separates the key from the value, so a value
// may itself contain ": ". Lines without the separator (such as the empty
// line left by a trailing newline) are ignored instead of causing a panic.
// If a key appears more than once, the last occurrence wins.
func parseVectorSetup(vector string) map[string]string {
	vals := map[string]string{}
	for _, line := range strings.Split(vector, "\n") {
		key, value, ok := strings.Cut(line, ": ")
		if !ok {
			continue
		}
		vals[key] = value
	}
	return vals
}
37
// parseVectorEncryptions parses blank-line-separated sections of
// "key: value" lines, returning one map per section in input order.
//
// Within a section, only the first ": " on a line separates the key from the
// value, and lines without the separator are ignored. Sections that contain
// no "key: value" lines at all (for example the empty one produced by a
// trailing blank line) are dropped rather than returned as empty maps.
func parseVectorEncryptions(vector string) []map[string]string {
	sections := strings.Split(vector, "\n\n")
	vals := make([]map[string]string, 0, len(sections))
	for _, section := range sections {
		e := map[string]string{}
		for _, line := range strings.Split(section, "\n") {
			key, value, ok := strings.Cut(line, ": ")
			if !ok {
				continue
			}
			e[key] = value
		}
		if len(e) == 0 {
			continue
		}
		vals = append(vals, e)
	}
	return vals
}
50
51 func TestRFC9180Vectors(t *testing.T) {
52 vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json")
53 if err != nil {
54 t.Fatal(err)
55 }
56
57 var vectors []struct {
58 Name string
59 Setup string
60 Encryptions string
61 }
62 if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
63 t.Fatal(err)
64 }
65
66 for _, vector := range vectors {
67 t.Run(vector.Name, func(t *testing.T) {
68 setup := parseVectorSetup(vector.Setup)
69
70 kemID, err := strconv.Atoi(setup["kem_id"])
71 if err != nil {
72 t.Fatal(err)
73 }
74 if _, ok := SupportedKEMs[uint16(kemID)]; !ok {
75 t.Skip("unsupported KEM")
76 }
77 kdfID, err := strconv.Atoi(setup["kdf_id"])
78 if err != nil {
79 t.Fatal(err)
80 }
81 if _, ok := SupportedKDFs[uint16(kdfID)]; !ok {
82 t.Skip("unsupported KDF")
83 }
84 aeadID, err := strconv.Atoi(setup["aead_id"])
85 if err != nil {
86 t.Fatal(err)
87 }
88 if _, ok := SupportedAEADs[uint16(aeadID)]; !ok {
89 t.Skip("unsupported AEAD")
90 }
91
92 info := mustDecodeHex(t, setup["info"])
93 pubKeyBytes := mustDecodeHex(t, setup["pkRm"])
94 pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes)
95 if err != nil {
96 t.Fatal(err)
97 }
98
99 ephemeralPrivKey := mustDecodeHex(t, setup["skEm"])
100
101 testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) {
102 return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey)
103 }
104 t.Cleanup(func() { testingOnlyGenerateKey = nil })
105
106 encap, context, err := SetupSender(
107 uint16(kemID),
108 uint16(kdfID),
109 uint16(aeadID),
110 pub,
111 info,
112 )
113 if err != nil {
114 t.Fatal(err)
115 }
116
117 expectedEncap := mustDecodeHex(t, setup["enc"])
118 if !bytes.Equal(encap, expectedEncap) {
119 t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
120 }
121 expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
122 if !bytes.Equal(context.sharedSecret, expectedSharedSecret) {
123 t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret)
124 }
125 expectedKey := mustDecodeHex(t, setup["key"])
126 if !bytes.Equal(context.key, expectedKey) {
127 t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey)
128 }
129 expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
130 if !bytes.Equal(context.baseNonce, expectedBaseNonce) {
131 t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce)
132 }
133 expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
134 if !bytes.Equal(context.exporterSecret, expectedExporterSecret) {
135 t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret)
136 }
137
138 for _, enc := range parseVectorEncryptions(vector.Encryptions) {
139 t.Run("seq num "+enc["sequence number"], func(t *testing.T) {
140 seqNum, err := strconv.Atoi(enc["sequence number"])
141 if err != nil {
142 t.Fatal(err)
143 }
144 context.seqNum = uint128{lo: uint64(seqNum)}
145 expectedNonce := mustDecodeHex(t, enc["nonce"])
146
147
148 computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():]
149 for i := range context.baseNonce {
150 computedNonce[i] ^= context.baseNonce[i]
151 }
152 if !bytes.Equal(computedNonce, expectedNonce) {
153 t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
154 }
155
156 expectedCiphertext := mustDecodeHex(t, enc["ct"])
157 ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
158 if err != nil {
159 t.Fatal(err)
160 }
161 if !bytes.Equal(ciphertext, expectedCiphertext) {
162 t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext)
163 }
164 })
165 }
166 })
167 }
168 }
169
View as plain text