Source file
src/crypto/tls/bogo_shim_test.go
1 package tls
2
3 import (
4 "bytes"
5 "crypto/x509"
6 "encoding/base64"
7 "encoding/json"
8 "encoding/pem"
9 "flag"
10 "fmt"
11 "internal/byteorder"
12 "internal/testenv"
13 "io"
14 "log"
15 "net"
16 "os"
17 "os/exec"
18 "path/filepath"
19 "runtime"
20 "strconv"
21 "strings"
22 "testing"
23 )
24
25 var (
26 port = flag.String("port", "", "")
27 server = flag.Bool("server", false, "")
28
29 isHandshakerSupported = flag.Bool("is-handshaker-supported", false, "")
30
31 keyfile = flag.String("key-file", "", "")
32 certfile = flag.String("cert-file", "", "")
33
34 trustCert = flag.String("trust-cert", "", "")
35
36 minVersion = flag.Int("min-version", VersionSSL30, "")
37 maxVersion = flag.Int("max-version", VersionTLS13, "")
38 expectVersion = flag.Int("expect-version", 0, "")
39
40 noTLS1 = flag.Bool("no-tls1", false, "")
41 noTLS11 = flag.Bool("no-tls11", false, "")
42 noTLS12 = flag.Bool("no-tls12", false, "")
43 noTLS13 = flag.Bool("no-tls13", false, "")
44
45 requireAnyClientCertificate = flag.Bool("require-any-client-certificate", false, "")
46
47 shimWritesFirst = flag.Bool("shim-writes-first", false, "")
48
49 resumeCount = flag.Int("resume-count", 0, "")
50
51 curves = flagStringSlice("curves", "")
52 expectedCurve = flag.String("expect-curve-id", "", "")
53
54 shimID = flag.Uint64("shim-id", 0, "")
55 _ = flag.Bool("ipv6", false, "")
56
57 echConfigListB64 = flag.String("ech-config-list", "", "")
58 expectECHAccepted = flag.Bool("expect-ech-accept", false, "")
59 expectHRR = flag.Bool("expect-hrr", false, "")
60 expectNoHRR = flag.Bool("expect-no-hrr", false, "")
61 expectedECHRetryConfigs = flag.String("expect-ech-retry-configs", "", "")
62 expectNoECHRetryConfigs = flag.Bool("expect-no-ech-retry-configs", false, "")
63 onInitialExpectECHAccepted = flag.Bool("on-initial-expect-ech-accept", false, "")
64 _ = flag.Bool("expect-no-ech-name-override", false, "")
65 _ = flag.String("expect-ech-name-override", "", "")
66 _ = flag.Bool("reverify-on-resume", false, "")
67 onResumeECHConfigListB64 = flag.String("on-resume-ech-config-list", "", "")
68 _ = flag.Bool("on-resume-expect-reject-early-data", false, "")
69 onResumeExpectECHAccepted = flag.Bool("on-resume-expect-ech-accept", false, "")
70 _ = flag.Bool("on-resume-expect-no-ech-name-override", false, "")
71 expectedServerName = flag.String("expect-server-name", "", "")
72
73 expectSessionMiss = flag.Bool("expect-session-miss", false, "")
74
75 _ = flag.Bool("enable-early-data", false, "")
76 _ = flag.Bool("on-resume-expect-accept-early-data", false, "")
77 _ = flag.Bool("expect-ticket-supports-early-data", false, "")
78 onResumeShimWritesFirst = flag.Bool("on-resume-shim-writes-first", false, "")
79
80 advertiseALPN = flag.String("advertise-alpn", "", "")
81 expectALPN = flag.String("expect-alpn", "", "")
82 rejectALPN = flag.Bool("reject-alpn", false, "")
83 declineALPN = flag.Bool("decline-alpn", false, "")
84
85 hostName = flag.String("host-name", "", "")
86
87 verifyPeer = flag.Bool("verify-peer", false, "")
88 _ = flag.Bool("use-custom-verify-callback", false, "")
89 )
90
// stringSlice is a flag.Value that collects every occurrence of a repeatable
// string flag, in command-line order.
type stringSlice []string

// flagStringSlice defines a repeatable string flag with the given name and
// usage on the default flag set (flag.CommandLine) and returns a pointer to
// the collected values.
func flagStringSlice(name, usage string) *stringSlice {
	f := &stringSlice{}
	flag.Var(f, name, usage)
	return f
}

// String returns the collected values joined with commas. The flag package
// may call String on a zero-valued receiver such as a nil pointer, so that
// case yields "".
func (saf *stringSlice) String() string {
	if saf == nil {
		return ""
	}
	return strings.Join(*saf, ",")
}

