Source file
src/crypto/tls/handshake_messages_test.go
1
2
3
4
5 package tls
6
7 import (
8 "bytes"
9 "crypto/x509"
10 "encoding/hex"
11 "math"
12 "math/rand"
13 "reflect"
14 "strings"
15 "testing"
16 "testing/quick"
17 "time"
18 )
19
20 var tests = []handshakeMessage{
21 &clientHelloMsg{},
22 &serverHelloMsg{},
23 &finishedMsg{},
24
25 &certificateMsg{},
26 &certificateRequestMsg{},
27 &certificateVerifyMsg{
28 hasSignatureAlgorithm: true,
29 },
30 &certificateStatusMsg{},
31 &clientKeyExchangeMsg{},
32 &newSessionTicketMsg{},
33 &encryptedExtensionsMsg{},
34 &endOfEarlyDataMsg{},
35 &keyUpdateMsg{},
36 &newSessionTicketMsgTLS13{},
37 &certificateRequestMsgTLS13{},
38 &certificateMsgTLS13{},
39 &SessionState{},
40 }
41
42 func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
43 t.Helper()
44 b, err := msg.marshal()
45 if err != nil {
46 t.Fatal(err)
47 }
48 return b
49 }
50
51 func TestMarshalUnmarshal(t *testing.T) {
52 rand := rand.New(rand.NewSource(time.Now().UnixNano()))
53
54 for i, m := range tests {
55 ty := reflect.ValueOf(m).Type()
56 t.Run(ty.String(), func(t *testing.T) {
57 n := 100
58 if testing.Short() {
59 n = 5
60 }
61 for j := 0; j < n; j++ {
62 v, ok := quick.Value(ty, rand)
63 if !ok {
64 t.Errorf("#%d: failed to create value", i)
65 break
66 }
67
68 m1 := v.Interface().(handshakeMessage)
69 marshaled := mustMarshal(t, m1)
70 if !m.unmarshal(marshaled) {
71 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
72 break
73 }
74
75 if m, ok := m.(*SessionState); ok {
76 m.activeCertHandles = nil
77 }
78
79 if ch, ok := m.(*clientHelloMsg); ok {
80
81
82
83
84
85 if len(ch.extensions) == 0 {
86 t.Errorf("expected ch.extensions to be populated on unmarshal")
87 }
88 ch.extensions = nil
89 }
90
91
92
93
94
95
96 switch t := m.(type) {
97 case *clientHelloMsg:
98 t.original = nil
99 case *serverHelloMsg:
100 t.original = nil
101 }
102
103 if !reflect.DeepEqual(m1, m) {
104 t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
105 break
106 }
107
108 if i >= 3 {
109
110
111
112
113
114 for j := 0; j < len(marshaled); j++ {
115 if m.unmarshal(marshaled[0:j]) {
116 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
117 break
118 }
119 }
120 }
121 }
122 })
123 }
124 }
125
126 func TestFuzz(t *testing.T) {
127 rand := rand.New(rand.NewSource(0))
128 for _, m := range tests {
129 for j := 0; j < 1000; j++ {
130 len := rand.Intn(1000)
131 bytes := randomBytes(len, rand)
132
133 m.unmarshal(bytes)
134 }
135 }
136 }
137
// randomBytes returns a freshly allocated slice of n bytes filled from rand.
// It panics if rand.Read reports an error.
func randomBytes(n int, rand *rand.Rand) []byte {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		panic("rand.Read failed: " + err.Error())
	}
	return buf
}
145
146 func randomString(n int, rand *rand.Rand) string {
147 b := randomBytes(n, rand)
148 return string(b)
149 }
150
151 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
152 m := &clientHelloMsg{}
153 m.vers = uint16(rand.Intn(65536))
154 m.random = randomBytes(32, rand)
155 m.sessionId = randomBytes(rand.Intn(32), rand)
156 m.cipherSuites = make([]uint16, rand.Intn(63)+1)
157 for i := 0; i < len(m.cipherSuites); i++ {
158 cs := uint16(rand.Int31())
159 if cs == scsvRenegotiation {
160 cs += 1
161 }
162 m.cipherSuites[i] = cs
163 }
164 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
165 if rand.Intn(10) > 5 {
166 m.serverName = randomString(rand.Intn(255), rand)
167 for strings.HasSuffix(m.serverName, ".") {
168 m.serverName = m.serverName[:len(m.serverName)-1]
169 }
170 }
171 m.ocspStapling = rand.Intn(10) > 5
172 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
173 m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
174 for i := range m.supportedCurves {
175 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
176 }
177 if rand.Intn(10) > 5 {
178 m.ticketSupported = true
179 if rand.Intn(10) > 5 {
180 m.sessionTicket = randomBytes(rand.Intn(300), rand)
181 } else {
182 m.sessionTicket = make([]byte, 0)
183 }
184 }
185 if rand.Intn(10) > 5 {
186 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
187 }
188 if rand.Intn(10) > 5 {
189 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
190 }
191 for i := 0; i < rand.Intn(5); i++ {
192 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
193 }
194 if rand.Intn(10) > 5 {
195 m.scts = true
196 }
197 if rand.Intn(10) > 5 {
198 m.secureRenegotiationSupported = true
199 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
200 }
201 if rand.Intn(10) > 5 {
202 m.extendedMasterSecret = true
203 }
204 for i := 0; i < rand.Intn(5); i++ {
205 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
206 }
207 if rand.Intn(10) > 5 {
208 m.cookie = randomBytes(rand.Intn(500)+1, rand)
209 }
210 for i := 0; i < rand.Intn(5); i++ {
211 var ks keyShare
212 ks.group = CurveID(rand.Intn(30000) + 1)
213 ks.data = randomBytes(rand.Intn(200)+1, rand)
214 m.keyShares = append(m.keyShares, ks)
215 }
216 switch rand.Intn(3) {
217 case 1:
218 m.pskModes = []uint8{pskModeDHE}
219 case 2:
220 m.pskModes = []uint8{pskModeDHE, pskModePlain}
221 }
222 for i := 0; i < rand.Intn(5); i++ {
223 var psk pskIdentity
224 psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
225 psk.label = randomBytes(rand.Intn(500)+1, rand)
226 m.pskIdentities = append(m.pskIdentities, psk)
227 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
228 }
229 if rand.Intn(10) > 5 {
230 m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
231 }
232 if rand.Intn(10) > 5 {
233 m.earlyData = true
234 }
235
236 return reflect.ValueOf(m)
237 }
238
239 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
240 m := &serverHelloMsg{}
241 m.vers = uint16(rand.Intn(65536))
242 m.random = randomBytes(32, rand)
243 m.sessionId = randomBytes(rand.Intn(32), rand)
244 m.cipherSuite = uint16(rand.Int31())
245 m.compressionMethod = uint8(rand.Intn(256))
246 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
247
248 if rand.Intn(10) > 5 {
249 m.ocspStapling = true
250 }
251 if rand.Intn(10) > 5 {
252 m.ticketSupported = true
253 }
254 if rand.Intn(10) > 5 {
255 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
256 }
257
258 for i := 0; i < rand.Intn(4); i++ {
259 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
260 }
261
262 if rand.Intn(10) > 5 {
263 m.secureRenegotiationSupported = true
264 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
265 }
266 if rand.Intn(10) > 5 {
267 m.extendedMasterSecret = true
268 }
269 if rand.Intn(10) > 5 {
270 m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
271 }
272 if rand.Intn(10) > 5 {
273 m.cookie = randomBytes(rand.Intn(500)+1, rand)
274 }
275 if rand.Intn(10) > 5 {
276 for i := 0; i < rand.Intn(5); i++ {
277 m.serverShare.group = CurveID(rand.Intn(30000) + 1)
278 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
279 }
280 } else if rand.Intn(10) > 5 {
281 m.selectedGroup = CurveID(rand.Intn(30000) + 1)
282 }
283 if rand.Intn(10) > 5 {
284 m.selectedIdentityPresent = true
285 m.selectedIdentity = uint16(rand.Intn(0xffff))
286 }
287 if rand.Intn(10) > 5 {
288 m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
289 }
290 if rand.Intn(10) > 5 {
291 m.serverNameAck = rand.Intn(2) == 1
292 }
293
294 return reflect.ValueOf(m)
295 }
296
297 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
298 m := &encryptedExtensionsMsg{}
299
300 if rand.Intn(10) > 5 {
301 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
302 }
303 if rand.Intn(10) > 5 {
304 m.earlyData = true
305 }
306
307 return reflect.ValueOf(m)
308 }
309
310 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
311 m := &certificateMsg{}
312 numCerts := rand.Intn(20)
313 m.certificates = make([][]byte, numCerts)
314 for i := 0; i < numCerts; i++ {
315 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
316 }
317 return reflect.ValueOf(m)
318 }
319
320 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
321 m := &certificateRequestMsg{}
322 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
323 for i := 0; i < rand.Intn(100); i++ {
324 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
325 }
326 return reflect.ValueOf(m)
327 }
328
329 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
330 m := &certificateVerifyMsg{}
331 m.hasSignatureAlgorithm = true
332 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
333 m.signature = randomBytes(rand.Intn(15)+1, rand)
334 return reflect.ValueOf(m)
335 }
336
337 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
338 m := &certificateStatusMsg{}
339 m.response = randomBytes(rand.Intn(10)+1, rand)
340 return reflect.ValueOf(m)
341 }
342
343 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
344 m := &clientKeyExchangeMsg{}
345 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
346 return reflect.ValueOf(m)
347 }
348
349 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
350 m := &finishedMsg{}
351 m.verifyData = randomBytes(12, rand)
352 return reflect.ValueOf(m)
353 }
354
355 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
356 m := &newSessionTicketMsg{}
357 m.ticket = randomBytes(rand.Intn(4), rand)
358 return reflect.ValueOf(m)
359 }
360
361 var sessionTestCerts []*x509.Certificate
362
363 func init() {
364 cert, err := x509.ParseCertificate(testRSACertificate)
365 if err != nil {
366 panic(err)
367 }
368 sessionTestCerts = append(sessionTestCerts, cert)
369 cert, err = x509.ParseCertificate(testRSACertificateIssuer)
370 if err != nil {
371 panic(err)
372 }
373 sessionTestCerts = append(sessionTestCerts, cert)
374 }
375
376 func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
377 s := &SessionState{}
378 isTLS13 := rand.Intn(10) > 5
379 if isTLS13 {
380 s.version = VersionTLS13
381 } else {
382 s.version = uint16(rand.Intn(VersionTLS13))
383 }
384 s.isClient = rand.Intn(10) > 5
385 s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
386 s.createdAt = uint64(rand.Int63())
387 s.secret = randomBytes(rand.Intn(100)+1, rand)
388 for n, i := rand.Intn(3), 0; i < n; i++ {
389 s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
390 }
391 if rand.Intn(10) > 5 {
392 s.EarlyData = true
393 }
394 if rand.Intn(10) > 5 {
395 s.extMasterSecret = true
396 }
397 if s.isClient || rand.Intn(10) > 5 {
398 if rand.Intn(10) > 5 {
399 s.peerCertificates = sessionTestCerts
400 } else {
401 s.peerCertificates = sessionTestCerts[:1]
402 }
403 }
404 if rand.Intn(10) > 5 && s.peerCertificates != nil {
405 s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
406 }
407 if rand.Intn(10) > 5 && s.peerCertificates != nil {
408 for i := 0; i < rand.Intn(2)+1; i++ {
409 s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
410 }
411 }
412 if len(s.peerCertificates) > 0 {
413 for i := 0; i < rand.Intn(3); i++ {
414 if rand.Intn(10) > 5 {
415 s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
416 } else {
417 s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
418 }
419 }
420 }
421 if rand.Intn(10) > 5 && s.EarlyData {
422 s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
423 }
424 if s.isClient {
425 if isTLS13 {
426 s.useBy = uint64(rand.Int63())
427 s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
428 }
429 }
430 return reflect.ValueOf(s)
431 }
432
433 func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() }
434 func (s *SessionState) unmarshal(b []byte) bool {
435 ss, err := ParseSessionState(b)
436 if err != nil {
437 return false
438 }
439 *s = *ss
440 return true
441 }
442
443 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
444 m := &endOfEarlyDataMsg{}
445 return reflect.ValueOf(m)
446 }
447
448 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
449 m := &keyUpdateMsg{}
450 m.updateRequested = rand.Intn(10) > 5
451 return reflect.ValueOf(m)
452 }
453
454 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
455 m := &newSessionTicketMsgTLS13{}
456 m.lifetime = uint32(rand.Intn(500000))
457 m.ageAdd = uint32(rand.Intn(500000))
458 m.nonce = randomBytes(rand.Intn(100), rand)
459 m.label = randomBytes(rand.Intn(1000), rand)
460 if rand.Intn(10) > 5 {
461 m.maxEarlyData = uint32(rand.Intn(500000))
462 }
463 return reflect.ValueOf(m)
464 }
465
466 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
467 m := &certificateRequestMsgTLS13{}
468 if rand.Intn(10) > 5 {
469 m.ocspStapling = true
470 }
471 if rand.Intn(10) > 5 {
472 m.scts = true
473 }
474 if rand.Intn(10) > 5 {
475 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
476 }
477 if rand.Intn(10) > 5 {
478 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
479 }
480 if rand.Intn(10) > 5 {
481 m.certificateAuthorities = make([][]byte, 3)
482 for i := 0; i < 3; i++ {
483 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
484 }
485 }
486 return reflect.ValueOf(m)
487 }
488
489 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
490 m := &certificateMsgTLS13{}
491 for i := 0; i < rand.Intn(2)+1; i++ {
492 m.certificate.Certificate = append(
493 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
494 }
495 if rand.Intn(10) > 5 {
496 m.ocspStapling = true
497 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
498 }
499 if rand.Intn(10) > 5 {
500 m.scts = true
501 for i := 0; i < rand.Intn(2)+1; i++ {
502 m.certificate.SignedCertificateTimestamps = append(
503 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
504 }
505 }
506 return reflect.ValueOf(m)
507 }
508
509 func TestRejectEmptySCTList(t *testing.T) {
510
511
512 var random [32]byte
513 sct := []byte{0x42, 0x42, 0x42, 0x42}
514 serverHello := &serverHelloMsg{
515 vers: VersionTLS12,
516 random: random[:],
517 scts: [][]byte{sct},
518 }
519 serverHelloBytes := mustMarshal(t, serverHello)
520
521 var serverHelloCopy serverHelloMsg
522 if !serverHelloCopy.unmarshal(serverHelloBytes) {
523 t.Fatal("Failed to unmarshal initial message")
524 }
525
526
527 i := bytes.Index(serverHelloBytes, sct)
528 if i < 0 {
529 t.Fatal("Cannot find SCT in ServerHello")
530 }
531
532 var serverHelloEmptySCT []byte
533 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
534
535 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
536 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
537
538
539 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
540 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
541 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
542
543
544 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
545 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
546
547 if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
548 t.Fatal("Unmarshaled ServerHello with empty SCT list")
549 }
550 }
551
552 func TestRejectEmptySCT(t *testing.T) {
553
554
555
556 var random [32]byte
557 serverHello := &serverHelloMsg{
558 vers: VersionTLS12,
559 random: random[:],
560 scts: [][]byte{nil},
561 }
562 serverHelloBytes := mustMarshal(t, serverHello)
563
564 var serverHelloCopy serverHelloMsg
565 if serverHelloCopy.unmarshal(serverHelloBytes) {
566 t.Fatal("Unmarshaled ServerHello with zero-length SCT")
567 }
568 }
569
570 func TestRejectDuplicateExtensions(t *testing.T) {
571 clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f")
572 if err != nil {
573 t.Fatalf("failed to decode test ClientHello: %s", err)
574 }
575 var clientHelloCopy clientHelloMsg
576 if clientHelloCopy.unmarshal(clientHelloBytes) {
577 t.Error("Unmarshaled ClientHello with duplicate extensions")
578 }
579
580 serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000")
581 if err != nil {
582 t.Fatalf("failed to decode test ServerHello: %s", err)
583 }
584 var serverHelloCopy serverHelloMsg
585 if serverHelloCopy.unmarshal(serverHelloBytes) {
586 t.Fatal("Unmarshaled ServerHello with duplicate extensions")
587 }
588 }
589
View as plain text