Source file
src/crypto/tls/handshake_messages_test.go
1
2
3
4
5 package tls
6
7 import (
8 "bytes"
9 "crypto/x509"
10 "encoding/hex"
11 "math"
12 "math/rand"
13 "reflect"
14 "strings"
15 "testing"
16 "testing/quick"
17 "time"
18 )
19
20 var tests = []handshakeMessage{
21 &clientHelloMsg{},
22 &serverHelloMsg{},
23 &finishedMsg{},
24
25 &certificateMsg{},
26 &certificateRequestMsg{},
27 &certificateVerifyMsg{
28 hasSignatureAlgorithm: true,
29 },
30 &certificateStatusMsg{},
31 &clientKeyExchangeMsg{},
32 &newSessionTicketMsg{},
33 &encryptedExtensionsMsg{},
34 &endOfEarlyDataMsg{},
35 &keyUpdateMsg{},
36 &newSessionTicketMsgTLS13{},
37 &certificateRequestMsgTLS13{},
38 &certificateMsgTLS13{},
39 &SessionState{},
40 }
41
42 func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
43 t.Helper()
44 b, err := msg.marshal()
45 if err != nil {
46 t.Fatal(err)
47 }
48 return b
49 }
50
51 func TestMarshalUnmarshal(t *testing.T) {
52 rand := rand.New(rand.NewSource(time.Now().UnixNano()))
53
54 for i, m := range tests {
55 ty := reflect.ValueOf(m).Type()
56 t.Run(ty.String(), func(t *testing.T) {
57 n := 100
58 if testing.Short() {
59 n = 5
60 }
61 for j := 0; j < n; j++ {
62 v, ok := quick.Value(ty, rand)
63 if !ok {
64 t.Errorf("#%d: failed to create value", i)
65 break
66 }
67
68 m1 := v.Interface().(handshakeMessage)
69 marshaled := mustMarshal(t, m1)
70 if !m.unmarshal(marshaled) {
71 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
72 break
73 }
74
75 if m, ok := m.(*SessionState); ok {
76 m.activeCertHandles = nil
77 }
78
79
80
81
82
83
84 switch t := m.(type) {
85 case *clientHelloMsg:
86 t.original = nil
87 case *serverHelloMsg:
88 t.original = nil
89 }
90
91 if !reflect.DeepEqual(m1, m) {
92 t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
93 break
94 }
95
96 if i >= 3 {
97
98
99
100
101
102 for j := 0; j < len(marshaled); j++ {
103 if m.unmarshal(marshaled[0:j]) {
104 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
105 break
106 }
107 }
108 }
109 }
110 })
111 }
112 }
113
114 func TestFuzz(t *testing.T) {
115 rand := rand.New(rand.NewSource(0))
116 for _, m := range tests {
117 for j := 0; j < 1000; j++ {
118 len := rand.Intn(1000)
119 bytes := randomBytes(len, rand)
120
121 m.unmarshal(bytes)
122 }
123 }
124 }
125
// randomBytes returns a freshly allocated slice of n bytes filled from rand.
// It panics if rand.Read reports an error.
func randomBytes(n int, rand *rand.Rand) []byte {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		panic("rand.Read failed: " + err.Error())
	}
	return buf
}
133
134 func randomString(n int, rand *rand.Rand) string {
135 b := randomBytes(n, rand)
136 return string(b)
137 }
138
139 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
140 m := &clientHelloMsg{}
141 m.vers = uint16(rand.Intn(65536))
142 m.random = randomBytes(32, rand)
143 m.sessionId = randomBytes(rand.Intn(32), rand)
144 m.cipherSuites = make([]uint16, rand.Intn(63)+1)
145 for i := 0; i < len(m.cipherSuites); i++ {
146 cs := uint16(rand.Int31())
147 if cs == scsvRenegotiation {
148 cs += 1
149 }
150 m.cipherSuites[i] = cs
151 }
152 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
153 if rand.Intn(10) > 5 {
154 m.serverName = randomString(rand.Intn(255), rand)
155 for strings.HasSuffix(m.serverName, ".") {
156 m.serverName = m.serverName[:len(m.serverName)-1]
157 }
158 }
159 m.ocspStapling = rand.Intn(10) > 5
160 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
161 m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
162 for i := range m.supportedCurves {
163 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
164 }
165 if rand.Intn(10) > 5 {
166 m.ticketSupported = true
167 if rand.Intn(10) > 5 {
168 m.sessionTicket = randomBytes(rand.Intn(300), rand)
169 } else {
170 m.sessionTicket = make([]byte, 0)
171 }
172 }
173 if rand.Intn(10) > 5 {
174 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
175 }
176 if rand.Intn(10) > 5 {
177 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
178 }
179 for i := 0; i < rand.Intn(5); i++ {
180 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
181 }
182 if rand.Intn(10) > 5 {
183 m.scts = true
184 }
185 if rand.Intn(10) > 5 {
186 m.secureRenegotiationSupported = true
187 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
188 }
189 if rand.Intn(10) > 5 {
190 m.extendedMasterSecret = true
191 }
192 for i := 0; i < rand.Intn(5); i++ {
193 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
194 }
195 if rand.Intn(10) > 5 {
196 m.cookie = randomBytes(rand.Intn(500)+1, rand)
197 }
198 for i := 0; i < rand.Intn(5); i++ {
199 var ks keyShare
200 ks.group = CurveID(rand.Intn(30000) + 1)
201 ks.data = randomBytes(rand.Intn(200)+1, rand)
202 m.keyShares = append(m.keyShares, ks)
203 }
204 switch rand.Intn(3) {
205 case 1:
206 m.pskModes = []uint8{pskModeDHE}
207 case 2:
208 m.pskModes = []uint8{pskModeDHE, pskModePlain}
209 }
210 for i := 0; i < rand.Intn(5); i++ {
211 var psk pskIdentity
212 psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
213 psk.label = randomBytes(rand.Intn(500)+1, rand)
214 m.pskIdentities = append(m.pskIdentities, psk)
215 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
216 }
217 if rand.Intn(10) > 5 {
218 m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
219 }
220 if rand.Intn(10) > 5 {
221 m.earlyData = true
222 }
223
224 return reflect.ValueOf(m)
225 }
226
227 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
228 m := &serverHelloMsg{}
229 m.vers = uint16(rand.Intn(65536))
230 m.random = randomBytes(32, rand)
231 m.sessionId = randomBytes(rand.Intn(32), rand)
232 m.cipherSuite = uint16(rand.Int31())
233 m.compressionMethod = uint8(rand.Intn(256))
234 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
235
236 if rand.Intn(10) > 5 {
237 m.ocspStapling = true
238 }
239 if rand.Intn(10) > 5 {
240 m.ticketSupported = true
241 }
242 if rand.Intn(10) > 5 {
243 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
244 }
245
246 for i := 0; i < rand.Intn(4); i++ {
247 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
248 }
249
250 if rand.Intn(10) > 5 {
251 m.secureRenegotiationSupported = true
252 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
253 }
254 if rand.Intn(10) > 5 {
255 m.extendedMasterSecret = true
256 }
257 if rand.Intn(10) > 5 {
258 m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
259 }
260 if rand.Intn(10) > 5 {
261 m.cookie = randomBytes(rand.Intn(500)+1, rand)
262 }
263 if rand.Intn(10) > 5 {
264 for i := 0; i < rand.Intn(5); i++ {
265 m.serverShare.group = CurveID(rand.Intn(30000) + 1)
266 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
267 }
268 } else if rand.Intn(10) > 5 {
269 m.selectedGroup = CurveID(rand.Intn(30000) + 1)
270 }
271 if rand.Intn(10) > 5 {
272 m.selectedIdentityPresent = true
273 m.selectedIdentity = uint16(rand.Intn(0xffff))
274 }
275 if rand.Intn(10) > 5 {
276 m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
277 }
278 if rand.Intn(10) > 5 {
279 m.serverNameAck = rand.Intn(2) == 1
280 }
281
282 return reflect.ValueOf(m)
283 }
284
285 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
286 m := &encryptedExtensionsMsg{}
287
288 if rand.Intn(10) > 5 {
289 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
290 }
291 if rand.Intn(10) > 5 {
292 m.earlyData = true
293 }
294
295 return reflect.ValueOf(m)
296 }
297
298 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
299 m := &certificateMsg{}
300 numCerts := rand.Intn(20)
301 m.certificates = make([][]byte, numCerts)
302 for i := 0; i < numCerts; i++ {
303 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
304 }
305 return reflect.ValueOf(m)
306 }
307
308 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
309 m := &certificateRequestMsg{}
310 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
311 for i := 0; i < rand.Intn(100); i++ {
312 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
313 }
314 return reflect.ValueOf(m)
315 }
316
317 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
318 m := &certificateVerifyMsg{}
319 m.hasSignatureAlgorithm = true
320 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
321 m.signature = randomBytes(rand.Intn(15)+1, rand)
322 return reflect.ValueOf(m)
323 }
324
325 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
326 m := &certificateStatusMsg{}
327 m.response = randomBytes(rand.Intn(10)+1, rand)
328 return reflect.ValueOf(m)
329 }
330
331 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
332 m := &clientKeyExchangeMsg{}
333 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
334 return reflect.ValueOf(m)
335 }
336
337 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
338 m := &finishedMsg{}
339 m.verifyData = randomBytes(12, rand)
340 return reflect.ValueOf(m)
341 }
342
343 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
344 m := &newSessionTicketMsg{}
345 m.ticket = randomBytes(rand.Intn(4), rand)
346 return reflect.ValueOf(m)
347 }
348
349 var sessionTestCerts []*x509.Certificate
350
351 func init() {
352 cert, err := x509.ParseCertificate(testRSACertificate)
353 if err != nil {
354 panic(err)
355 }
356 sessionTestCerts = append(sessionTestCerts, cert)
357 cert, err = x509.ParseCertificate(testRSACertificateIssuer)
358 if err != nil {
359 panic(err)
360 }
361 sessionTestCerts = append(sessionTestCerts, cert)
362 }
363
364 func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
365 s := &SessionState{}
366 isTLS13 := rand.Intn(10) > 5
367 if isTLS13 {
368 s.version = VersionTLS13
369 } else {
370 s.version = uint16(rand.Intn(VersionTLS13))
371 }
372 s.isClient = rand.Intn(10) > 5
373 s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
374 s.createdAt = uint64(rand.Int63())
375 s.secret = randomBytes(rand.Intn(100)+1, rand)
376 for n, i := rand.Intn(3), 0; i < n; i++ {
377 s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
378 }
379 if rand.Intn(10) > 5 {
380 s.EarlyData = true
381 }
382 if rand.Intn(10) > 5 {
383 s.extMasterSecret = true
384 }
385 if s.isClient || rand.Intn(10) > 5 {
386 if rand.Intn(10) > 5 {
387 s.peerCertificates = sessionTestCerts
388 } else {
389 s.peerCertificates = sessionTestCerts[:1]
390 }
391 }
392 if rand.Intn(10) > 5 && s.peerCertificates != nil {
393 s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
394 }
395 if rand.Intn(10) > 5 && s.peerCertificates != nil {
396 for i := 0; i < rand.Intn(2)+1; i++ {
397 s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
398 }
399 }
400 if len(s.peerCertificates) > 0 {
401 for i := 0; i < rand.Intn(3); i++ {
402 if rand.Intn(10) > 5 {
403 s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
404 } else {
405 s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
406 }
407 }
408 }
409 if rand.Intn(10) > 5 && s.EarlyData {
410 s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
411 }
412 if s.isClient {
413 if isTLS13 {
414 s.useBy = uint64(rand.Int63())
415 s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
416 }
417 }
418 return reflect.ValueOf(s)
419 }
420
421 func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() }
422 func (s *SessionState) unmarshal(b []byte) bool {
423 ss, err := ParseSessionState(b)
424 if err != nil {
425 return false
426 }
427 *s = *ss
428 return true
429 }
430
431 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
432 m := &endOfEarlyDataMsg{}
433 return reflect.ValueOf(m)
434 }
435
436 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
437 m := &keyUpdateMsg{}
438 m.updateRequested = rand.Intn(10) > 5
439 return reflect.ValueOf(m)
440 }
441
442 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
443 m := &newSessionTicketMsgTLS13{}
444 m.lifetime = uint32(rand.Intn(500000))
445 m.ageAdd = uint32(rand.Intn(500000))
446 m.nonce = randomBytes(rand.Intn(100), rand)
447 m.label = randomBytes(rand.Intn(1000), rand)
448 if rand.Intn(10) > 5 {
449 m.maxEarlyData = uint32(rand.Intn(500000))
450 }
451 return reflect.ValueOf(m)
452 }
453
454 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
455 m := &certificateRequestMsgTLS13{}
456 if rand.Intn(10) > 5 {
457 m.ocspStapling = true
458 }
459 if rand.Intn(10) > 5 {
460 m.scts = true
461 }
462 if rand.Intn(10) > 5 {
463 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
464 }
465 if rand.Intn(10) > 5 {
466 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
467 }
468 if rand.Intn(10) > 5 {
469 m.certificateAuthorities = make([][]byte, 3)
470 for i := 0; i < 3; i++ {
471 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
472 }
473 }
474 return reflect.ValueOf(m)
475 }
476
477 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
478 m := &certificateMsgTLS13{}
479 for i := 0; i < rand.Intn(2)+1; i++ {
480 m.certificate.Certificate = append(
481 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
482 }
483 if rand.Intn(10) > 5 {
484 m.ocspStapling = true
485 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
486 }
487 if rand.Intn(10) > 5 {
488 m.scts = true
489 for i := 0; i < rand.Intn(2)+1; i++ {
490 m.certificate.SignedCertificateTimestamps = append(
491 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
492 }
493 }
494 return reflect.ValueOf(m)
495 }
496
497 func TestRejectEmptySCTList(t *testing.T) {
498
499
500 var random [32]byte
501 sct := []byte{0x42, 0x42, 0x42, 0x42}
502 serverHello := &serverHelloMsg{
503 vers: VersionTLS12,
504 random: random[:],
505 scts: [][]byte{sct},
506 }
507 serverHelloBytes := mustMarshal(t, serverHello)
508
509 var serverHelloCopy serverHelloMsg
510 if !serverHelloCopy.unmarshal(serverHelloBytes) {
511 t.Fatal("Failed to unmarshal initial message")
512 }
513
514
515 i := bytes.Index(serverHelloBytes, sct)
516 if i < 0 {
517 t.Fatal("Cannot find SCT in ServerHello")
518 }
519
520 var serverHelloEmptySCT []byte
521 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
522
523 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
524 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
525
526
527 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
528 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
529 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
530
531
532 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
533 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
534
535 if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
536 t.Fatal("Unmarshaled ServerHello with empty SCT list")
537 }
538 }
539
540 func TestRejectEmptySCT(t *testing.T) {
541
542
543
544 var random [32]byte
545 serverHello := &serverHelloMsg{
546 vers: VersionTLS12,
547 random: random[:],
548 scts: [][]byte{nil},
549 }
550 serverHelloBytes := mustMarshal(t, serverHello)
551
552 var serverHelloCopy serverHelloMsg
553 if serverHelloCopy.unmarshal(serverHelloBytes) {
554 t.Fatal("Unmarshaled ServerHello with zero-length SCT")
555 }
556 }
557
558 func TestRejectDuplicateExtensions(t *testing.T) {
559 clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f")
560 if err != nil {
561 t.Fatalf("failed to decode test ClientHello: %s", err)
562 }
563 var clientHelloCopy clientHelloMsg
564 if clientHelloCopy.unmarshal(clientHelloBytes) {
565 t.Error("Unmarshaled ClientHello with duplicate extensions")
566 }
567
568 serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000")
569 if err != nil {
570 t.Fatalf("failed to decode test ServerHello: %s", err)
571 }
572 var serverHelloCopy serverHelloMsg
573 if serverHelloCopy.unmarshal(serverHelloBytes) {
574 t.Fatal("Unmarshaled ServerHello with duplicate extensions")
575 }
576 }
577
View as plain text