Source file src/vendor/golang.org/x/net/quic/conn_id.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package quic
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"slices"
    11  )
    12  
// connIDState is a conn's connection IDs.
//
// It holds the IDs the conn receives packets on (local) and sends packets to
// (remote), together with the bookkeeping used to issue new local IDs and
// to retire remote ones.
type connIDState struct {
	// The destination connection IDs of packets we receive are local.
	// The destination connection IDs of packets we send are remote.
	//
	// Local IDs are usually issued by us, and remote IDs by the peer.
	// The exception is the transient destination connection ID sent in
	// a client's Initial packets, which is chosen by the client.
	//
	// These are slices of values rather than slices of pointers
	// to minimize allocations.
	local  []connID
	remote []remoteConnID

	nextLocalSeq          int64 // sequence number for the next local ID we issue
	peerActiveConnIDLimit int64 // peer's active_connection_id_limit

	// Handling of retirement of remote connection IDs.
	// The rangesets track ID sequence numbers.
	// IDs in need of retirement are added to remoteRetiring,
	// moved to remoteRetiringSent once we send a RETIRE_CONNECTION_ID frame,
	// and removed from the set once retirement completes.
	retireRemotePriorTo int64           // largest Retire Prior To value sent by the peer
	remoteRetiring      rangeset[int64] // remote IDs in need of retirement
	remoteRetiringSent  rangeset[int64] // remote IDs waiting for ack of retirement

	originalDstConnID []byte // expected original_destination_connection_id param
	retrySrcConnID    []byte // expected retry_source_connection_id param

	// needSend is set when there may be NEW_CONNECTION_ID or
	// RETIRE_CONNECTION_ID frames to send, and cleared by appendFrames
	// once everything pending has been written.
	needSend bool
}
    43  
// A connID is a connection ID and associated metadata.
type connID struct {
	// cid is the connection ID itself.
	cid []byte

	// seq is the connection ID's sequence number:
	// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-1
	//
	// For the transient destination ID in a client's Initial packet, this is -1.
	seq int64

	// send is set when the connection ID's state needs to be sent to the peer.
	//
	// For local IDs, this indicates a new ID that should be sent
	// in a NEW_CONNECTION_ID frame.
	//
	// For remote IDs, this indicates a retired ID that should be sent
	// in a RETIRE_CONNECTION_ID frame.
	//
	// NOTE(review): in this file, retirement of remote IDs appears to be
	// tracked by connIDState.remoteRetiring and remoteRetiringSent rather
	// than by this field; confirm whether the remote-ID use is still current.
	send sentVal
}
    64  
