1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "internal/godebug"
19 "io"
20 "net"
21 "sync"
22 "sync/atomic"
23 "time"
24 )
25
26
27
// A Conn represents a secured connection.
// It implements the net.Conn interface (Read, Write, Close, LocalAddr,
// RemoteAddr and the deadline setters are all defined on it).
type Conn struct {
	// Constant for the lifetime of the connection.
	conn        net.Conn                    // underlying transport
	isClient    bool                        // true for the client side of the connection
	handshakeFn func(context.Context) error // presumably (*Conn).clientHandshake or serverHandshake — set by the constructor outside this view
	quic        *quicState                  // non-nil when running over a QUIC transport instead of TLS records

	// isHandshakeComplete is true once a handshake has completed
	// successfully. It is accessed atomically so the fast paths in
	// Read/Write/Close/handshakeContext can check it without taking
	// handshakeMutex.
	isHandshakeComplete atomic.Bool

	// State guarded by handshakeMutex.
	handshakeMutex sync.Mutex
	handshakeErr   error   // error resulting from the handshake; returned again by later Handshake calls
	vers           uint16  // TLS version in use
	haveVers       bool    // whether vers has been set (record versions are checked only after this)
	config         *Config // configuration passed to the constructor (assumed; constructor not shown)

	// handshakes counts successfully completed handshakes, including
	// renegotiations; zero means the initial handshake has not finished.
	handshakes int
	// Handshake results. NOTE(review): most of these are populated by the
	// handshake code, which is not part of this file section.
	extMasterSecret  bool
	didResume        bool // presumably: whether this connection resumed a session
	didHRR           bool // presumably: whether a HelloRetryRequest was sent/received
	cipherSuite      uint16
	curveID          CurveID
	ocspResponse     []byte   // stapled OCSP response (presumably)
	scts             [][]byte // signed certificate timestamps (presumably)
	peerCertificates []*x509.Certificate

	// activeCertHandles — NOTE(review): purpose not visible here; presumably
	// handles to cached peer certificates kept alive for the connection.
	activeCertHandles []*activeCert

	// verifiedChains — presumably the peer's verified certificate chains.
	verifiedChains [][]*x509.Certificate

	// serverName — presumably the SNI value (sent by the client / received by the server).
	serverName string

	// secureRenegotiation — presumably whether the peer supports the
	// renegotiation_info extension (RFC 5746).
	secureRenegotiation bool

	// ekm presumably exports keying material; not used in this section.
	ekm func(label string, context []byte, length int) ([]byte, error)

	// Presumably TLS 1.3 session resumption state (set outside this view).
	resumptionSecret []byte
	echAccepted      bool

	// ticketKeys — presumably the session-ticket keys in use; not used in this section.
	ticketKeys []ticketKey

	// clientFinishedIsFirst — not used in this section; presumably records
	// which Finished message was sent first (relevant to tls-unique).
	clientFinishedIsFirst bool

	// closeNotifyErr is the result of sending the close_notify alert (see closeNotify).
	closeNotifyErr error

	// closeNotifySent is true once closeNotify has attempted to send the
	// alert; Write then fails with errShutdown.
	closeNotifySent bool

	// Copies of the Finished verify data; not used in this section
	// (presumably kept for renegotiation / channel binding).
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol — presumably the negotiated ALPN protocol.
	clientProtocol string

	// Input/output record-layer state.
	in, out   halfConn
	rawInput  bytes.Buffer // raw bytes read from conn, not yet decrypted
	input     bytes.Reader // decrypted application data waiting to be Read; aliases rawInput's memory
	hand      bytes.Buffer // buffered handshake data waiting to be parsed
	buffering bool         // whether write appends to sendBuf instead of writing to conn
	sendBuf   []byte       // output accumulated while buffering is true; written by flush

	// bytesSent counts bytes written to conn; packetsSent counts records
	// sized by maxPayloadSizeForWrite. Both drive dynamic record sizing.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records that did not advance the
	// connection state; it is reset by any state-advancing record and
	// capped at maxUselessRecords.
	retryCount int

	// activeCall interlocks Write and Close: bit 0 is set once Close has
	// been called, and each in-flight Write adds 2.
	activeCall atomic.Int32

	tmp [16]byte // scratch space (e.g. the 2-byte alert in sendAlertLocked)
}
127
128
129
130
131
132
133 func (c *Conn) LocalAddr() net.Addr {
134 return c.conn.LocalAddr()
135 }
136
137
138 func (c *Conn) RemoteAddr() net.Addr {
139 return c.conn.RemoteAddr()
140 }
141
142
143
144
145 func (c *Conn) SetDeadline(t time.Time) error {
146 return c.conn.SetDeadline(t)
147 }
148
149
150
151 func (c *Conn) SetReadDeadline(t time.Time) error {
152 return c.conn.SetReadDeadline(t)
153 }
154
155
156
157
158 func (c *Conn) SetWriteDeadline(t time.Time) error {
159 return c.conn.SetWriteDeadline(t)
160 }
161
162
163
164
165 func (c *Conn) NetConn() net.Conn {
166 return c.conn
167 }
168
169
170
// A halfConn represents one direction of the record layer connection,
// either sending (Conn.out) or receiving (Conn.in).
type halfConn struct {
	sync.Mutex

	err     error     // sticky error set via setErrorLocked; later operations return it
	version uint16    // protocol version
	cipher  any       // active cipher: cipher.Stream, aead, or cbcMode (nil = no protection)
	mac     hash.Hash // active MAC for non-AEAD ciphers (nil if none)
	seq     [8]byte   // 64-bit big-endian record sequence number

	// scratchBuf avoids allocations when building the 13-byte
	// additional data / MAC header (8-byte seq + 5-byte record header).
	scratchBuf [13]byte

	nextCipher any       // pending cipher, installed by changeCipherSpec
	nextMac    hash.Hash // pending MAC, installed by changeCipherSpec

	level         QUICEncryptionLevel // encryption level recorded by setTrafficSecret
	trafficSecret []byte              // current TLS 1.3 traffic secret
}
188
// permanentError wraps a net.Error so that it is never reported as
// temporary. Its message, timeout status, and identity (for errors.Is/As)
// are those of the wrapped error.
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message.
func (e *permanentError) Error() string {
	return e.err.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *permanentError) Unwrap() error {
	return e.err
}

