Source file
src/crypto/tls/bogo_shim_test.go
1 package tls
2
3 import (
4 "bytes"
5 "crypto/x509"
6 "encoding/base64"
7 "encoding/json"
8 "encoding/pem"
9 "flag"
10 "fmt"
11 "internal/byteorder"
12 "internal/testenv"
13 "io"
14 "log"
15 "net"
16 "os"
17 "os/exec"
18 "path/filepath"
19 "runtime"
20 "slices"
21 "strconv"
22 "strings"
23 "testing"
24
25 "golang.org/x/crypto/cryptobyte"
26 )
27
// Flags read by bogoShim when the test binary runs as a BoGo shim. Flags
// bound to the blank identifier are registered only so that flag parsing
// accepts them; their values are never read.
var (
	// -port is the TCP port to dial on localhost; -server makes the shim act
	// as the TLS server instead of the client.
	port   = flag.String("port", "", "")
	server = flag.Bool("server", false, "")

	// When set, bogoShim prints "No" and returns without connecting.
	isHandshakerSupported = flag.Bool("is-handshaker-supported", false, "")

	// Certificate and key files loaded with LoadX509KeyPair as the shim's own
	// certificate.
	keyfile  = flag.String("key-file", "", "")
	certfile = flag.String("cert-file", "", "")

	// PEM file whose first certificate becomes the only entry of RootCAs.
	trustCert = flag.String("trust-cert", "", "")

	// Version bounds for the Config, and the version the handshake must
	// negotiate (0 disables that check).
	minVersion    = flag.Int("min-version", VersionSSL30, "")
	maxVersion    = flag.Int("max-version", VersionTLS13, "")
	expectVersion = flag.Int("expect-version", 0, "")

	// Disable individual protocol versions by narrowing MinVersion or
	// MaxVersion in bogoShim.
	noTLS1  = flag.Bool("no-tls1", false, "")
	noTLS11 = flag.Bool("no-tls11", false, "")
	noTLS12 = flag.Bool("no-tls12", false, "")
	noTLS13 = flag.Bool("no-tls13", false, "")

	// Sets ClientAuth to RequireAnyClientCert.
	requireAnyClientCertificate = flag.Bool("require-any-client-certificate", false, "")

	// Write "hello" before reading, on the initial connection only.
	shimWritesFirst = flag.Bool("shim-writes-first", false, "")

	// Number of connections made after the initial one.
	resumeCount = flag.Int("resume-count", 0, "")

	// Decimal CurveID values appended to CurvePreferences, and the decimal
	// CurveID the connection is expected to end up using.
	curves        = flagStringSlice("curves", "")
	expectedCurve = flag.String("expect-curve-id", "", "")

	// -shim-id is written to the runner as a little-endian uint64 before the
	// handshake. -ipv6 is ignored: bogoShim always dials "localhost".
	shimID = flag.Uint64("shim-id", 0, "")
	_      = flag.Bool("ipv6", false, "")

	// Encrypted Client Hello, HRR and server-name expectations. The config
	// lists (-ech-config-list, -on-resume-ech-config-list and
	// -expect-ech-retry-configs) are base64-encoded. The on-initial and
	// on-resume variants apply only to the first connection or to the later
	// connections, respectively.
	echConfigListB64           = flag.String("ech-config-list", "", "")
	expectECHAccepted          = flag.Bool("expect-ech-accept", false, "")
	expectHRR                  = flag.Bool("expect-hrr", false, "")
	expectNoHRR                = flag.Bool("expect-no-hrr", false, "")
	expectedECHRetryConfigs    = flag.String("expect-ech-retry-configs", "", "")
	expectNoECHRetryConfigs    = flag.Bool("expect-no-ech-retry-configs", false, "")
	onInitialExpectECHAccepted = flag.Bool("on-initial-expect-ech-accept", false, "")
	_                          = flag.Bool("expect-no-ech-name-override", false, "")
	_                          = flag.String("expect-ech-name-override", "", "")
	_                          = flag.Bool("reverify-on-resume", false, "")
	onResumeECHConfigListB64   = flag.String("on-resume-ech-config-list", "", "")
	_                          = flag.Bool("on-resume-expect-reject-early-data", false, "")
	onResumeExpectECHAccepted  = flag.Bool("on-resume-expect-ech-accept", false, "")
	_                          = flag.Bool("on-resume-expect-no-ech-name-override", false, "")
	expectedServerName         = flag.String("expect-server-name", "", "")

	// Fail if the connection resumed a session.
	expectSessionMiss = flag.Bool("expect-session-miss", false, "")

	// Early-data flags are accepted but ignored.
	// NOTE(review): onResumeShimWritesFirst is given a name but is not read
	// anywhere in this file; confirm whether resumed connections should honor
	// it.
	_                       = flag.Bool("enable-early-data", false, "")
	_                       = flag.Bool("on-resume-expect-accept-early-data", false, "")
	_                       = flag.Bool("expect-ticket-supports-early-data", false, "")
	onResumeShimWritesFirst = flag.Bool("on-resume-shim-writes-first", false, "")

	// ALPN. -advertise-alpn and -expect-advertised-alpn are sequences of
	// uint8-length-prefixed protocol names. -reject-alpn, -decline-alpn and
	// -select-alpn each replace NextProtos outright.
	advertiseALPN        = flag.String("advertise-alpn", "", "")
	expectALPN           = flag.String("expect-alpn", "", "")
	rejectALPN           = flag.Bool("reject-alpn", false, "")
	declineALPN          = flag.Bool("decline-alpn", false, "")
	expectAdvertisedALPN = flag.String("expect-advertised-alpn", "", "")
	selectALPN           = flag.String("select-alpn", "", "")

	// Overrides the default ServerName of "test".
	hostName = flag.String("host-name", "", "")

	// Sets ClientAuth to VerifyClientCertIfGiven. -use-custom-verify-callback
	// is ignored.
	verifyPeer = flag.Bool("verify-peer", false, "")
	_          = flag.Bool("use-custom-verify-callback", false, "")
)
95
// stringSlice is a flag.Value that collects every occurrence of a repeatable
// string flag, in command-line order.
type stringSlice []string

