1
2
3
4
5 package quic
6
7 import (
8 "crypto"
9 "crypto/aes"
10 "crypto/cipher"
11 "crypto/sha256"
12 "crypto/tls"
13 "errors"
14 "hash"
15
16 "golang.org/x/crypto/chacha20"
17 "golang.org/x/crypto/chacha20poly1305"
18 "golang.org/x/crypto/cryptobyte"
19 "golang.org/x/crypto/hkdf"
20 )
21
// errInvalidPacket is returned, for example, when a received packet is too
// short to contain a header protection sample.
var errInvalidPacket = errors.New("quic: invalid packet")

const (
	// headerProtectionSampleSize is the number of ciphertext bytes sampled
	// to compute the header protection mask.
	headerProtectionSampleSize = 16

	// aeadOverhead is the number of bytes the packet protection AEAD adds
	// to each payload (the 16-byte authentication tag of the supported
	// AES-GCM and ChaCha20-Poly1305 suites).
	aeadOverhead = 16
)
31
32
33
34 type headerKey struct {
35 hp headerProtection
36 }
37
38 func (k headerKey) isSet() bool {
39 return k.hp != nil
40 }
41
42 func (k *headerKey) init(suite uint16, secret []byte) {
43 h, keySize := hashForSuite(suite)
44 hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keySize)
45 switch suite {
46 case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
47 c, err := aes.NewCipher(hpKey)
48 if err != nil {
49 panic(err)
50 }
51 k.hp = &aesHeaderProtection{cipher: c}
52 case tls.TLS_CHACHA20_POLY1305_SHA256:
53 k.hp = chaCha20HeaderProtection{hpKey}
54 default:
55 panic("BUG: unknown cipher suite")
56 }
57 }
58
59
60
61 func (k headerKey) protect(hdr []byte, pnumOff int) {
62
63 pnumSize := int(hdr[0]&0x03) + 1
64 sample := hdr[pnumOff+4:][:headerProtectionSampleSize]
65 mask := k.hp.headerProtection(sample)
66 if isLongHeader(hdr[0]) {
67 hdr[0] ^= mask[0] & 0x0f
68 } else {
69 hdr[0] ^= mask[0] & 0x1f
70 }
71 for i := 0; i < pnumSize; i++ {
72 hdr[pnumOff+i] ^= mask[1+i]
73 }
74 }
75
76
77
78
79 func (k headerKey) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (hdr, pay []byte, pnum packetNumber, _ error) {
80 if len(pkt) < pnumOff+4+headerProtectionSampleSize {
81 return nil, nil, 0, errInvalidPacket
82 }
83 numpay := pkt[pnumOff:]
84 sample := numpay[4:][:headerProtectionSampleSize]
85 mask := k.hp.headerProtection(sample)
86 if isLongHeader(pkt[0]) {
87 pkt[0] ^= mask[0] & 0x0f
88 } else {
89 pkt[0] ^= mask[0] & 0x1f
90 }
91 pnumLen := int(pkt[0]&0x03) + 1
92 pnum = packetNumber(0)
93 for i := 0; i < pnumLen; i++ {
94 numpay[i] ^= mask[1+i]
95 pnum = (pnum << 8) | packetNumber(numpay[i])
96 }
97 pnum = decodePacketNumber(pnumMax, pnum, pnumLen)
98 hdr = pkt[:pnumOff+pnumLen]
99 pay = numpay[pnumLen:]
100 return hdr, pay, pnum, nil
101 }
102
103
104
105
106
107
108
// headerProtection computes a header protection mask from a ciphertext
// sample. Only five mask bytes are used: one for the first header byte
// and up to four for the packet number.
type headerProtection interface {
	headerProtection(sample []byte) [5]byte
}
112
113
114
// aesHeaderProtection implements AES-based header protection
// (RFC 9001, Section 5.4.3). The mask is the leading bytes of the sample
// encrypted with a single AES block operation.
//
// Encrypt writes into the shared scratch field, so one value must not be
// used concurrently.
type aesHeaderProtection struct {
	cipher  cipher.Block
	scratch [aes.BlockSize]byte
}