// Timeout reports whether the wrapped error was a timeout.
func (e *permanentError) Timeout() bool {
	return e.err.Timeout()
}

// Temporary always returns false: once recorded, the error is permanent.
func (e *permanentError) Temporary() bool {
	return false
}
197
198 func (hc *halfConn) setErrorLocked(err error) error {
199 if e, ok := err.(net.Error); ok {
200 hc.err = &permanentError{err: e}
201 } else {
202 hc.err = err
203 }
204 return hc.err
205 }
206
207
208
209 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
210 hc.version = version
211 hc.nextCipher = cipher
212 hc.nextMac = mac
213 }
214
215
216
217 func (hc *halfConn) changeCipherSpec() error {
218 if hc.nextCipher == nil || hc.version == VersionTLS13 {
219 return alertInternalError
220 }
221 hc.cipher = hc.nextCipher
222 hc.mac = hc.nextMac
223 hc.nextCipher = nil
224 hc.nextMac = nil
225 for i := range hc.seq {
226 hc.seq[i] = 0
227 }
228 return nil
229 }
230
231 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
232 hc.trafficSecret = secret
233 hc.level = level
234 key, iv := suite.trafficKey(secret)
235 hc.cipher = suite.aead(key, iv)
236 for i := range hc.seq {
237 hc.seq[i] = 0
238 }
239 }
240
241
242 func (hc *halfConn) incSeq() {
243 for i := 7; i >= 0; i-- {
244 hc.seq[i]++
245 if hc.seq[i] != 0 {
246 return
247 }
248 }
249
250
251
252
253 panic("TLS: sequence number wraparound")
254 }
255
256
257
258
259 func (hc *halfConn) explicitNonceLen() int {
260 if hc.cipher == nil {
261 return 0
262 }
263
264 switch c := hc.cipher.(type) {
265 case cipher.Stream:
266 return 0
267 case aead:
268 return c.explicitNonceLen()
269 case cbcMode:
270
271 if hc.version >= VersionTLS11 {
272 return c.BlockSize()
273 }
274 return 0
275 default:
276 panic("unknown cipher type")
277 }
278 }
279
280
281
282
// extractPadding returns the number of bytes to remove from the end of a
// CBC-decrypted payload (the padding plus its length byte) and a byte that is
// 255 if the padding is well formed and 0 otherwise. See RFC 2246, Section
// 6.2.3.2.
//
// The bytes it inspects are secret, so it uses no branches or memory accesses
// that depend on them; only len(payload), which is public, affects control
// flow. On bad padding toRemove is 1, so the unchecked bytes are still fed to
// the MAC.
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// If paddingLen <= len(payload)-1 (the padding fits), t is small and
	// bit 31 of ^t is set, so good becomes 255; otherwise t wrapped around
	// and good becomes 0.
	good = byte(int32(^t) >> 31)

	// The maximum possible padding length (255) plus the length byte itself.
	toCheck := 256
	// The length of the padded data is public, so an if is fine here.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff if i <= paddingLen (this byte is part of the
		// padding), 0 otherwise.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear any bit of good where a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND together all eight bits of good into bit 7, then replicate bit 7
	// across the byte: good ends up 0xff only if every bit survived.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Zero the padding length on error, so that any unchecked bytes are
	// included in the MAC. If an attacker could tell MAC failures from
	// padding failures, a POODLE-style attack becomes possible; the caller
	// combines good with the MAC result in constant time to prevent that.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
