Source file src/vendor/golang.org/x/net/quic/conn.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package quic
     6  
     7  import (
     8  	"context"
     9  	cryptorand "crypto/rand"
    10  	"crypto/tls"
    11  	"errors"
    12  	"fmt"
    13  	"log/slog"
    14  	"math/rand/v2"
    15  	"net/netip"
    16  	"time"
    17  )
    18  
// A Conn is a QUIC connection.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
// Internally, except where otherwise noted, the fields below are owned by
// the goroutine running the connection's main loop; other goroutines reach
// them by sending messages (including funcs to execute) on msgc.
type Conn struct {
	side      connSide
	endpoint  *Endpoint
	config    *Config
	testHooks connTestHooks
	peerAddr  netip.AddrPort // set from unmapAddrPort(peerAddr) in newConn
	localAddr netip.AddrPort
	prng      *rand.Rand // per-conn ChaCha8 PRNG, seeded from crypto/rand in newConn

	// msgc carries messages to the main loop. newConn gives it a
	// one-element buffer so that wake can signal the loop without blocking.
	msgc  chan any
	donec chan struct{} // closed when conn loop exits (or by newConn on failure)

	// Per-subsystem connection state.
	w           packetWriter
	acks        [numberSpaceCount]ackState // indexed by number space
	lifetime    lifetimeState
	idle        idleState
	connIDState connIDState
	loss        lossState
	streams     streamsState
	path        pathState
	skip        skipState

	// Packet protection keys, CRYPTO streams, and TLS state.
	keysInitial   fixedKeyPair
	keysHandshake fixedKeyPair
	keysAppData   updatingKeyPair
	crypto        [numberSpaceCount]cryptoStream
	tls           *tls.QUICConn

	// retryToken is the token provided by the peer in a Retry packet.
	retryToken []byte

	// handshakeConfirmed is set when the handshake is confirmed.
	// For server connections, it tracks sending HANDSHAKE_DONE.
	handshakeConfirmed sentVal

	peerAckDelayExponent int8 // -1 when unknown

	// Tests only: Send a PING in a specific number space.
	testSendPingSpace numberSpace
	testSendPing      sentVal

	log *slog.Logger
}
    66  
// connTestHooks override conn behavior in tests.
//
// The hooks are optional: newConn only calls init when c.testHooks is non-nil.
type connTestHooks interface {
	// init is called after a conn is created.
	// newConn calls it with first == true.
	init(first bool)

	// handleTLSEvent is called with each TLS event.
	handleTLSEvent(tls.QUICEvent)

	// newConnID is called to generate a new connection ID.
	// Permits tests to generate consistent connection IDs rather than random ones.
	// seq is presumably the sequence number of the ID being generated — TODO confirm.
	newConnID(seq int64) ([]byte, error)
}
    79  
// newServerConnIDs holds the connection IDs associated with a new server connection.
//
// For server connections, newConn uses retrySrcConnID (when non-nil), or
// otherwise originalDstConnID, as the connection ID from which the Initial
// packet protection keys are generated. originalDstConnID and retrySrcConnID
// are also placed in the transport parameters passed to startTLS.
type newServerConnIDs struct {
	srcConnID         []byte // source from client's current Initial
	dstConnID         []byte // destination from client's current Initial
	originalDstConnID []byte // destination from client's first Initial
	retrySrcConnID    []byte // source from server's Retry
}
    87  
