1
2
3
4
5 package quic
6
7 import (
8 "context"
9 "crypto/rand"
10 "errors"
11 "net"
12 "net/netip"
13 "sync"
14 "sync/atomic"
15 "time"
16 )
17
18
19
20
21
// An Endpoint sends and receives QUIC datagrams over a single packetConn
// and tracks the connections that share it.
type Endpoint struct {
	// listenConfig is the Config used for server-side conns created from
	// incoming Initial packets. When nil, no server-side conns are created.
	listenConfig *Config
	// packetConn is the datagram transport read by listen and written by
	// sendDatagram.
	packetConn packetConn
	// testHooks is whatever was passed to newEndpoint; Listen and
	// NewEndpoint pass nil.
	testHooks endpointTestHooks
	// resetGen derives stateless reset tokens; it is initialized from
	// Config.StatelessResetKey (or a zero key when there is no Config).
	resetGen statelessResetTokenGenerator
	// retry is initialized only when Config.RequireAddressValidation is set.
	retry retryState

	// acceptQueue holds conns handed over by serverConnEstablished until
	// Accept returns them.
	acceptQueue queue[*Conn]
	// connsMap maps connection IDs and stateless reset tokens to conns.
	// It is read from the callback that listen passes to packetConn.Read;
	// changes are queued with connsMap.updateConnIDs.
	connsMap connsMap

	// connsMu guards conns and closing.
	connsMu sync.Mutex
	// conns is the set of conns created by newConn and not yet drained.
	conns map[*Conn]struct{}
	// closing is set by Close; once set, newConn refuses to create conns.
	closing bool
	// closec is closed when listen returns.
	closec chan struct{}
}
37
// endpointTestHooks is the optional hook set stored in Endpoint.testHooks.
// It is supplied through newEndpoint's hooks argument.
type endpointTestHooks interface {
	// newConn receives a conn together with its server connection IDs.
	// NOTE(review): the call site is not in this part of the file;
	// presumably invoked when a conn is created — confirm.
	newConn(c *Conn, cids newServerConnIDs)
}
41
42
// A packetConn is the datagram transport an Endpoint runs on.
// NewEndpoint builds one from a net.PacketConn (with a dedicated
// constructor for *net.UDPConn).
type packetConn interface {
	// Close closes the transport.
	Close() error
	// LocalAddr reports the local address; Endpoint.LocalAddr returns it.
	LocalAddr() netip.AddrPort
	// Read delivers received datagrams to f. Endpoint.listen treats the
	// return of Read as the end of the endpoint's receive loop.
	Read(f func(*datagram))
	// Write sends a datagram.
	Write(datagram) error
}
49
50
51
52
53
54 func Listen(network, address string, listenConfig *Config) (*Endpoint, error) {
55 if listenConfig != nil && listenConfig.TLSConfig == nil {
56 return nil, errors.New("TLSConfig is not set")
57 }
58 a, err := net.ResolveUDPAddr(network, address)
59 if err != nil {
60 return nil, err
61 }
62 udpConn, err := net.ListenUDP(network, a)
63 if err != nil {
64 return nil, err
65 }
66 pc, err := newNetUDPConn(udpConn)
67 if err != nil {
68 return nil, err
69 }
70 return newEndpoint(pc, listenConfig, nil)
71 }
72
73
74
75
76
77 func NewEndpoint(conn net.PacketConn, config *Config) (*Endpoint, error) {
78 var pc packetConn
79 var err error
80 switch conn := conn.(type) {
81 case *net.UDPConn:
82 pc, err = newNetUDPConn(conn)
83 default:
84 pc, err = newNetPacketConn(conn)
85 }
86 if err != nil {
87 return nil, err
88 }
89 return newEndpoint(pc, config, nil)
90 }
91
92 func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) {
93 e := &Endpoint{
94 listenConfig: config,
95 packetConn: pc,
96 testHooks: hooks,
97 conns: make(map[*Conn]struct{}),
98 acceptQueue: newQueue[*Conn](),
99 closec: make(chan struct{}),
100 }
101 var statelessResetKey [32]byte
102 if config != nil {
103 statelessResetKey = config.StatelessResetKey
104 }
105 e.resetGen.init(statelessResetKey)
106 e.connsMap.init()
107 if config != nil && config.RequireAddressValidation {
108 if err := e.retry.init(); err != nil {
109 return nil, err
110 }
111 }
112 go e.listen()
113 return e, nil
114 }
115
116
117 func (e *Endpoint) LocalAddr() netip.AddrPort {
118 return e.packetConn.LocalAddr()
119 }
120
121
122
123
124
125
126
127
128 func (e *Endpoint) Close(ctx context.Context) error {
129 e.acceptQueue.close(errors.New("endpoint closed"))
130
131
132
133 var conns []*Conn
134 e.connsMu.Lock()
135 if !e.closing {
136 e.closing = true
137 for c := range e.conns {
138 conns = append(conns, c)
139 }
140 if len(e.conns) == 0 {
141 e.packetConn.Close()
142 }
143 }
144 e.connsMu.Unlock()
145
146 for _, c := range conns {
147 c.Abort(localTransportError{code: errNo})
148 }
149 select {
150 case <-e.closec:
151 case <-ctx.Done():
152 for _, c := range conns {
153 c.exit()
154 }
155 return ctx.Err()
156 }
157 return nil
158 }
159
160
161 func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) {
162 return e.acceptQueue.get(ctx)
163 }
164
165
166
167 func (e *Endpoint) Dial(ctx context.Context, network, address string, config *Config) (*Conn, error) {
168 u, err := net.ResolveUDPAddr(network, address)
169 if err != nil {
170 return nil, err
171 }
172 addr := u.AddrPort()
173 addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
174 c, err := e.newConn(time.Now(), config, clientSide, newServerConnIDs{}, address, addr)
175 if err != nil {
176 return nil, err
177 }
178 if err := c.waitReady(ctx); err != nil {
179 c.Abort(nil)
180 return nil, err
181 }
182 return c, nil
183 }
184
185 func (e *Endpoint) newConn(now time.Time, config *Config, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort) (*Conn, error) {
186 e.connsMu.Lock()
187 defer e.connsMu.Unlock()
188 if e.closing {
189 return nil, errors.New("endpoint closed")
190 }
191 c, err := newConn(now, side, cids, peerHostname, peerAddr, config, e)
192 if err != nil {
193 return nil, err
194 }
195 e.conns[c] = struct{}{}
196 return c, nil
197 }
198
199
200
// serverConnEstablished puts c on the accept queue, from which Accept
// retrieves conns.
// NOTE(review): the caller is not in this part of the file; by its name it
// is presumably invoked once a server-side conn is established — confirm.
func (e *Endpoint) serverConnEstablished(c *Conn) {
	e.acceptQueue.put(c)
}
204
205
206
207 func (e *Endpoint) connDrained(c *Conn) {
208 var cids [][]byte
209 for i := range c.connIDState.local {
210 cids = append(cids, c.connIDState.local[i].cid)
211 }
212 var tokens []statelessResetToken
213 for i := range c.connIDState.remote {
214 tokens = append(tokens, c.connIDState.remote[i].resetToken)
215 }
216 e.connsMap.updateConnIDs(func(conns *connsMap) {
217 for _, cid := range cids {
218 conns.retireConnID(c, cid)
219 }
220 for _, token := range tokens {
221 conns.retireResetToken(c, token)
222 }
223 })
224 e.connsMu.Lock()
225 defer e.connsMu.Unlock()
226 delete(e.conns, c)
227 if e.closing && len(e.conns) == 0 {
228 e.packetConn.Close()
229 }
230 }
231
232 func (e *Endpoint) listen() {
233 defer close(e.closec)
234 e.packetConn.Read(func(m *datagram) {
235 if e.connsMap.updateNeeded.Load() {
236 e.connsMap.applyUpdates()
237 }
238 e.handleDatagram(m)
239 })
240 }
241
242 func (e *Endpoint) handleDatagram(m *datagram) {
243 dstConnID, ok := dstConnIDForDatagram(m.b)
244 if !ok {
245 m.recycle()
246 return
247 }
248 c := e.connsMap.byConnID[string(dstConnID)]
249 if c == nil {
250
251
252 e.handleUnknownDestinationDatagram(m)
253 return
254 }
255
256
257
258 c.sendMsg(m)
259 }
260
261 func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) {
262 defer func() {
263 if m != nil {
264 m.recycle()
265 }
266 }()
267 const minimumValidPacketSize = 21
268 if len(m.b) < minimumValidPacketSize {
269 return
270 }
271 now := time.Now()
272
273 var token statelessResetToken
274 copy(token[:], m.b[len(m.b)-len(token):])
275 if c := e.connsMap.byResetToken[token]; c != nil {
276 c.sendMsg(func(now time.Time, c *Conn) {
277 c.handleStatelessReset(now, token)
278 })
279 return
280 }
281
282
283 if !isLongHeader(m.b[0]) {
284 e.maybeSendStatelessReset(m.b, m.peerAddr)
285 return
286 }
287 p, ok := parseGenericLongHeaderPacket(m.b)
288 if !ok || len(m.b) < paddedInitialDatagramSize {
289 return
290 }
291 switch p.version {
292 case quicVersion1:
293 case 0:
294
295 return
296 default:
297
298 e.sendVersionNegotiation(p, m.peerAddr)
299 return
300 }
301 if getPacketType(m.b) != packetTypeInitial {
302
303
304
305
306
307 return
308 }
309 if e.listenConfig == nil {
310
311 return
312 }
313 if len(p.srcConnID) > maxConnIDLen || len(p.dstConnID) > maxConnIDLen {
314
315
316
317 return
318 }
319 cids := newServerConnIDs{
320 srcConnID: p.srcConnID,
321 dstConnID: p.dstConnID,
322 }
323 if e.listenConfig.RequireAddressValidation {
324 var ok bool
325 cids.retrySrcConnID = p.dstConnID
326 cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr)
327 if !ok {
328 return
329 }
330 } else {
331 cids.originalDstConnID = p.dstConnID
332 }
333 var err error
334 c, err := e.newConn(now, e.listenConfig, serverSide, cids, "", m.peerAddr)
335 if err != nil {
336
337
338
339
340 return
341 }
342 c.sendMsg(m)
343 m = nil
344 }
345
346 func (e *Endpoint) maybeSendStatelessReset(b []byte, peerAddr netip.AddrPort) {
347 if !e.resetGen.canReset {
348
349 return
350 }
351
352
353
354
355
356
357 if len(b) < 1+connIDLen+1+1+16 {
358 return
359 }
360
361 cid := b[1:][:connIDLen]
362 token := e.resetGen.tokenForConnID(cid)
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380 size := min(len(b)-1, 42)
381
382 b = b[:size]
383 rand.Read(b[:len(b)-statelessResetTokenLen])
384 b[0] &^= headerFormLong
385 b[0] |= fixedBit
386 copy(b[len(b)-statelessResetTokenLen:], token[:])
387 e.sendDatagram(datagram{
388 b: b,
389 peerAddr: peerAddr,
390 })
391 }
392
393 func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, peerAddr netip.AddrPort) {
394 m := newDatagram()
395 m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
396 m.peerAddr = peerAddr
397 e.sendDatagram(*m)
398 m.recycle()
399 }
400
401 func (e *Endpoint) sendConnectionClose(in genericLongPacket, peerAddr netip.AddrPort, code transportError) {
402 keys := initialKeys(in.dstConnID, serverSide)
403 var w packetWriter
404 p := longPacket{
405 ptype: packetTypeInitial,
406 version: quicVersion1,
407 num: 0,
408 dstConnID: in.srcConnID,
409 srcConnID: in.dstConnID,
410 }
411 const pnumMaxAcked = 0
412 w.reset(paddedInitialDatagramSize)
413 w.startProtectedLongHeaderPacket(pnumMaxAcked, p)
414 w.appendConnectionCloseTransportFrame(code, 0, "")
415 w.finishProtectedLongHeaderPacket(pnumMaxAcked, keys.w, p)
416 buf := w.datagram()
417 if len(buf) == 0 {
418 return
419 }
420 e.sendDatagram(datagram{
421 b: buf,
422 peerAddr: peerAddr,
423 })
424 }
425
426 func (e *Endpoint) sendDatagram(dgram datagram) error {
427 return e.packetConn.Write(dgram)
428 }
429
430
// A connsMap maps connection IDs and stateless reset tokens to conns.
//
// The two maps are accessed without holding a lock; changes are queued as
// functions with updateConnIDs and run by applyUpdates.
type connsMap struct {
	// byConnID is keyed by the connection ID bytes converted to a string.
	byConnID map[string]*Conn
	// byResetToken is keyed by stateless reset token.
	byResetToken map[statelessResetToken]*Conn

	// updateMu guards updates.
	updateMu sync.Mutex
	// updateNeeded is set when updates holds pending functions, so that
	// Endpoint.listen can check for work without taking updateMu.
	updateNeeded atomic.Bool
	// updates holds the queued functions, in the order they were added.
	updates []func(*connsMap)
}
439
440 func (m *connsMap) init() {
441 m.byConnID = map[string]*Conn{}
442 m.byResetToken = map[statelessResetToken]*Conn{}
443 }
444
445 func (m *connsMap) addConnID(c *Conn, cid []byte) {
446 m.byConnID[string(cid)] = c
447 }
448
449 func (m *connsMap) retireConnID(c *Conn, cid []byte) {
450 delete(m.byConnID, string(cid))
451 }
452
453 func (m *connsMap) addResetToken(c *Conn, token statelessResetToken) {
454 m.byResetToken[token] = c
455 }
456
457 func (m *connsMap) retireResetToken(c *Conn, token statelessResetToken) {
458 delete(m.byResetToken, token)
459 }
460
461 func (m *connsMap) updateConnIDs(f func(*connsMap)) {
462 m.updateMu.Lock()
463 defer m.updateMu.Unlock()
464 m.updates = append(m.updates, f)
465 m.updateNeeded.Store(true)
466 }
467
468
469 func (m *connsMap) applyUpdates() {
470 m.updateMu.Lock()
471 defer m.updateMu.Unlock()
472 for _, f := range m.updates {
473 f(m)
474 }
475 clear(m.updates)
476 m.updates = m.updates[:0]
477 m.updateNeeded.Store(false)
478 }
479
View as plain text