329
// roundUp returns a rounded up to the next multiple of b (a itself if it is
// already a multiple). b must be non-zero.
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
333
334
// cbcMode is a cipher.BlockMode whose IV can be replaced, which the record
// layer uses to install each record's explicit IV before CryptBlocks.
type cbcMode interface {
	SetIV([]byte)
	cipher.BlockMode
}
339
340
341
342 func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
343 var plaintext []byte
344 typ := recordType(record[0])
345 payload := record[recordHeaderLen:]
346
347
348
349 if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
350 return payload, typ, nil
351 }
352
353 paddingGood := byte(255)
354 paddingLen := 0
355
356 explicitNonceLen := hc.explicitNonceLen()
357
358 if hc.cipher != nil {
359 switch c := hc.cipher.(type) {
360 case cipher.Stream:
361 c.XORKeyStream(payload, payload)
362 case aead:
363 if len(payload) < explicitNonceLen {
364 return nil, 0, alertBadRecordMAC
365 }
366 nonce := payload[:explicitNonceLen]
367 if len(nonce) == 0 {
368 nonce = hc.seq[:]
369 }
370 payload = payload[explicitNonceLen:]
371
372 var additionalData []byte
373 if hc.version == VersionTLS13 {
374 additionalData = record[:recordHeaderLen]
375 } else {
376 additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
377 additionalData = append(additionalData, record[:3]...)
378 n := len(payload) - c.Overhead()
379 additionalData = append(additionalData, byte(n>>8), byte(n))
380 }
381
382 var err error
383 plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
384 if err != nil {
385 return nil, 0, alertBadRecordMAC
386 }
387 case cbcMode:
388 blockSize := c.BlockSize()
389 minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
390 if len(payload)%blockSize != 0 || len(payload) < minPayload {
391 return nil, 0, alertBadRecordMAC
392 }
393
394 if explicitNonceLen > 0 {
395 c.SetIV(payload[:explicitNonceLen])
396 payload = payload[explicitNonceLen:]
397 }
398 c.CryptBlocks(payload, payload)
399
400
401
402
403
404
405
406 paddingLen, paddingGood = extractPadding(payload)
407 default:
408 panic("unknown cipher type")
409 }
410
411 if hc.version == VersionTLS13 {
412 if typ != recordTypeApplicationData {
413 return nil, 0, alertUnexpectedMessage
414 }
415 if len(plaintext) > maxPlaintext+1 {
416 return nil, 0, alertRecordOverflow
417 }
418
419 for i := len(plaintext) - 1; i >= 0; i-- {
420 if plaintext[i] != 0 {
421 typ = recordType(plaintext[i])
422 plaintext = plaintext[:i]
423 break
424 }
425 if i == 0 {
426 return nil, 0, alertUnexpectedMessage
427 }
428 }
429 }
430 } else {
431 plaintext = payload
432 }
433
434 if hc.mac != nil {
435 macSize := hc.mac.Size()
436 if len(payload) < macSize {
437 return nil, 0, alertBadRecordMAC
438 }
439
440 n := len(payload) - macSize - paddingLen
441 n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
442 record[3] = byte(n >> 8)
443 record[4] = byte(n)
444 remoteMAC := payload[n : n+macSize]
445 localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
446
447
448
449
450
451
452
453
454 macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
455 if macAndPaddingGood != 1 {
456 return nil, 0, alertBadRecordMAC
457 }
458
459 plaintext = payload[:n]
460 }
461
462 hc.incSeq()
463 return plaintext, typ, nil
464 }
465
466
467
468
// sliceForAppend extends in by n bytes, reusing its backing array when the
// capacity allows and copying into a fresh array otherwise. It returns the
// extended slice (head) and the n newly available bytes (tail), which
// aliases the end of head.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	return head, head[len(in):]
}
479
480
481
// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
// appends it to record, which must already contain the record header.
// The header's length field is rewritten to cover the protected payload, and
// the sequence number is incremented. rand supplies explicit IVs for CBC.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// An AEAD with a short explicit nonce (e.g. AES-GCM in TLS 1.2
			// uses 8 bytes). That is too small for a safe random nonce, so
			// the unique sequence number is used instead.
			copy(explicitNonce, hc.seq[:])
		} else {
			// CBC explicit IVs must be unpredictable (RFC 5246, Appendix
			// F.3), so they are filled with randomness.
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC-then-encrypt: append the MAC and encrypt both.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Encrypt the actual ContentType and replace the plaintext one
			// with application_data (RFC 8446, Section 5.2).
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so its length field must
			// already hold the final ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// Additional data: seq_num || record header.
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		// MAC-then-pad-then-encrypt. Each padding byte, including the final
		// length byte, holds paddingLen-1.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Update length to include nonce, MAC and any block padding needed.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
566
567
// RecordHeaderError is returned when a TLS record header is invalid.
type RecordHeaderError struct {
	// Msg contains a human readable string that describes the error.
	Msg string
	// RecordHeader contains the five bytes of TLS record header that
	// triggered the error.
	RecordHeader [5]byte
	// Conn provides the underlying net.Conn when the first record received
	// did not look like a TLS handshake (for example, plain HTTP), so the
	// caller can still respond on it. It is nil otherwise; in those cases an
	// alert has typically already been written to the connection.
	Conn net.Conn
}
580
581 func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
582
583 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
584 err.Msg = msg
585 err.Conn = conn
586 copy(err.RecordHeader[:], c.rawInput.Bytes())
587 return err
588 }
589
590 func (c *Conn) readRecord() error {
591 return c.readRecordOrCCS(false)
592 }
593
594 func (c *Conn) readChangeCipherSpec() error {
595 return c.readRecordOrCCS(true)
596 }
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
// readRecordOrCCS reads one or more TLS records from the connection and
// updates the record layer state. Some invariants:
//   - c.in must be locked
//   - c.input must be empty
//
// During the handshake one and only one of the following will happen:
//   - c.hand grows
//   - c.in.changeCipherSpec is called
//   - an error is returned
//
// After the handshake one and only one of the following will happen:
//   - c.hand grows
//   - c.input is set
//   - an error is returned
//
// Records that carry nothing useful (warning alerts, TLS 1.3 CCS, empty
// application data) are skipped via retryReadRecord, which bounds how many
// may be skipped in a row.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.isHandshakeComplete.Load()

	// This function modifies c.rawInput, which owns the c.input memory.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	if c.quic != nil {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
	}

	// Read header, payload.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
		// is an error, but popular web sites seem to do this, so we accept it
		// if and only if at the record boundary.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary network errors are returned without being made sticky.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// No valid TLS record has a type of 0x80, however SSLv2 handshakes
	// start with a uint16 length where the MSB is set and the first record
	// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
	// an SSLv2 client.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	expectedVers := c.vers
	if expectedVers == VersionTLS13 {
		// All TLS 1.3 records are expected to have 0x0303 (1.2) after
		// the initial hello (RFC 8446 Section 5.1).
		expectedVers = VersionTLS12
	}
	n := int(hdr[3])<<8 | int(hdr[4])
	if c.haveVers && vers != expectedVers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First message, be extra suspicious: this might not be a TLS
		// client. Bail out before reading a full 'body', if possible.
		// The current max version is 3.3 so if the version is >= 16.0,
		// it's probably not real. The underlying conn is exposed in the
		// error so the caller can still reply (e.g. to plain HTTP).
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process message.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application Data messages are always protected.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// This is a state-advancing message: reset the retry count.
		c.retryCount = 0
	}

	// Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if c.quic != nil {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		if c.vers == VersionTLS13 {
			// TLS 1.3 has no warning-level alerts to skip: any other alert
			// is fatal.
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop the record on the floor and retry.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// Handshake messages are not allowed to fragment across the CCS.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3, change_cipher_spec records are ignored (RFC 8446,
		// Appendix D.4). This check uses c.vers, so a CCS received before
		// the version is negotiated takes the TLS 1.2 path below.
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Some peers send empty records (e.g. to randomize the CBC IV).
		// Ignore a limited number of empty records.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// Note that data is owned by c.rawInput, following the Next call above,
		// to avoid copying the plaintext. This is safe because c.rawInput is
		// not read from or written to until c.input is drained.
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
785
786
787
788 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
789 c.retryCount++
790 if c.retryCount > maxUselessRecords {
791 c.sendAlert(alertUnexpectedMessage)
792 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
793 }
794 return c.readRecordOrCCS(expectChangeCipherSpec)
795 }
796
797
798
799
// atLeastReader reads from R, stopping with io.EOF once at least N bytes
// have been read (N is decremented as bytes arrive). If R hits EOF before
// N bytes were read, it reports io.ErrUnexpectedEOF instead.
type atLeastReader struct {
	R io.Reader
	N int64
}