// newConn creates a connection and starts its main loop on a new goroutine.
//
// side selects client or server behavior. For server connections, cids holds
// the connection IDs from the client's Initial packet (and from the server's
// Retry, if any), and the Initial keys are generated from cids.retrySrcConnID
// when it is non-nil, or from cids.originalDstConnID otherwise. Client
// connections initialize their IDs via connIDState.initClient and use the
// resulting destination connection ID for the Initial keys.
// peerHostname is passed through to startTLS. now is used to initialize the
// connection's timers and is the loop's starting time.
//
// On success the loop goroutine is running. On error, no loop is started and
// the deferred func below closes c.donec.
func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) {
	c := &Conn{
		side:                 side,
		endpoint:             e,
		config:               config,
		peerAddr:             unmapAddrPort(peerAddr),
		donec:                make(chan struct{}),
		peerAckDelayExponent: -1,
	}
	defer func() {
		// If we hit an error in newConn, close donec so tests don't get stuck waiting for it.
		// This is only relevant if we've got a bug, but it makes tracking that bug down
		// much easier.
		if conn == nil {
			close(c.donec)
		}
	}()

	// A one-element buffer allows us to wake a Conn's event loop as a
	// non-blocking operation.
	c.msgc = make(chan any, 1)

	// NOTE(review): presumably the endpoint's test hook installs c.testHooks,
	// which is consulted before the loop starts below; confirm.
	if e.testHooks != nil {
		e.testHooks.newConn(c, cids)
	}

	// initialConnID is the connection ID used to generate Initial packet protection keys.
	var initialConnID []byte
	if c.side == clientSide {
		if err := c.connIDState.initClient(c); err != nil {
			return nil, err
		}
		// The second result of dstConnID is discarded here.
		// NOTE(review): presumably an "ok" flag that is always true right
		// after initClient succeeds; confirm.
		initialConnID, _ = c.connIDState.dstConnID()
	} else {
		initialConnID = cids.originalDstConnID
		if cids.retrySrcConnID != nil {
			// After a Retry, Initial keys come from the Retry's source connection ID.
			initialConnID = cids.retrySrcConnID
		}
		if err := c.connIDState.initServer(c, cids); err != nil {
			return nil, err
		}
	}

	// A per-conn ChaCha8 PRNG is probably more than we need,
	// but at least it's fairly small.
	var seed [32]byte
	if _, err := cryptorand.Read(seed[:]); err != nil {
		// A failure to read system randomness is treated as unrecoverable.
		panic(err)
	}
	c.prng = rand.New(rand.NewChaCha8(seed))

	// TODO: PMTU discovery.
	// NOTE(review): this logs the peerAddr argument as given, not the
	// unmapped c.peerAddr stored above; confirm that is intended.
	c.logConnectionStarted(cids.originalDstConnID, peerAddr)
	c.keysAppData.init()
	c.loss.init(c.side, smallestMaxDatagramSize, now)
	c.streamsInit()
	c.lifetimeInit()
	c.restartIdleTimer(now)
	c.skip.init(c)

	// Start TLS with this endpoint's local transport parameters. The stream
	// limits are read from c.streams, so this must follow streamsInit.
	if err := c.startTLS(now, initialConnID, peerHostname, transportParameters{
		initialSrcConnID:               c.connIDState.srcConnID(),
		originalDstConnID:              cids.originalDstConnID,
		retrySrcConnID:                 cids.retrySrcConnID,
		ackDelayExponent:               ackDelayExponent,
		maxUDPPayloadSize:              maxUDPPayloadSize,
		maxAckDelay:                    maxAckDelay,
		disableActiveMigration:         true,
		initialMaxData:                 config.maxConnReadBufferSize(),
		initialMaxStreamDataBidiLocal:  config.maxStreamReadBufferSize(),
		initialMaxStreamDataBidiRemote: config.maxStreamReadBufferSize(),
		initialMaxStreamDataUni:        config.maxStreamReadBufferSize(),
		initialMaxStreamsBidi:          c.streams.remoteLimit[bidiStream].max,
		initialMaxStreamsUni:           c.streams.remoteLimit[uniStream].max,
		activeConnIDLimit:              activeConnIDLimit,
	}); err != nil {
		return nil, err
	}

	if c.testHooks != nil {
		c.testHooks.init(true)
	}
	go c.loop(now)
	return c, nil
}
   173  
   174  func (c *Conn) String() string {
   175  	return fmt.Sprintf("quic.Conn(%v,->%v)", c.side, c.peerAddr)
   176  }
   177  
   178  // LocalAddr returns the local network address, if known.
   179  func (c *Conn) LocalAddr() netip.AddrPort {
   180  	return c.localAddr
   181  }
   182  
   183  // RemoteAddr returns the remote network address, if known.
   184  func (c *Conn) RemoteAddr() netip.AddrPort {
   185  	return c.peerAddr
   186  }
   187  
   188  // ConnectionState returns basic TLS details about the connection.
   189  func (c *Conn) ConnectionState() tls.ConnectionState {
   190  	return c.tls.ConnectionState()
   191  }
   192  
   193  // confirmHandshake is called when the handshake is confirmed.
   194  // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2
   195  func (c *Conn) confirmHandshake(now time.Time) {
   196  	// If handshakeConfirmed is unset, the handshake is not confirmed.
   197  	// If it is unsent, the handshake is confirmed and we need to send a HANDSHAKE_DONE.
   198  	// If it is sent, we have sent a HANDSHAKE_DONE.
   199  	// If it is received, the handshake is confirmed and we do not need to send anything.
   200  	if c.handshakeConfirmed.isSet() {
   201  		return // already confirmed
   202  	}
   203  	if c.side == serverSide {
   204  		// When the server confirms the handshake, it sends a HANDSHAKE_DONE.
   205  		c.handshakeConfirmed.setUnsent()
   206  		c.endpoint.serverConnEstablished(c)
   207  	} else {
   208  		// The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed
   209  		// to the received state, indicating that the handshake is confirmed and we
   210  		// don't need to send anything.
   211  		c.handshakeConfirmed.setReceived()
   212  	}
   213  	c.restartIdleTimer(now)
   214  	c.loss.confirmHandshake()
   215  	// "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed"
   216  	// https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1
   217  	c.discardKeys(now, handshakeSpace)
   218  }
   219  
   220  // discardKeys discards unused packet protection keys.
   221  // https://www.rfc-editor.org/rfc/rfc9001#section-4.9
   222  func (c *Conn) discardKeys(now time.Time, space numberSpace) {
   223  	if err := c.crypto[space].discardKeys(); err != nil {
   224  		c.abort(now, err)
   225  	}
   226  	switch space {
   227  	case initialSpace:
   228  		c.keysInitial.discard()
   229  	case handshakeSpace:
   230  		c.keysHandshake.discard()
   231  	}
   232  	c.loss.discardKeys(now, c.log, space)
   233  }
   234  
   235  // receiveTransportParameters applies transport parameters sent by the peer.
   236  func (c *Conn) receiveTransportParameters(p transportParameters) error {
   237  	isRetry := c.retryToken != nil
   238  	if err := c.connIDState.validateTransportParameters(c, isRetry, p); err != nil {
   239  		return err
   240  	}
   241  	c.streams.outflow.setMaxData(p.initialMaxData)
   242  	c.streams.localLimit[bidiStream].setMax(p.initialMaxStreamsBidi)
   243  	c.streams.localLimit[uniStream].setMax(p.initialMaxStreamsUni)
   244  	c.streams.peerInitialMaxStreamDataBidiLocal = p.initialMaxStreamDataBidiLocal
   245  	c.streams.peerInitialMaxStreamDataRemote[bidiStream] = p.initialMaxStreamDataBidiRemote
   246  	c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni
   247  	c.receivePeerMaxIdleTimeout(p.maxIdleTimeout)
   248  	c.peerAckDelayExponent = p.ackDelayExponent
   249  	c.loss.setMaxAckDelay(p.maxAckDelay)
   250  	if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil {
   251  		return err
   252  	}
   253  	if p.preferredAddrConnID != nil {
   254  		var (
   255  			seq           int64 = 1 // sequence number of this conn id is 1
   256  			retirePriorTo int64 = 0 // retire nothing
   257  			resetToken    [16]byte
   258  		)
   259  		copy(resetToken[:], p.preferredAddrResetToken)
   260  		if err := c.connIDState.handleNewConnID(c, seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil {
   261  			return err
   262  		}
   263  	}
   264  	// TODO: stateless_reset_token
   265  	// TODO: max_udp_payload_size
   266  	// TODO: disable_active_migration
   267  	// TODO: preferred_address
   268  	return nil
   269  }
   270  
   271  type (
   272  	timerEvent struct{}
   273  	wakeEvent  struct{}
   274  )
   275  
   276  var errIdleTimeout = errors.New("idle timeout")
   277  
// loop is the connection main loop.
//
// Except where otherwise noted, all connection state is owned by the loop goroutine.
//
// The loop processes messages from c.msgc and timer events.
// Other goroutines may examine or modify conn state by sending the loop funcs to execute.
//
// Each iteration tries to send, computes the earliest pending deadline, and
// then either handles an already-passed deadline as a timerEvent or waits for
// the next message. The loop returns when lifetime.state reaches
// connStateDone, when the idle timer expires, or when lifetimeAdvance reports
// that draining has finished; cleanup then runs via defer.
func (c *Conn) loop(now time.Time) {
	defer c.cleanup()

	// The connection timer sends a message to the connection loop on expiry.
	// We need to give it an expiry when creating it, so set the initial timeout to
	// an arbitrary large value. The timer will be reset before this expires (and it
	// isn't a problem if it does anyway).
	var lastTimeout time.Time
	timer := time.AfterFunc(1*time.Hour, func() {
		c.sendMsg(timerEvent{})
	})
	defer timer.Stop()

	for c.lifetime.state != connStateDone {
		// maybeSend's result is treated as a deadline; as with the other
		// deadlines combined by firstTime, the zero time means "none".
		sendTimeout := c.maybeSend(now) // try sending

		// Note that we only need to consider the ack timer for the App Data space,
		// since the Initial and Handshake spaces always ack immediately.
		nextTimeout := sendTimeout
		nextTimeout = firstTime(nextTimeout, c.idle.nextTimeout)
		if c.isAlive() {
			nextTimeout = firstTime(nextTimeout, c.loss.timer)
			nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck)
		} else {
			// No longer alive: wait for the drain period to end instead of
			// the loss and ack timers.
			nextTimeout = firstTime(nextTimeout, c.lifetime.drainEndTime)
		}

		var m any
		if !nextTimeout.IsZero() && nextTimeout.Before(now) {
			// A connection timer has expired.
			// Handle it directly rather than round-tripping through the timer.
			now = time.Now()
			m = timerEvent{}
		} else {
			// Reschedule the connection timer if necessary
			// and wait for the next event.
			if !nextTimeout.Equal(lastTimeout) && !nextTimeout.IsZero() {
				// Resetting a timer created with time.AfterFunc guarantees
				// that the timer will run again. We might generate a spurious
				// timer event under some circumstances, but that's okay.
				timer.Reset(nextTimeout.Sub(now))
				lastTimeout = nextTimeout
			}
			m = <-c.msgc
			now = time.Now()
		}
		switch m := m.(type) {
		case *datagram:
			if !c.handleDatagram(now, m) {
				if c.logEnabled(QLogLevelPacket) {
					c.logPacketDropped(m)
				}
			}
			// The datagram is recycled whether or not it was handled.
			m.recycle()
		case timerEvent:
			// A connection timer has expired.
			if c.idleAdvance(now) {
				// The connection idle timer has expired.
				c.abortImmediately(now, errIdleTimeout)
				return
			}
			c.loss.advance(now, c.handleAckOrLoss)
			if c.lifetimeAdvance(now) {
				// The connection has completed the draining period,
				// and may be shut down.
				return
			}
		case wakeEvent:
			// We're being woken up to try sending some frames.
		case func(time.Time, *Conn):
			// Send a func to msgc to run it on the main Conn goroutine
			m(now, c)
		case func(now, next time.Time, _ *Conn):
			// Send a func to msgc to run it on the main Conn goroutine.
			// This variant also receives the nextTimeout computed above.
			m(now, nextTimeout, c)
		default:
			panic(fmt.Sprintf("quic: unrecognized conn message %T", m))
		}
	}
}
   363  
