Source file src/crypto/rsa/rsa_wycheproof_test.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  package rsa_test
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/internal/cryptotest/wycheproof"
     9  	"crypto/rsa"
    10  	"crypto/x509"
    11  	"fmt"
    12  	"slices"
    13  	"testing"
    14  )
    15  
    16  func TestRSAOAEPDecryptWycheproof(t *testing.T) {
    17  	flagsShouldPass := map[string]bool{
    18  		"Constructed":         true,
    19  		"EncryptionWithLabel": true,
    20  		// rsa.DecryptOAEP happily supports small key sizes
    21  		"SmallIntegerCiphertext": true,
    22  	}
    23  
    24  	// TODO(XXX): support test files with different hashes for MGF/label
    25  	for _, file := range []string{
    26  		"rsa_oaep_2048_sha1_mgf1sha1_test.json",
    27  		"rsa_oaep_2048_sha224_mgf1sha224_test.json",
    28  		"rsa_oaep_2048_sha256_mgf1sha256_test.json",
    29  		"rsa_oaep_2048_sha384_mgf1sha384_test.json",
    30  		"rsa_oaep_2048_sha512_mgf1sha512_test.json",
    31  		"rsa_oaep_3072_sha256_mgf1sha256_test.json",
    32  		"rsa_oaep_3072_sha512_mgf1sha512_test.json",
    33  		"rsa_oaep_4096_sha256_mgf1sha256_test.json",
    34  		"rsa_oaep_4096_sha512_mgf1sha512_test.json",
    35  		"rsa_oaep_misc_test.json",
    36  	} {
    37  		var testdata wycheproof.RsaesOaepDecryptSchemaV1Json
    38  		wycheproof.LoadVectorFile(t, file, &testdata)
    39  
    40  		for _, tg := range testdata.TestGroups {
    41  			// TODO(XXX): support rsa_oaep_misc_test test cases with different hashes for MGF/label
    42  			if tg.MgfSha != tg.Sha {
    43  				t.Skip("test cases with different hashes for MGF/label not yet supported")
    44  			}
    45  
    46  			rawPriv, err := x509.ParsePKCS8PrivateKey(wycheproof.MustDecodeHex(tg.PrivateKeyPkcs8))
    47  			if err != nil {
    48  				t.Fatalf("%s failed to parse PKCS #8 private key: %s", file, err)
    49  			}
    50  			priv := rawPriv.(*rsa.PrivateKey)
    51  			hash := wycheproof.ParseHash(tg.Sha)
    52  
    53  			for _, tv := range tg.Tests {
    54  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
    55  					t.Parallel()
    56  
    57  					ct := wycheproof.MustDecodeHex(tv.Ct)
    58  					label := wycheproof.MustDecodeHex(tv.Label)
    59  					wantPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, flagsShouldPass)
    60  					plaintext, err := rsa.DecryptOAEP(hash.New(), nil, priv, ct, label)
    61  					if wantPass {
    62  						if err != nil {
    63  							t.Fatalf("expected success: %s", err)
    64  						}
    65  						if !bytes.Equal(plaintext, wycheproof.MustDecodeHex(tv.Msg)) {
    66  							t.Errorf("unexpected plaintext: got %x, want %s", plaintext, tv.Msg)
    67  						}
    68  					} else if err == nil {
    69  						t.Errorf("expected failure")
    70  					}
    71  				})
    72  			}
    73  		}
    74  	}
    75  }
    76  
    77  func TestRSAPKCS1SignaturesWycheproof(t *testing.T) {
    78  	// A map of supported modulus sizes to the list of hashes that Wycheproof has
    79  	// test vector coverage for.
    80  	modsAndHashes := map[int][]string{
    81  		2048: {
    82  			"sha224",
    83  			"sha256",
    84  			"sha384",
    85  			"sha512",
    86  			"sha512_224",
    87  			"sha512_256",
    88  			"sha3_224",
    89  			"sha3_256",
    90  			"sha3_384",
    91  			"sha3_512",
    92  		},
    93  		3072: {
    94  			"sha256",
    95  			"sha384",
    96  			"sha512",
    97  			"sha512_256",
    98  			"sha3_256",
    99  			"sha3_384",
   100  			"sha3_512",
   101  		},
   102  		4096: {
   103  			"sha256",
   104  			"sha384",
   105  			"sha512",
   106  			"sha512_256",
   107  		},
   108  		8192: {
   109  			"sha256",
   110  			"sha384",
   111  			"sha512",
   112  		},
   113  	}
   114  
   115  	var files []string
   116  	for m, hashes := range modsAndHashes {
   117  		for _, h := range hashes {
   118  			files = append(files, fmt.Sprintf("rsa_signature_%d_%s_test.json", m, h))
   119  		}
   120  	}
   121  
   122  	flagsShouldPass := map[string]bool{
   123  		// Omitting the parameter field in an ASN encoded integer is a legacy behavior.
   124  		"MissingNull": false,
   125  	}
   126  
   127  	for _, file := range files {
   128  		var testdata wycheproof.RsassaPkcs1VerifySchemaV1Json
   129  		wycheproof.LoadVectorFile(t, file, &testdata)
   130  
   131  		for _, tg := range testdata.TestGroups {
   132  			hash := wycheproof.ParseHash(tg.Sha)
   133  
   134  			pub, err := x509.ParsePKCS1PublicKey(wycheproof.MustDecodeHex(tg.PublicKeyAsn))
   135  			if err != nil {
   136  				t.Fatalf("failed to decode pubkey: %v", err)
   137  			}
   138  
   139  			for _, tv := range tg.Tests {
   140  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
   141  					t.Parallel()
   142  
   143  					sig := wycheproof.MustDecodeHex(tv.Sig)
   144  					h := hash.New()
   145  					h.Write(wycheproof.MustDecodeHex(tv.Msg))
   146  					err := rsa.VerifyPKCS1v15(pub, hash, h.Sum(nil), sig)
   147  					want := wycheproof.ShouldPass(t, tv.Result, tv.Flags, flagsShouldPass)
   148  					if (err == nil) != want {
   149  						t.Errorf("wanted success: %t err: %v", want, err)
   150  					}
   151  				})
   152  			}
   153  		}
   154  	}
   155  }
   156  
   157  func TestRSAPSSSignaturesWycheproof(t *testing.T) {
   158  	// filesOverrideToPassZeroSLen is a map of all test files
   159  	// and which TcIds that should be overridden to pass if the
   160  	// rsa.PSSOptions.SaltLength is zero.
   161  	// These tests expect a failure with a PSSOptions.SaltLength: 0
   162  	// and a signature that uses a different salt length. However,
   163  	// a salt length of 0 is defined as rsa.PSSSaltLengthAuto which
   164  	// works deterministically to auto-detect the length when
   165  	// verifying, so these tests actually pass as they should.
   166  	filesOverrideToPassZeroSLen := map[string][]int{
   167  		"rsa_pss_2048_sha1_mgf1_20_test.json":   {46, 47, 48, 49, 50, 51},
   168  		"rsa_pss_2048_sha256_mgf1_0_test.json":  {67, 68, 69, 70},
   169  		"rsa_pss_2048_sha256_mgf1_32_test.json": {67, 68, 69, 70, 71, 72},
   170  		"rsa_pss_3072_sha256_mgf1_32_test.json": {67, 68, 69, 70, 71, 72},
   171  		"rsa_pss_4096_sha256_mgf1_32_test.json": {67, 68, 69, 70, 71, 72},
   172  		"rsa_pss_4096_sha512_mgf1_32_test.json": {136, 137, 138, 139, 140, 141},
   173  		// "rsa_pss_misc_test.json": nil,  // TODO: This ones seems to be broken right now, but can enable later on.
   174  	}
   175  
   176  	for file, overrideIDs := range filesOverrideToPassZeroSLen {
   177  		var testdata wycheproof.RsassaPssVerifySchemaV1Json
   178  		wycheproof.LoadVectorFile(t, file, &testdata)
   179  
   180  		for _, tg := range testdata.TestGroups {
   181  			hash := wycheproof.ParseHash(tg.Sha)
   182  
   183  			pub, err := x509.ParsePKCS1PublicKey(wycheproof.MustDecodeHex(tg.PublicKeyAsn))
   184  			if err != nil {
   185  				t.Fatalf("failed to decode pubkey: %v", err)
   186  			}
   187  
   188  			// Run all the tests twice: the first time with the salt length
   189  			// as PSSSaltLengthAuto, and the second time with the salt length
   190  			// explicitly set to tg.SLen.
   191  			for i := 0; i < 2; i++ {
   192  				saltLabel := "autoSalt"
   193  				if i == 1 {
   194  					saltLabel = "vecSalt"
   195  				}
   196  				opts := &rsa.PSSOptions{
   197  					Hash:       hash,
   198  					SaltLength: rsa.PSSSaltLengthAuto,
   199  				}
   200  
   201  				for _, tv := range tg.Tests {
   202  					t.Run(wycheproof.TestName(file, tv)+" "+saltLabel, func(t *testing.T) {
   203  						h := hash.New()
   204  						h.Write(wycheproof.MustDecodeHex(tv.Msg))
   205  						sig := wycheproof.MustDecodeHex(tv.Sig)
   206  						err = rsa.VerifyPSS(pub, hash, h.Sum(nil), sig, opts)
   207  						want := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
   208  						if opts.SaltLength == 0 && slices.Contains(overrideIDs, tv.TcId) {
   209  							want = true
   210  						}
   211  						if (err == nil) != want {
   212  							t.Errorf("wanted success: %t err: %v", want, err)
   213  						}
   214  					})
   215  				}
   216  
   217  				// Update opts.SaltLength for the second run of the tests.
   218  				opts.SaltLength = tg.SLen
   219  			}
   220  		}
   221  	}
   222  }
   223  

View as plain text