Source file src/crypto/internal/hpke/hpke_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hpke
     6  
     7  import (
     8  	"bytes"
     9  	"encoding/hex"
    10  	"encoding/json"
    11  	"os"
    12  	"strconv"
    13  	"strings"
    14  	"testing"
    15  
    16  	"crypto/ecdh"
    17  	_ "crypto/sha256"
    18  	_ "crypto/sha512"
    19  )
    20  
    21  func mustDecodeHex(t *testing.T, in string) []byte {
    22  	t.Helper()
    23  	b, err := hex.DecodeString(in)
    24  	if err != nil {
    25  		t.Fatal(err)
    26  	}
    27  	return b
    28  }
    29  
    30  func parseVectorSetup(vector string) map[string]string {
    31  	vals := map[string]string{}
    32  	for _, l := range strings.Split(vector, "\n") {
    33  		fields := strings.Split(l, ": ")
    34  		vals[fields[0]] = fields[1]
    35  	}
    36  	return vals
    37  }
    38  
    39  func parseVectorEncryptions(vector string) []map[string]string {
    40  	vals := []map[string]string{}
    41  	for _, section := range strings.Split(vector, "\n\n") {
    42  		e := map[string]string{}
    43  		for _, l := range strings.Split(section, "\n") {
    44  			fields := strings.Split(l, ": ")
    45  			e[fields[0]] = fields[1]
    46  		}
    47  		vals = append(vals, e)
    48  	}
    49  	return vals
    50  }
    51  
    52  func TestRFC9180Vectors(t *testing.T) {
    53  	vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json")
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  
    58  	var vectors []struct {
    59  		Name        string
    60  		Setup       string
    61  		Encryptions string
    62  	}
    63  	if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
    64  		t.Fatal(err)
    65  	}
    66  
    67  	for _, vector := range vectors {
    68  		t.Run(vector.Name, func(t *testing.T) {
    69  			setup := parseVectorSetup(vector.Setup)
    70  
    71  			kemID, err := strconv.Atoi(setup["kem_id"])
    72  			if err != nil {
    73  				t.Fatal(err)
    74  			}
    75  			if _, ok := SupportedKEMs[uint16(kemID)]; !ok {
    76  				t.Skip("unsupported KEM")
    77  			}
    78  			kdfID, err := strconv.Atoi(setup["kdf_id"])
    79  			if err != nil {
    80  				t.Fatal(err)
    81  			}
    82  			if _, ok := SupportedKDFs[uint16(kdfID)]; !ok {
    83  				t.Skip("unsupported KDF")
    84  			}
    85  			aeadID, err := strconv.Atoi(setup["aead_id"])
    86  			if err != nil {
    87  				t.Fatal(err)
    88  			}
    89  			if _, ok := SupportedAEADs[uint16(aeadID)]; !ok {
    90  				t.Skip("unsupported AEAD")
    91  			}
    92  
    93  			info := mustDecodeHex(t, setup["info"])
    94  			pubKeyBytes := mustDecodeHex(t, setup["pkRm"])
    95  			pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes)
    96  			if err != nil {
    97  				t.Fatal(err)
    98  			}
    99  
   100  			ephemeralPrivKey := mustDecodeHex(t, setup["skEm"])
   101  
   102  			testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) {
   103  				return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey)
   104  			}
   105  			t.Cleanup(func() { testingOnlyGenerateKey = nil })
   106  
   107  			encap, context, err := SetupSender(
   108  				uint16(kemID),
   109  				uint16(kdfID),
   110  				uint16(aeadID),
   111  				pub,
   112  				info,
   113  			)
   114  			if err != nil {
   115  				t.Fatal(err)
   116  			}
   117  
   118  			expectedEncap := mustDecodeHex(t, setup["enc"])
   119  			if !bytes.Equal(encap, expectedEncap) {
   120  				t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
   121  			}
   122  			expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
   123  			if !bytes.Equal(context.sharedSecret, expectedSharedSecret) {
   124  				t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret)
   125  			}
   126  			expectedKey := mustDecodeHex(t, setup["key"])
   127  			if !bytes.Equal(context.key, expectedKey) {
   128  				t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey)
   129  			}
   130  			expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
   131  			if !bytes.Equal(context.baseNonce, expectedBaseNonce) {
   132  				t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce)
   133  			}
   134  			expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
   135  			if !bytes.Equal(context.exporterSecret, expectedExporterSecret) {
   136  				t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret)
   137  			}
   138  
   139  			for _, enc := range parseVectorEncryptions(vector.Encryptions) {
   140  				t.Run("seq num "+enc["sequence number"], func(t *testing.T) {
   141  					seqNum, err := strconv.Atoi(enc["sequence number"])
   142  					if err != nil {
   143  						t.Fatal(err)
   144  					}
   145  					context.seqNum = uint128{lo: uint64(seqNum)}
   146  					expectedNonce := mustDecodeHex(t, enc["nonce"])
   147  					// We can't call nextNonce, because it increments the sequence number,
   148  					// so just compute it directly.
   149  					computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():]
   150  					for i := range context.baseNonce {
   151  						computedNonce[i] ^= context.baseNonce[i]
   152  					}
   153  					if !bytes.Equal(computedNonce, expectedNonce) {
   154  						t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
   155  					}
   156  
   157  					expectedCiphertext := mustDecodeHex(t, enc["ct"])
   158  					ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
   159  					if err != nil {
   160  						t.Fatal(err)
   161  					}
   162  					if !bytes.Equal(ciphertext, expectedCiphertext) {
   163  						t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext)
   164  					}
   165  				})
   166  			}
   167  		})
   168  	}
   169  }
   170  

View as plain text