1
2
3
4
5 package quic
6
7 import (
8 "bytes"
9 "crypto/aes"
10 "crypto/cipher"
11 "crypto/rand"
12 "encoding/binary"
13 "net/netip"
14 "time"
15
16 "golang.org/x/crypto/chacha20poly1305"
17 "golang.org/x/net/internal/quic/quicwire"
18 )
19
20
21
// Key material for the QUIC version 1 Retry Integrity Tag, as fixed by
// RFC 9001, Section 5.8. Both values are public: the tag lets a client detect
// corrupted Retry packets and ones forged by an entity that never observed its
// Initial packet, but it provides no confidentiality.
var (
	retrySecret = []byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}
	retryNonce  = []byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}

	// retryAEAD is AES-128-GCM keyed with retrySecret; encodeRetryPacket and
	// parseRetryPacket use it to compute Retry Integrity Tags. It is built once
	// at package initialization. A failure here would mean the constants above
	// are wrong, so it panics instead of reporting an error.
	retryAEAD = func() (aead cipher.AEAD) {
		block, err := aes.NewCipher(retrySecret)
		if err != nil {
			panic(err)
		}
		if aead, err = cipher.NewGCM(block); err != nil {
			panic(err)
		}
		return aead
	}()
)
37
38
// retryTokenValidityPeriod bounds how far the timestamp sealed in a Retry
// token may be from the current time, in either direction, for validateToken
// to accept the token.
const retryTokenValidityPeriod = 5000 * time.Millisecond
40
41
// retryState holds the AEAD that seals (makeToken) and opens (validateToken)
// the server's Retry tokens. The zero value has no AEAD; init must be called
// before use.
type retryState struct{ aead cipher.AEAD }
45
46 func (rs *retryState) init() error {
47
48
49 secret := make([]byte, chacha20poly1305.KeySize)
50 if _, err := rand.Read(secret); err != nil {
51 return err
52 }
53 aead, err := chacha20poly1305.NewX(secret)
54 if err != nil {
55 panic(err)
56 }
57 rs.aead = aead
58 return nil
59 }
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90 func (rs *retryState) makeToken(now time.Time, srcConnID, origDstConnID []byte, addr netip.AddrPort) (token, newDstConnID []byte, err error) {
91 nonce := make([]byte, rs.aead.NonceSize())
92 if _, err := rand.Read(nonce); err != nil {
93 return nil, nil, err
94 }
95
96 var plaintext []byte
97 plaintext = binary.BigEndian.AppendUint64(plaintext, uint64(now.Unix()))
98 plaintext = append(plaintext, origDstConnID...)
99
100 token = append(token, nonce[maxConnIDLen:]...)
101 token = rs.aead.Seal(token, nonce, plaintext, rs.additionalData(srcConnID, addr))
102 return token, nonce[:maxConnIDLen], nil
103 }
104
105 func (rs *retryState) validateToken(now time.Time, token, srcConnID, dstConnID []byte, addr netip.AddrPort) (origDstConnID []byte, ok bool) {
106 tokenNonceLen := rs.aead.NonceSize() - maxConnIDLen
107 if len(token) < tokenNonceLen {
108 return nil, false
109 }
110 nonce := append([]byte{}, dstConnID...)
111 nonce = append(nonce, token[:tokenNonceLen]...)
112 ciphertext := token[tokenNonceLen:]
113 if len(nonce) != rs.aead.NonceSize() {
114 return nil, false
115 }
116
117 plaintext, err := rs.aead.Open(nil, nonce, ciphertext, rs.additionalData(srcConnID, addr))
118 if err != nil {
119 return nil, false
120 }
121 if len(plaintext) < 8 {
122 return nil, false
123 }
124 when := time.Unix(int64(binary.BigEndian.Uint64(plaintext)), 0)
125 origDstConnID = plaintext[8:]
126
127
128
129 if d := abs(now.Sub(when)); d > retryTokenValidityPeriod {
130 return nil, false
131 }
132
133 return origDstConnID, true
134 }
135
136 func (rs *retryState) additionalData(srcConnID []byte, addr netip.AddrPort) []byte {
137 var additional []byte
138 additional = quicwire.AppendUint8Bytes(additional, srcConnID)
139 additional = append(additional, addr.Addr().AsSlice()...)
140 additional = binary.BigEndian.AppendUint16(additional, addr.Port())
141 return additional
142 }
143
144 func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) (origDstConnID []byte, ok bool) {
145
146 token, n := quicwire.ConsumeUint8Bytes(p.data)
147 if n < 0 {
148
149
150
151 return nil, false
152 }
153 if len(token) == 0 {
154
155
156 e.sendRetry(now, p, peerAddr)
157 return nil, false
158 }
159 origDstConnID, ok = e.retry.validateToken(now, token, p.srcConnID, p.dstConnID, peerAddr)
160 if !ok {
161
162
163
164 e.sendConnectionClose(p, peerAddr, errInvalidToken)
165 return nil, false
166 }
167 return origDstConnID, true
168 }
169
170 func (e *Endpoint) sendRetry(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) {
171 token, srcConnID, err := e.retry.makeToken(now, p.srcConnID, p.dstConnID, peerAddr)
172 if err != nil {
173 return
174 }
175 b := encodeRetryPacket(p.dstConnID, retryPacket{
176 dstConnID: p.srcConnID,
177 srcConnID: srcConnID,
178 token: token,
179 })
180 e.sendDatagram(datagram{
181 b: b,
182 peerAddr: peerAddr,
183 })
184 }
185
// retryPacket is the decoded form of a QUIC Retry packet, not including its
// Retry Integrity Tag.
type retryPacket struct {
	dstConnID, srcConnID []byte // connection IDs from the long header
	token                []byte // Retry Token: the packet data preceding the tag
}
191
192 func encodeRetryPacket(originalDstConnID []byte, p retryPacket) []byte {
193
194
195
196
197
198
199
200 var b []byte
201 b = quicwire.AppendUint8Bytes(b, originalDstConnID)
202 start := len(b)
203 b = append(b, headerFormLong|fixedBit|longPacketTypeRetry)
204 b = binary.BigEndian.AppendUint32(b, quicVersion1)
205 b = quicwire.AppendUint8Bytes(b, p.dstConnID)
206 b = quicwire.AppendUint8Bytes(b, p.srcConnID)
207 b = append(b, p.token...)
208 b = retryAEAD.Seal(b, retryNonce, nil, b)
209 return b[start:]
210 }
211
212 func parseRetryPacket(b, origDstConnID []byte) (p retryPacket, ok bool) {
213 const retryIntegrityTagLength = 128 / 8
214
215 lp, ok := parseGenericLongHeaderPacket(b)
216 if !ok {
217 return retryPacket{}, false
218 }
219 if len(lp.data) < retryIntegrityTagLength {
220 return retryPacket{}, false
221 }
222 gotTag := lp.data[len(lp.data)-retryIntegrityTagLength:]
223
224
225
226
227 pseudo := quicwire.AppendUint8Bytes(nil, origDstConnID)
228 pseudo = append(pseudo, b[:len(b)-retryIntegrityTagLength]...)
229 wantTag := retryAEAD.Seal(nil, retryNonce, nil, pseudo)
230 if !bytes.Equal(gotTag, wantTag) {
231 return retryPacket{}, false
232 }
233
234 token := lp.data[:len(lp.data)-retryIntegrityTagLength]
235 return retryPacket{
236 dstConnID: lp.dstConnID,
237 srcConnID: lp.srcConnID,
238 token: token,
239 }, true
240 }
241
View as plain text