// A remoteConnID is a connection ID and stateless reset token.
//
// The token is left at its zero value until one is learned from the peer,
// either through the stateless_reset_token transport parameter or
// through a NEW_CONNECTION_ID frame.
type remoteConnID struct {
	connID
	resetToken statelessResetToken
}
    70  
    71  func (s *connIDState) initClient(c *Conn) error {
    72  	// Client chooses its initial connection ID, and sends it
    73  	// in the Source Connection ID field of the first Initial packet.
    74  	locid, err := c.newConnID(0)
    75  	if err != nil {
    76  		return err
    77  	}
    78  	s.local = append(s.local, connID{
    79  		seq: 0,
    80  		cid: locid,
    81  	})
    82  	s.nextLocalSeq = 1
    83  	c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
    84  		conns.addConnID(c, locid)
    85  	})
    86  
    87  	// Client chooses an initial, transient connection ID for the server,
    88  	// and sends it in the Destination Connection ID field of the first Initial packet.
    89  	remid, err := c.newConnID(-1)
    90  	if err != nil {
    91  		return err
    92  	}
    93  	s.remote = append(s.remote, remoteConnID{
    94  		connID: connID{
    95  			seq: -1,
    96  			cid: remid,
    97  		},
    98  	})
    99  	s.originalDstConnID = remid
   100  	return nil
   101  }
   102  
   103  func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error {
   104  	dstConnID := cloneBytes(cids.dstConnID)
   105  	// Client-chosen, transient connection ID received in the first Initial packet.
   106  	// The server will not use this as the Source Connection ID of packets it sends,
   107  	// but remembers it because it may receive packets sent to this destination.
   108  	s.local = append(s.local, connID{
   109  		seq: -1,
   110  		cid: dstConnID,
   111  	})
   112  
   113  	// Server chooses a connection ID, and sends it in the Source Connection ID of
   114  	// the response to the clent.
   115  	locid, err := c.newConnID(0)
   116  	if err != nil {
   117  		return err
   118  	}
   119  	s.local = append(s.local, connID{
   120  		seq: 0,
   121  		cid: locid,
   122  	})
   123  	s.nextLocalSeq = 1
   124  	c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   125  		conns.addConnID(c, dstConnID)
   126  		conns.addConnID(c, locid)
   127  	})
   128  
   129  	// Client chose its own connection ID.
   130  	s.remote = append(s.remote, remoteConnID{
   131  		connID: connID{
   132  			seq: 0,
   133  			cid: cloneBytes(cids.srcConnID),
   134  		},
   135  	})
   136  	return nil
   137  }
   138  
   139  // srcConnID is the Source Connection ID to use in a sent packet.
   140  func (s *connIDState) srcConnID() []byte {
   141  	if s.local[0].seq == -1 && len(s.local) > 1 {
   142  		// Don't use the transient connection ID if another is available.
   143  		return s.local[1].cid
   144  	}
   145  	return s.local[0].cid
   146  }
   147  
   148  // dstConnID is the Destination Connection ID to use in a sent packet.
   149  func (s *connIDState) dstConnID() (cid []byte, ok bool) {
   150  	for i := range s.remote {
   151  		return s.remote[i].cid, true
   152  	}
   153  	return nil, false
   154  }
   155  
   156  // isValidStatelessResetToken reports whether the given reset token is
   157  // associated with a non-retired connection ID which we have used.
   158  func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool {
   159  	if len(s.remote) == 0 {
   160  		return false
   161  	}
   162  	// We currently only use the first available remote connection ID,
   163  	// so any other reset token is not valid.
   164  	return s.remote[0].resetToken == resetToken
   165  }
   166  
   167  // setPeerActiveConnIDLimit sets the active_connection_id_limit
   168  // transport parameter received from the peer.
   169  func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
   170  	s.peerActiveConnIDLimit = lim
   171  	return s.issueLocalIDs(c)
   172  }
   173  
   174  func (s *connIDState) issueLocalIDs(c *Conn) error {
   175  	toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
   176  	for i := range s.local {
   177  		if s.local[i].seq != -1 {
   178  			toIssue--
   179  		}
   180  	}
   181  	var newIDs [][]byte
   182  	for toIssue > 0 {
   183  		cid, err := c.newConnID(s.nextLocalSeq)
   184  		if err != nil {
   185  			return err
   186  		}
   187  		newIDs = append(newIDs, cid)
   188  		s.local = append(s.local, connID{
   189  			seq: s.nextLocalSeq,
   190  			cid: cid,
   191  		})
   192  		s.local[len(s.local)-1].send.setUnsent()
   193  		s.nextLocalSeq++
   194  		s.needSend = true
   195  		toIssue--
   196  	}
   197  	c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   198  		for _, cid := range newIDs {
   199  			conns.addConnID(c, cid)
   200  		}
   201  	})
   202  	return nil
   203  }
   204  
   205  // validateTransportParameters verifies the original_destination_connection_id and
   206  // initial_source_connection_id transport parameters match the expected values.
   207  func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p transportParameters) error {
   208  	// TODO: Consider returning more detailed errors, for debugging.
   209  	// Verify original_destination_connection_id matches
   210  	// the transient remote connection ID we chose (client)
   211  	// or is empty (server).
   212  	if !bytes.Equal(s.originalDstConnID, p.originalDstConnID) {
   213  		return localTransportError{
   214  			code:   errTransportParameter,
   215  			reason: "original_destination_connection_id mismatch",
   216  		}
   217  	}
   218  	s.originalDstConnID = nil // we have no further need for this
   219  	// Verify retry_source_connection_id matches the value from
   220  	// the server's Retry packet (when one was sent), or is empty.
   221  	if !bytes.Equal(p.retrySrcConnID, s.retrySrcConnID) {
   222  		return localTransportError{
   223  			code:   errTransportParameter,
   224  			reason: "retry_source_connection_id mismatch",
   225  		}
   226  	}
   227  	s.retrySrcConnID = nil // we have no further need for this
   228  	// Verify initial_source_connection_id matches the first remote connection ID.
   229  	if len(s.remote) == 0 || s.remote[0].seq != 0 {
   230  		return localTransportError{
   231  			code:   errInternal,
   232  			reason: "remote connection id missing",
   233  		}
   234  	}
   235  	if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
   236  		return localTransportError{
   237  			code:   errTransportParameter,
   238  			reason: "initial_source_connection_id mismatch",
   239  		}
   240  	}
   241  	if len(p.statelessResetToken) > 0 {
   242  		if c.side == serverSide {
   243  			return localTransportError{
   244  				code:   errTransportParameter,
   245  				reason: "client sent stateless_reset_token",
   246  			}
   247  		}
   248  		token := statelessResetToken(p.statelessResetToken)
   249  		s.remote[0].resetToken = token
   250  		c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   251  			conns.addResetToken(c, token)
   252  		})
   253  	}
   254  	return nil
   255  }
   256  
   257  // handlePacket updates the connection ID state during the handshake
   258  // (Initial and Handshake packets).
   259  func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) {
   260  	switch {
   261  	case ptype == packetTypeInitial && c.side == clientSide:
   262  		if len(s.remote) == 1 && s.remote[0].seq == -1 {
   263  			// We're a client connection processing the first Initial packet
   264  			// from the server. Replace the transient remote connection ID
   265  			// with the Source Connection ID from the packet.
   266  			s.remote[0] = remoteConnID{
   267  				connID: connID{
   268  					seq: 0,
   269  					cid: cloneBytes(srcConnID),
   270  				},
   271  			}
   272  		}
   273  	case ptype == packetTypeHandshake && c.side == serverSide:
   274  		if len(s.local) > 0 && s.local[0].seq == -1 {
   275  			// We're a server connection processing the first Handshake packet from
   276  			// the client. Discard the transient, client-chosen connection ID used
   277  			// for Initial packets; the client will never send it again.
   278  			cid := s.local[0].cid
   279  			c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   280  				conns.retireConnID(c, cid)
   281  			})
   282  			s.local = append(s.local[:0], s.local[1:]...)
   283  		}
   284  	}
   285  }
   286  
   287  func (s *connIDState) handleRetryPacket(srcConnID []byte) {
   288  	if len(s.remote) != 1 || s.remote[0].seq != -1 {
   289  		panic("BUG: handling retry with non-transient remote conn id")
   290  	}
   291  	s.retrySrcConnID = cloneBytes(srcConnID)
   292  	s.remote[0].cid = s.retrySrcConnID
   293  }
   294  