func (hp *aesHeaderProtection) headerProtection(sample []byte) (mask [5]byte) {
	block := hp.scratch[:]
	hp.cipher.Encrypt(block, sample)
	copy(mask[:], block[:len(mask)])
	return mask
}
125
126
127
128 type chaCha20HeaderProtection struct {
129 key []byte
130 }
131
132 func (hp chaCha20HeaderProtection) headerProtection(sample []byte) (mask [5]byte) {
133 counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0])
134 nonce := sample[4:16]
135 c, err := chacha20.NewUnauthenticatedCipher(hp.key, nonce)
136 if err != nil {
137 panic(err)
138 }
139 c.SetCounter(counter)
140 c.XORKeyStream(mask[:], mask[:])
141 return mask
142 }
143
144
145
146 type packetKey struct {
147 aead cipher.AEAD
148 iv []byte
149 }
150
151 func (k *packetKey) init(suite uint16, secret []byte) {
152
153 h, keySize := hashForSuite(suite)
154 key := hkdfExpandLabel(h.New, secret, "quic key", nil, keySize)
155 switch suite {
156 case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
157 k.aead = newAESAEAD(key)
158 case tls.TLS_CHACHA20_POLY1305_SHA256:
159 k.aead = newChaCha20AEAD(key)
160 default:
161 panic("BUG: unknown cipher suite")
162 }
163 k.iv = hkdfExpandLabel(h.New, secret, "quic iv", nil, k.aead.NonceSize())
164 }
165
// newAESAEAD returns AES-GCM keyed with key. It panics if key is not a
// valid AES key length.
func newAESAEAD(key []byte) cipher.AEAD {
	var (
		block cipher.Block
		aead  cipher.AEAD
		err   error
	)
	if block, err = aes.NewCipher(key); err != nil {
		panic(err)
	}
	if aead, err = cipher.NewGCM(block); err != nil {
		panic(err)
	}
	return aead
}
177
178 func newChaCha20AEAD(key []byte) cipher.AEAD {
179 var err error
180 aead, err := chacha20poly1305.New(key)
181 if err != nil {
182 panic(err)
183 }
184 return aead
185 }
186
187 func (k packetKey) protect(hdr, pay []byte, pnum packetNumber) []byte {
188 k.xorIV(pnum)
189 defer k.xorIV(pnum)
190 return k.aead.Seal(hdr, k.iv, pay, hdr)
191 }
192
193 func (k packetKey) unprotect(hdr, pay []byte, pnum packetNumber) (dec []byte, err error) {
194 k.xorIV(pnum)
195 defer k.xorIV(pnum)
196 return k.aead.Open(pay[:0], k.iv, pay, hdr)
197 }
198
199
200 func (k packetKey) xorIV(pnum packetNumber) {
201 k.iv[len(k.iv)-8] ^= uint8(pnum >> 56)
202 k.iv[len(k.iv)-7] ^= uint8(pnum >> 48)
203 k.iv[len(k.iv)-6] ^= uint8(pnum >> 40)
204 k.iv[len(k.iv)-5] ^= uint8(pnum >> 32)
205 k.iv[len(k.iv)-4] ^= uint8(pnum >> 24)
206 k.iv[len(k.iv)-3] ^= uint8(pnum >> 16)
207 k.iv[len(k.iv)-2] ^= uint8(pnum >> 8)
208 k.iv[len(k.iv)-1] ^= uint8(pnum)
209 }
210
211
212
213
214
215 type fixedKeys struct {
216 hdr headerKey
217 pkt packetKey
218 }
219
220 func (k *fixedKeys) init(suite uint16, secret []byte) {
221 k.hdr.init(suite, secret)
222 k.pkt.init(suite, secret)
223 }
224
225 func (k fixedKeys) isSet() bool {
226 return k.hdr.hp != nil
227 }
228
229
230
231
232
233
234
235
236
237 func (k fixedKeys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte {
238 pkt := k.pkt.protect(hdr, pay, pnum)
239 k.hdr.protect(pkt, pnumOff)
240 return pkt
241 }
242
243
244
245
246
247
248
249
250
251 func (k fixedKeys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) {
252 hdr, pay, pnum, err := k.hdr.unprotect(pkt, pnumOff, pnumMax)
253 if err != nil {
254 return nil, 0, err
255 }
256 pay, err = k.pkt.unprotect(hdr, pay, pnum)
257 if err != nil {
258 return nil, 0, err
259 }
260 return pay, pnum, nil
261 }
262
263
264 type fixedKeyPair struct {
265 r, w fixedKeys
266 }
267
268 func (k *fixedKeyPair) discard() {
269 *k = fixedKeyPair{}
270 }
271
272 func (k *fixedKeyPair) canRead() bool {
273 return k.r.isSet()
274 }
275
276 func (k *fixedKeyPair) canWrite() bool {
277 return k.w.isSet()
278 }
279
280
281
282
283
284 type updatingKeys struct {
285 suite uint16
286 hdr headerKey
287 pkt [2]packetKey
288 nextSecret []byte
289 }
290
291 func (k *updatingKeys) init(suite uint16, secret []byte) {
292 k.suite = suite
293 k.hdr.init(suite, secret)
294
295 k.pkt[1].init(suite, secret)
296 k.nextSecret = secret
297 k.update()
298 }
299
300
301
302
303
304 func (k *updatingKeys) update() {
305 k.nextSecret = updateSecret(k.suite, k.nextSecret)
306 k.pkt[0] = k.pkt[1]
307 k.pkt[1].init(k.suite, k.nextSecret)
308 }
309
310 func updateSecret(suite uint16, secret []byte) (nextSecret []byte) {
311 h, _ := hashForSuite(suite)
312 return hkdfExpandLabel(h.New, secret, "quic ku", nil, len(secret))
313 }
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336 type updatingKeyPair struct {
337 phase uint8
338 updating bool
339 authFailures int64
340 minSent packetNumber
341 minReceived packetNumber
342 updateAfter packetNumber
343 r, w updatingKeys
344 }
345
346 func (k *updatingKeyPair) init() {
347
348
349
350
351
352
353
354
355
356
357
358 k.updateAfter = 100
359 }
360
361 func (k *updatingKeyPair) canRead() bool {
362 return k.r.hdr.hp != nil
363 }
364
365 func (k *updatingKeyPair) canWrite() bool {
366 return k.w.hdr.hp != nil
367 }
368
369
370 func (k *updatingKeyPair) handleAckFor(pnum packetNumber) {
371 if k.updating && pnum >= k.minSent {
372 k.updating = false
373 k.phase ^= keyPhaseBit
374 k.r.update()
375 k.w.update()
376 }
377 }
378
379
380
381
382 func (k *updatingKeyPair) needAckEliciting() bool {
383 return k.updating && k.minSent == maxPacketNumber
384 }
385
386
387
// protect encrypts pay with packet number pnum and appends the ciphertext
// to hdr. It then sets the key phase bit in hdr[0], applies header
// protection with the packet number at offset pnumOff, and returns the
// complete packet.
//
// While a key update is in progress, packets are sent with the next keys
// and the flipped phase bit. minSent records the lowest such packet number,
// so handleAckFor can tell when the update has been acknowledged. Otherwise
// the current keys are used, and once pnum reaches updateAfter a locally
// initiated update begins with the following packet.
//
// NOTE(review): hdr[0] is ORed with the phase bit, so this assumes the
// caller leaves the key phase bit clear in hdr[0].
func (k *updatingKeyPair) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte {
	var pkt []byte
	if k.updating {
		// Use the next keys. The phase bit is the opposite of the current
		// phase until the update completes.
		hdr[0] |= k.phase ^ keyPhaseBit
		pkt = k.w.pkt[1].protect(hdr, pay, pnum)
		k.minSent = min(pnum, k.minSent)
	} else {
		hdr[0] |= k.phase
		pkt = k.w.pkt[0].protect(hdr, pay, pnum)
		if pnum >= k.updateAfter {
			// Begin a locally initiated key update. This packet was sent
			// with the current keys; later packets use the next keys.
			// maxPacketNumber means "nothing sent or received with the
			// next keys yet".
			k.updating = true
			k.minSent = maxPacketNumber
			k.minReceived = maxPacketNumber

			// Schedule the next local update 1<<22 packet numbers later.
			// NOTE(review): presumably chosen to stay under the AEAD
			// confidentiality limits in RFC 9001, Section 6.6. Confirm.
			k.updateAfter += (1 << 22)
		}
	}
	k.w.hdr.protect(pkt, pnumOff)
	return pkt
}
416
417
418
// unprotect removes header protection from pkt, chooses a packet key from
// the key phase bit, and decrypts the payload in place. pnumOff is the
// offset of the packet number in pkt, and pnumMax is passed to
// headerKey.unprotect to decode it. On success unprotect returns the
// plaintext payload and the decoded packet number.
//
// Errors from removing header protection are returned unchanged and are
// not counted as authentication failures. Each payload authentication
// failure increments authFailures. Once that count reaches the suite's
// AEAD integrity limit, unprotect returns a localTransportError with
// errAEADLimitReached instead of the decryption error.
func (k *updatingKeyPair) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, pnum packetNumber, err error) {
	hdr, pay, pnum, err := k.r.hdr.unprotect(pkt, pnumOff, pnumMax)
	if err != nil {
		return nil, 0, err
	}
	// Packets with the current phase bit are opened with the current keys.
	// The exception is during an update, for packet numbers at or above the
	// lowest one already received with the next keys. Every other packet is
	// tried with the next keys.
	if hdr[0]&keyPhaseBit == k.phase && (!k.updating || pnum < k.minReceived) {
		pay, err = k.r.pkt[0].unprotect(hdr, pay, pnum)
	} else {
		pay, err = k.r.pkt[1].unprotect(hdr, pay, pnum)
		if err == nil {
			if !k.updating {
				// The packet authenticated with the next keys while no
				// update was in progress, so the peer started one.
				k.updating = true
				k.minSent = maxPacketNumber
				k.minReceived = pnum
			} else {
				k.minReceived = min(pnum, k.minReceived)
			}
		}
	}
	if err != nil {
		k.authFailures++
		if k.authFailures >= aeadIntegrityLimit(k.r.suite) {
			return nil, 0, localTransportError{code: errAEADLimitReached}
		}
		return nil, 0, err
	}
	return pay, pnum, nil
}
454
455
456
457
458
459
// aeadIntegrityLimit returns the number of packets that may fail
// authentication before the connection is closed with errAEADLimitReached
// (RFC 9001, Section 6.6). It panics if suite is not supported.
func aeadIntegrityLimit(suite uint16) int64 {
	var log2Limit uint
	switch suite {
	case tls.TLS_CHACHA20_POLY1305_SHA256:
		log2Limit = 36
	case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
		log2Limit = 52
	default:
		panic("BUG: unknown cipher suite")
	}
	return int64(1) << log2Limit
}
470
471
// initialSalt is the QUIC version 1 salt that initialKeys uses to derive
// Initial secrets (RFC 9001, Section 5.2).
var initialSalt = []byte{
	0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17,
	0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a,
}
473
474
475
476
477
478
479
480 func initialKeys(cid []byte, side connSide) fixedKeyPair {
481 initialSecret := hkdf.Extract(sha256.New, cid, initialSalt)
482 var clientKeys fixedKeys
483 clientSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size)
484 clientKeys.init(tls.TLS_AES_128_GCM_SHA256, clientSecret)
485 var serverKeys fixedKeys
486 serverSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size)
487 serverKeys.init(tls.TLS_AES_128_GCM_SHA256, serverSecret)
488 if side == clientSide {
489 return fixedKeyPair{r: serverKeys, w: clientKeys}
490 } else {
491 return fixedKeyPair{w: serverKeys, r: clientKeys}
492 }
493 }
494
495
// checkCipherSuite returns an error unless suite is one of the TLS 1.3
// cipher suites this package supports.
func checkCipherSuite(suite uint16) error {
	switch suite {
	case tls.TLS_AES_128_GCM_SHA256,
		tls.TLS_AES_256_GCM_SHA384,
		tls.TLS_CHACHA20_POLY1305_SHA256:
		return nil
	}
	return errors.New("invalid cipher suite")
}
506
507 func hashForSuite(suite uint16) (h crypto.Hash, keySize int) {
508 switch suite {
509 case tls.TLS_AES_128_GCM_SHA256:
510 return crypto.SHA256, 128 / 8
511 case tls.TLS_AES_256_GCM_SHA384:
512 return crypto.SHA384, 256 / 8
513 case tls.TLS_CHACHA20_POLY1305_SHA256:
514 return crypto.SHA256, chacha20.KeySize
515 default:
516 panic("BUG: unknown cipher suite")
517 }
518 }
519
520
521
522
523 func hkdfExpandLabel(hash func() hash.Hash, secret []byte, label string, context []byte, length int) []byte {
524 var hkdfLabel cryptobyte.Builder
525 hkdfLabel.AddUint16(uint16(length))
526 hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
527 b.AddBytes([]byte("tls13 "))
528 b.AddBytes([]byte(label))
529 })
530 hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
531 b.AddBytes(context)
532 })
533 out := make([]byte, length)
534 n, err := hkdf.Expand(hash, secret, hkdfLabel.BytesOrPanic()).Read(out)
535 if err != nil || n != length {
536 panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
537 }
538 return out
539 }
540
View as plain text