// flagStringSlice registers a repeatable string flag called name on the
// default flag set and returns a pointer to the collected values.
func flagStringSlice(name, usage string) *stringSlice {
	f := &stringSlice{}
	flag.Var(f, name, usage)
	return f
}

// String returns the collected values joined with commas. The flag package
// may call String on a zero-valued receiver such as a nil pointer, so that
// case yields "".
func (s *stringSlice) String() string {
	if s == nil {
		return ""
	}
	return strings.Join(*s, ",")
}

// Set appends one occurrence of the flag. It needs a pointer receiver: with a
// value receiver the append would only update a copy and every value would
// be silently dropped.
func (s *stringSlice) Set(v string) error {
	*s = append(*s, v)
	return nil
}
112
113 func bogoShim() {
114 if *isHandshakerSupported {
115 fmt.Println("No")
116 return
117 }
118
119 cfg := &Config{
120 ServerName: "test",
121
122 MinVersion: uint16(*minVersion),
123 MaxVersion: uint16(*maxVersion),
124
125 ClientSessionCache: NewLRUClientSessionCache(0),
126
127 GetConfigForClient: func(chi *ClientHelloInfo) (*Config, error) {
128
129 if *expectAdvertisedALPN != "" {
130
131 s := cryptobyte.String(*expectAdvertisedALPN)
132
133 var expectedALPNs []string
134
135 for !s.Empty() {
136 var alpn cryptobyte.String
137 if !s.ReadUint8LengthPrefixed(&alpn) {
138 return nil, fmt.Errorf("unexpected error while parsing arguments for -expect-advertised-alpn")
139 }
140 expectedALPNs = append(expectedALPNs, string(alpn))
141 }
142
143 if !slices.Equal(chi.SupportedProtos, expectedALPNs) {
144 return nil, fmt.Errorf("unexpected ALPN: got %q, want %q", chi.SupportedProtos, expectedALPNs)
145 }
146 }
147 return nil, nil
148 },
149 }
150
151 if *noTLS1 {
152 cfg.MinVersion = VersionTLS11
153 if *noTLS11 {
154 cfg.MinVersion = VersionTLS12
155 if *noTLS12 {
156 cfg.MinVersion = VersionTLS13
157 if *noTLS13 {
158 log.Fatalf("no supported versions enabled")
159 }
160 }
161 }
162 } else if *noTLS13 {
163 cfg.MaxVersion = VersionTLS12
164 if *noTLS12 {
165 cfg.MaxVersion = VersionTLS11
166 if *noTLS11 {
167 cfg.MaxVersion = VersionTLS10
168 if *noTLS1 {
169 log.Fatalf("no supported versions enabled")
170 }
171 }
172 }
173 }
174
175 if *advertiseALPN != "" {
176 alpns := *advertiseALPN
177 for len(alpns) > 0 {
178 alpnLen := int(alpns[0])
179 cfg.NextProtos = append(cfg.NextProtos, alpns[1:1+alpnLen])
180 alpns = alpns[alpnLen+1:]
181 }
182 }
183
184 if *rejectALPN {
185 cfg.NextProtos = []string{"unnegotiableprotocol"}
186 }
187
188 if *declineALPN {
189 cfg.NextProtos = []string{}
190 }
191 if *selectALPN != "" {
192 cfg.NextProtos = []string{*selectALPN}
193 }
194
195 if *hostName != "" {
196 cfg.ServerName = *hostName
197 }
198
199 if *keyfile != "" || *certfile != "" {
200 pair, err := LoadX509KeyPair(*certfile, *keyfile)
201 if err != nil {
202 log.Fatalf("load key-file err: %s", err)
203 }
204 cfg.Certificates = []Certificate{pair}
205 }
206 if *trustCert != "" {
207 pool := x509.NewCertPool()
208 certFile, err := os.ReadFile(*trustCert)
209 if err != nil {
210 log.Fatalf("load trust-cert err: %s", err)
211 }
212 block, _ := pem.Decode(certFile)
213 cert, err := x509.ParseCertificate(block.Bytes)
214 if err != nil {
215 log.Fatalf("parse trust-cert err: %s", err)
216 }
217 pool.AddCert(cert)
218 cfg.RootCAs = pool
219 }
220
221 if *requireAnyClientCertificate {
222 cfg.ClientAuth = RequireAnyClientCert
223 }
224 if *verifyPeer {
225 cfg.ClientAuth = VerifyClientCertIfGiven
226 }
227
228 if *echConfigListB64 != "" {
229 echConfigList, err := base64.StdEncoding.DecodeString(*echConfigListB64)
230 if err != nil {
231 log.Fatalf("parse ech-config-list err: %s", err)
232 }
233 cfg.EncryptedClientHelloConfigList = echConfigList
234 cfg.MinVersion = VersionTLS13
235 }
236
237 if len(*curves) != 0 {
238 for _, curveStr := range *curves {
239 id, err := strconv.Atoi(curveStr)
240 if err != nil {
241 log.Fatalf("failed to parse curve id %q: %s", curveStr, err)
242 }
243 cfg.CurvePreferences = append(cfg.CurvePreferences, CurveID(id))
244 }
245 }
246
247 for i := 0; i < *resumeCount+1; i++ {
248 if i > 0 && (*onResumeECHConfigListB64 != "") {
249 echConfigList, err := base64.StdEncoding.DecodeString(*onResumeECHConfigListB64)
250 if err != nil {
251 log.Fatalf("parse ech-config-list err: %s", err)
252 }
253 cfg.EncryptedClientHelloConfigList = echConfigList
254 }
255
256 conn, err := net.Dial("tcp", net.JoinHostPort("localhost", *port))
257 if err != nil {
258 log.Fatalf("dial err: %s", err)
259 }
260 defer conn.Close()
261
262
263 shimIDBytes := make([]byte, 8)
264 byteorder.LePutUint64(shimIDBytes, *shimID)
265 if _, err := conn.Write(shimIDBytes); err != nil {
266 log.Fatalf("failed to write shim id: %s", err)
267 }
268
269 var tlsConn *Conn
270 if *server {
271 tlsConn = Server(conn, cfg)
272 } else {
273 tlsConn = Client(conn, cfg)
274 }
275
276 if i == 0 && *shimWritesFirst {
277 if _, err := tlsConn.Write([]byte("hello")); err != nil {
278 log.Fatalf("write err: %s", err)
279 }
280 }
281
282 for {
283 buf := make([]byte, 500)
284 var n int
285 n, err = tlsConn.Read(buf)
286 if err != nil {
287 break
288 }
289 buf = buf[:n]
290 for i := range buf {
291 buf[i] ^= 0xff
292 }
293 if _, err = tlsConn.Write(buf); err != nil {
294 break
295 }
296 }
297 if err != nil && err != io.EOF {
298 retryErr, ok := err.(*ECHRejectionError)
299 if !ok {
300 log.Fatalf("unexpected error type returned: %v", err)
301 }
302 if *expectNoECHRetryConfigs && len(retryErr.RetryConfigList) > 0 {
303 log.Fatalf("expected no ECH retry configs, got some")
304 }
305 if *expectedECHRetryConfigs != "" {
306 expectedRetryConfigs, err := base64.StdEncoding.DecodeString(*expectedECHRetryConfigs)
307 if err != nil {
308 log.Fatalf("failed to decode expected retry configs: %s", err)
309 }
310 if !bytes.Equal(retryErr.RetryConfigList, expectedRetryConfigs) {
311 log.Fatalf("unexpected retry list returned: got %x, want %x", retryErr.RetryConfigList, expectedRetryConfigs)
312 }
313 }
314 log.Fatalf("conn error: %s", err)
315 }
316
317 cs := tlsConn.ConnectionState()
318 if cs.HandshakeComplete {
319 if *expectALPN != "" && cs.NegotiatedProtocol != *expectALPN {
320 log.Fatalf("unexpected protocol negotiated: want %q, got %q", *expectALPN, cs.NegotiatedProtocol)
321 }
322
323 if *selectALPN != "" && cs.NegotiatedProtocol != *selectALPN {
324 log.Fatalf("unexpected protocol negotiated: want %q, got %q", *selectALPN, cs.NegotiatedProtocol)
325 }
326
327 if *expectVersion != 0 && cs.Version != uint16(*expectVersion) {
328 log.Fatalf("expected ssl version %q, got %q", uint16(*expectVersion), cs.Version)
329 }
330 if *declineALPN && cs.NegotiatedProtocol != "" {
331 log.Fatal("unexpected ALPN protocol")
332 }
333 if *expectECHAccepted && !cs.ECHAccepted {
334 log.Fatal("expected ECH to be accepted, but connection state shows it was not")
335 } else if i == 0 && *onInitialExpectECHAccepted && !cs.ECHAccepted {
336 log.Fatal("expected ECH to be accepted, but connection state shows it was not")
337 } else if i > 0 && *onResumeExpectECHAccepted && !cs.ECHAccepted {
338 log.Fatal("expected ECH to be accepted on resumption, but connection state shows it was not")
339 } else if i == 0 && !*expectECHAccepted && cs.ECHAccepted {
340 log.Fatal("did not expect ECH, but it was accepted")
341 }
342
343 if *expectHRR && !cs.testingOnlyDidHRR {
344 log.Fatal("expected HRR but did not do it")
345 }
346
347 if *expectNoHRR && cs.testingOnlyDidHRR {
348 log.Fatal("expected no HRR but did do it")
349 }
350
351 if *expectSessionMiss && cs.DidResume {
352 log.Fatal("unexpected session resumption")
353 }
354
355 if *expectedServerName != "" && cs.ServerName != *expectedServerName {
356 log.Fatalf("unexpected server name: got %q, want %q", cs.ServerName, *expectedServerName)
357 }
358 }
359
360 if *expectedCurve != "" {
361 expectedCurveID, err := strconv.Atoi(*expectedCurve)
362 if err != nil {
363 log.Fatalf("failed to parse -expect-curve-id: %s", err)
364 }
365 if tlsConn.curveID != CurveID(expectedCurveID) {
366 log.Fatalf("unexpected curve id: want %d, got %d", expectedCurveID, tlsConn.curveID)
367 }
368 }
369 }
370 }
371
372 func TestBogoSuite(t *testing.T) {
373 testenv.SkipIfShortAndSlow(t)
374 testenv.MustHaveExternalNetwork(t)
375 testenv.MustHaveGoRun(t)
376 testenv.MustHaveExec(t)
377
378 if testing.Short() {
379 t.Skip("skipping in short mode")
380 }
381 if testenv.Builder() != "" && runtime.GOOS == "windows" {
382 t.Skip("#66913: windows network connections are flakey on builders")
383 }
384
385
386
387
388
389 if _, err := os.Stat("bogo_config.json"); err != nil {
390 t.Fatal(err)
391 }
392
393 var bogoDir string
394 if *bogoLocalDir != "" {
395 bogoDir = *bogoLocalDir
396 } else {
397 const boringsslModVer = "v0.0.0-20240523173554-273a920f84e8"
398 output, err := exec.Command("go", "mod", "download", "-json", "boringssl.googlesource.com/boringssl.git@"+boringsslModVer).CombinedOutput()
399 if err != nil {
400 t.Fatalf("failed to download boringssl: %s", err)
401 }
402 var j struct {
403 Dir string
404 }
405 if err := json.Unmarshal(output, &j); err != nil {
406 t.Fatalf("failed to parse 'go mod download' output: %s", err)
407 }
408 bogoDir = j.Dir
409 }
410
411 cwd, err := os.Getwd()
412 if err != nil {
413 t.Fatal(err)
414 }
415
416 resultsFile := filepath.Join(t.TempDir(), "results.json")
417
418 args := []string{
419 "test",
420 ".",
421 fmt.Sprintf("-shim-config=%s", filepath.Join(cwd, "bogo_config.json")),
422 fmt.Sprintf("-shim-path=%s", os.Args[0]),
423 "-shim-extra-flags=-bogo-mode",
424 "-allow-unimplemented",
425 "-loose-errors",
426 fmt.Sprintf("-json-output=%s", resultsFile),
427 }
428 if *bogoFilter != "" {
429 args = append(args, fmt.Sprintf("-test=%s", *bogoFilter))
430 }
431
432 goCmd, err := testenv.GoTool()
433 if err != nil {
434 t.Fatal(err)
435 }
436 cmd := exec.Command(goCmd, args...)
437 out := &strings.Builder{}
438 cmd.Stderr = out
439 cmd.Dir = filepath.Join(bogoDir, "ssl/test/runner")
440 err = cmd.Run()
441
442
443
444
445
446
447 resultsJSON, jsonErr := os.ReadFile(resultsFile)
448 if jsonErr != nil {
449 if err != nil {
450 t.Fatalf("bogo failed: %s\n%s", err, out)
451 }
452 t.Fatalf("failed to read results JSON file: %s", jsonErr)
453 }
454
455 var results bogoResults
456 if err := json.Unmarshal(resultsJSON, &results); err != nil {
457 t.Fatalf("failed to parse results JSON: %s", err)
458 }
459
460
461
462
463 assertResults := map[string]string{
464 "CurveTest-Client-Kyber-TLS13": "PASS",
465 "CurveTest-Server-Kyber-TLS13": "PASS",
466 }
467
468 for name, result := range results.Tests {
469
470 t.Run(name, func(t *testing.T) {
471 if result.Actual == "FAIL" && result.IsUnexpected {
472 t.Fatal(result.Error)
473 }
474 if expectedResult, ok := assertResults[name]; ok && expectedResult != result.Actual {
475 t.Fatalf("unexpected result: got %s, want %s", result.Actual, assertResults[name])
476 }
477 delete(assertResults, name)
478 if result.Actual == "SKIP" {
479 t.Skip()
480 }
481 })
482 }
483 if *bogoFilter == "" {
484
485 for name, expectedResult := range assertResults {
486 t.Run(name, func(t *testing.T) {
487 t.Fatalf("expected test to run with result %s, but it was not present in the test results", expectedResult)
488 })
489 }
490 }
491 }
492
493
// bogoResults decodes the JSON results file that the BoGo runner writes to
// the path TestBogoSuite passes as -json-output.
type bogoResults struct {
	Version           int            `json:"version"`
	Interrupted       bool           `json:"interrupted"`
	PathDelimiter     string         `json:"path_delimiter"`
	SecondsSinceEpoch float64        `json:"seconds_since_epoch"`
	NumFailuresByType map[string]int `json:"num_failures_by_type"`

	// Tests maps each BoGo test name to its outcome. TestBogoSuite fails the
	// matching subtest when Actual is "FAIL" and IsUnexpected is set, and
	// skips it when Actual is "SKIP". IsUnexpected presumably means Actual
	// differs from Expected; confirm against the BoGo runner.
	Tests map[string]struct {
		Actual       string `json:"actual"`
		Expected     string `json:"expected"`
		IsUnexpected bool   `json:"is_unexpected"`
		Error        string `json:"error,omitempty"`
	} `json:"tests"`
}
507
View as plain text