Source file
src/crypto/tls/tls_test.go
1
2
3
4
5 package tls
6
7 import (
8 "bytes"
9 "context"
10 "crypto"
11 "crypto/ecdsa"
12 "crypto/elliptic"
13 "crypto/rand"
14 "crypto/x509"
15 "crypto/x509/pkix"
16 "encoding/asn1"
17 "encoding/json"
18 "encoding/pem"
19 "errors"
20 "fmt"
21 "internal/testenv"
22 "io"
23 "math"
24 "math/big"
25 "net"
26 "os"
27 "reflect"
28 "slices"
29 "strings"
30 "testing"
31 "time"
32 )
33
34 var rsaCertPEM = `-----BEGIN CERTIFICATE-----
35 MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
36 BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
37 aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
38 MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
39 ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
40 hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
41 rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
42 zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
43 MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
44 r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
45 -----END CERTIFICATE-----
46 `
47
48 var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY-----
49 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
50 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
51 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
52 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
53 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
54 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
55 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
56 -----END RSA TESTING KEY-----
57 `)
58
59
60
61 var keyPEM = testingKey(`-----BEGIN TESTING KEY-----
62 MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
63 k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
64 6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
65 MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
66 SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
67 xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
68 D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
69 -----END TESTING KEY-----
70 `)
71
72 var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
73 MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
74 EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
75 eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG
76 EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk
77 Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR
78 lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl
79 01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8
80 XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo
81 A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb
82 H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1
83 +jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA==
84 -----END CERTIFICATE-----
85 `
86
87 var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS-----
88 BgUrgQQAIw==
89 -----END EC PARAMETERS-----
90 -----BEGIN EC TESTING KEY-----
91 MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
92 NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
93 06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
94 VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
95 kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
96 -----END EC TESTING KEY-----
97 `)
98
99 var keyPairTests = []struct {
100 algo string
101 cert string
102 key string
103 }{
104 {"ECDSA", ecdsaCertPEM, ecdsaKeyPEM},
105 {"RSA", rsaCertPEM, rsaKeyPEM},
106 {"RSA-untyped", rsaCertPEM, keyPEM},
107 }
108
109 func TestX509KeyPair(t *testing.T) {
110 t.Parallel()
111 var pem []byte
112 for _, test := range keyPairTests {
113 pem = []byte(test.cert + test.key)
114 if _, err := X509KeyPair(pem, pem); err != nil {
115 t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
116 }
117 pem = []byte(test.key + test.cert)
118 if _, err := X509KeyPair(pem, pem); err != nil {
119 t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
120 }
121 }
122 }
123
124 func TestX509KeyPairErrors(t *testing.T) {
125 _, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM))
126 if err == nil {
127 t.Fatalf("X509KeyPair didn't return an error when arguments were switched")
128 }
129 if subStr := "been switched"; !strings.Contains(err.Error(), subStr) {
130 t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err)
131 }
132
133 _, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM))
134 if err == nil {
135 t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates")
136 }
137 if subStr := "certificate"; !strings.Contains(err.Error(), subStr) {
138 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err)
139 }
140
141 const nonsensePEM = `
142 -----BEGIN NONSENSE-----
143 Zm9vZm9vZm9v
144 -----END NONSENSE-----
145 `
146
147 _, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM))
148 if err == nil {
149 t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense")
150 }
151 if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) {
152 t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err)
153 }
154 }
155
156 func TestX509MixedKeyPair(t *testing.T) {
157 if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
158 t.Error("Load of RSA certificate succeeded with ECDSA private key")
159 }
160 if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
161 t.Error("Load of ECDSA certificate succeeded with RSA private key")
162 }
163 }
164
// newLocalListener returns a TCP listener on an ephemeral loopback port,
// preferring IPv4 and falling back to IPv6. It fails the test (via t.Fatal)
// if neither is available.
func newLocalListener(t testing.TB) net.Listener {
	t.Helper()
	if ln, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
176
177 func TestDialTimeout(t *testing.T) {
178 if testing.Short() {
179 t.Skip("skipping in short mode")
180 }
181
182 timeout := 100 * time.Microsecond
183 for !t.Failed() {
184 acceptc := make(chan net.Conn)
185 listener := newLocalListener(t)
186 go func() {
187 for {
188 conn, err := listener.Accept()
189 if err != nil {
190 close(acceptc)
191 return
192 }
193 acceptc <- conn
194 }
195 }()
196
197 addr := listener.Addr().String()
198 dialer := &net.Dialer{
199 Timeout: timeout,
200 }
201 if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil {
202 conn.Close()
203 t.Errorf("DialWithTimeout unexpectedly completed successfully")
204 } else if !isTimeoutError(err) {
205 t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
206 }
207
208 listener.Close()
209
210
211
212
213
214
215 lconn, ok := <-acceptc
216 if ok {
217
218
219 t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr())
220 lconn.Close()
221 }
222
223
224 for extraConn := range acceptc {
225 t.Logf("spurious extra connection from %s", extraConn.RemoteAddr())
226 extraConn.Close()
227 }
228 if ok {
229 break
230 }
231
232 t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout)
233 timeout *= 2
234 }
235 }
236
237 func TestDeadlineOnWrite(t *testing.T) {
238 if testing.Short() {
239 t.Skip("skipping in short mode")
240 }
241
242 ln := newLocalListener(t)
243 defer ln.Close()
244
245 srvCh := make(chan *Conn, 1)
246
247 go func() {
248 sconn, err := ln.Accept()
249 if err != nil {
250 srvCh <- nil
251 return
252 }
253 srv := Server(sconn, testConfig.Clone())
254 if err := srv.Handshake(); err != nil {
255 srvCh <- nil
256 return
257 }
258 srvCh <- srv
259 }()
260
261 clientConfig := testConfig.Clone()
262 clientConfig.MaxVersion = VersionTLS12
263 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
264 if err != nil {
265 t.Fatal(err)
266 }
267 defer conn.Close()
268
269 srv := <-srvCh
270 if srv == nil {
271 t.Error(err)
272 }
273
274
275 buf := make([]byte, 6)
276 if _, err := srv.Write([]byte("foobar")); err != nil {
277 t.Errorf("Write err: %v", err)
278 }
279 if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
280 t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
281 }
282
283
284 if err = srv.SetDeadline(time.Now()); err != nil {
285 t.Fatalf("SetDeadline(time.Now()) err: %v", err)
286 }
287 if _, err = srv.Write([]byte("should fail")); err == nil {
288 t.Fatal("Write should have timed out")
289 }
290
291
292 if err = srv.SetDeadline(time.Time{}); err != nil {
293 t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
294 }
295 if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
296 t.Fatal("Write which previously failed should still time out")
297 }
298
299
300 if ne := err.(net.Error); ne.Temporary() != false {
301 t.Error("Write timed out but incorrectly classified the error as Temporary")
302 }
303 if !isTimeoutError(err) {
304 t.Error("Write timed out but did not classify the error as a Timeout")
305 }
306 }
307
// readerFunc adapts an ordinary function to the io.Reader interface.
type readerFunc func([]byte) (int, error)

// Read forwards p to the underlying function and returns its results.
func (fn readerFunc) Read(p []byte) (n int, err error) {
	n, err = fn(p)
	return n, err
}
311
312
313
314
315 func TestDialer(t *testing.T) {
316 ln := newLocalListener(t)
317 defer ln.Close()
318
319 unblockServer := make(chan struct{})
320 defer close(unblockServer)
321 go func() {
322 conn, err := ln.Accept()
323 if err != nil {
324 return
325 }
326 defer conn.Close()
327 <-unblockServer
328 }()
329
330 ctx, cancel := context.WithCancel(context.Background())
331 d := Dialer{Config: &Config{
332 Rand: readerFunc(func(b []byte) (n int, err error) {
333
334
335
336
337
338 cancel()
339 return len(b), nil
340 }),
341 ServerName: "foo",
342 }}
343 _, err := d.DialContext(ctx, "tcp", ln.Addr().String())
344 if err != context.Canceled {
345 t.Errorf("err = %v; want context.Canceled", err)
346 }
347 }
348
// isTimeoutError reports whether err is, or wraps, a net.Error whose
// Timeout method returns true. A nil error is never a timeout.
//
// errors.As is used instead of a bare type assertion so that timeouts
// wrapped with %w (or by other error types implementing Unwrap) are still
// recognized.
func isTimeoutError(err error) bool {
	var ne net.Error
	if errors.As(err, &ne) {
		return ne.Timeout()
	}
	return false
}
355
356
357
358
359 func TestConnReadNonzeroAndEOF(t *testing.T) {
360
361
362
363
364
365
366 if testing.Short() {
367 t.Skip("skipping in short mode")
368 }
369 var err error
370 for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 {
371 if err = testConnReadNonzeroAndEOF(t, delay); err == nil {
372 return
373 }
374 }
375 t.Error(err)
376 }
377
378 func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error {
379 ln := newLocalListener(t)
380 defer ln.Close()
381
382 srvCh := make(chan *Conn, 1)
383 var serr error
384 go func() {
385 sconn, err := ln.Accept()
386 if err != nil {
387 serr = err
388 srvCh <- nil
389 return
390 }
391 serverConfig := testConfig.Clone()
392 srv := Server(sconn, serverConfig)
393 if err := srv.Handshake(); err != nil {
394 serr = fmt.Errorf("handshake: %v", err)
395 srvCh <- nil
396 return
397 }
398 srvCh <- srv
399 }()
400
401 clientConfig := testConfig.Clone()
402
403
404 clientConfig.MaxVersion = VersionTLS12
405 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
406 if err != nil {
407 t.Fatal(err)
408 }
409 defer conn.Close()
410
411 srv := <-srvCh
412 if srv == nil {
413 return serr
414 }
415
416 buf := make([]byte, 6)
417
418 srv.Write([]byte("foobar"))
419 n, err := conn.Read(buf)
420 if n != 6 || err != nil || string(buf) != "foobar" {
421 return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
422 }
423
424 srv.Write([]byte("abcdef"))
425 srv.Close()
426 time.Sleep(delay)
427 n, err = conn.Read(buf)
428 if n != 6 || string(buf) != "abcdef" {
429 return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf)
430 }
431 if err != io.EOF {
432 return fmt.Errorf("Second Read error = %v; want io.EOF", err)
433 }
434 return nil
435 }
436
437 func TestTLSUniqueMatches(t *testing.T) {
438 ln := newLocalListener(t)
439 defer ln.Close()
440
441 serverTLSUniques := make(chan []byte)
442 parentDone := make(chan struct{})
443 childDone := make(chan struct{})
444 defer close(parentDone)
445 go func() {
446 defer close(childDone)
447 for i := 0; i < 2; i++ {
448 sconn, err := ln.Accept()
449 if err != nil {
450 t.Error(err)
451 return
452 }
453 serverConfig := testConfig.Clone()
454 serverConfig.MaxVersion = VersionTLS12
455 srv := Server(sconn, serverConfig)
456 if err := srv.Handshake(); err != nil {
457 t.Error(err)
458 return
459 }
460 select {
461 case <-parentDone:
462 return
463 case serverTLSUniques <- srv.ConnectionState().TLSUnique:
464 }
465 }
466 }()
467
468 clientConfig := testConfig.Clone()
469 clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
470 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
471 if err != nil {
472 t.Fatal(err)
473 }
474
475 var serverTLSUniquesValue []byte
476 select {
477 case <-childDone:
478 return
479 case serverTLSUniquesValue = <-serverTLSUniques:
480 }
481
482 if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
483 t.Error("client and server channel bindings differ")
484 }
485 if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
486 t.Error("tls-unique is empty or zero")
487 }
488 conn.Close()
489
490 conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
491 if err != nil {
492 t.Fatal(err)
493 }
494 defer conn.Close()
495 if !conn.ConnectionState().DidResume {
496 t.Error("second session did not use resumption")
497 }
498
499 select {
500 case <-childDone:
501 return
502 case serverTLSUniquesValue = <-serverTLSUniques:
503 }
504
505 if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
506 t.Error("client and server channel bindings differ when session resumption is used")
507 }
508 if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
509 t.Error("resumption tls-unique is empty or zero")
510 }
511 }
512
513 func TestVerifyHostname(t *testing.T) {
514 testenv.MustHaveExternalNetwork(t)
515
516 c, err := Dial("tcp", "www.google.com:https", nil)
517 if err != nil {
518 t.Fatal(err)
519 }
520 if err := c.VerifyHostname("www.google.com"); err != nil {
521 t.Fatalf("verify www.google.com: %v", err)
522 }
523 if err := c.VerifyHostname("www.yahoo.com"); err == nil {
524 t.Fatalf("verify www.yahoo.com succeeded")
525 }
526
527 c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true})
528 if err != nil {
529 t.Fatal(err)
530 }
531 if err := c.VerifyHostname("www.google.com"); err == nil {
532 t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
533 }
534 }
535
536 func TestConnCloseBreakingWrite(t *testing.T) {
537 ln := newLocalListener(t)
538 defer ln.Close()
539
540 srvCh := make(chan *Conn, 1)
541 var serr error
542 var sconn net.Conn
543 go func() {
544 var err error
545 sconn, err = ln.Accept()
546 if err != nil {
547 serr = err
548 srvCh <- nil
549 return
550 }
551 serverConfig := testConfig.Clone()
552 srv := Server(sconn, serverConfig)
553 if err := srv.Handshake(); err != nil {
554 serr = fmt.Errorf("handshake: %v", err)
555 srvCh <- nil
556 return
557 }
558 srvCh <- srv
559 }()
560
561 cconn, err := net.Dial("tcp", ln.Addr().String())
562 if err != nil {
563 t.Fatal(err)
564 }
565 defer cconn.Close()
566
567 conn := &changeImplConn{
568 Conn: cconn,
569 }
570
571 clientConfig := testConfig.Clone()
572 tconn := Client(conn, clientConfig)
573 if err := tconn.Handshake(); err != nil {
574 t.Fatal(err)
575 }
576
577 srv := <-srvCh
578 if srv == nil {
579 t.Fatal(serr)
580 }
581 defer sconn.Close()
582
583 connClosed := make(chan struct{})
584 conn.closeFunc = func() error {
585 close(connClosed)
586 return nil
587 }
588
589 inWrite := make(chan bool, 1)
590 var errConnClosed = errors.New("conn closed for test")
591 conn.writeFunc = func(p []byte) (n int, err error) {
592 inWrite <- true
593 <-connClosed
594 return 0, errConnClosed
595 }
596
597 closeReturned := make(chan bool, 1)
598 go func() {
599 <-inWrite
600 tconn.Close()
601 closeReturned <- true
602 }()
603
604 _, err = tconn.Write([]byte("foo"))
605 if err != errConnClosed {
606 t.Errorf("Write error = %v; want errConnClosed", err)
607 }
608
609 <-closeReturned
610 if err := tconn.Close(); err != net.ErrClosed {
611 t.Errorf("Close error = %v; want net.ErrClosed", err)
612 }
613 }
614
615 func TestConnCloseWrite(t *testing.T) {
616 ln := newLocalListener(t)
617 defer ln.Close()
618
619 clientDoneChan := make(chan struct{})
620
621 serverCloseWrite := func() error {
622 sconn, err := ln.Accept()
623 if err != nil {
624 return fmt.Errorf("accept: %v", err)
625 }
626 defer sconn.Close()
627
628 serverConfig := testConfig.Clone()
629 srv := Server(sconn, serverConfig)
630 if err := srv.Handshake(); err != nil {
631 return fmt.Errorf("handshake: %v", err)
632 }
633 defer srv.Close()
634
635 data, err := io.ReadAll(srv)
636 if err != nil {
637 return err
638 }
639 if len(data) > 0 {
640 return fmt.Errorf("Read data = %q; want nothing", data)
641 }
642
643 if err := srv.CloseWrite(); err != nil {
644 return fmt.Errorf("server CloseWrite: %v", err)
645 }
646
647
648
649
650
651 <-clientDoneChan
652 return nil
653 }
654
655 clientCloseWrite := func() error {
656 defer close(clientDoneChan)
657
658 clientConfig := testConfig.Clone()
659 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
660 if err != nil {
661 return err
662 }
663 if err := conn.Handshake(); err != nil {
664 return err
665 }
666 defer conn.Close()
667
668 if err := conn.CloseWrite(); err != nil {
669 return fmt.Errorf("client CloseWrite: %v", err)
670 }
671
672 if _, err := conn.Write([]byte{0}); err != errShutdown {
673 return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
674 }
675
676 data, err := io.ReadAll(conn)
677 if err != nil {
678 return err
679 }
680 if len(data) > 0 {
681 return fmt.Errorf("Read data = %q; want nothing", data)
682 }
683 return nil
684 }
685
686 errChan := make(chan error, 2)
687
688 go func() { errChan <- serverCloseWrite() }()
689 go func() { errChan <- clientCloseWrite() }()
690
691 for i := 0; i < 2; i++ {
692 select {
693 case err := <-errChan:
694 if err != nil {
695 t.Fatal(err)
696 }
697 case <-time.After(10 * time.Second):
698 t.Fatal("deadlock")
699 }
700 }
701
702
703
704 {
705 ln2 := newLocalListener(t)
706 defer ln2.Close()
707
708 netConn, err := net.Dial("tcp", ln2.Addr().String())
709 if err != nil {
710 t.Fatal(err)
711 }
712 defer netConn.Close()
713 conn := Client(netConn, testConfig.Clone())
714
715 if err := conn.CloseWrite(); err != errEarlyCloseWrite {
716 t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
717 }
718 }
719 }
720
721 func TestWarningAlertFlood(t *testing.T) {
722 ln := newLocalListener(t)
723 defer ln.Close()
724
725 server := func() error {
726 sconn, err := ln.Accept()
727 if err != nil {
728 return fmt.Errorf("accept: %v", err)
729 }
730 defer sconn.Close()
731
732 serverConfig := testConfig.Clone()
733 srv := Server(sconn, serverConfig)
734 if err := srv.Handshake(); err != nil {
735 return fmt.Errorf("handshake: %v", err)
736 }
737 defer srv.Close()
738
739 _, err = io.ReadAll(srv)
740 if err == nil {
741 return errors.New("unexpected lack of error from server")
742 }
743 const expected = "too many ignored"
744 if str := err.Error(); !strings.Contains(str, expected) {
745 return fmt.Errorf("expected error containing %q, but saw: %s", expected, str)
746 }
747
748 return nil
749 }
750
751 errChan := make(chan error, 1)
752 go func() { errChan <- server() }()
753
754 clientConfig := testConfig.Clone()
755 clientConfig.MaxVersion = VersionTLS12
756 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
757 if err != nil {
758 t.Fatal(err)
759 }
760 defer conn.Close()
761 if err := conn.Handshake(); err != nil {
762 t.Fatal(err)
763 }
764
765 for i := 0; i < maxUselessRecords+1; i++ {
766 conn.sendAlert(alertNoRenegotiation)
767 }
768
769 if err := <-errChan; err != nil {
770 t.Fatal(err)
771 }
772 }
773
774 func TestCloneFuncFields(t *testing.T) {
775 const expectedCount = 9
776 called := 0
777
778 c1 := Config{
779 Time: func() time.Time {
780 called |= 1 << 0
781 return time.Time{}
782 },
783 GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
784 called |= 1 << 1
785 return nil, nil
786 },
787 GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) {
788 called |= 1 << 2
789 return nil, nil
790 },
791 GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
792 called |= 1 << 3
793 return nil, nil
794 },
795 VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
796 called |= 1 << 4
797 return nil
798 },
799 VerifyConnection: func(ConnectionState) error {
800 called |= 1 << 5
801 return nil
802 },
803 UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
804 called |= 1 << 6
805 return nil, nil
806 },
807 WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
808 called |= 1 << 7
809 return nil, nil
810 },
811 EncryptedClientHelloRejectionVerify: func(ConnectionState) error {
812 called |= 1 << 8
813 return nil
814 },
815 }
816
817 c2 := c1.Clone()
818
819 c2.Time()
820 c2.GetCertificate(nil)
821 c2.GetClientCertificate(nil)
822 c2.GetConfigForClient(nil)
823 c2.VerifyPeerCertificate(nil, nil)
824 c2.VerifyConnection(ConnectionState{})
825 c2.UnwrapSession(nil, ConnectionState{})
826 c2.WrapSession(ConnectionState{}, nil)
827 c2.EncryptedClientHelloRejectionVerify(ConnectionState{})
828
829 if called != (1<<expectedCount)-1 {
830 t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
831 }
832 }
833
834 func TestCloneNonFuncFields(t *testing.T) {
835 var c1 Config
836 v := reflect.ValueOf(&c1).Elem()
837
838 typ := v.Type()
839 for i := 0; i < typ.NumField(); i++ {
840 f := v.Field(i)
841
842
843 switch fn := typ.Field(i).Name; fn {
844 case "Rand":
845 f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
846 case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession", "EncryptedClientHelloRejectionVerify":
847
848
849
850
851 case "Certificates":
852 f.Set(reflect.ValueOf([]Certificate{
853 {Certificate: [][]byte{{'b'}}},
854 }))
855 case "NameToCertificate":
856 f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil}))
857 case "RootCAs", "ClientCAs":
858 f.Set(reflect.ValueOf(x509.NewCertPool()))
859 case "ClientSessionCache":
860 f.Set(reflect.ValueOf(NewLRUClientSessionCache(10)))
861 case "KeyLogWriter":
862 f.Set(reflect.ValueOf(io.Writer(os.Stdout)))
863 case "NextProtos":
864 f.Set(reflect.ValueOf([]string{"a", "b"}))
865 case "ServerName":
866 f.Set(reflect.ValueOf("b"))
867 case "ClientAuth":
868 f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
869 case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
870 f.Set(reflect.ValueOf(true))
871 case "MinVersion", "MaxVersion":
872 f.Set(reflect.ValueOf(uint16(VersionTLS12)))
873 case "SessionTicketKey":
874 f.Set(reflect.ValueOf([32]byte{}))
875 case "CipherSuites":
876 f.Set(reflect.ValueOf([]uint16{1, 2}))
877 case "CurvePreferences":
878 f.Set(reflect.ValueOf([]CurveID{CurveP256}))
879 case "Renegotiation":
880 f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
881 case "EncryptedClientHelloConfigList":
882 f.Set(reflect.ValueOf([]byte{'x'}))
883 case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
884 continue
885 default:
886 t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
887 }
888 }
889
890 c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
891 c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
892
893 c2 := c1.Clone()
894 if !reflect.DeepEqual(&c1, c2) {
895 t.Errorf("clone failed to copy a field")
896 }
897 }
898
899 func TestCloneNilConfig(t *testing.T) {
900 var config *Config
901 if cc := config.Clone(); cc != nil {
902 t.Fatalf("Clone with nil should return nil, got: %+v", cc)
903 }
904 }
905
906
907
// changeImplConn wraps a net.Conn and lets a test replace its Write and
// Close behavior at runtime. A nil hook falls through to the embedded Conn.
type changeImplConn struct {
	net.Conn
	// writeFunc, if non-nil, is called instead of Conn.Write.
	writeFunc func([]byte) (int, error)
	// closeFunc, if non-nil, is called instead of Conn.Close.
	closeFunc func() error
}

// Write uses the writeFunc hook when set, otherwise the embedded Conn.
func (c *changeImplConn) Write(p []byte) (n int, err error) {
	if hook := c.writeFunc; hook != nil {
		return hook(p)
	}
	return c.Conn.Write(p)
}

// Close uses the closeFunc hook when set, otherwise the embedded Conn.
func (c *changeImplConn) Close() error {
	if hook := c.closeFunc; hook != nil {
		return hook()
	}
	return c.Conn.Close()
}
927
928 func throughput(b *testing.B, version uint16, totalBytes int64, dynamicRecordSizingDisabled bool) {
929 ln := newLocalListener(b)
930 defer ln.Close()
931
932 N := b.N
933
934
935
936 const bufsize = 32 << 10
937
938 go func() {
939 buf := make([]byte, bufsize)
940 for i := 0; i < N; i++ {
941 sconn, err := ln.Accept()
942 if err != nil {
943
944
945 panic(fmt.Errorf("accept: %v", err))
946 }
947 serverConfig := testConfig.Clone()
948 serverConfig.CipherSuites = nil
949 serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
950 srv := Server(sconn, serverConfig)
951 if err := srv.Handshake(); err != nil {
952 panic(fmt.Errorf("handshake: %v", err))
953 }
954 if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
955 panic(fmt.Errorf("copy buffer: %v", err))
956 }
957 }
958 }()
959
960 b.SetBytes(totalBytes)
961 clientConfig := testConfig.Clone()
962 clientConfig.CipherSuites = nil
963 clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
964 clientConfig.MaxVersion = version
965
966 buf := make([]byte, bufsize)
967 chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
968 for i := 0; i < N; i++ {
969 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
970 if err != nil {
971 b.Fatal(err)
972 }
973 for j := 0; j < chunks; j++ {
974 _, err := conn.Write(buf)
975 if err != nil {
976 b.Fatal(err)
977 }
978 _, err = io.ReadFull(conn, buf)
979 if err != nil {
980 b.Fatal(err)
981 }
982 }
983 conn.Close()
984 }
985 }
986
987 func BenchmarkThroughput(b *testing.B) {
988 for _, mode := range []string{"Max", "Dynamic"} {
989 for size := 1; size <= 64; size <<= 1 {
990 name := fmt.Sprintf("%sPacket/%dMB", mode, size)
991 b.Run(name, func(b *testing.B) {
992 b.Run("TLSv12", func(b *testing.B) {
993 throughput(b, VersionTLS12, int64(size<<20), mode == "Max")
994 })
995 b.Run("TLSv13", func(b *testing.B) {
996 throughput(b, VersionTLS13, int64(size<<20), mode == "Max")
997 })
998 })
999 }
1000 }
1001 }
1002
// slowConn wraps a net.Conn and throttles writes to roughly bps bits per
// second.
type slowConn struct {
	net.Conn
	bps int
}

// Write sends p in pieces, polling every 100µs and releasing only as many
// bytes as the elapsed time allows at c.bps. It panics if bps is zero and
// returns early with the bytes written so far if the underlying Write fails.
func (c *slowConn) Write(p []byte) (int, error) {
	if c.bps == 0 {
		panic("too slow")
	}
	start := time.Now()
	for sent := 0; sent < len(p); {
		time.Sleep(100 * time.Microsecond)
		budget := int(time.Since(start).Seconds()*float64(c.bps)) / 8
		if budget > len(p) {
			budget = len(p)
		}
		if sent >= budget {
			continue
		}
		n, err := c.Conn.Write(p[sent:budget])
		sent += n
		if err != nil {
			return sent, err
		}
	}
	return len(p), nil
}
1030
1031 func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) {
1032 ln := newLocalListener(b)
1033 defer ln.Close()
1034
1035 N := b.N
1036
1037 go func() {
1038 for i := 0; i < N; i++ {
1039 sconn, err := ln.Accept()
1040 if err != nil {
1041
1042
1043 panic(fmt.Errorf("accept: %v", err))
1044 }
1045 serverConfig := testConfig.Clone()
1046 serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
1047 srv := Server(&slowConn{sconn, bps}, serverConfig)
1048 if err := srv.Handshake(); err != nil {
1049 panic(fmt.Errorf("handshake: %v", err))
1050 }
1051 io.Copy(srv, srv)
1052 }
1053 }()
1054
1055 clientConfig := testConfig.Clone()
1056 clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
1057 clientConfig.MaxVersion = version
1058
1059 buf := make([]byte, 16384)
1060 peek := make([]byte, 1)
1061
1062 for i := 0; i < N; i++ {
1063 conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
1064 if err != nil {
1065 b.Fatal(err)
1066 }
1067
1068 if _, err := conn.Write(buf[:1]); err != nil {
1069 b.Fatal(err)
1070 }
1071 if _, err := io.ReadFull(conn, peek); err != nil {
1072 b.Fatal(err)
1073 }
1074 if _, err := conn.Write(buf); err != nil {
1075 b.Fatal(err)
1076 }
1077 if _, err = io.ReadFull(conn, peek); err != nil {
1078 b.Fatal(err)
1079 }
1080 conn.Close()
1081 }
1082 }
1083
1084 func BenchmarkLatency(b *testing.B) {
1085 for _, mode := range []string{"Max", "Dynamic"} {
1086 for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
1087 name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
1088 b.Run(name, func(b *testing.B) {
1089 b.Run("TLSv12", func(b *testing.B) {
1090 latency(b, VersionTLS12, kbps*1000, mode == "Max")
1091 })
1092 b.Run("TLSv13", func(b *testing.B) {
1093 latency(b, VersionTLS13, kbps*1000, mode == "Max")
1094 })
1095 })
1096 }
1097 }
1098 }
1099
1100 func TestConnectionStateMarshal(t *testing.T) {
1101 cs := &ConnectionState{}
1102 _, err := json.Marshal(cs)
1103 if err != nil {
1104 t.Errorf("json.Marshal failed on ConnectionState: %v", err)
1105 }
1106 }
1107
1108 func TestConnectionState(t *testing.T) {
1109 issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1110 if err != nil {
1111 panic(err)
1112 }
1113 rootCAs := x509.NewCertPool()
1114 rootCAs.AddCert(issuer)
1115
1116 now := func() time.Time { return time.Unix(1476984729, 0) }
1117
1118 const alpnProtocol = "golang"
1119 const serverName = "example.golang"
1120 var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1121 var ocsp = []byte("dummy ocsp")
1122
1123 for _, v := range []uint16{VersionTLS12, VersionTLS13} {
1124 var name string
1125 switch v {
1126 case VersionTLS12:
1127 name = "TLSv12"
1128 case VersionTLS13:
1129 name = "TLSv13"
1130 }
1131 t.Run(name, func(t *testing.T) {
1132 config := &Config{
1133 Time: now,
1134 Rand: zeroSource{},
1135 Certificates: make([]Certificate, 1),
1136 MaxVersion: v,
1137 RootCAs: rootCAs,
1138 ClientCAs: rootCAs,
1139 ClientAuth: RequireAndVerifyClientCert,
1140 NextProtos: []string{alpnProtocol},
1141 ServerName: serverName,
1142 }
1143 config.Certificates[0].Certificate = [][]byte{testRSACertificate}
1144 config.Certificates[0].PrivateKey = testRSAPrivateKey
1145 config.Certificates[0].SignedCertificateTimestamps = scts
1146 config.Certificates[0].OCSPStaple = ocsp
1147
1148 ss, cs, err := testHandshake(t, config, config)
1149 if err != nil {
1150 t.Fatalf("Handshake failed: %v", err)
1151 }
1152
1153 if ss.Version != v || cs.Version != v {
1154 t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v)
1155 }
1156
1157 if !ss.HandshakeComplete || !cs.HandshakeComplete {
1158 t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete)
1159 }
1160
1161 if ss.DidResume || cs.DidResume {
1162 t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume)
1163 }
1164
1165 if ss.CipherSuite == 0 || cs.CipherSuite == 0 {
1166 t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite)
1167 }
1168
1169 if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol {
1170 t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol)
1171 }
1172
1173 if !cs.NegotiatedProtocolIsMutual {
1174 t.Errorf("Got false NegotiatedProtocolIsMutual on the client side")
1175 }
1176
1177
1178 if ss.ServerName != serverName {
1179 t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
1180 }
1181 if cs.ServerName != serverName {
1182 t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
1183 }
1184
1185 if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
1186 t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1)
1187 }
1188
1189 if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 {
1190 t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1)
1191 } else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 {
1192 t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2)
1193 }
1194
1195 if len(cs.SignedCertificateTimestamps) != 2 {
1196 t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2)
1197 }
1198 if !bytes.Equal(cs.OCSPResponse, ocsp) {
1199 t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp)
1200 }
1201
1202 if v == VersionTLS13 {
1203 if len(ss.SignedCertificateTimestamps) != 2 {
1204 t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2)
1205 }
1206 if !bytes.Equal(ss.OCSPResponse, ocsp) {
1207 t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp)
1208 }
1209 }
1210
1211 if v == VersionTLS13 {
1212 if ss.TLSUnique != nil || cs.TLSUnique != nil {
1213 t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique)
1214 }
1215 } else {
1216 if ss.TLSUnique == nil || cs.TLSUnique == nil {
1217 t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique)
1218 }
1219 }
1220 })
1221 }
1222 }
1223
1224
1225
1226 func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) {
1227 c0 := Certificate{
1228 Certificate: [][]byte{testRSACertificate},
1229 PrivateKey: testRSAPrivateKey,
1230 }
1231 c1 := Certificate{
1232 Certificate: [][]byte{testSNICertificate},
1233 PrivateKey: testRSAPrivateKey,
1234 }
1235 config := testConfig.Clone()
1236 config.Certificates = []Certificate{c0, c1}
1237
1238 config.BuildNameToCertificate()
1239 got := config.Certificates
1240 want := []Certificate{c0, c1}
1241 if !reflect.DeepEqual(got, want) {
1242 t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want)
1243 }
1244 }
1245
// testingKey rewrites every "TESTING KEY" placeholder in s to "PRIVATE KEY",
// turning the PEM labels of the embedded test keys back into standard ones.
// NOTE(review): presumably the placeholder exists so that secret scanners do
// not flag these test keys; confirm against the file's history.
func testingKey(s string) string {
	const placeholder, replacement = "TESTING KEY", "PRIVATE KEY"
	return strings.Replace(s, placeholder, replacement, -1)
}
1247
1248 func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
1249 rsaCert := &Certificate{
1250 Certificate: [][]byte{testRSACertificate},
1251 PrivateKey: testRSAPrivateKey,
1252 }
1253 pkcs1Cert := &Certificate{
1254 Certificate: [][]byte{testRSACertificate},
1255 PrivateKey: testRSAPrivateKey,
1256 SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256},
1257 }
1258 ecdsaCert := &Certificate{
1259
1260 Certificate: [][]byte{testP256Certificate},
1261 PrivateKey: testP256PrivateKey,
1262 }
1263 ed25519Cert := &Certificate{
1264 Certificate: [][]byte{testEd25519Certificate},
1265 PrivateKey: testEd25519PrivateKey,
1266 }
1267
1268 tests := []struct {
1269 c *Certificate
1270 chi *ClientHelloInfo
1271 wantErr string
1272 }{
1273 {rsaCert, &ClientHelloInfo{
1274 ServerName: "example.golang",
1275 SignatureSchemes: []SignatureScheme{PSSWithSHA256},
1276 SupportedVersions: []uint16{VersionTLS13},
1277 }, ""},
1278 {ecdsaCert, &ClientHelloInfo{
1279 SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1280 SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1281 }, ""},
1282 {rsaCert, &ClientHelloInfo{
1283 ServerName: "example.com",
1284 SignatureSchemes: []SignatureScheme{PSSWithSHA256},
1285 SupportedVersions: []uint16{VersionTLS13},
1286 }, "not valid for requested server name"},
1287 {ecdsaCert, &ClientHelloInfo{
1288 SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384},
1289 SupportedVersions: []uint16{VersionTLS13},
1290 }, "signature algorithms"},
1291 {pkcs1Cert, &ClientHelloInfo{
1292 SignatureSchemes: []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1293 SupportedVersions: []uint16{VersionTLS13},
1294 }, "signature algorithms"},
1295
1296 {rsaCert, &ClientHelloInfo{
1297 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1298 SignatureSchemes: []SignatureScheme{PKCS1WithSHA1},
1299 SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1300 }, "signature algorithms"},
1301 {rsaCert, &ClientHelloInfo{
1302 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1303 SignatureSchemes: []SignatureScheme{PKCS1WithSHA1},
1304 SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1305 config: &Config{
1306 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1307 MaxVersion: VersionTLS12,
1308 },
1309 }, ""},
1310
1311 {ecdsaCert, &ClientHelloInfo{
1312 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1313 SupportedCurves: []CurveID{CurveP256},
1314 SupportedPoints: []uint8{pointFormatUncompressed},
1315 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256},
1316 SupportedVersions: []uint16{VersionTLS12},
1317 }, ""},
1318 {ecdsaCert, &ClientHelloInfo{
1319 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1320 SupportedCurves: []CurveID{CurveP256},
1321 SupportedPoints: []uint8{pointFormatUncompressed},
1322 SignatureSchemes: []SignatureScheme{ECDSAWithP384AndSHA384},
1323 SupportedVersions: []uint16{VersionTLS12},
1324 }, ""},
1325 {ecdsaCert, &ClientHelloInfo{
1326 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1327 SupportedCurves: []CurveID{CurveP256},
1328 SupportedPoints: []uint8{pointFormatUncompressed},
1329 SignatureSchemes: nil,
1330 SupportedVersions: []uint16{VersionTLS12},
1331 }, ""},
1332 {ecdsaCert, &ClientHelloInfo{
1333 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1334 SupportedCurves: []CurveID{CurveP256},
1335 SupportedPoints: []uint8{pointFormatUncompressed},
1336 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256},
1337 SupportedVersions: []uint16{VersionTLS12},
1338 }, "cipher suite"},
1339 {ecdsaCert, &ClientHelloInfo{
1340 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1341 SupportedCurves: []CurveID{CurveP256},
1342 SupportedPoints: []uint8{pointFormatUncompressed},
1343 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256},
1344 SupportedVersions: []uint16{VersionTLS12},
1345 config: &Config{
1346 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1347 },
1348 }, "cipher suite"},
1349 {ecdsaCert, &ClientHelloInfo{
1350 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1351 SupportedCurves: []CurveID{CurveP384},
1352 SupportedPoints: []uint8{pointFormatUncompressed},
1353 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256},
1354 SupportedVersions: []uint16{VersionTLS12},
1355 }, "certificate curve"},
1356 {ecdsaCert, &ClientHelloInfo{
1357 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1358 SupportedCurves: []CurveID{CurveP256},
1359 SupportedPoints: []uint8{1},
1360 SignatureSchemes: []SignatureScheme{ECDSAWithP256AndSHA256},
1361 SupportedVersions: []uint16{VersionTLS12},
1362 }, "doesn't support ECDHE"},
1363 {ecdsaCert, &ClientHelloInfo{
1364 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1365 SupportedCurves: []CurveID{CurveP256},
1366 SupportedPoints: []uint8{pointFormatUncompressed},
1367 SignatureSchemes: []SignatureScheme{PSSWithSHA256},
1368 SupportedVersions: []uint16{VersionTLS12},
1369 }, "signature algorithms"},
1370
1371 {ed25519Cert, &ClientHelloInfo{
1372 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1373 SupportedCurves: []CurveID{CurveP256},
1374 SupportedPoints: []uint8{pointFormatUncompressed},
1375 SignatureSchemes: []SignatureScheme{Ed25519},
1376 SupportedVersions: []uint16{VersionTLS12},
1377 }, ""},
1378 {ed25519Cert, &ClientHelloInfo{
1379 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1380 SupportedCurves: []CurveID{CurveP256},
1381 SupportedPoints: []uint8{pointFormatUncompressed},
1382 SignatureSchemes: []SignatureScheme{Ed25519},
1383 SupportedVersions: []uint16{VersionTLS10},
1384 config: &Config{MinVersion: VersionTLS10},
1385 }, "doesn't support Ed25519"},
1386 {ed25519Cert, &ClientHelloInfo{
1387 CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1388 SupportedCurves: []CurveID{},
1389 SupportedPoints: []uint8{pointFormatUncompressed},
1390 SignatureSchemes: []SignatureScheme{Ed25519},
1391 SupportedVersions: []uint16{VersionTLS12},
1392 }, "doesn't support ECDHE"},
1393
1394 {rsaCert, &ClientHelloInfo{
1395 CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
1396 SupportedCurves: []CurveID{CurveP256},
1397 SupportedPoints: []uint8{pointFormatUncompressed},
1398 SupportedVersions: []uint16{VersionTLS10},
1399 config: &Config{MinVersion: VersionTLS10},
1400 }, ""},
1401 {rsaCert, &ClientHelloInfo{
1402 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1403 SupportedVersions: []uint16{VersionTLS12},
1404 config: &Config{
1405 CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1406 },
1407 }, ""},
1408 }
1409 for i, tt := range tests {
1410 err := tt.chi.SupportsCertificate(tt.c)
1411 switch {
1412 case tt.wantErr == "" && err != nil:
1413 t.Errorf("%d: unexpected error: %v", i, err)
1414 case tt.wantErr != "" && err == nil:
1415 t.Errorf("%d: unexpected success", i)
1416 case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr):
1417 t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr)
1418 }
1419 }
1420 }
1421
1422 func TestCipherSuites(t *testing.T) {
1423 var lastID uint16
1424 for _, c := range CipherSuites() {
1425 if lastID > c.ID {
1426 t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1427 } else {
1428 lastID = c.ID
1429 }
1430
1431 if c.Insecure {
1432 t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID)
1433 }
1434 }
1435 lastID = 0
1436 for _, c := range InsecureCipherSuites() {
1437 if lastID > c.ID {
1438 t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
1439 } else {
1440 lastID = c.ID
1441 }
1442
1443 if !c.Insecure {
1444 t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID)
1445 }
1446 }
1447
1448 CipherSuiteByID := func(id uint16) *CipherSuite {
1449 for _, c := range CipherSuites() {
1450 if c.ID == id {
1451 return c
1452 }
1453 }
1454 for _, c := range InsecureCipherSuites() {
1455 if c.ID == id {
1456 return c
1457 }
1458 }
1459 return nil
1460 }
1461
1462 for _, c := range cipherSuites {
1463 cc := CipherSuiteByID(c.id)
1464 if cc == nil {
1465 t.Errorf("%#04x: no CipherSuite entry", c.id)
1466 continue
1467 }
1468
1469 if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
1470 t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1471 } else if !tls12Only && len(cc.SupportedVersions) != 3 {
1472 t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1473 }
1474
1475 if cc.Insecure {
1476 if slices.Contains(defaultCipherSuites(), c.id) {
1477 t.Errorf("%#04x: insecure suite in default list", c.id)
1478 }
1479 } else {
1480 if !slices.Contains(defaultCipherSuites(), c.id) {
1481 t.Errorf("%#04x: secure suite not in default list", c.id)
1482 }
1483 }
1484
1485 if got := CipherSuiteName(c.id); got != cc.Name {
1486 t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1487 }
1488 }
1489 for _, c := range cipherSuitesTLS13 {
1490 cc := CipherSuiteByID(c.id)
1491 if cc == nil {
1492 t.Errorf("%#04x: no CipherSuite entry", c.id)
1493 continue
1494 }
1495
1496 if cc.Insecure {
1497 t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure)
1498 }
1499 if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 {
1500 t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
1501 }
1502
1503 if got := CipherSuiteName(c.id); got != cc.Name {
1504 t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
1505 }
1506 }
1507
1508 if got := CipherSuiteName(0xabc); got != "0x0ABC" {
1509 t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
1510 }
1511
1512 if len(cipherSuitesPreferenceOrder) != len(cipherSuites) {
1513 t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites")
1514 }
1515 if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) {
1516 t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder")
1517 }
1518
1519
1520 for _, badSuites := range []map[uint16]bool{disabledCipherSuites, rsaKexCiphers} {
1521 for id := range badSuites {
1522 c := CipherSuiteByID(id)
1523 if c == nil {
1524 t.Errorf("%#04x: no CipherSuite entry", id)
1525 continue
1526 }
1527 if !c.Insecure {
1528 t.Errorf("%#04x: disabled by default but not marked insecure", id)
1529 }
1530 }
1531 }
1532
1533 for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} {
1534
1535
1536 var sawInsecure, sawBad bool
1537 for _, id := range prefOrder {
1538 c := CipherSuiteByID(id)
1539 if c == nil {
1540 t.Errorf("%#04x: no CipherSuite entry", id)
1541 continue
1542 }
1543
1544 if c.Insecure {
1545 sawInsecure = true
1546 } else if sawInsecure {
1547 t.Errorf("%#04x: secure suite after insecure one(s)", id)
1548 }
1549
1550 if http2isBadCipher(id) {
1551 sawBad = true
1552 } else if sawBad {
1553 t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id)
1554 }
1555 }
1556
1557
1558 isBetter := func(a, b uint16) int {
1559 aSuite, bSuite := cipherSuiteByID(a), cipherSuiteByID(b)
1560 aName, bName := CipherSuiteName(a), CipherSuiteName(b)
1561
1562 if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") {
1563 return -1
1564 } else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") {
1565 return +1
1566 }
1567
1568 if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") {
1569 return -1
1570 } else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") {
1571 return +1
1572 }
1573
1574 if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") {
1575 return -1
1576 } else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") {
1577 return +1
1578 }
1579
1580 if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 {
1581 return -1
1582 } else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 {
1583 return +1
1584 }
1585
1586 if aSuite.aead != nil && bSuite.aead == nil {
1587 return -1
1588 } else if aSuite.aead == nil && bSuite.aead != nil {
1589 return +1
1590 }
1591
1592 if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") {
1593
1594 if i == 0 {
1595 return -1
1596 } else {
1597 return +1
1598 }
1599 } else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") {
1600
1601 if i != 0 {
1602 return -1
1603 } else {
1604 return +1
1605 }
1606 }
1607
1608 if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") {
1609 return -1
1610 } else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") {
1611 return +1
1612 }
1613
1614 if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 {
1615 return -1
1616 } else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign != 0 {
1617 return +1
1618 }
1619 t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName)
1620 panic("unreachable")
1621 }
1622 if !slices.IsSortedFunc(prefOrder, isBetter) {
1623 t.Error("preference order is not sorted according to the rules")
1624 }
1625 }
1626 }
1627
1628 func TestVersionName(t *testing.T) {
1629 if got, exp := VersionName(VersionTLS13), "TLS 1.3"; got != exp {
1630 t.Errorf("unexpected VersionName: got %q, expected %q", got, exp)
1631 }
1632 if got, exp := VersionName(0x12a), "0x012A"; got != exp {
1633 t.Errorf("unexpected fallback VersionName: got %q, expected %q", got, exp)
1634 }
1635 }
1636
1637
1638
1639 func http2isBadCipher(cipher uint16) bool {
1640 switch cipher {
1641 case TLS_RSA_WITH_RC4_128_SHA,
1642 TLS_RSA_WITH_3DES_EDE_CBC_SHA,
1643 TLS_RSA_WITH_AES_128_CBC_SHA,
1644 TLS_RSA_WITH_AES_256_CBC_SHA,
1645 TLS_RSA_WITH_AES_128_CBC_SHA256,
1646 TLS_RSA_WITH_AES_128_GCM_SHA256,
1647 TLS_RSA_WITH_AES_256_GCM_SHA384,
1648 TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
1649 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
1650 TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
1651 TLS_ECDHE_RSA_WITH_RC4_128_SHA,
1652 TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
1653 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
1654 TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
1655 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
1656 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
1657 return true
1658 default:
1659 return false
1660 }
1661 }
1662
// brokenSigner wraps a crypto.Signer and throws away every signing option
// except the hash function. The wrapped signer therefore always receives a
// bare crypto.Hash, even when an RSA-PSS signature (*rsa.PSSOptions) was
// requested; crypto/rsa answers such a request with PKCS #1 v1.5.
type brokenSigner struct {
	crypto.Signer
}

// Sign delegates to the embedded Signer with opts reduced to opts.HashFunc().
func (b brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
	hashOnly := opts.HashFunc()
	return b.Signer.Sign(rand, digest, hashOnly)
}
1669
1670
1671
1672 func TestPKCS1OnlyCert(t *testing.T) {
1673 clientConfig := testConfig.Clone()
1674 clientConfig.Certificates = []Certificate{{
1675 Certificate: [][]byte{testRSACertificate},
1676 PrivateKey: brokenSigner{testRSAPrivateKey},
1677 }}
1678 serverConfig := testConfig.Clone()
1679 serverConfig.MaxVersion = VersionTLS12
1680 serverConfig.ClientAuth = RequireAnyClientCert
1681
1682
1683 if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil {
1684 t.Fatal("expected broken certificate to cause connection to fail")
1685 }
1686
1687 clientConfig.Certificates[0].SupportedSignatureAlgorithms =
1688 []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}
1689
1690
1691
1692 if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
1693 t.Error(err)
1694 }
1695 }
1696
1697 func TestVerifyCertificates(t *testing.T) {
1698
1699 t.Run("TLSv12", func(t *testing.T) { testVerifyCertificates(t, VersionTLS12) })
1700 t.Run("TLSv13", func(t *testing.T) { testVerifyCertificates(t, VersionTLS13) })
1701 }
1702
1703 func testVerifyCertificates(t *testing.T, version uint16) {
1704 tests := []struct {
1705 name string
1706
1707 InsecureSkipVerify bool
1708 ClientAuth ClientAuthType
1709 ClientCertificates bool
1710 }{
1711 {
1712 name: "defaults",
1713 },
1714 {
1715 name: "InsecureSkipVerify",
1716 InsecureSkipVerify: true,
1717 },
1718 {
1719 name: "RequestClientCert with no certs",
1720 ClientAuth: RequestClientCert,
1721 },
1722 {
1723 name: "RequestClientCert with certs",
1724 ClientAuth: RequestClientCert,
1725 ClientCertificates: true,
1726 },
1727 {
1728 name: "RequireAnyClientCert",
1729 ClientAuth: RequireAnyClientCert,
1730 ClientCertificates: true,
1731 },
1732 {
1733 name: "VerifyClientCertIfGiven with no certs",
1734 ClientAuth: VerifyClientCertIfGiven,
1735 },
1736 {
1737 name: "VerifyClientCertIfGiven with certs",
1738 ClientAuth: VerifyClientCertIfGiven,
1739 ClientCertificates: true,
1740 },
1741 {
1742 name: "RequireAndVerifyClientCert",
1743 ClientAuth: RequireAndVerifyClientCert,
1744 ClientCertificates: true,
1745 },
1746 }
1747
1748 issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1749 if err != nil {
1750 t.Fatal(err)
1751 }
1752 rootCAs := x509.NewCertPool()
1753 rootCAs.AddCert(issuer)
1754
1755 for _, test := range tests {
1756 test := test
1757 t.Run(test.name, func(t *testing.T) {
1758 t.Parallel()
1759
1760 var serverVerifyConnection, clientVerifyConnection bool
1761 var serverVerifyPeerCertificates, clientVerifyPeerCertificates bool
1762
1763 clientConfig := testConfig.Clone()
1764 clientConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
1765 clientConfig.MaxVersion = version
1766 clientConfig.MinVersion = version
1767 clientConfig.RootCAs = rootCAs
1768 clientConfig.ServerName = "example.golang"
1769 clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
1770 serverConfig := clientConfig.Clone()
1771 serverConfig.ClientCAs = rootCAs
1772
1773 clientConfig.VerifyConnection = func(cs ConnectionState) error {
1774 clientVerifyConnection = true
1775 return nil
1776 }
1777 clientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
1778 clientVerifyPeerCertificates = true
1779 return nil
1780 }
1781 serverConfig.VerifyConnection = func(cs ConnectionState) error {
1782 serverVerifyConnection = true
1783 return nil
1784 }
1785 serverConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
1786 serverVerifyPeerCertificates = true
1787 return nil
1788 }
1789
1790 clientConfig.InsecureSkipVerify = test.InsecureSkipVerify
1791 serverConfig.ClientAuth = test.ClientAuth
1792 if !test.ClientCertificates {
1793 clientConfig.Certificates = nil
1794 }
1795
1796 if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
1797 t.Fatal(err)
1798 }
1799
1800 want := serverConfig.ClientAuth != NoClientCert
1801 if serverVerifyPeerCertificates != want {
1802 t.Errorf("VerifyPeerCertificates on the server: got %v, want %v",
1803 serverVerifyPeerCertificates, want)
1804 }
1805 if !clientVerifyPeerCertificates {
1806 t.Errorf("VerifyPeerCertificates not called on the client")
1807 }
1808 if !serverVerifyConnection {
1809 t.Error("VerifyConnection did not get called on the server")
1810 }
1811 if !clientVerifyConnection {
1812 t.Error("VerifyConnection did not get called on the client")
1813 }
1814
1815 serverVerifyPeerCertificates, clientVerifyPeerCertificates = false, false
1816 serverVerifyConnection, clientVerifyConnection = false, false
1817 cs, _, err := testHandshake(t, clientConfig, serverConfig)
1818 if err != nil {
1819 t.Fatal(err)
1820 }
1821 if !cs.DidResume {
1822 t.Error("expected resumption")
1823 }
1824
1825 if serverVerifyPeerCertificates {
1826 t.Error("VerifyPeerCertificates got called on the server on resumption")
1827 }
1828 if clientVerifyPeerCertificates {
1829 t.Error("VerifyPeerCertificates got called on the client on resumption")
1830 }
1831 if !serverVerifyConnection {
1832 t.Error("VerifyConnection did not get called on the server on resumption")
1833 }
1834 if !clientVerifyConnection {
1835 t.Error("VerifyConnection did not get called on the client on resumption")
1836 }
1837 })
1838 }
1839 }
1840
1841 func TestHandshakeKyber(t *testing.T) {
1842 if x25519Kyber768Draft00.String() != "X25519Kyber768Draft00" {
1843 t.Fatalf("unexpected CurveID string: %v", x25519Kyber768Draft00.String())
1844 }
1845
1846 var tests = []struct {
1847 name string
1848 clientConfig func(*Config)
1849 serverConfig func(*Config)
1850 preparation func(*testing.T)
1851 expectClientSupport bool
1852 expectKyber bool
1853 expectHRR bool
1854 }{
1855 {
1856 name: "Default",
1857 expectClientSupport: true,
1858 expectKyber: true,
1859 expectHRR: false,
1860 },
1861 {
1862 name: "ClientCurvePreferences",
1863 clientConfig: func(config *Config) {
1864 config.CurvePreferences = []CurveID{X25519}
1865 },
1866 expectClientSupport: false,
1867 },
1868 {
1869 name: "ServerCurvePreferencesX25519",
1870 serverConfig: func(config *Config) {
1871 config.CurvePreferences = []CurveID{X25519}
1872 },
1873 expectClientSupport: true,
1874 expectKyber: false,
1875 expectHRR: false,
1876 },
1877 {
1878 name: "ServerCurvePreferencesHRR",
1879 serverConfig: func(config *Config) {
1880 config.CurvePreferences = []CurveID{CurveP256}
1881 },
1882 expectClientSupport: true,
1883 expectKyber: false,
1884 expectHRR: true,
1885 },
1886 {
1887 name: "ClientTLSv12",
1888 clientConfig: func(config *Config) {
1889 config.MaxVersion = VersionTLS12
1890 },
1891 expectClientSupport: false,
1892 },
1893 {
1894 name: "ServerTLSv12",
1895 serverConfig: func(config *Config) {
1896 config.MaxVersion = VersionTLS12
1897 },
1898 expectClientSupport: true,
1899 expectKyber: false,
1900 },
1901 {
1902 name: "GODEBUG",
1903 preparation: func(t *testing.T) {
1904 t.Setenv("GODEBUG", "tlskyber=0")
1905 },
1906 expectClientSupport: false,
1907 },
1908 }
1909
1910 baseConfig := testConfig.Clone()
1911 baseConfig.CurvePreferences = nil
1912 for _, test := range tests {
1913 t.Run(test.name, func(t *testing.T) {
1914 if test.preparation != nil {
1915 test.preparation(t)
1916 } else {
1917 t.Parallel()
1918 }
1919 serverConfig := baseConfig.Clone()
1920 if test.serverConfig != nil {
1921 test.serverConfig(serverConfig)
1922 }
1923 serverConfig.GetConfigForClient = func(hello *ClientHelloInfo) (*Config, error) {
1924 if !test.expectClientSupport && slices.Contains(hello.SupportedCurves, x25519Kyber768Draft00) {
1925 return nil, errors.New("client supports Kyber768Draft00")
1926 } else if test.expectClientSupport && !slices.Contains(hello.SupportedCurves, x25519Kyber768Draft00) {
1927 return nil, errors.New("client does not support Kyber768Draft00")
1928 }
1929 return nil, nil
1930 }
1931 clientConfig := baseConfig.Clone()
1932 if test.clientConfig != nil {
1933 test.clientConfig(clientConfig)
1934 }
1935 ss, cs, err := testHandshake(t, clientConfig, serverConfig)
1936 if err != nil {
1937 t.Fatal(err)
1938 }
1939 if test.expectKyber {
1940 if ss.testingOnlyCurveID != x25519Kyber768Draft00 {
1941 t.Errorf("got CurveID %v (server), expected %v", ss.testingOnlyCurveID, x25519Kyber768Draft00)
1942 }
1943 if cs.testingOnlyCurveID != x25519Kyber768Draft00 {
1944 t.Errorf("got CurveID %v (client), expected %v", cs.testingOnlyCurveID, x25519Kyber768Draft00)
1945 }
1946 } else {
1947 if ss.testingOnlyCurveID == x25519Kyber768Draft00 {
1948 t.Errorf("got CurveID %v (server), expected not Kyber", ss.testingOnlyCurveID)
1949 }
1950 if cs.testingOnlyCurveID == x25519Kyber768Draft00 {
1951 t.Errorf("got CurveID %v (client), expected not Kyber", cs.testingOnlyCurveID)
1952 }
1953 }
1954 if test.expectHRR {
1955 if !ss.testingOnlyDidHRR {
1956 t.Error("server did not use HRR")
1957 }
1958 if !cs.testingOnlyDidHRR {
1959 t.Error("client did not use HRR")
1960 }
1961 } else {
1962 if ss.testingOnlyDidHRR {
1963 t.Error("server used HRR")
1964 }
1965 if cs.testingOnlyDidHRR {
1966 t.Error("client used HRR")
1967 }
1968 }
1969 })
1970 }
1971 }
1972
1973 func TestX509KeyPairPopulateCertificate(t *testing.T) {
1974 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
1975 if err != nil {
1976 t.Fatal(err)
1977 }
1978 keyDER, err := x509.MarshalPKCS8PrivateKey(key)
1979 if err != nil {
1980 t.Fatal(err)
1981 }
1982 keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyDER})
1983 tmpl := &x509.Certificate{
1984 SerialNumber: big.NewInt(1),
1985 Subject: pkix.Name{CommonName: "test"},
1986 }
1987 certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
1988 if err != nil {
1989 t.Fatal(err)
1990 }
1991 certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
1992
1993 t.Run("x509keypairleaf=0", func(t *testing.T) {
1994 t.Setenv("GODEBUG", "x509keypairleaf=0")
1995 cert, err := X509KeyPair(certPEM, keyPEM)
1996 if err != nil {
1997 t.Fatal(err)
1998 }
1999 if cert.Leaf != nil {
2000 t.Fatal("Leaf should not be populated")
2001 }
2002 })
2003 t.Run("x509keypairleaf=1", func(t *testing.T) {
2004 t.Setenv("GODEBUG", "x509keypairleaf=1")
2005 cert, err := X509KeyPair(certPEM, keyPEM)
2006 if err != nil {
2007 t.Fatal(err)
2008 }
2009 if cert.Leaf == nil {
2010 t.Fatal("Leaf should be populated")
2011 }
2012 })
2013 t.Run("GODEBUG unset", func(t *testing.T) {
2014 cert, err := X509KeyPair(certPEM, keyPEM)
2015 if err != nil {
2016 t.Fatal(err)
2017 }
2018 if cert.Leaf == nil {
2019 t.Fatal("Leaf should be populated")
2020 }
2021 })
2022 }
2023
2024 func TestEarlyLargeCertMsg(t *testing.T) {
2025 client, server := localPipe(t)
2026
2027 go func() {
2028 if _, err := client.Write([]byte{byte(recordTypeHandshake), 3, 4, 0, 4, typeCertificate, 1, 255, 255}); err != nil {
2029 t.Log(err)
2030 }
2031 }()
2032
2033 expectedErr := "tls: handshake message of length 131071 bytes exceeds maximum of 65536 bytes"
2034 servConn := Server(server, testConfig)
2035 err := servConn.Handshake()
2036 if err == nil {
2037 t.Fatal("unexpected success")
2038 }
2039 if err.Error() != expectedErr {
2040 t.Fatalf("unexpected error: got %q, want %q", err, expectedErr)
2041 }
2042 }
2043
2044 func TestLargeCertMsg(t *testing.T) {
2045 k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
2046 if err != nil {
2047 t.Fatal(err)
2048 }
2049 tmpl := &x509.Certificate{
2050 SerialNumber: big.NewInt(1),
2051 Subject: pkix.Name{CommonName: "test"},
2052 ExtraExtensions: []pkix.Extension{
2053 {
2054 Id: asn1.ObjectIdentifier{1, 2, 3},
2055
2056
2057 Value: make([]byte, 65536),
2058 },
2059 },
2060 }
2061 cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, k.Public(), k)
2062 if err != nil {
2063 t.Fatal(err)
2064 }
2065
2066 clientConfig, serverConfig := testConfig.Clone(), testConfig.Clone()
2067 clientConfig.InsecureSkipVerify = true
2068 serverConfig.Certificates = []Certificate{
2069 {
2070 Certificate: [][]byte{cert},
2071 PrivateKey: k,
2072 },
2073 }
2074 if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
2075 t.Fatalf("unexpected failure :%s", err)
2076 }
2077 }
2078
View as plain text