Source file src/crypto/internal/hpke/hpke_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hpke
     6  
     7  import (
     8  	"bytes"
     9  	"encoding/hex"
    10  	"encoding/json"
    11  	"os"
    12  	"strconv"
    13  	"strings"
    14  	"testing"
    15  
    16  	"crypto/ecdh"
    17  	_ "crypto/sha256"
    18  	_ "crypto/sha512"
    19  )
    20  
    21  func mustDecodeHex(t *testing.T, in string) []byte {
    22  	b, err := hex.DecodeString(in)
    23  	if err != nil {
    24  		t.Fatal(err)
    25  	}
    26  	return b
    27  }
    28  
    29  func parseVectorSetup(vector string) map[string]string {
    30  	vals := map[string]string{}
    31  	for _, l := range strings.Split(vector, "\n") {
    32  		fields := strings.Split(l, ": ")
    33  		vals[fields[0]] = fields[1]
    34  	}
    35  	return vals
    36  }
    37  
    38  func parseVectorEncryptions(vector string) []map[string]string {
    39  	vals := []map[string]string{}
    40  	for _, section := range strings.Split(vector, "\n\n") {
    41  		e := map[string]string{}
    42  		for _, l := range strings.Split(section, "\n") {
    43  			fields := strings.Split(l, ": ")
    44  			e[fields[0]] = fields[1]
    45  		}
    46  		vals = append(vals, e)
    47  	}
    48  	return vals
    49  }
    50  
    51  func TestRFC9180Vectors(t *testing.T) {
    52  	vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json")
    53  	if err != nil {
    54  		t.Fatal(err)
    55  	}
    56  
    57  	var vectors []struct {
    58  		Name        string
    59  		Setup       string
    60  		Encryptions string
    61  	}
    62  	if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
    63  		t.Fatal(err)
    64  	}
    65  
    66  	for _, vector := range vectors {
    67  		t.Run(vector.Name, func(t *testing.T) {
    68  			setup := parseVectorSetup(vector.Setup)
    69  
    70  			kemID, err := strconv.Atoi(setup["kem_id"])
    71  			if err != nil {
    72  				t.Fatal(err)
    73  			}
    74  			if _, ok := SupportedKEMs[uint16(kemID)]; !ok {
    75  				t.Skip("unsupported KEM")
    76  			}
    77  			kdfID, err := strconv.Atoi(setup["kdf_id"])
    78  			if err != nil {
    79  				t.Fatal(err)
    80  			}
    81  			if _, ok := SupportedKDFs[uint16(kdfID)]; !ok {
    82  				t.Skip("unsupported KDF")
    83  			}
    84  			aeadID, err := strconv.Atoi(setup["aead_id"])
    85  			if err != nil {
    86  				t.Fatal(err)
    87  			}
    88  			if _, ok := SupportedAEADs[uint16(aeadID)]; !ok {
    89  				t.Skip("unsupported AEAD")
    90  			}
    91  
    92  			info := mustDecodeHex(t, setup["info"])
    93  			pubKeyBytes := mustDecodeHex(t, setup["pkRm"])
    94  			pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes)
    95  			if err != nil {
    96  				t.Fatal(err)
    97  			}
    98  
    99  			ephemeralPrivKey := mustDecodeHex(t, setup["skEm"])
   100  
   101  			testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) {
   102  				return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey)
   103  			}
   104  			t.Cleanup(func() { testingOnlyGenerateKey = nil })
   105  
   106  			encap, context, err := SetupSender(
   107  				uint16(kemID),
   108  				uint16(kdfID),
   109  				uint16(aeadID),
   110  				pub,
   111  				info,
   112  			)
   113  			if err != nil {
   114  				t.Fatal(err)
   115  			}
   116  
   117  			expectedEncap := mustDecodeHex(t, setup["enc"])
   118  			if !bytes.Equal(encap, expectedEncap) {
   119  				t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
   120  			}
   121  			expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
   122  			if !bytes.Equal(context.sharedSecret, expectedSharedSecret) {
   123  				t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret)
   124  			}
   125  			expectedKey := mustDecodeHex(t, setup["key"])
   126  			if !bytes.Equal(context.key, expectedKey) {
   127  				t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey)
   128  			}
   129  			expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
   130  			if !bytes.Equal(context.baseNonce, expectedBaseNonce) {
   131  				t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce)
   132  			}
   133  			expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
   134  			if !bytes.Equal(context.exporterSecret, expectedExporterSecret) {
   135  				t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret)
   136  			}
   137  
   138  			for _, enc := range parseVectorEncryptions(vector.Encryptions) {
   139  				t.Run("seq num "+enc["sequence number"], func(t *testing.T) {
   140  					seqNum, err := strconv.Atoi(enc["sequence number"])
   141  					if err != nil {
   142  						t.Fatal(err)
   143  					}
   144  					context.seqNum = uint128{lo: uint64(seqNum)}
   145  					expectedNonce := mustDecodeHex(t, enc["nonce"])
   146  					// We can't call nextNonce, because it increments the sequence number,
   147  					// so just compute it directly.
   148  					computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():]
   149  					for i := range context.baseNonce {
   150  						computedNonce[i] ^= context.baseNonce[i]
   151  					}
   152  					if !bytes.Equal(computedNonce, expectedNonce) {
   153  						t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
   154  					}
   155  
   156  					expectedCiphertext := mustDecodeHex(t, enc["ct"])
   157  					ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
   158  					if err != nil {
   159  						t.Fatal(err)
   160  					}
   161  					if !bytes.Equal(ciphertext, expectedCiphertext) {
   162  						t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext)
   163  					}
   164  				})
   165  			}
   166  		})
   167  	}
   168  }
   169  

View as plain text