// cleanup runs when the main loop exits (loop defers it).
//
// It logs the close, notifies the endpoint via connDrained, and closes the
// TLS state (the error from Close is ignored). It closes donec last, which
// releases goroutines blocked in sendMsg and runOnLoop.
func (c *Conn) cleanup() {
	c.logConnectionClosed()
	c.endpoint.connDrained(c)
	c.tls.Close()
	close(c.donec)
}
   370  
   371  // sendMsg sends a message to the conn's loop.
   372  // It does not wait for the message to be processed.
   373  // The conn may close before processing the message, in which case it is lost.
   374  func (c *Conn) sendMsg(m any) {
   375  	select {
   376  	case c.msgc <- m:
   377  	case <-c.donec:
   378  	}
   379  }
   380  
   381  // wake wakes up the conn's loop.
   382  func (c *Conn) wake() {
   383  	select {
   384  	case c.msgc <- wakeEvent{}:
   385  	default:
   386  	}
   387  }
   388  
   389  // runOnLoop executes a function within the conn's loop goroutine.
   390  func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) error {
   391  	donec := make(chan struct{})
   392  	msg := func(now time.Time, c *Conn) {
   393  		defer close(donec)
   394  		f(now, c)
   395  	}
   396  	c.sendMsg(msg)
   397  	select {
   398  	case <-donec:
   399  	case <-c.donec:
   400  		return errors.New("quic: connection closed")
   401  	}
   402  	return nil
   403  }
   404  
   405  func (c *Conn) waitOnDone(ctx context.Context, ch <-chan struct{}) error {
   406  	// Check the channel before the context.
   407  	// We always prefer to return results when available,
   408  	// even when provided with an already-canceled context.
   409  	select {
   410  	case <-ch:
   411  		return nil
   412  	default:
   413  	}
   414  	select {
   415  	case <-ch:
   416  	case <-ctx.Done():
   417  		return ctx.Err()
   418  	}
   419  	return nil
   420  }
   421  
   422  // firstTime returns the earliest non-zero time, or zero if both times are zero.
   423  func firstTime(a, b time.Time) time.Time {
   424  	switch {
   425  	case a.IsZero():
   426  		return b
   427  	case b.IsZero():
   428  		return a
   429  	case a.Before(b):
   430  		return a
   431  	default:
   432  		return b
   433  	}
   434  }
   435  

View as plain text