// Read implements io.Reader.
func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	n, err := r.R.Read(p)
	r.N -= int64(n) // cannot realistically underflow
	switch {
	case r.N > 0 && err == io.EOF:
		return n, io.ErrUnexpectedEOF
	case r.N <= 0 && err == nil:
		return n, io.EOF
	default:
		return n, err
	}
}
819
820
821
822 func (c *Conn) readFromUntil(r io.Reader, n int) error {
823 if c.rawInput.Len() >= n {
824 return nil
825 }
826 needs := n - c.rawInput.Len()
827
828
829
830 c.rawInput.Grow(needs + bytes.MinRead)
831 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
832 return err
833 }
834
835
836 func (c *Conn) sendAlertLocked(err alert) error {
837 if c.quic != nil {
838 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
839 }
840
841 switch err {
842 case alertNoRenegotiation, alertCloseNotify:
843 c.tmp[0] = alertLevelWarning
844 default:
845 c.tmp[0] = alertLevelError
846 }
847 c.tmp[1] = byte(err)
848
849 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
850 if err == alertCloseNotify {
851
852 return writeErr
853 }
854
855 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
856 }
857
858
859 func (c *Conn) sendAlert(err alert) error {
860 c.out.Lock()
861 defer c.out.Unlock()
862 return c.sendAlertLocked(err)
863 }
864
const (
	// tcpMSSEstimate is a conservative estimate of the TCP maximum segment
	// size, used instead of querying the kernel. It is the IPv6 minimum MTU
	// (1280 bytes) minus an IPv6 header (40 bytes) and a TCP header with
	// timestamps (32 bytes).
	tcpMSSEstimate = 1280 - 40 - 32

	// recordSizeBoostThreshold is the number of bytes (128 KiB) sent on the
	// connection after which application data records are no longer size
	// limited and always use the maximum payload size.
	recordSizeBoostThreshold = 128 << 10
)
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the
// next record of type typ.
//
// For application data, unless Config.DynamicRecordSizingDisabled is set,
// records start out sized so that each fits in one estimated TCP segment
// (tcpMSSEstimate minus record overheads). This favors latency-sensitive
// traffic. The limit then grows linearly with each record written, capped at
// maxPlaintext, to favor throughput. Once recordSizeBoostThreshold bytes have
// been sent, or after 1000 records, maxPlaintext is always used. Other record
// types always use maxPlaintext.
//
// Side effect: when dynamic sizing applies, c.packetsSent is incremented.
func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
	if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
		return maxPlaintext
	}

	if c.bytesSent >= recordSizeBoostThreshold {
		return maxPlaintext
	}

	// Subtract TLS overheads to get the maximum payload size.
	payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
	if c.out.cipher != nil {
		switch ciph := c.out.cipher.(type) {
		case cipher.Stream:
			payloadBytes -= c.out.mac.Size()
		case cipher.AEAD:
			payloadBytes -= ciph.Overhead()
		case cbcMode:
			blockSize := ciph.BlockSize()
			// The payload must fit in a multiple of blockSize, with
			// room for at least one padding byte (the mask assumes
			// blockSize is a power of two).
			payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
			// The MAC is appended before padding so affects the
			// payload size directly.
			payloadBytes -= c.out.mac.Size()
		default:
			panic("unknown cipher type")
		}
	}
	if c.vers == VersionTLS13 {
		payloadBytes-- // encrypted ContentType
	}

	// Allow packet growth in arithmetic progression up to max.
	pkt := c.packetsSent
	c.packetsSent++
	if pkt > 1000 {
		return maxPlaintext // avoid overflow in multiply below
	}

	n := payloadBytes * int(pkt+1)
	if n > maxPlaintext {
		n = maxPlaintext
	}
	return n
}
941
942 func (c *Conn) write(data []byte) (int, error) {
943 if c.buffering {
944 c.sendBuf = append(c.sendBuf, data...)
945 return len(data), nil
946 }
947
948 n, err := c.conn.Write(data)
949 c.bytesSent += int64(n)
950 return n, err
951 }
952
953 func (c *Conn) flush() (int, error) {
954 if len(c.sendBuf) == 0 {
955 return 0, nil
956 }
957
958 n, err := c.conn.Write(c.sendBuf)
959 c.bytesSent += int64(n)
960 c.sendBuf = nil
961 c.buffering = false
962 return n, err
963 }
964
965
// outBufPool provides reusable record-assembly buffers for
// writeRecordLocked. It pools *[]byte rather than []byte; writeRecordLocked
// stores the grown slice back through the pointer before returning it.
var outBufPool = sync.Pool{
	New: func() any {
		var buf []byte
		return &buf
	},
}
971
972
973
// writeRecordLocked writes data as one or more TLS records of type typ and
// updates the record layer state. It returns the number of bytes of data
// written. c.out must be locked (callers in this file take it first).
//
// Each record carries at most maxPayloadSizeForWrite(typ) bytes. Under QUIC
// only handshake data is accepted, and it is handed to the QUIC layer rather
// than framed as records. Writing a ChangeCipherSpec record (before TLS 1.3)
// also switches c.out to the pending cipher.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	if c.quic != nil {
		if typ != recordTypeHandshake {
			return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
		}
		c.quicWriteCryptoData(c.out.level, data)
		if !c.buffering {
			if _, err := c.flush(); err != nil {
				return 0, err
			}
		}
		return len(data), nil
	}

	outBufPtr := outBufPool.Get().(*[]byte)
	outBuf := *outBufPtr
	defer func() {
		// You might be tempted to simplify this by just passing &outBuf to Put,
		// but that would make the local copy of the outBuf slice header escape
		// to the heap, causing an allocation. Instead, we keep around the
		// pointer to the slice header returned by Get, which is already on the
		// heap, and overwrite and return that.
		*outBufPtr = outBuf
		outBufPool.Put(outBufPtr)
	}()

	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
		outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// Some TLS servers fail if the record version is
			// greater than TLS 1.0 for the initial ClientHello.
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 froze the record layer version to 1.2.
			// See RFC 8446, Section 5.1.
			vers = VersionTLS12
		}
		outBuf[1] = byte(vers >> 8)
		outBuf[2] = byte(vers)
		outBuf[3] = byte(m >> 8)
		outBuf[4] = byte(m)

		var err error
		outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
