Source file src/vendor/golang.org/x/net/quic/endpoint.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package quic
     6  
     7  import (
     8  	"context"
     9  	"crypto/rand"
    10  	"errors"
    11  	"net"
    12  	"net/netip"
    13  	"sync"
    14  	"sync/atomic"
    15  	"time"
    16  )
    17  
// An Endpoint handles QUIC traffic on a network address.
// It can accept inbound connections or create outbound ones.
//
// Multiple goroutines may invoke methods on an Endpoint simultaneously.
type Endpoint struct {
	// listenConfig configures accepted connections; when nil, inbound
	// connections are not accepted. packetConn carries the endpoint's
	// datagrams and is closed once the endpoint is closing and no conns
	// remain (see Close and connDrained). testHooks is nil for endpoints
	// created by Listen or NewEndpoint. resetGen derives stateless reset
	// tokens from Config.StatelessResetKey. retry is initialized only when
	// Config.RequireAddressValidation is set.
	listenConfig *Config
	packetConn   packetConn
	testHooks    endpointTestHooks
	resetGen     statelessResetTokenGenerator
	retry        retryState

	acceptQueue queue[*Conn] // new inbound connections
	connsMap    connsMap     // only accessed by the listen loop

	// connsMu guards conns and closing. conns holds every conn created by
	// newConn that has not yet drained.
	connsMu sync.Mutex
	conns   map[*Conn]struct{}
	closing bool          // set when Close is called; newConn fails afterwards
	closec  chan struct{} // closed when the listen loop exits
}
    37  
// endpointTestHooks lets tests observe endpoint internals.
// Listen and NewEndpoint install no hooks (they pass nil).
// NOTE(review): the newConn hook's call site is not in this file; presumably
// it is invoked when a Conn is created with its server connection IDs.
type endpointTestHooks interface {
	newConn(c *Conn, cids newServerConnIDs)
}
    41  
