Source file src/crypto/tls/bogo_shim_test.go

     1  package tls
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/x509"
     6  	"encoding/base64"
     7  	"encoding/json"
     8  	"encoding/pem"
     9  	"flag"
    10  	"fmt"
    11  	"internal/byteorder"
    12  	"internal/testenv"
    13  	"io"
    14  	"log"
    15  	"net"
    16  	"os"
    17  	"os/exec"
    18  	"path/filepath"
    19  	"runtime"
    20  	"slices"
    21  	"strconv"
    22  	"strings"
    23  	"testing"
    24  
    25  	"golang.org/x/crypto/cryptobyte"
    26  )
    27  
// Command-line flags read by bogoShim when this test binary runs as a BoGo
// shim. They are registered on the default flag set. Flags assigned to the
// blank identifier are accepted so that the command line parses, but their
// values are never read.
var (
	// The shim dials localhost:<port>. With -server it takes the TLS server
	// role; otherwise it acts as a client.
	port   = flag.String("port", "", "")
	server = flag.Bool("server", false, "")

	// When set, bogoShim prints "No" and returns without connecting.
	isHandshakerSupported = flag.Bool("is-handshaker-supported", false, "")

	// Certificate and key files loaded with LoadX509KeyPair into
	// Config.Certificates.
	keyfile  = flag.String("key-file", "", "")
	certfile = flag.String("cert-file", "", "")

	// PEM file whose first certificate becomes the only entry in
	// Config.RootCAs.
	trustCert = flag.String("trust-cert", "", "")

	// Config.MinVersion and Config.MaxVersion, plus the version each
	// completed handshake must negotiate (0 disables that check).
	minVersion    = flag.Int("min-version", VersionSSL30, "")
	maxVersion    = flag.Int("max-version", VersionTLS13, "")
	expectVersion = flag.Int("expect-version", 0, "")

	// Disable individual protocol versions by narrowing MinVersion/MaxVersion.
	noTLS1  = flag.Bool("no-tls1", false, "")
	noTLS11 = flag.Bool("no-tls11", false, "")
	noTLS12 = flag.Bool("no-tls12", false, "")
	noTLS13 = flag.Bool("no-tls13", false, "")

	// Sets Config.ClientAuth to RequireAnyClientCert.
	requireAnyClientCertificate = flag.Bool("require-any-client-certificate", false, "")

	// On the first connection only, write "hello" before reading from the peer.
	shimWritesFirst = flag.Bool("shim-writes-first", false, "")

	// Number of additional connections made after the first one.
	resumeCount = flag.Int("resume-count", 0, "")

	// Decimal curve IDs for Config.CurvePreferences, and the decimal curve ID
	// each connection must end up using.
	curves        = flagStringSlice("curves", "")
	expectedCurve = flag.String("expect-curve-id", "", "")

	// shimID is written as a little-endian uint64 at the start of every
	// connection. -ipv6 is ignored: the shim always dials "localhost".
	shimID = flag.Uint64("shim-id", 0, "")
	_      = flag.Bool("ipv6", false, "")

	// Encrypted Client Hello, HelloRetryRequest and server-name expectations.
	// ECH config lists and expected retry configs are standard base64.
	// The -on-initial-* variants apply only to the first connection, and the
	// -on-resume-* variants only to the later ones.
	echConfigListB64           = flag.String("ech-config-list", "", "")
	expectECHAccepted          = flag.Bool("expect-ech-accept", false, "")
	expectHRR                  = flag.Bool("expect-hrr", false, "")
	expectNoHRR                = flag.Bool("expect-no-hrr", false, "")
	expectedECHRetryConfigs    = flag.String("expect-ech-retry-configs", "", "")
	expectNoECHRetryConfigs    = flag.Bool("expect-no-ech-retry-configs", false, "")
	onInitialExpectECHAccepted = flag.Bool("on-initial-expect-ech-accept", false, "")
	_                          = flag.Bool("expect-no-ech-name-override", false, "")
	_                          = flag.String("expect-ech-name-override", "", "")
	_                          = flag.Bool("reverify-on-resume", false, "")
	onResumeECHConfigListB64   = flag.String("on-resume-ech-config-list", "", "")
	_                          = flag.Bool("on-resume-expect-reject-early-data", false, "")
	onResumeExpectECHAccepted  = flag.Bool("on-resume-expect-ech-accept", false, "")
	_                          = flag.Bool("on-resume-expect-no-ech-name-override", false, "")
	expectedServerName         = flag.String("expect-server-name", "", "")

	// Fail if a completed handshake resumed a session.
	expectSessionMiss = flag.Bool("expect-session-miss", false, "")

	// The early-data flags are ignored.
	// NOTE(review): onResumeShimWritesFirst is registered, but bogoShim never
	// reads it.
	_                       = flag.Bool("enable-early-data", false, "")
	_                       = flag.Bool("on-resume-expect-accept-early-data", false, "")
	_                       = flag.Bool("expect-ticket-supports-early-data", false, "")
	onResumeShimWritesFirst = flag.Bool("on-resume-shim-writes-first", false, "")

	// ALPN settings and expectations:
	//   - -advertise-alpn and -expect-advertised-alpn hold lists of 8-bit
	//     length-prefixed protocol names.
	//   - -expect-alpn and -select-alpn name the protocol that must be
	//     negotiated.
	//   - -reject-alpn, -decline-alpn and -select-alpn replace
	//     Config.NextProtos.
	advertiseALPN        = flag.String("advertise-alpn", "", "")
	expectALPN           = flag.String("expect-alpn", "", "")
	rejectALPN           = flag.Bool("reject-alpn", false, "")
	declineALPN          = flag.Bool("decline-alpn", false, "")
	expectAdvertisedALPN = flag.String("expect-advertised-alpn", "", "")
	selectALPN           = flag.String("select-alpn", "", "")

	// Overrides Config.ServerName, which otherwise defaults to "test".
	hostName = flag.String("host-name", "", "")

	// Sets Config.ClientAuth to VerifyClientCertIfGiven.
	verifyPeer = flag.Bool("verify-peer", false, "")
	_          = flag.Bool("use-custom-verify-callback", false, "")
)
    95  
    96  type stringSlice []string
    97  
    98  func flagStringSlice(name, usage string) *stringSlice {
    99  	f := &stringSlice{}
   100  	flag.Var(f, name, usage)
   101  	return f
   102  }
   103  
   104  func (saf stringSlice) String() string {
   105  	return strings.Join(saf, ",")
   106  }
   107  
   108  func (saf stringSlice) Set(s string) error {
   109  	saf = append(saf, s)
   110  	return nil
   111  }
   112  
   113  func bogoShim() {
   114  	if *isHandshakerSupported {
   115  		fmt.Println("No")
   116  		return
   117  	}
   118  
   119  	cfg := &Config{
   120  		ServerName: "test",
   121  
   122  		MinVersion: uint16(*minVersion),
   123  		MaxVersion: uint16(*maxVersion),
   124  
   125  		ClientSessionCache: NewLRUClientSessionCache(0),
   126  
   127  		GetConfigForClient: func(chi *ClientHelloInfo) (*Config, error) {
   128  
   129  			if *expectAdvertisedALPN != "" {
   130  
   131  				s := cryptobyte.String(*expectAdvertisedALPN)
   132  
   133  				var expectedALPNs []string
   134  
   135  				for !s.Empty() {
   136  					var alpn cryptobyte.String
   137  					if !s.ReadUint8LengthPrefixed(&alpn) {
   138  						return nil, fmt.Errorf("unexpected error while parsing arguments for -expect-advertised-alpn")
   139  					}
   140  					expectedALPNs = append(expectedALPNs, string(alpn))
   141  				}
   142  
   143  				if !slices.Equal(chi.SupportedProtos, expectedALPNs) {
   144  					return nil, fmt.Errorf("unexpected ALPN: got %q, want %q", chi.SupportedProtos, expectedALPNs)
   145  				}
   146  			}
   147  			return nil, nil
   148  		},
   149  	}
   150  
   151  	if *noTLS1 {
   152  		cfg.MinVersion = VersionTLS11
   153  		if *noTLS11 {
   154  			cfg.MinVersion = VersionTLS12
   155  			if *noTLS12 {
   156  				cfg.MinVersion = VersionTLS13
   157  				if *noTLS13 {
   158  					log.Fatalf("no supported versions enabled")
   159  				}
   160  			}
   161  		}
   162  	} else if *noTLS13 {
   163  		cfg.MaxVersion = VersionTLS12
   164  		if *noTLS12 {
   165  			cfg.MaxVersion = VersionTLS11
   166  			if *noTLS11 {
   167  				cfg.MaxVersion = VersionTLS10
   168  				if *noTLS1 {
   169  					log.Fatalf("no supported versions enabled")
   170  				}
   171  			}
   172  		}
   173  	}
   174  
   175  	if *advertiseALPN != "" {
   176  		alpns := *advertiseALPN
   177  		for len(alpns) > 0 {
   178  			alpnLen := int(alpns[0])
   179  			cfg.NextProtos = append(cfg.NextProtos, alpns[1:1+alpnLen])
   180  			alpns = alpns[alpnLen+1:]
   181  		}
   182  	}
   183  
   184  	if *rejectALPN {
   185  		cfg.NextProtos = []string{"unnegotiableprotocol"}
   186  	}
   187  
   188  	if *declineALPN {
   189  		cfg.NextProtos = []string{}
   190  	}
   191  	if *selectALPN != "" {
   192  		cfg.NextProtos = []string{*selectALPN}
   193  	}
   194  
   195  	if *hostName != "" {
   196  		cfg.ServerName = *hostName
   197  	}
   198  
   199  	if *keyfile != "" || *certfile != "" {
   200  		pair, err := LoadX509KeyPair(*certfile, *keyfile)
   201  		if err != nil {
   202  			log.Fatalf("load key-file err: %s", err)
   203  		}
   204  		cfg.Certificates = []Certificate{pair}
   205  	}
   206  	if *trustCert != "" {
   207  		pool := x509.NewCertPool()
   208  		certFile, err := os.ReadFile(*trustCert)
   209  		if err != nil {
   210  			log.Fatalf("load trust-cert err: %s", err)
   211  		}
   212  		block, _ := pem.Decode(certFile)
   213  		cert, err := x509.ParseCertificate(block.Bytes)
   214  		if err != nil {
   215  			log.Fatalf("parse trust-cert err: %s", err)
   216  		}
   217  		pool.AddCert(cert)
   218  		cfg.RootCAs = pool
   219  	}
   220  
   221  	if *requireAnyClientCertificate {
   222  		cfg.ClientAuth = RequireAnyClientCert
   223  	}
   224  	if *verifyPeer {
   225  		cfg.ClientAuth = VerifyClientCertIfGiven
   226  	}
   227  
   228  	if *echConfigListB64 != "" {
   229  		echConfigList, err := base64.StdEncoding.DecodeString(*echConfigListB64)
   230  		if err != nil {
   231  			log.Fatalf("parse ech-config-list err: %s", err)
   232  		}
   233  		cfg.EncryptedClientHelloConfigList = echConfigList
   234  		cfg.MinVersion = VersionTLS13
   235  	}
   236  
   237  	if len(*curves) != 0 {
   238  		for _, curveStr := range *curves {
   239  			id, err := strconv.Atoi(curveStr)
   240  			if err != nil {
   241  				log.Fatalf("failed to parse curve id %q: %s", curveStr, err)
   242  			}
   243  			cfg.CurvePreferences = append(cfg.CurvePreferences, CurveID(id))
   244  		}
   245  	}
   246  
   247  	for i := 0; i < *resumeCount+1; i++ {
   248  		if i > 0 && (*onResumeECHConfigListB64 != "") {
   249  			echConfigList, err := base64.StdEncoding.DecodeString(*onResumeECHConfigListB64)
   250  			if err != nil {
   251  				log.Fatalf("parse ech-config-list err: %s", err)
   252  			}
   253  			cfg.EncryptedClientHelloConfigList = echConfigList
   254  		}
   255  
   256  		conn, err := net.Dial("tcp", net.JoinHostPort("localhost", *port))
   257  		if err != nil {
   258  			log.Fatalf("dial err: %s", err)
   259  		}
   260  		defer conn.Close()
   261  
   262  		// Write the shim ID we were passed as a little endian uint64
   263  		shimIDBytes := make([]byte, 8)
   264  		byteorder.LePutUint64(shimIDBytes, *shimID)
   265  		if _, err := conn.Write(shimIDBytes); err != nil {
   266  			log.Fatalf("failed to write shim id: %s", err)
   267  		}
   268  
   269  		var tlsConn *Conn
   270  		if *server {
   271  			tlsConn = Server(conn, cfg)
   272  		} else {
   273  			tlsConn = Client(conn, cfg)
   274  		}
   275  
   276  		if i == 0 && *shimWritesFirst {
   277  			if _, err := tlsConn.Write([]byte("hello")); err != nil {
   278  				log.Fatalf("write err: %s", err)
   279  			}
   280  		}
   281  
   282  		for {
   283  			buf := make([]byte, 500)
   284  			var n int
   285  			n, err = tlsConn.Read(buf)
   286  			if err != nil {
   287  				break
   288  			}
   289  			buf = buf[:n]
   290  			for i := range buf {
   291  				buf[i] ^= 0xff
   292  			}
   293  			if _, err = tlsConn.Write(buf); err != nil {
   294  				break
   295  			}
   296  		}
   297  		if err != nil && err != io.EOF {
   298  			retryErr, ok := err.(*ECHRejectionError)
   299  			if !ok {
   300  				log.Fatalf("unexpected error type returned: %v", err)
   301  			}
   302  			if *expectNoECHRetryConfigs && len(retryErr.RetryConfigList) > 0 {
   303  				log.Fatalf("expected no ECH retry configs, got some")
   304  			}
   305  			if *expectedECHRetryConfigs != "" {
   306  				expectedRetryConfigs, err := base64.StdEncoding.DecodeString(*expectedECHRetryConfigs)
   307  				if err != nil {
   308  					log.Fatalf("failed to decode expected retry configs: %s", err)
   309  				}
   310  				if !bytes.Equal(retryErr.RetryConfigList, expectedRetryConfigs) {
   311  					log.Fatalf("unexpected retry list returned: got %x, want %x", retryErr.RetryConfigList, expectedRetryConfigs)
   312  				}
   313  			}
   314  			log.Fatalf("conn error: %s", err)
   315  		}
   316  
   317  		cs := tlsConn.ConnectionState()
   318  		if cs.HandshakeComplete {
   319  			if *expectALPN != "" && cs.NegotiatedProtocol != *expectALPN {
   320  				log.Fatalf("unexpected protocol negotiated: want %q, got %q", *expectALPN, cs.NegotiatedProtocol)
   321  			}
   322  
   323  			if *selectALPN != "" && cs.NegotiatedProtocol != *selectALPN {
   324  				log.Fatalf("unexpected protocol negotiated: want %q, got %q", *selectALPN, cs.NegotiatedProtocol)
   325  			}
   326  
   327  			if *expectVersion != 0 && cs.Version != uint16(*expectVersion) {
   328  				log.Fatalf("expected ssl version %q, got %q", uint16(*expectVersion), cs.Version)
   329  			}
   330  			if *declineALPN && cs.NegotiatedProtocol != "" {
   331  				log.Fatal("unexpected ALPN protocol")
   332  			}
   333  			if *expectECHAccepted && !cs.ECHAccepted {
   334  				log.Fatal("expected ECH to be accepted, but connection state shows it was not")
   335  			} else if i == 0 && *onInitialExpectECHAccepted && !cs.ECHAccepted {
   336  				log.Fatal("expected ECH to be accepted, but connection state shows it was not")
   337  			} else if i > 0 && *onResumeExpectECHAccepted && !cs.ECHAccepted {
   338  				log.Fatal("expected ECH to be accepted on resumption, but connection state shows it was not")
   339  			} else if i == 0 && !*expectECHAccepted && cs.ECHAccepted {
   340  				log.Fatal("did not expect ECH, but it was accepted")
   341  			}
   342  
   343  			if *expectHRR && !cs.testingOnlyDidHRR {
   344  				log.Fatal("expected HRR but did not do it")
   345  			}
   346  
   347  			if *expectNoHRR && cs.testingOnlyDidHRR {
   348  				log.Fatal("expected no HRR but did do it")
   349  			}
   350  
   351  			if *expectSessionMiss && cs.DidResume {
   352  				log.Fatal("unexpected session resumption")
   353  			}
   354  
   355  			if *expectedServerName != "" && cs.ServerName != *expectedServerName {
   356  				log.Fatalf("unexpected server name: got %q, want %q", cs.ServerName, *expectedServerName)
   357  			}
   358  		}
   359  
   360  		if *expectedCurve != "" {
   361  			expectedCurveID, err := strconv.Atoi(*expectedCurve)
   362  			if err != nil {
   363  				log.Fatalf("failed to parse -expect-curve-id: %s", err)
   364  			}
   365  			if tlsConn.curveID != CurveID(expectedCurveID) {
   366  				log.Fatalf("unexpected curve id: want %d, got %d", expectedCurveID, tlsConn.curveID)
   367  			}
   368  		}
   369  	}
   370  }
   371  
   372  func TestBogoSuite(t *testing.T) {
   373  	testenv.SkipIfShortAndSlow(t)
   374  	testenv.MustHaveExternalNetwork(t)
   375  	testenv.MustHaveGoRun(t)
   376  	testenv.MustHaveExec(t)
   377  
   378  	if testing.Short() {
   379  		t.Skip("skipping in short mode")
   380  	}
   381  	if testenv.Builder() != "" && runtime.GOOS == "windows" {
   382  		t.Skip("#66913: windows network connections are flakey on builders")
   383  	}
   384  
   385  	// In order to make Go test caching work as expected, we stat the
   386  	// bogo_config.json file, so that the Go testing hooks know that it is
   387  	// important for this test and will invalidate a cached test result if the
   388  	// file changes.
   389  	if _, err := os.Stat("bogo_config.json"); err != nil {
   390  		t.Fatal(err)
   391  	}
   392  
   393  	var bogoDir string
   394  	if *bogoLocalDir != "" {
   395  		bogoDir = *bogoLocalDir
   396  	} else {
   397  		const boringsslModVer = "v0.0.0-20240523173554-273a920f84e8"
   398  		output, err := exec.Command("go", "mod", "download", "-json", "boringssl.googlesource.com/boringssl.git@"+boringsslModVer).CombinedOutput()
   399  		if err != nil {
   400  			t.Fatalf("failed to download boringssl: %s", err)
   401  		}
   402  		var j struct {
   403  			Dir string
   404  		}
   405  		if err := json.Unmarshal(output, &j); err != nil {
   406  			t.Fatalf("failed to parse 'go mod download' output: %s", err)
   407  		}
   408  		bogoDir = j.Dir
   409  	}
   410  
   411  	cwd, err := os.Getwd()
   412  	if err != nil {
   413  		t.Fatal(err)
   414  	}
   415  
   416  	resultsFile := filepath.Join(t.TempDir(), "results.json")
   417  
   418  	args := []string{
   419  		"test",
   420  		".",
   421  		fmt.Sprintf("-shim-config=%s", filepath.Join(cwd, "bogo_config.json")),
   422  		fmt.Sprintf("-shim-path=%s", os.Args[0]),
   423  		"-shim-extra-flags=-bogo-mode",
   424  		"-allow-unimplemented",
   425  		"-loose-errors", // TODO(roland): this should be removed eventually
   426  		fmt.Sprintf("-json-output=%s", resultsFile),
   427  	}
   428  	if *bogoFilter != "" {
   429  		args = append(args, fmt.Sprintf("-test=%s", *bogoFilter))
   430  	}
   431  
   432  	goCmd, err := testenv.GoTool()
   433  	if err != nil {
   434  		t.Fatal(err)
   435  	}
   436  	cmd := exec.Command(goCmd, args...)
   437  	out := &strings.Builder{}
   438  	cmd.Stderr = out
   439  	cmd.Dir = filepath.Join(bogoDir, "ssl/test/runner")
   440  	err = cmd.Run()
   441  	// NOTE: we don't immediately check the error, because the failure could be either because
   442  	// the runner failed for some unexpected reason, or because a test case failed, and we
   443  	// cannot easily differentiate these cases. We check if the JSON results file was written,
   444  	// which should only happen if the failure was because of a test failure, and use that
   445  	// to determine the failure mode.
   446  
   447  	resultsJSON, jsonErr := os.ReadFile(resultsFile)
   448  	if jsonErr != nil {
   449  		if err != nil {
   450  			t.Fatalf("bogo failed: %s\n%s", err, out)
   451  		}
   452  		t.Fatalf("failed to read results JSON file: %s", jsonErr)
   453  	}
   454  
   455  	var results bogoResults
   456  	if err := json.Unmarshal(resultsJSON, &results); err != nil {
   457  		t.Fatalf("failed to parse results JSON: %s", err)
   458  	}
   459  
   460  	// assertResults contains test results we want to make sure
   461  	// are present in the output. They are only checked if -bogo-filter
   462  	// was not passed.
   463  	assertResults := map[string]string{
   464  		"CurveTest-Client-Kyber-TLS13": "PASS",
   465  		"CurveTest-Server-Kyber-TLS13": "PASS",
   466  	}
   467  
   468  	for name, result := range results.Tests {
   469  		// This is not really the intended way to do this... but... it works?
   470  		t.Run(name, func(t *testing.T) {
   471  			if result.Actual == "FAIL" && result.IsUnexpected {
   472  				t.Fatal(result.Error)
   473  			}
   474  			if expectedResult, ok := assertResults[name]; ok && expectedResult != result.Actual {
   475  				t.Fatalf("unexpected result: got %s, want %s", result.Actual, assertResults[name])
   476  			}
   477  			delete(assertResults, name)
   478  			if result.Actual == "SKIP" {
   479  				t.Skip()
   480  			}
   481  		})
   482  	}
   483  	if *bogoFilter == "" {
   484  		// Anything still in assertResults did not show up in the results, so we should fail
   485  		for name, expectedResult := range assertResults {
   486  			t.Run(name, func(t *testing.T) {
   487  				t.Fatalf("expected test to run with result %s, but it was not present in the test results", expectedResult)
   488  			})
   489  		}
   490  	}
   491  }
   492  
// bogoResults is a copy of boringssl.googlesource.com/boringssl/testresults.Results.
//
// TestBogoSuite decodes into it the JSON file that the BoGo runner writes to
// the path given by its -json-output flag.
type bogoResults struct {
	Version           int            `json:"version"`
	Interrupted       bool           `json:"interrupted"`
	PathDelimiter     string         `json:"path_delimiter"`
	SecondsSinceEpoch float64        `json:"seconds_since_epoch"`
	NumFailuresByType map[string]int `json:"num_failures_by_type"`
	// Tests maps each BoGo test name to its result. TestBogoSuite compares
	// Actual against "PASS", "FAIL" and "SKIP". It treats a "FAIL" as a test
	// failure only when IsUnexpected is set, and reports Error as the failure
	// message.
	Tests             map[string]struct {
		Actual       string `json:"actual"`
		Expected     string `json:"expected"`
		IsUnexpected bool   `json:"is_unexpected"`
		Error        string `json:"error,omitempty"`
	} `json:"tests"`
}
   507  

View as plain text