// Set appends s to the collected values. It must use a pointer receiver. With
// a value receiver the append updates only a local copy, and the flag never
// records anything.
func (saf *stringSlice) Set(s string) error {
	*saf = append(*saf, s)
	return nil
}
107
108 func bogoShim() {
109 if *isHandshakerSupported {
110 fmt.Println("No")
111 return
112 }
113
114 cfg := &Config{
115 ServerName: "test",
116
117 MinVersion: uint16(*minVersion),
118 MaxVersion: uint16(*maxVersion),
119
120 ClientSessionCache: NewLRUClientSessionCache(0),
121 }
122
123 if *noTLS1 {
124 cfg.MinVersion = VersionTLS11
125 if *noTLS11 {
126 cfg.MinVersion = VersionTLS12
127 if *noTLS12 {
128 cfg.MinVersion = VersionTLS13
129 if *noTLS13 {
130 log.Fatalf("no supported versions enabled")
131 }
132 }
133 }
134 } else if *noTLS13 {
135 cfg.MaxVersion = VersionTLS12
136 if *noTLS12 {
137 cfg.MaxVersion = VersionTLS11
138 if *noTLS11 {
139 cfg.MaxVersion = VersionTLS10
140 if *noTLS1 {
141 log.Fatalf("no supported versions enabled")
142 }
143 }
144 }
145 }
146
147 if *advertiseALPN != "" {
148 alpns := *advertiseALPN
149 for len(alpns) > 0 {
150 alpnLen := int(alpns[0])
151 cfg.NextProtos = append(cfg.NextProtos, alpns[1:1+alpnLen])
152 alpns = alpns[alpnLen+1:]
153 }
154 }
155
156 if *rejectALPN {
157 cfg.NextProtos = []string{"unnegotiableprotocol"}
158 }
159
160 if *declineALPN {
161 cfg.NextProtos = []string{}
162 }
163
164 if *hostName != "" {
165 cfg.ServerName = *hostName
166 }
167
168 if *keyfile != "" || *certfile != "" {
169 pair, err := LoadX509KeyPair(*certfile, *keyfile)
170 if err != nil {
171 log.Fatalf("load key-file err: %s", err)
172 }
173 cfg.Certificates = []Certificate{pair}
174 }
175 if *trustCert != "" {
176 pool := x509.NewCertPool()
177 certFile, err := os.ReadFile(*trustCert)
178 if err != nil {
179 log.Fatalf("load trust-cert err: %s", err)
180 }
181 block, _ := pem.Decode(certFile)
182 cert, err := x509.ParseCertificate(block.Bytes)
183 if err != nil {
184 log.Fatalf("parse trust-cert err: %s", err)
185 }
186 pool.AddCert(cert)
187 cfg.RootCAs = pool
188 }
189
190 if *requireAnyClientCertificate {
191 cfg.ClientAuth = RequireAnyClientCert
192 }
193 if *verifyPeer {
194 cfg.ClientAuth = VerifyClientCertIfGiven
195 }
196
197 if *echConfigListB64 != "" {
198 echConfigList, err := base64.StdEncoding.DecodeString(*echConfigListB64)
199 if err != nil {
200 log.Fatalf("parse ech-config-list err: %s", err)
201 }
202 cfg.EncryptedClientHelloConfigList = echConfigList
203 cfg.MinVersion = VersionTLS13
204 }
205
206 if len(*curves) != 0 {
207 for _, curveStr := range *curves {
208 id, err := strconv.Atoi(curveStr)
209 if err != nil {
210 log.Fatalf("failed to parse curve id %q: %s", curveStr, err)
211 }
212 cfg.CurvePreferences = append(cfg.CurvePreferences, CurveID(id))
213 }
214 }
215
216 for i := 0; i < *resumeCount+1; i++ {
217 if i > 0 && (*onResumeECHConfigListB64 != "") {
218 echConfigList, err := base64.StdEncoding.DecodeString(*onResumeECHConfigListB64)
219 if err != nil {
220 log.Fatalf("parse ech-config-list err: %s", err)
221 }
222 cfg.EncryptedClientHelloConfigList = echConfigList
223 }
224
225 conn, err := net.Dial("tcp", net.JoinHostPort("localhost", *port))
226 if err != nil {
227 log.Fatalf("dial err: %s", err)
228 }
229 defer conn.Close()
230
231
232 shimIDBytes := make([]byte, 8)
233 byteorder.LePutUint64(shimIDBytes, *shimID)
234 if _, err := conn.Write(shimIDBytes); err != nil {
235 log.Fatalf("failed to write shim id: %s", err)
236 }
237
238 var tlsConn *Conn
239 if *server {
240 tlsConn = Server(conn, cfg)
241 } else {
242 tlsConn = Client(conn, cfg)
243 }
244
245 if i == 0 && *shimWritesFirst {
246 if _, err := tlsConn.Write([]byte("hello")); err != nil {
247 log.Fatalf("write err: %s", err)
248 }
249 }
250
251 for {
252 buf := make([]byte, 500)
253 var n int
254 n, err = tlsConn.Read(buf)
255 if err != nil {
256 break
257 }
258 buf = buf[:n]
259 for i := range buf {
260 buf[i] ^= 0xff
261 }
262 if _, err = tlsConn.Write(buf); err != nil {
263 break
264 }
265 }
266 if err != nil && err != io.EOF {
267 retryErr, ok := err.(*ECHRejectionError)
268 if !ok {
269 log.Fatalf("unexpected error type returned: %v", err)
270 }
271 if *expectNoECHRetryConfigs && len(retryErr.RetryConfigList) > 0 {
272 log.Fatalf("expected no ECH retry configs, got some")
273 }
274 if *expectedECHRetryConfigs != "" {
275 expectedRetryConfigs, err := base64.StdEncoding.DecodeString(*expectedECHRetryConfigs)
276 if err != nil {
277 log.Fatalf("failed to decode expected retry configs: %s", err)
278 }
279 if !bytes.Equal(retryErr.RetryConfigList, expectedRetryConfigs) {
280 log.Fatalf("unexpected retry list returned: got %x, want %x", retryErr.RetryConfigList, expectedRetryConfigs)
281 }
282 }
283 log.Fatalf("conn error: %s", err)
284 }
285
286 cs := tlsConn.ConnectionState()
287 if cs.HandshakeComplete {
288 if *expectALPN != "" && cs.NegotiatedProtocol != *expectALPN {
289 log.Fatalf("unexpected protocol negotiated: want %q, got %q", *expectALPN, cs.NegotiatedProtocol)
290 }
291 if *expectVersion != 0 && cs.Version != uint16(*expectVersion) {
292 log.Fatalf("expected ssl version %q, got %q", uint16(*expectVersion), cs.Version)
293 }
294 if *declineALPN && cs.NegotiatedProtocol != "" {
295 log.Fatal("unexpected ALPN protocol")
296 }
297 if *expectECHAccepted && !cs.ECHAccepted {
298 log.Fatal("expected ECH to be accepted, but connection state shows it was not")
299 } else if i == 0 && *onInitialExpectECHAccepted && !cs.ECHAccepted {
300 log.Fatal("expected ECH to be accepted, but connection state shows it was not")
301 } else if i > 0 && *onResumeExpectECHAccepted && !cs.ECHAccepted {
302 log.Fatal("expected ECH to be accepted on resumption, but connection state shows it was not")
303 } else if i == 0 && !*expectECHAccepted && cs.ECHAccepted {
304 log.Fatal("did not expect ECH, but it was accepted")
305 }
306
307 if *expectHRR && !cs.testingOnlyDidHRR {
308 log.Fatal("expected HRR but did not do it")
309 }
310
311 if *expectNoHRR && cs.testingOnlyDidHRR {
312 log.Fatal("expected no HRR but did do it")
313 }
314
315 if *expectSessionMiss && cs.DidResume {
316 log.Fatal("unexpected session resumption")
317 }
318
319 if *expectedServerName != "" && cs.ServerName != *expectedServerName {
320 log.Fatalf("unexpected server name: got %q, want %q", cs.ServerName, *expectedServerName)
321 }
322 }
323
324 if *expectedCurve != "" {
325 expectedCurveID, err := strconv.Atoi(*expectedCurve)
326 if err != nil {
327 log.Fatalf("failed to parse -expect-curve-id: %s", err)
328 }
329 if tlsConn.curveID != CurveID(expectedCurveID) {
330 log.Fatalf("unexpected curve id: want %d, got %d", expectedCurveID, tlsConn.curveID)
331 }
332 }
333 }
334 }
335
336 func TestBogoSuite(t *testing.T) {
337 testenv.SkipIfShortAndSlow(t)
338 testenv.MustHaveExternalNetwork(t)
339 testenv.MustHaveGoRun(t)
340 testenv.MustHaveExec(t)
341
342 if testing.Short() {
343 t.Skip("skipping in short mode")
344 }
345 if testenv.Builder() != "" && runtime.GOOS == "windows" {
346 t.Skip("#66913: windows network connections are flakey on builders")
347 }
348
349
350
351
352
353 if _, err := os.Stat("bogo_config.json"); err != nil {
354 t.Fatal(err)
355 }
356
357 var bogoDir string
358 if *bogoLocalDir != "" {
359 bogoDir = *bogoLocalDir
360 } else {
361 const boringsslModVer = "v0.0.0-20240523173554-273a920f84e8"
362 output, err := exec.Command("go", "mod", "download", "-json", "boringssl.googlesource.com/boringssl.git@"+boringsslModVer).CombinedOutput()
363 if err != nil {
364 t.Fatalf("failed to download boringssl: %s", err)
365 }
366 var j struct {
367 Dir string
368 }
369 if err := json.Unmarshal(output, &j); err != nil {
370 t.Fatalf("failed to parse 'go mod download' output: %s", err)
371 }
372 bogoDir = j.Dir
373 }
374
375 cwd, err := os.Getwd()
376 if err != nil {
377 t.Fatal(err)
378 }
379
380 resultsFile := filepath.Join(t.TempDir(), "results.json")
381
382 args := []string{
383 "test",
384 ".",
385 fmt.Sprintf("-shim-config=%s", filepath.Join(cwd, "bogo_config.json")),
386 fmt.Sprintf("-shim-path=%s", os.Args[0]),
387 "-shim-extra-flags=-bogo-mode",
388 "-allow-unimplemented",
389 "-loose-errors",
390 fmt.Sprintf("-json-output=%s", resultsFile),
391 }
392 if *bogoFilter != "" {
393 args = append(args, fmt.Sprintf("-test=%s", *bogoFilter))
394 }
395
396 goCmd, err := testenv.GoTool()
397 if err != nil {
398 t.Fatal(err)
399 }
400 cmd := exec.Command(goCmd, args...)
401 out := &strings.Builder{}
402 cmd.Stderr = out
403 cmd.Dir = filepath.Join(bogoDir, "ssl/test/runner")
404 err = cmd.Run()
405
406
407
408
409
410
411 resultsJSON, jsonErr := os.ReadFile(resultsFile)
412 if jsonErr != nil {
413 if err != nil {
414 t.Fatalf("bogo failed: %s\n%s", err, out)
415 }
416 t.Fatalf("failed to read results JSON file: %s", jsonErr)
417 }
418
419 var results bogoResults
420 if err := json.Unmarshal(resultsJSON, &results); err != nil {
421 t.Fatalf("failed to parse results JSON: %s", err)
422 }
423
424
425
426
427 assertResults := map[string]string{
428 "CurveTest-Client-Kyber-TLS13": "PASS",
429 "CurveTest-Server-Kyber-TLS13": "PASS",
430 }
431
432 for name, result := range results.Tests {
433
434 t.Run(name, func(t *testing.T) {
435 if result.Actual == "FAIL" && result.IsUnexpected {
436 t.Fatal(result.Error)
437 }
438 if expectedResult, ok := assertResults[name]; ok && expectedResult != result.Actual {
439 t.Fatalf("unexpected result: got %s, want %s", result.Actual, assertResults[name])
440 }
441 delete(assertResults, name)
442 if result.Actual == "SKIP" {
443 t.Skip()
444 }
445 })
446 }
447 if *bogoFilter == "" {
448
449 for name, expectedResult := range assertResults {
450 t.Run(name, func(t *testing.T) {
451 t.Fatalf("expected test to run with result %s, but it was not present in the test results", expectedResult)
452 })
453 }
454 }
455 }
456
457
// bogoResults is the JSON summary that the BoGo runner writes to the file
// named by its -json-output flag. TestBogoSuite decodes it to report each
// BoGo test case as a subtest.
type bogoResults struct {
	Version           int                       `json:"version"`
	Interrupted       bool                      `json:"interrupted"`
	PathDelimiter     string                    `json:"path_delimiter"`
	SecondsSinceEpoch float64                   `json:"seconds_since_epoch"`
	NumFailuresByType map[string]int            `json:"num_failures_by_type"`
	Tests             map[string]bogoTestResult `json:"tests"`
}

// bogoTestResult is the outcome of a single BoGo test case, keyed by test
// name in bogoResults.Tests.
type bogoTestResult struct {
	// Actual is the observed outcome; TestBogoSuite compares it against
	// "PASS", "FAIL" and "SKIP".
	Actual   string `json:"actual"`
	Expected string `json:"expected"`
	// IsUnexpected is the runner's flag for a result it did not expect;
	// TestBogoSuite fails only "FAIL" results that set it.
	IsUnexpected bool `json:"is_unexpected"`
	// Error is the runner's failure message, omitted when empty.
	Error string `json:"error,omitempty"`
}
471
View as plain text