// A packetConn is the interface to sending and receiving UDP packets.
//
// Listen and NewEndpoint obtain one from newNetUDPConn (for *net.UDPConn)
// or newNetPacketConn (for any other net.PacketConn).
type packetConn interface {
	// Close releases the underlying connection. The endpoint calls it once
	// it is closing and has no remaining conns.
	Close() error
	// LocalAddr returns the local address of the connection.
	LocalAddr() netip.AddrPort
	// Read delivers received datagrams to f. The endpoint's listen loop
	// calls it once and treats its return as the end of the loop
	// (presumably Read returns after Close).
	Read(f func(*datagram))
	// Write sends a single datagram.
	Write(datagram) error
}
    49  
    50  // Listen listens on a local network address.
    51  //
    52  // The config is used to for connections accepted by the endpoint.
    53  // If the config is nil, the endpoint will not accept connections.
    54  func Listen(network, address string, listenConfig *Config) (*Endpoint, error) {
    55  	if listenConfig != nil && listenConfig.TLSConfig == nil {
    56  		return nil, errors.New("TLSConfig is not set")
    57  	}
    58  	a, err := net.ResolveUDPAddr(network, address)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	udpConn, err := net.ListenUDP(network, a)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	pc, err := newNetUDPConn(udpConn)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	return newEndpoint(pc, listenConfig, nil)
    71  }
    72  
    73  // NewEndpoint creates an endpoint using a net.PacketConn as the underlying transport.
    74  //
    75  // If the PacketConn is not a *net.UDPConn, the endpoint may be slower and lack
    76  // access to some features of the network.
    77  func NewEndpoint(conn net.PacketConn, config *Config) (*Endpoint, error) {
    78  	var pc packetConn
    79  	var err error
    80  	switch conn := conn.(type) {
    81  	case *net.UDPConn:
    82  		pc, err = newNetUDPConn(conn)
    83  	default:
    84  		pc, err = newNetPacketConn(conn)
    85  	}
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	return newEndpoint(pc, config, nil)
    90  }
    91  
    92  func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) {
    93  	e := &Endpoint{
    94  		listenConfig: config,
    95  		packetConn:   pc,
    96  		testHooks:    hooks,
    97  		conns:        make(map[*Conn]struct{}),
    98  		acceptQueue:  newQueue[*Conn](),
    99  		closec:       make(chan struct{}),
   100  	}
   101  	var statelessResetKey [32]byte
   102  	if config != nil {
   103  		statelessResetKey = config.StatelessResetKey
   104  	}
   105  	e.resetGen.init(statelessResetKey)
   106  	e.connsMap.init()
   107  	if config != nil && config.RequireAddressValidation {
   108  		if err := e.retry.init(); err != nil {
   109  			return nil, err
   110  		}
   111  	}
   112  	go e.listen()
   113  	return e, nil
   114  }
   115  
   116  // LocalAddr returns the local network address.
   117  func (e *Endpoint) LocalAddr() netip.AddrPort {
   118  	return e.packetConn.LocalAddr()
   119  }
   120  
   121  // Close closes the Endpoint.
   122  // Any blocked operations on the Endpoint or associated Conns and Stream will be unblocked
   123  // and return errors.
   124  //
   125  // Close aborts every open connection.
   126  // Data in stream read and write buffers is discarded.
   127  // It waits for the peers of any open connection to acknowledge the connection has been closed.
   128  func (e *Endpoint) Close(ctx context.Context) error {
   129  	e.acceptQueue.close(errors.New("endpoint closed"))
   130  
   131  	// It isn't safe to call Conn.Abort or conn.exit with connsMu held,
   132  	// so copy the list of conns.
   133  	var conns []*Conn
   134  	e.connsMu.Lock()
   135  	if !e.closing {
   136  		e.closing = true // setting e.closing prevents new conns from being created
   137  		for c := range e.conns {
   138  			conns = append(conns, c)
   139  		}
   140  		if len(e.conns) == 0 {
   141  			e.packetConn.Close()
   142  		}
   143  	}
   144  	e.connsMu.Unlock()
   145  
   146  	for _, c := range conns {
   147  		c.Abort(localTransportError{code: errNo})
   148  	}
   149  	select {
   150  	case <-e.closec:
   151  	case <-ctx.Done():
   152  		for _, c := range conns {
   153  			c.exit()
   154  		}
   155  		return ctx.Err()
   156  	}
   157  	return nil
   158  }
   159  
   160  // Accept waits for and returns the next connection.
   161  func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) {
   162  	return e.acceptQueue.get(ctx)
   163  }
   164  
   165  // Dial creates and returns a connection to a network address.
   166  // The config cannot be nil.
   167  func (e *Endpoint) Dial(ctx context.Context, network, address string, config *Config) (*Conn, error) {
   168  	u, err := net.ResolveUDPAddr(network, address)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	addr := u.AddrPort()
   173  	addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
   174  	c, err := e.newConn(time.Now(), config, clientSide, newServerConnIDs{}, address, addr)
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	if err := c.waitReady(ctx); err != nil {
   179  		c.Abort(nil)
   180  		return nil, err
   181  	}
   182  	return c, nil
   183  }
   184  
   185  func (e *Endpoint) newConn(now time.Time, config *Config, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort) (*Conn, error) {
   186  	e.connsMu.Lock()
   187  	defer e.connsMu.Unlock()
   188  	if e.closing {
   189  		return nil, errors.New("endpoint closed")
   190  	}
   191  	c, err := newConn(now, side, cids, peerHostname, peerAddr, config, e)
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	e.conns[c] = struct{}{}
   196  	return c, nil
   197  }
   198  
   199  // serverConnEstablished is called by a conn when the handshake completes
   200  // for an inbound (serverSide) connection.
   201  func (e *Endpoint) serverConnEstablished(c *Conn) {
   202  	e.acceptQueue.put(c)
   203  }
   204  
   205  // connDrained is called by a conn when it leaves the draining state,
   206  // either when the peer acknowledges connection closure or the drain timeout expires.
   207  func (e *Endpoint) connDrained(c *Conn) {
   208  	var cids [][]byte
   209  	for i := range c.connIDState.local {
   210  		cids = append(cids, c.connIDState.local[i].cid)
   211  	}
   212  	var tokens []statelessResetToken
   213  	for i := range c.connIDState.remote {
   214  		tokens = append(tokens, c.connIDState.remote[i].resetToken)
   215  	}
   216  	e.connsMap.updateConnIDs(func(conns *connsMap) {
   217  		for _, cid := range cids {
   218  			conns.retireConnID(c, cid)
   219  		}
   220  		for _, token := range tokens {
   221  			conns.retireResetToken(c, token)
   222  		}
   223  	})
   224  	e.connsMu.Lock()
   225  	defer e.connsMu.Unlock()
   226  	delete(e.conns, c)
   227  	if e.closing && len(e.conns) == 0 {
   228  		e.packetConn.Close()
   229  	}
   230  }
   231  
   232  func (e *Endpoint) listen() {
   233  	defer close(e.closec)
   234  	e.packetConn.Read(func(m *datagram) {
   235  		if e.connsMap.updateNeeded.Load() {
   236  			e.connsMap.applyUpdates()
   237  		}
   238  		e.handleDatagram(m)
   239  	})
   240  }
   241  
   242  func (e *Endpoint) handleDatagram(m *datagram) {
   243  	dstConnID, ok := dstConnIDForDatagram(m.b)
   244  	if !ok {
   245  		m.recycle()
   246  		return
   247  	}
   248  	c := e.connsMap.byConnID[string(dstConnID)]
   249  	if c == nil {
   250  		// TODO: Move this branch into a separate goroutine to avoid blocking
   251  		// the endpoint while processing packets.
   252  		e.handleUnknownDestinationDatagram(m)
   253  		return
   254  	}
   255  
   256  	// TODO: This can block the endpoint while waiting for the conn to accept the dgram.
   257  	// Think about buffering between the receive loop and the conn.
   258  	c.sendMsg(m)
   259  }
   260  
   261  func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) {
   262  	defer func() {
   263  		if m != nil {
   264  			m.recycle()
   265  		}
   266  	}()
   267  	const minimumValidPacketSize = 21
   268  	if len(m.b) < minimumValidPacketSize {
   269  		return
   270  	}
   271  	now := time.Now()
   272  	// Check to see if this is a stateless reset.
   273  	var token statelessResetToken
   274  	copy(token[:], m.b[len(m.b)-len(token):])
   275  	if c := e.connsMap.byResetToken[token]; c != nil {
   276  		c.sendMsg(func(now time.Time, c *Conn) {
   277  			c.handleStatelessReset(now, token)
   278  		})
   279  		return
   280  	}
   281  	// If this is a 1-RTT packet, there's nothing productive we can do with it.
   282  	// Send a stateless reset if possible.
   283  	if !isLongHeader(m.b[0]) {
   284  		e.maybeSendStatelessReset(m.b, m.peerAddr)
   285  		return
   286  	}
   287  	p, ok := parseGenericLongHeaderPacket(m.b)
   288  	if !ok || len(m.b) < paddedInitialDatagramSize {
   289  		return
   290  	}
   291  	switch p.version {
   292  	case quicVersion1:
   293  	case 0:
   294  		// Version Negotiation for an unknown connection.
   295  		return
   296  	default:
   297  		// Unknown version.
   298  		e.sendVersionNegotiation(p, m.peerAddr)
   299  		return
   300  	}
   301  	if getPacketType(m.b) != packetTypeInitial {
   302  		// This packet isn't trying to create a new connection.
   303  		// It might be associated with some connection we've lost state for.
   304  		// We are technically permitted to send a stateless reset for
   305  		// a long-header packet, but this isn't generally useful. See:
   306  		// https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16
   307  		return
   308  	}
   309  	if e.listenConfig == nil {
   310  		// We are not configured to accept connections.
   311  		return
   312  	}
   313  	if len(p.srcConnID) > maxConnIDLen || len(p.dstConnID) > maxConnIDLen {
   314  		// Enforce QUICv1 connection ID length limits.
   315  		// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.12.1
   316  		// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.16.1
   317  		return
   318  	}
   319  	cids := newServerConnIDs{
   320  		srcConnID: p.srcConnID,
   321  		dstConnID: p.dstConnID,
   322  	}
   323  	if e.listenConfig.RequireAddressValidation {
   324  		var ok bool
   325  		cids.retrySrcConnID = p.dstConnID
   326  		cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr)
   327  		if !ok {
   328  			return
   329  		}
   330  	} else {
   331  		cids.originalDstConnID = p.dstConnID
   332  	}
   333  	var err error
   334  	c, err := e.newConn(now, e.listenConfig, serverSide, cids, "", m.peerAddr)
   335  	if err != nil {
   336  		// The accept queue is probably full.
   337  		// We could send a CONNECTION_CLOSE to the peer to reject the connection.
   338  		// Currently, we just drop the datagram.
   339  		// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5
   340  		return
   341  	}
   342  	c.sendMsg(m)
   343  	m = nil // don't recycle, sendMsg takes ownership
   344  }
   345  
   346  func (e *Endpoint) maybeSendStatelessReset(b []byte, peerAddr netip.AddrPort) {
   347  	if !e.resetGen.canReset {
   348  		// Config.StatelessResetKey isn't set, so we don't send stateless resets.
   349  		return
   350  	}
   351  	// The smallest possible valid packet a peer can send us is:
   352  	//   1 byte of header
   353  	//   connIDLen bytes of destination connection ID
   354  	//   1 byte of packet number
   355  	//   1 byte of payload
   356  	//   16 bytes AEAD expansion
   357  	if len(b) < 1+connIDLen+1+1+16 {
   358  		return
   359  	}
   360  	// TODO: Rate limit stateless resets.
   361  	cid := b[1:][:connIDLen]
   362  	token := e.resetGen.tokenForConnID(cid)
   363  	// We want to generate a stateless reset that is as short as possible,
   364  	// but long enough to be difficult to distinguish from a 1-RTT packet.
   365  	//
   366  	// The minimal 1-RTT packet is:
   367  	//   1 byte of header
   368  	//   0-20 bytes of destination connection ID
   369  	//   1-4 bytes of packet number
   370  	//   1 byte of payload
   371  	//   16 bytes AEAD expansion
   372  	//
   373  	// Assuming the maximum possible connection ID and packet number size,
   374  	// this gives 1 + 20 + 4 + 1 + 16 = 42 bytes.
   375  	//
   376  	// We also must generate a stateless reset that is shorter than the datagram
   377  	// we are responding to, in order to ensure that reset loops terminate.
   378  	//
   379  	// See: https://www.rfc-editor.org/rfc/rfc9000#section-10.3
   380  	size := min(len(b)-1, 42)
   381  	// Reuse the input buffer for generating the stateless reset.
   382  	b = b[:size]
   383  	rand.Read(b[:len(b)-statelessResetTokenLen])
   384  	b[0] &^= headerFormLong // clear long header bit
   385  	b[0] |= fixedBit        // set fixed bit
   386  	copy(b[len(b)-statelessResetTokenLen:], token[:])
   387  	e.sendDatagram(datagram{
   388  		b:        b,
   389  		peerAddr: peerAddr,
   390  	})
   391  }
   392  
   393  func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, peerAddr netip.AddrPort) {
   394  	m := newDatagram()
   395  	m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
   396  	m.peerAddr = peerAddr
   397  	e.sendDatagram(*m)
   398  	m.recycle()
   399  }
   400  
   401  func (e *Endpoint) sendConnectionClose(in genericLongPacket, peerAddr netip.AddrPort, code transportError) {
   402  	keys := initialKeys(in.dstConnID, serverSide)
   403  	var w packetWriter
   404  	p := longPacket{
   405  		ptype:     packetTypeInitial,
   406  		version:   quicVersion1,
   407  		num:       0,
   408  		dstConnID: in.srcConnID,
   409  		srcConnID: in.dstConnID,
   410  	}
   411  	const pnumMaxAcked = 0
   412  	w.reset(paddedInitialDatagramSize)
   413  	w.startProtectedLongHeaderPacket(pnumMaxAcked, p)
   414  	w.appendConnectionCloseTransportFrame(code, 0, "")
   415  	w.finishProtectedLongHeaderPacket(pnumMaxAcked, keys.w, p)
   416  	buf := w.datagram()
   417  	if len(buf) == 0 {
   418  		return
   419  	}
   420  	e.sendDatagram(datagram{
   421  		b:        buf,
   422  		peerAddr: peerAddr,
   423  	})
   424  }
   425  
   426  func (e *Endpoint) sendDatagram(dgram datagram) error {
   427  	return e.packetConn.Write(dgram)
   428  }
   429  
   430  // A connsMap is an endpoint's mapping of conn ids and reset tokens to conns.
   431  type connsMap struct {
   432  	byConnID     map[string]*Conn
   433  	byResetToken map[statelessResetToken]*Conn
   434  
   435  	updateMu     sync.Mutex
   436  	updateNeeded atomic.Bool
   437  	updates      []func(*connsMap)
   438  }
   439  
   440  func (m *connsMap) init() {
   441  	m.byConnID = map[string]*Conn{}
   442  	m.byResetToken = map[statelessResetToken]*Conn{}
   443  }
   444  
   445  func (m *connsMap) addConnID(c *Conn, cid []byte) {
   446  	m.byConnID[string(cid)] = c
   447  }
   448  
   449  func (m *connsMap) retireConnID(c *Conn, cid []byte) {
   450  	delete(m.byConnID, string(cid))
   451  }
   452  
   453  func (m *connsMap) addResetToken(c *Conn, token statelessResetToken) {
   454  	m.byResetToken[token] = c
   455  }
   456  
   457  func (m *connsMap) retireResetToken(c *Conn, token statelessResetToken) {
   458  	delete(m.byResetToken, token)
   459  }
   460  
   461  func (m *connsMap) updateConnIDs(f func(*connsMap)) {
   462  	m.updateMu.Lock()
   463  	defer m.updateMu.Unlock()
   464  	m.updates = append(m.updates, f)
   465  	m.updateNeeded.Store(true)
   466  }
   467  
   468  // applyUpdates is called by the datagram receive loop to update its connection ID map.
   469  func (m *connsMap) applyUpdates() {
   470  	m.updateMu.Lock()
   471  	defer m.updateMu.Unlock()
   472  	for _, f := range m.updates {
   473  		f(m)
   474  	}
   475  	clear(m.updates)
   476  	m.updates = m.updates[:0]
   477  	m.updateNeeded.Store(false)
   478  }
   479  

View as plain text