// handleNewConnID processes a NEW_CONNECTION_ID frame from the peer.
//
// seq is the sequence number of the connection ID cid, retire is the frame's
// Retire Prior To value, and resetToken is the stateless reset token that
// accompanies cid.
//
// Remote IDs with a sequence number below retire are removed and queued for
// RETIRE_CONNECTION_ID frames. If seq is not already known, the ID is added
// and its reset token is registered with the endpoint.
//
// It returns a localTransportError when the frame is not acceptable:
// we are sending to a zero-length connection ID, cid differs from the ID
// previously seen for seq, too many remote IDs are active, or too many
// retirements are outstanding. The limit checks run after the state has
// already been updated.
//
// NOTE(review): s.remote[0] is read without a length check; presumably the
// caller guarantees at least one remote ID exists. Confirm.
func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, resetToken statelessResetToken) error {
	if len(s.remote[0].cid) == 0 {
		// "An endpoint that is sending packets with a zero-length
		// Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID
		// frame as a connection error of type PROTOCOL_VIOLATION."
		// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6
		return localTransportError{
			code:   errProtocolViolation,
			reason: "NEW_CONNECTION_ID from peer with zero-length DCID",
		}
	}

	if seq < s.retireRemotePriorTo {
		// This ID was already retired by a previous NEW_CONNECTION_ID frame.
		// Nothing to do.
		return nil
	}

	if retire > s.retireRemotePriorTo {
		// Add newly-retired connection IDs to the set we need to send
		// RETIRE_CONNECTION_ID frames for, and remove them from s.remote.
		//
		// (This might cause us to send a RETIRE_CONNECTION_ID for an ID we've
		// never seen. That's fine.)
		s.remoteRetiring.add(s.retireRemotePriorTo, retire)
		s.retireRemotePriorTo = retire
		s.needSend = true
		s.remote = slices.DeleteFunc(s.remote, func(rcid remoteConnID) bool {
			return rcid.seq < s.retireRemotePriorTo
		})
	}

	have := false // do we already have this connection ID?
	for i := range s.remote {
		rcid := &s.remote[i]
		if rcid.seq == seq {
			// A repeated sequence number must carry the same connection ID.
			if !bytes.Equal(rcid.cid, cid) {
				return localTransportError{
					code:   errProtocolViolation,
					reason: "NEW_CONNECTION_ID does not match prior id",
				}
			}
			have = true // yes, we've seen this sequence number
			break
		}
	}

	if !have {
		// This is a new connection ID that we have not seen before.
		//
		// We could take steps to keep the list of remote connection IDs
		// sorted by sequence number, but there's no particular need
		// so we don't bother.
		s.remote = append(s.remote, remoteConnID{
			connID: connID{
				seq: seq,
				cid: cloneBytes(cid),
			},
			resetToken: resetToken,
		})
		c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
			conns.addResetToken(c, resetToken)
		})
	}

	if len(s.remote) > activeConnIDLimit {
		// Retired connection IDs (including newly-retired ones) do not count
		// against the limit.
		// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5
		return localTransportError{
			code:   errConnectionIDLimit,
			reason: "active_connection_id_limit exceeded",
		}
	}

	// "An endpoint SHOULD limit the number of connection IDs it has retired locally
	// for which RETIRE_CONNECTION_ID frames have not yet been acknowledged."
	// https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6
	//
	// Set a limit of three times the active_connection_id_limit for
	// the total number of remote connection IDs we keep retirement state for.
	if s.remoteRetiring.size()+s.remoteRetiringSent.size() > 3*activeConnIDLimit {
		return localTransportError{
			code:   errConnectionIDLimit,
			reason: "too many unacknowledged retired connection ids",
		}
	}

	return nil
}
   385  
   386  func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error {
   387  	if seq >= s.nextLocalSeq {
   388  		return localTransportError{
   389  			code:   errProtocolViolation,
   390  			reason: "RETIRE_CONNECTION_ID for unissued sequence number",
   391  		}
   392  	}
   393  	for i := range s.local {
   394  		if s.local[i].seq == seq {
   395  			cid := s.local[i].cid
   396  			c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
   397  				conns.retireConnID(c, cid)
   398  			})
   399  			s.local = append(s.local[:i], s.local[i+1:]...)
   400  			break
   401  		}
   402  	}
   403  	s.issueLocalIDs(c)
   404  	return nil
   405  }
   406  
   407  func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fate packetFate) {
   408  	for i := range s.local {
   409  		if s.local[i].seq != seq {
   410  			continue
   411  		}
   412  		s.local[i].send.ackOrLoss(pnum, fate)
   413  		if fate != packetAcked {
   414  			s.needSend = true
   415  		}
   416  		return
   417  	}
   418  }
   419  
   420  func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) {
   421  	s.remoteRetiringSent.sub(seq, seq+1)
   422  	if fate == packetLost {
   423  		// RETIRE_CONNECTION_ID frame was lost, mark for retransmission.
   424  		s.remoteRetiring.add(seq, seq+1)
   425  		s.needSend = true
   426  	}
   427  }
   428  
   429  // appendFrames appends NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames
   430  // to the current packet.
   431  //
   432  // It returns true if no more frames need appending,
   433  // false if not everything fit in the current packet.
   434  func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool {
   435  	if !s.needSend && !pto {
   436  		// Fast path: We don't need to send anything.
   437  		return true
   438  	}
   439  	retireBefore := int64(0)
   440  	if s.local[0].seq != -1 {
   441  		retireBefore = s.local[0].seq
   442  	}
   443  	for i := range s.local {
   444  		if !s.local[i].send.shouldSendPTO(pto) {
   445  			continue
   446  		}
   447  		if !c.w.appendNewConnectionIDFrame(
   448  			s.local[i].seq,
   449  			retireBefore,
   450  			s.local[i].cid,
   451  			c.endpoint.resetGen.tokenForConnID(s.local[i].cid),
   452  		) {
   453  			return false
   454  		}
   455  		s.local[i].send.setSent(pnum)
   456  	}
   457  	if pto {
   458  		for _, r := range s.remoteRetiringSent {
   459  			for cid := r.start; cid < r.end; cid++ {
   460  				if !c.w.appendRetireConnectionIDFrame(cid) {
   461  					return false
   462  				}
   463  			}
   464  		}
   465  	}
   466  	for s.remoteRetiring.numRanges() > 0 {
   467  		cid := s.remoteRetiring.min()
   468  		if !c.w.appendRetireConnectionIDFrame(cid) {
   469  			return false
   470  		}
   471  		s.remoteRetiring.sub(cid, cid+1)
   472  		s.remoteRetiringSent.add(cid, cid+1)
   473  	}
   474  	s.needSend = false
   475  	return true
   476  }
   477  
   478  func cloneBytes(b []byte) []byte {
   479  	n := make([]byte, len(b))
   480  	copy(n, b)
   481  	return n
   482  }
   483  
   484  func (c *Conn) newConnID(seq int64) ([]byte, error) {
   485  	if c.testHooks != nil {
   486  		return c.testHooks.newConnID(seq)
   487  	}
   488  	return newRandomConnID(seq)
   489  }
   490  
   491  func newRandomConnID(_ int64) ([]byte, error) {
   492  	// It is not necessary for connection IDs to be cryptographically secure,
   493  	// but it doesn't hurt.
   494  	id := make([]byte, connIDLen)
   495  	if _, err := rand.Read(id); err != nil {
   496  		// TODO: Surface this error as a metric or log event or something.
   497  		// rand.Read really shouldn't ever fail, but if it does, we should
   498  		// have a way to inform the user.
   499  		return nil, err
   500  	}
   501  	return id, nil
   502  }
   503  

View as plain text