1
2
3
4
5 package quic
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "slices"
11 )
12
13
// connIDState tracks a connection's local and remote connection IDs.
type connIDState struct {
	// local holds the connection IDs owned by this endpoint; srcConnID picks
	// the Source Connection ID for sent packets from this list.
	// remote holds the peer's connection IDs; dstConnID picks the
	// Destination Connection ID for sent packets from this list.
	//
	// An entry with seq == -1 is the transient, client-chosen connection ID
	// used for the client's first Initial packets: it is a remote ID on the
	// client (initClient) and a local ID on the server (initServer).
	local  []connID
	remote []remoteConnID

	nextLocalSeq          int64 // sequence number assigned to the next local ID issued
	peerActiveConnIDLimit int64 // peer's active_connection_id_limit transport parameter

	// Retirement of remote connection IDs, tracked by sequence number.
	// Sequence numbers needing retirement are added to remoteRetiring,
	// moved to remoteRetiringSent once appendFrames writes a
	// RETIRE_CONNECTION_ID frame for them, and removed from
	// remoteRetiringSent by ackOrLossRetireConnectionID (lost frames are
	// put back into remoteRetiring for retransmission).
	retireRemotePriorTo int64           // largest Retire Prior To value received from the peer
	remoteRetiring      rangeset[int64] // remote IDs awaiting a RETIRE_CONNECTION_ID frame
	remoteRetiringSent  rangeset[int64] // remote IDs whose RETIRE_CONNECTION_ID awaits ack

	originalDstConnID []byte // expected original_destination_connection_id; nil once validated
	retrySrcConnID    []byte // expected retry_source_connection_id (set on Retry); nil once validated

	needSend bool // true when appendFrames has NEW/RETIRE_CONNECTION_ID frames to write
}
43
44
// A connID is a connection ID and its associated metadata.
type connID struct {
	// cid is the connection ID itself.
	cid []byte

	// seq is the connection ID's sequence number (RFC 9000, Section 5.1.1).
	// The transient connection ID chosen by the client for its first
	// Initial packets uses seq == -1.
	seq int64

	// send tracks whether this ID's state still has to be sent to the peer.
	// For local IDs, issueLocalIDs marks new IDs unsent and appendFrames
	// writes a NEW_CONNECTION_ID frame for them.
	// NOTE(review): retirement of remote IDs is tracked in connIDState's
	// rangesets rather than here — confirm whether send is used for remote IDs.
	send sentVal
}
64
65
// A remoteConnID is a connection ID used to address the peer, together with
// the stateless reset token associated with it. resetToken is the zero value
// until a token is received (initClient and initServer do not set it).
type remoteConnID struct {
	connID
	resetToken statelessResetToken
}
70
71 func (s *connIDState) initClient(c *Conn) error {
72
73
74 locid, err := c.newConnID(0)
75 if err != nil {
76 return err
77 }
78 s.local = append(s.local, connID{
79 seq: 0,
80 cid: locid,
81 })
82 s.nextLocalSeq = 1
83 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
84 conns.addConnID(c, locid)
85 })
86
87
88
89 remid, err := c.newConnID(-1)
90 if err != nil {
91 return err
92 }
93 s.remote = append(s.remote, remoteConnID{
94 connID: connID{
95 seq: -1,
96 cid: remid,
97 },
98 })
99 s.originalDstConnID = remid
100 return nil
101 }
102
103 func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error {
104 dstConnID := cloneBytes(cids.dstConnID)
105
106
107
108 s.local = append(s.local, connID{
109 seq: -1,
110 cid: dstConnID,
111 })
112
113
114
115 locid, err := c.newConnID(0)
116 if err != nil {
117 return err
118 }
119 s.local = append(s.local, connID{
120 seq: 0,
121 cid: locid,
122 })
123 s.nextLocalSeq = 1
124 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
125 conns.addConnID(c, dstConnID)
126 conns.addConnID(c, locid)
127 })
128
129
130 s.remote = append(s.remote, remoteConnID{
131 connID: connID{
132 seq: 0,
133 cid: cloneBytes(cids.srcConnID),
134 },
135 })
136 return nil
137 }
138
139
140 func (s *connIDState) srcConnID() []byte {
141 if s.local[0].seq == -1 && len(s.local) > 1 {
142
143 return s.local[1].cid
144 }
145 return s.local[0].cid
146 }
147
148
149 func (s *connIDState) dstConnID() (cid []byte, ok bool) {
150 for i := range s.remote {
151 return s.remote[i].cid, true
152 }
153 return nil, false
154 }
155
156
157
158 func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool {
159 if len(s.remote) == 0 {
160 return false
161 }
162
163
164 return s.remote[0].resetToken == resetToken
165 }
166
167
168
169 func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
170 s.peerActiveConnIDLimit = lim
171 return s.issueLocalIDs(c)
172 }
173
174 func (s *connIDState) issueLocalIDs(c *Conn) error {
175 toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
176 for i := range s.local {
177 if s.local[i].seq != -1 {
178 toIssue--
179 }
180 }
181 var newIDs [][]byte
182 for toIssue > 0 {
183 cid, err := c.newConnID(s.nextLocalSeq)
184 if err != nil {
185 return err
186 }
187 newIDs = append(newIDs, cid)
188 s.local = append(s.local, connID{
189 seq: s.nextLocalSeq,
190 cid: cid,
191 })
192 s.local[len(s.local)-1].send.setUnsent()
193 s.nextLocalSeq++
194 s.needSend = true
195 toIssue--
196 }
197 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
198 for _, cid := range newIDs {
199 conns.addConnID(c, cid)
200 }
201 })
202 return nil
203 }
204
205
206
207 func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p transportParameters) error {
208
209
210
211
212 if !bytes.Equal(s.originalDstConnID, p.originalDstConnID) {
213 return localTransportError{
214 code: errTransportParameter,
215 reason: "original_destination_connection_id mismatch",
216 }
217 }
218 s.originalDstConnID = nil
219
220
221 if !bytes.Equal(p.retrySrcConnID, s.retrySrcConnID) {
222 return localTransportError{
223 code: errTransportParameter,
224 reason: "retry_source_connection_id mismatch",
225 }
226 }
227 s.retrySrcConnID = nil
228
229 if len(s.remote) == 0 || s.remote[0].seq != 0 {
230 return localTransportError{
231 code: errInternal,
232 reason: "remote connection id missing",
233 }
234 }
235 if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
236 return localTransportError{
237 code: errTransportParameter,
238 reason: "initial_source_connection_id mismatch",
239 }
240 }
241 if len(p.statelessResetToken) > 0 {
242 if c.side == serverSide {
243 return localTransportError{
244 code: errTransportParameter,
245 reason: "client sent stateless_reset_token",
246 }
247 }
248 token := statelessResetToken(p.statelessResetToken)
249 s.remote[0].resetToken = token
250 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
251 conns.addResetToken(c, token)
252 })
253 }
254 return nil
255 }
256
257
258
259 func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) {
260 switch {
261 case ptype == packetTypeInitial && c.side == clientSide:
262 if len(s.remote) == 1 && s.remote[0].seq == -1 {
263
264
265
266 s.remote[0] = remoteConnID{
267 connID: connID{
268 seq: 0,
269 cid: cloneBytes(srcConnID),
270 },
271 }
272 }
273 case ptype == packetTypeHandshake && c.side == serverSide:
274 if len(s.local) > 0 && s.local[0].seq == -1 {
275
276
277
278 cid := s.local[0].cid
279 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
280 conns.retireConnID(c, cid)
281 })
282 s.local = append(s.local[:0], s.local[1:]...)
283 }
284 }
285 }
286
287 func (s *connIDState) handleRetryPacket(srcConnID []byte) {
288 if len(s.remote) != 1 || s.remote[0].seq != -1 {
289 panic("BUG: handling retry with non-transient remote conn id")
290 }
291 s.retrySrcConnID = cloneBytes(srcConnID)
292 s.remote[0].cid = s.retrySrcConnID
293 }
294
// handleNewConnID processes a NEW_CONNECTION_ID frame from the peer. The frame
// carries connection ID cid with sequence number seq, Retire Prior To value
// retire, and stateless reset token resetToken.
//
// It retires remote IDs with sequence numbers below retire and records cid if
// its sequence number is new. It returns a transport error if the frame is not
// allowed or the peer exceeds our connection ID limits.
//
// NOTE(review): this reads s.remote[0] without a length check and does not
// verify retire <= seq; both appear to rely on the caller / frame parser —
// confirm.
func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, resetToken statelessResetToken) error {
	if len(s.remote[0].cid) == 0 {
		// "An endpoint that is sending packets with a zero-length Destination
		// Connection ID MUST treat receipt of a NEW_CONNECTION_ID frame as a
		// connection error of type PROTOCOL_VIOLATION."
		// RFC 9000, Section 19.15.
		return localTransportError{
			code:   errProtocolViolation,
			reason: "NEW_CONNECTION_ID from peer with zero-length DCID",
		}
	}

	if seq < s.retireRemotePriorTo {
		// An earlier Retire Prior To value already retired this ID; ignore it.
		return nil
	}

	if retire > s.retireRemotePriorTo {
		// Queue every sequence number in [old retireRemotePriorTo, retire)
		// for a RETIRE_CONNECTION_ID frame and drop those IDs from s.remote.
		// The range may include sequence numbers we never received.
		s.remoteRetiring.add(s.retireRemotePriorTo, retire)
		s.retireRemotePriorTo = retire
		s.needSend = true
		s.remote = slices.DeleteFunc(s.remote, func(rcid remoteConnID) bool {
			return rcid.seq < s.retireRemotePriorTo
		})
	}

	have := false // whether seq is already present in s.remote
	for i := range s.remote {
		rcid := &s.remote[i]
		if rcid.seq == seq {
			// A repeated sequence number must carry the same connection ID.
			// (The reset token is not compared.)
			if !bytes.Equal(rcid.cid, cid) {
				return localTransportError{
					code:   errProtocolViolation,
					reason: "NEW_CONNECTION_ID does not match prior id",
				}
			}
			have = true
			break
		}
	}

	if !have {
		// New sequence number: remember the ID and register its reset token.
		// s.remote is appended to, not kept sorted by sequence number.
		s.remote = append(s.remote, remoteConnID{
			connID: connID{
				seq: seq,
				cid: cloneBytes(cid),
			},
			resetToken: resetToken,
		})
		c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
			conns.addResetToken(c, resetToken)
		})
	}

	// Checked after retirement above, so IDs retired by this frame do not
	// count toward the limit.
	if len(s.remote) > activeConnIDLimit {
		return localTransportError{
			code:   errConnectionIDLimit,
			reason: "active_connection_id_limit exceeded",
		}
	}

	// Bound the retirement state kept for the peer (RFC 9000, Section 5.1.2
	// says an endpoint SHOULD limit retired IDs whose RETIRE_CONNECTION_ID
	// frames are unacknowledged); the cap here is 3*activeConnIDLimit.
	if s.remoteRetiring.size()+s.remoteRetiringSent.size() > 3*activeConnIDLimit {
		return localTransportError{
			code:   errConnectionIDLimit,
			reason: "too many unacknowledged retired connection ids",
		}
	}

	return nil
}
385
386 func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error {
387 if seq >= s.nextLocalSeq {
388 return localTransportError{
389 code: errProtocolViolation,
390 reason: "RETIRE_CONNECTION_ID for unissued sequence number",
391 }
392 }
393 for i := range s.local {
394 if s.local[i].seq == seq {
395 cid := s.local[i].cid
396 c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
397 conns.retireConnID(c, cid)
398 })
399 s.local = append(s.local[:i], s.local[i+1:]...)
400 break
401 }
402 }
403 s.issueLocalIDs(c)
404 return nil
405 }
406
407 func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fate packetFate) {
408 for i := range s.local {
409 if s.local[i].seq != seq {
410 continue
411 }
412 s.local[i].send.ackOrLoss(pnum, fate)
413 if fate != packetAcked {
414 s.needSend = true
415 }
416 return
417 }
418 }
419
420 func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) {
421 s.remoteRetiringSent.sub(seq, seq+1)
422 if fate == packetLost {
423
424 s.remoteRetiring.add(seq, seq+1)
425 s.needSend = true
426 }
427 }
428
429
430
431
432
433
// appendFrames writes pending NEW_CONNECTION_ID and RETIRE_CONNECTION_ID
// frames into the packet being built (c.w), recording pnum as the packet
// they were sent in.
//
// Local IDs are written when send.shouldSendPTO(pto) says so. When pto is
// true, every RETIRE_CONNECTION_ID still awaiting acknowledgment is also
// written again.
//
// It returns true when everything pending was written. It returns false as
// soon as a frame does not fit; needSend is then left unchanged, so the next
// call does not take the fast path.
func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool {
	if !s.needSend && !pto {
		// Fast path: nothing to send.
		return true
	}
	// Retire Prior To value advertised in each NEW_CONNECTION_ID frame: the
	// sequence number of the first local ID, or 0 while the transient
	// (seq -1) ID is still first.
	retireBefore := int64(0)
	if s.local[0].seq != -1 {
		retireBefore = s.local[0].seq
	}
	for i := range s.local {
		if !s.local[i].send.shouldSendPTO(pto) {
			continue
		}
		if !c.w.appendNewConnectionIDFrame(
			s.local[i].seq,
			retireBefore,
			s.local[i].cid,
			c.endpoint.resetGen.tokenForConnID(s.local[i].cid),
		) {
			return false
		}
		s.local[i].send.setSent(pnum)
	}
	if pto {
		// Resend retirements that are still unacknowledged.
		for _, r := range s.remoteRetiringSent {
			for cid := r.start; cid < r.end; cid++ {
				if !c.w.appendRetireConnectionIDFrame(cid) {
					return false
				}
			}
		}
	}
	// Send new retirements one at a time, moving each from remoteRetiring
	// to remoteRetiringSent only after its frame was written.
	for s.remoteRetiring.numRanges() > 0 {
		cid := s.remoteRetiring.min()
		if !c.w.appendRetireConnectionIDFrame(cid) {
			return false
		}
		s.remoteRetiring.sub(cid, cid+1)
		s.remoteRetiringSent.add(cid, cid+1)
	}
	s.needSend = false
	return true
}
477
// cloneBytes returns a copy of b that does not share b's backing array.
// Unlike bytes.Clone, it always returns a non-nil slice, even for nil input.
func cloneBytes(b []byte) []byte {
	return append(make([]byte, 0, len(b)), b...)
}
483
484 func (c *Conn) newConnID(seq int64) ([]byte, error) {
485 if c.testHooks != nil {
486 return c.testHooks.newConnID(seq)
487 }
488 return newRandomConnID(seq)
489 }
490
491 func newRandomConnID(_ int64) ([]byte, error) {
492
493
494 id := make([]byte, connIDLen)
495 if _, err := rand.Read(id); err != nil {
496
497
498
499 return nil, err
500 }
501 return id, nil
502 }
503
View as plain text