1044
1045
1046
1047
1048 func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
1049 c.out.Lock()
1050 defer c.out.Unlock()
1051
1052 data, err := msg.marshal()
1053 if err != nil {
1054 return 0, err
1055 }
1056 if transcript != nil {
1057 transcript.Write(data)
1058 }
1059
1060 return c.writeRecordLocked(recordTypeHandshake, data)
1061 }
1062
1063
1064
1065 func (c *Conn) writeChangeCipherRecord() error {
1066 c.out.Lock()
1067 defer c.out.Unlock()
1068 _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
1069 return err
1070 }
1071
1072
1073 func (c *Conn) readHandshakeBytes(n int) error {
1074 if c.quic != nil {
1075 return c.quicReadHandshakeBytes(n)
1076 }
1077 for c.hand.Len() < n {
1078 if err := c.readRecord(); err != nil {
1079 return err
1080 }
1081 }
1082 return nil
1083 }
1084
1085
1086
1087
1088 func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
1089 if err := c.readHandshakeBytes(4); err != nil {
1090 return nil, err
1091 }
1092 data := c.hand.Bytes()
1093
1094 maxHandshakeSize := maxHandshake
1095
1096
1097
1098 if c.haveVers && data[0] == typeCertificate {
1099
1100
1101
1102 maxHandshakeSize = maxHandshakeCertificateMsg
1103 }
1104
1105 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1106 if n > maxHandshakeSize {
1107 c.sendAlertLocked(alertInternalError)
1108 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
1109 }
1110 if err := c.readHandshakeBytes(4 + n); err != nil {
1111 return nil, err
1112 }
1113 data = c.hand.Next(4 + n)
1114 return c.unmarshalHandshakeMessage(data, transcript)
1115 }
1116
// unmarshalHandshakeMessage parses data, a complete handshake message
// including its 4-byte header, into the concrete message type selected by
// data[0]. It uses the negotiated version where the wire format differs.
// On success the message bytes are written to transcript (if non-nil).
// Unknown types and unparsable messages produce an unexpected_message alert.
func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
	var m handshakeMessage
	switch data[0] {
	case typeHelloRequest:
		m = new(helloRequestMsg)
	case typeClientHello:
		m = new(clientHelloMsg)
	case typeServerHello:
		m = new(serverHelloMsg)
	case typeNewSessionTicket:
		if c.vers == VersionTLS13 {
			m = new(newSessionTicketMsgTLS13)
		} else {
			m = new(newSessionTicketMsg)
		}
	case typeCertificate:
		if c.vers == VersionTLS13 {
			m = new(certificateMsgTLS13)
		} else {
			m = new(certificateMsg)
		}
	case typeCertificateRequest:
		if c.vers == VersionTLS13 {
			m = new(certificateRequestMsgTLS13)
		} else {
			m = &certificateRequestMsg{
				hasSignatureAlgorithm: c.vers >= VersionTLS12,
			}
		}
	case typeCertificateStatus:
		m = new(certificateStatusMsg)
	case typeServerKeyExchange:
		m = new(serverKeyExchangeMsg)
	case typeServerHelloDone:
		m = new(serverHelloDoneMsg)
	case typeClientKeyExchange:
		m = new(clientKeyExchangeMsg)
	case typeCertificateVerify:
		m = &certificateVerifyMsg{
			hasSignatureAlgorithm: c.vers >= VersionTLS12,
		}
	case typeFinished:
		m = new(finishedMsg)
	case typeEncryptedExtensions:
		m = new(encryptedExtensionsMsg)
	case typeEndOfEarlyData:
		m = new(endOfEarlyDataMsg)
	case typeKeyUpdate:
		m = new(keyUpdateMsg)
	default:
		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	// When called from readHandshake, data comes from c.hand.Next and aliases
	// c.hand's internal buffer, which later writes may overwrite. Pass the
	// unmarshaler a private copy so the parsed message may keep references
	// into it.
	data = append([]byte(nil), data...)

	if !m.unmarshal(data) {
		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if transcript != nil {
		transcript.Write(data)
	}

	return m, nil
}
1185
// errShutdown is returned by Write once a close_notify alert has been sent.
var errShutdown = errors.New("tls: protocol is shutdown")
1189
1190
1191
1192
1193
1194
1195
// Write writes data to the connection.
//
// Write performs the handshake first if it has not yet completed. Because of
// that, to avoid blocking indefinitely, set deadlines for both Read and Write
// before calling Write on a connection whose handshake is not yet complete.
// See SetDeadline, SetReadDeadline, and SetWriteDeadline.
//
// Errors from writing records are sticky: later Writes return them again.
func (c *Conn) Write(b []byte) (int, error) {
	// Interlock with Close: register as an in-flight call by adding 2 to
	// activeCall, unless its low bit (closed) is already set.
	for {
		x := c.activeCall.Load()
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x+2) {
			break
		}
	}
	defer c.activeCall.Add(-2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.isHandshakeComplete.Load() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// TLS 1.0 with CBC ciphers is susceptible to a chosen-plaintext attack
	// (BEAST) because IVs are predictable. Splitting each write into a
	// 1-byte record followed by the rest (the "1/n-1 split") effectively
	// randomizes the IV of the second record.
	//
	// https://www.openssl.org/~bodo/tls-cbc.txt
	// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
	// https://www.imperialviolet.org/2012/01/15/beastfollowup.html

	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1251
1252
1253 func (c *Conn) handleRenegotiation() error {
1254 if c.vers == VersionTLS13 {
1255 return errors.New("tls: internal error: unexpected renegotiation")
1256 }
1257
1258 msg, err := c.readHandshake(nil)
1259 if err != nil {
1260 return err
1261 }
1262
1263 helloReq, ok := msg.(*helloRequestMsg)
1264 if !ok {
1265 c.sendAlert(alertUnexpectedMessage)
1266 return unexpectedMessageError(helloReq, msg)
1267 }
1268
1269 if !c.isClient {
1270 return c.sendAlert(alertNoRenegotiation)
1271 }
1272
1273 switch c.config.Renegotiation {
1274 case RenegotiateNever:
1275 return c.sendAlert(alertNoRenegotiation)
1276 case RenegotiateOnceAsClient:
1277 if c.handshakes > 1 {
1278 return c.sendAlert(alertNoRenegotiation)
1279 }
1280 case RenegotiateFreelyAsClient:
1281
1282 default:
1283 c.sendAlert(alertInternalError)
1284 return errors.New("tls: unknown Renegotiation value")
1285 }
1286
1287 c.handshakeMutex.Lock()
1288 defer c.handshakeMutex.Unlock()
1289
1290 c.isHandshakeComplete.Store(false)
1291 if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1292 c.handshakes++
1293 }
1294 return c.handshakeErr
1295 }
1296
1297
1298
1299 func (c *Conn) handlePostHandshakeMessage() error {
1300 if c.vers != VersionTLS13 {
1301 return c.handleRenegotiation()
1302 }
1303
1304 msg, err := c.readHandshake(nil)
1305 if err != nil {
1306 return err
1307 }
1308 c.retryCount++
1309 if c.retryCount > maxUselessRecords {
1310 c.sendAlert(alertUnexpectedMessage)
1311 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1312 }
1313
1314 switch msg := msg.(type) {
1315 case *newSessionTicketMsgTLS13:
1316 return c.handleNewSessionTicket(msg)
1317 case *keyUpdateMsg:
1318 return c.handleKeyUpdate(msg)
1319 }
1320
1321
1322
1323
1324 c.sendAlert(alertUnexpectedMessage)
1325 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1326 }
1327
1328 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1329 if c.quic != nil {
1330 c.sendAlert(alertUnexpectedMessage)
1331 return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
1332 }
1333
1334 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1335 if cipherSuite == nil {
1336 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1337 }
1338
1339 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1340 c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1341
1342 if keyUpdate.updateRequested {
1343 c.out.Lock()
1344 defer c.out.Unlock()
1345
1346 msg := &keyUpdateMsg{}
1347 msgBytes, err := msg.marshal()
1348 if err != nil {
1349 return err
1350 }
1351 _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
1352 if err != nil {
1353
1354 c.out.setErrorLocked(err)
1355 return nil
1356 }
1357
1358 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1359 c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1360 }
1361
1362 return nil
1363 }
1364
1365
1366
1367
1368
1369
1370
// Read reads data from the connection.
//
// Read performs the handshake first if it has not yet completed. Because of
// that, to avoid blocking indefinitely, set deadlines for both Read and Write
// before calling Read on a connection whose handshake is not yet complete.
// See SetDeadline, SetReadDeadline, and SetWriteDeadline.
func (c *Conn) Read(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if len(b) == 0 {
		// Put this after Handshake, in case people were calling
		// Read(nil) for the side effect of the Handshake.
		return 0, nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	// Read records until some application data is available, handling any
	// post-handshake handshake messages (tickets, key updates,
	// renegotiation) along the way.
	for c.input.Len() == 0 {
		if err := c.readRecord(); err != nil {
			return 0, err
		}
		for c.hand.Len() > 0 {
			if err := c.handlePostHandshakeMessage(); err != nil {
				return 0, err
			}
		}
	}

	n, _ := c.input.Read(b)

	// If the next buffered record is an alert (e.g. close_notify), process it
	// now. A caller can then get (n, io.EOF) instead of (n, nil) and learn
	// immediately that the peer closed the connection. Otherwise the caller
	// would see the EOF only on its next Read, by which time, for example, an
	// HTTP client might already have tried to reuse the connection.
	// See https://golang.org/cl/76400046 and https://golang.org/issue/3514
	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
		recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
		if err := c.readRecord(); err != nil {
			return n, err // will be io.EOF on closeNotify
		}
	}

	return n, nil
}
1413
1414
// Close closes the connection. If the handshake has completed it first tries
// to send a close_notify alert. A failure to send the alert is returned, but
// only if closing the underlying connection itself succeeded.
func (c *Conn) Close() error {
	// Interlock with Conn.Write: set the low bit of activeCall to mark the
	// connection closed. Each in-flight Write contributes 2 to activeCall.
	var x int32
	for {
		x = c.activeCall.Load()
		if x&1 != 0 {
			return net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x|1) {
			break
		}
	}
	if x != 0 {
		// io.Writer and io.Closer should not be used concurrently.
		// If Close is called while a Write is in flight, interpret that as
		// an attempt to break the Write and/or clean up resources. Skip
		// sending close_notify, which would block on the c.out mutex (held
		// by Write) and possibly on the handshake.
		return c.conn.Close()
	}

	var alertErr error
	if c.isHandshakeComplete.Load() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1449
var (
	// errEarlyCloseWrite is returned by CloseWrite, instead of sending a
	// close_notify alert, when the handshake has not completed.
	errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
)
1451
1452
1453
1454
1455 func (c *Conn) CloseWrite() error {
1456 if !c.isHandshakeComplete.Load() {
1457 return errEarlyCloseWrite
1458 }
1459
1460 return c.closeNotify()
1461 }
1462
1463 func (c *Conn) closeNotify() error {
1464 c.out.Lock()
1465 defer c.out.Unlock()
1466
1467 if !c.closeNotifySent {
1468
1469 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1470 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1471 c.closeNotifySent = true
1472
1473 c.SetWriteDeadline(time.Now())
1474 }
1475 return c.closeNotifyErr
1476 }
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491 func (c *Conn) Handshake() error {
1492 return c.HandshakeContext(context.Background())
1493 }
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505 func (c *Conn) HandshakeContext(ctx context.Context) error {
1506
1507
1508 return c.handshakeContext(ctx)
1509 }
1510
// handshakeContext runs c.handshakeFn under c.handshakeMutex and c.in's lock.
// It records the outcome in c.handshakeErr, which later calls return without
// re-running the handshake.
//
// The named result ret lets the deferred interrupter cleanup replace the
// returned error with the context's error when ctx is canceled mid-handshake.
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
	// Fast path: once the handshake has completed, return immediately without
	// creating a context or taking any lock.
	if c.isHandshakeComplete.Load() {
		return nil
	}

	handshakeCtx, cancel := context.WithCancel(ctx)
	// This defer must be registered before the interrupter's cleanup below.
	// Defers run LIFO, so cancel fires only after that cleanup has stopped
	// the goroutine. The goroutine therefore sees handshakeCtx as done only
	// when ctx itself was canceled, never because of this function's own
	// cleanup.
	defer cancel()

	if c.quic != nil {
		// QUIC: hand the cancellation channel and func to the QUIC state
		// instead of starting an interrupter goroutine.
		c.quic.cancelc = handshakeCtx.Done()
		c.quic.cancel = cancel
	} else if ctx.Done() != nil {
		// ctx can be canceled (a nil Done channel, as with
		// context.Background, never fires). Start an "interrupter" goroutine
		// that closes the underlying connection if handshakeCtx is done before
		// this function returns.
		done := make(chan struct{})
		interruptRes := make(chan error, 1)
		defer func() {
			// Stop the interrupter and wait for its verdict.
			close(done)
			if ctxErr := <-interruptRes; ctxErr != nil {
				// The context was canceled: report its error to the caller
				// instead of whatever the handshake returned.
				ret = ctxErr
			}
		}()
		go func() {
			select {
			case <-handshakeCtx.Done():
				// Close the connection, discarding the error. Closing a
				// net.Conn unblocks pending reads and writes, which aborts
				// the in-progress handshake.
				_ = c.conn.Close()
				interruptRes <- handshakeCtx.Err()
			case <-done:
				interruptRes <- nil
			}
		}()
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// Re-check under the lock: another goroutine may have completed or
	// failed the handshake while this one was waiting.
	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.isHandshakeComplete.Load() {
		return nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	c.handshakeErr = c.handshakeFn(handshakeCtx)
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// On failure, try to flush any buffered output (presumably an alert
		// queued by the failed handshake). The result is discarded.
		c.flush()
	}

	// Invariant checks: handshakeFn must either return an error or mark the
	// handshake complete, never both and never neither.
	if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}
	if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
		panic("tls: internal error: handshake returned an error but is marked successful")
	}

	if c.quic != nil {
		if c.handshakeErr == nil {
			c.quicHandshakeComplete()
			// Provide the 1-RTT (application-level) read secret only now
			// that the handshake is complete. The QUIC layer must not
			// decrypt 1-RTT packets before that (RFC 9001, Section 5.7).
			c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
		} else {
			// Use the alert recorded in c.out.err, if there is one.
			// Otherwise fall back to internal_error.
			var a alert
			c.out.Lock()
			if !errors.As(c.out.err, &a) {
				a = alertInternalError
			}
			c.out.Unlock()
			// Wrap both the handshake error and the alert. "%.0w" truncates
			// the alert's text to zero characters, so the message is just the
			// handshake error, while errors.As can still extract the
			// AlertError.
			c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
		}
		// Close both QUIC signalling channels whatever the outcome.
		// NOTE(review): presumably this wakes waiters on them; confirm
		// against quicState.
		close(c.quic.blockedc)
		close(c.quic.signalc)
	}

	return c.handshakeErr
}
1610
1611
1612 func (c *Conn) ConnectionState() ConnectionState {
1613 c.handshakeMutex.Lock()
1614 defer c.handshakeMutex.Unlock()
1615 return c.connectionStateLocked()
1616 }
1617
1618 var tlsunsafeekm = godebug.New("tlsunsafeekm")
1619
1620 func (c *Conn) connectionStateLocked() ConnectionState {
1621 var state ConnectionState
1622 state.HandshakeComplete = c.isHandshakeComplete.Load()
1623 state.Version = c.vers
1624 state.NegotiatedProtocol = c.clientProtocol
1625 state.DidResume = c.didResume
1626 state.testingOnlyDidHRR = c.didHRR
1627
1628 state.testingOnlyCurveID = c.curveID
1629 state.NegotiatedProtocolIsMutual = true
1630 state.ServerName = c.serverName
1631 state.CipherSuite = c.cipherSuite
1632 state.PeerCertificates = c.peerCertificates
1633 state.VerifiedChains = c.verifiedChains
1634 state.SignedCertificateTimestamps = c.scts
1635 state.OCSPResponse = c.ocspResponse
1636 if (!c.didResume || c.extMasterSecret) && c.vers != VersionTLS13 {
1637 if c.clientFinishedIsFirst {
1638 state.TLSUnique = c.clientFinished[:]
1639 } else {
1640 state.TLSUnique = c.serverFinished[:]
1641 }
1642 }
1643 if c.config.Renegotiation != RenegotiateNever {
1644 state.ekm = noEKMBecauseRenegotiation
1645 } else if c.vers != VersionTLS13 && !c.extMasterSecret {
1646 state.ekm = func(label string, context []byte, length int) ([]byte, error) {
1647 if tlsunsafeekm.Value() == "1" {
1648 tlsunsafeekm.IncNonDefault()
1649 return c.ekm(label, context, length)
1650 }
1651 return noEKMBecauseNoEMS(label, context, length)
1652 }
1653 } else {
1654 state.ekm = c.ekm
1655 }
1656 state.ECHAccepted = c.echAccepted
1657 return state
1658 }
1659
1660
1661
1662 func (c *Conn) OCSPResponse() []byte {
1663 c.handshakeMutex.Lock()
1664 defer c.handshakeMutex.Unlock()
1665
1666 return c.ocspResponse
1667 }
1668
1669
1670
1671
1672 func (c *Conn) VerifyHostname(host string) error {
1673 c.handshakeMutex.Lock()
1674 defer c.handshakeMutex.Unlock()
1675 if !c.isClient {
1676 return errors.New("tls: VerifyHostname called on TLS server connection")
1677 }
1678 if !c.isHandshakeComplete.Load() {
1679 return errors.New("tls: handshake has not yet been performed")
1680 }
1681 if len(c.verifiedChains) == 0 {
1682 return errors.New("tls: handshake did not verify certificate chain")
1683 }
1684 return c.peerCertificates[0].VerifyHostname(host)
1685 }
1686
View as plain text