Source file src/vendor/golang.org/x/net/internal/http3/stream.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http3
     6  
     7  import (
     8  	"context"
     9  	"io"
    10  
    11  	"golang.org/x/net/quic"
    12  )
    13  
    14  // A stream wraps a QUIC stream, providing methods to read/write various values.
    15  type stream struct {
    16  	stream *quic.Stream
    17  
    18  	// lim is the current read limit.
    19  	// Reading a frame header sets the limit to the end of the frame.
    20  	// Reading past the limit or reading less than the limit and ending the frame
    21  	// results in an error.
    22  	// -1 indicates no limit.
    23  	lim int64
    24  }
    25  
    26  // newConnStream creates a new stream on a connection.
    27  // It writes the stream header for unidirectional streams.
    28  //
    29  // The stream returned by newStream is not flushed,
    30  // and will not be sent to the peer until the caller calls
    31  // Flush or writes enough data to the stream.
    32  func newConnStream(ctx context.Context, qconn *quic.Conn, stype streamType) (*stream, error) {
    33  	var qs *quic.Stream
    34  	var err error
    35  	if stype == streamTypeRequest {
    36  		// Request streams are bidirectional.
    37  		qs, err = qconn.NewStream(ctx)
    38  	} else {
    39  		// All other streams are unidirectional.
    40  		qs, err = qconn.NewSendOnlyStream(ctx)
    41  	}
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  	st := &stream{
    46  		stream: qs,
    47  		lim:    -1, // no limit
    48  	}
    49  	if stype != streamTypeRequest {
    50  		// Unidirectional stream header.
    51  		st.writeVarint(int64(stype))
    52  	}
    53  	return st, err
    54  }
    55  
    56  func newStream(qs *quic.Stream) *stream {
    57  	return &stream{
    58  		stream: qs,
    59  		lim:    -1, // no limit
    60  	}
    61  }
    62  
    63  // readFrameHeader reads the type and length fields of an HTTP/3 frame.
    64  // It sets the read limit to the end of the frame.
    65  //
    66  // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.1
    67  func (st *stream) readFrameHeader() (ftype frameType, err error) {
    68  	if st.lim >= 0 {
    69  		// We shouldn't call readFrameHeader before ending the previous frame.
    70  		return 0, errH3FrameError
    71  	}
    72  	ftype, err = readVarint[frameType](st)
    73  	if err != nil {
    74  		return 0, err
    75  	}
    76  	size, err := st.readVarint()
    77  	if err != nil {
    78  		return 0, err
    79  	}
    80  	st.lim = size
    81  	return ftype, nil
    82  }
    83  
    84  // endFrame is called after reading a frame to reset the read limit.
    85  // It returns an error if the entire contents of a frame have not been read.
    86  func (st *stream) endFrame() error {
    87  	if st.lim != 0 {
    88  		return &connectionError{
    89  			code:    errH3FrameError,
    90  			message: "invalid HTTP/3 frame",
    91  		}
    92  	}
    93  	st.lim = -1
    94  	return nil
    95  }
    96  
    97  // readFrameData returns the remaining data in the current frame.
    98  func (st *stream) readFrameData() ([]byte, error) {
    99  	if st.lim < 0 {
   100  		return nil, errH3FrameError
   101  	}
   102  	// TODO: Pool buffers to avoid allocation here.
   103  	b := make([]byte, st.lim)
   104  	_, err := io.ReadFull(st, b)
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	return b, nil
   109  }
   110  
   111  // ReadByte reads one byte from the stream.
   112  func (st *stream) ReadByte() (b byte, err error) {
   113  	if err := st.recordBytesRead(1); err != nil {
   114  		return 0, err
   115  	}
   116  	b, err = st.stream.ReadByte()
   117  	if err != nil {
   118  		if err == io.EOF && st.lim < 0 {
   119  			return 0, io.EOF
   120  		}
   121  		return 0, errH3FrameError
   122  	}
   123  	return b, nil
   124  }
   125  
   126  // Read reads from the stream.
   127  func (st *stream) Read(b []byte) (int, error) {
   128  	n, err := st.stream.Read(b)
   129  	if e2 := st.recordBytesRead(n); e2 != nil {
   130  		return 0, e2
   131  	}
   132  	if err == io.EOF {
   133  		if st.lim == 0 {
   134  			// EOF at end of frame, ignore.
   135  			return n, nil
   136  		} else if st.lim > 0 {
   137  			// EOF inside frame, error.
   138  			return 0, errH3FrameError
   139  		} else {
   140  			// EOF outside of frame, surface to caller.
   141  			return n, io.EOF
   142  		}
   143  	}
   144  	if err != nil {
   145  		return 0, errH3FrameError
   146  	}
   147  	return n, nil
   148  }
   149  
   150  // discardUnknownFrame discards an unknown frame.
   151  //
   152  // HTTP/3 requires that unknown frames be ignored on all streams.
   153  // However, a known frame appearing in an unexpected place is a fatal error,
   154  // so this returns an error if the frame is one we know.
   155  func (st *stream) discardUnknownFrame(ftype frameType) error {
   156  	switch ftype {
   157  	case frameTypeData,
   158  		frameTypeHeaders,
   159  		frameTypeCancelPush,
   160  		frameTypeSettings,
   161  		frameTypePushPromise,
   162  		frameTypeGoaway,
   163  		frameTypeMaxPushID:
   164  		return &connectionError{
   165  			code:    errH3FrameUnexpected,
   166  			message: "unexpected " + ftype.String() + " frame",
   167  		}
   168  	}
   169  	return st.discardFrame()
   170  }
   171  
   172  // discardFrame discards any remaining data in the current frame and resets the read limit.
   173  func (st *stream) discardFrame() error {
   174  	// TODO: Consider adding a *quic.Stream method to discard some amount of data.
   175  	for range st.lim {
   176  		_, err := st.stream.ReadByte()
   177  		if err != nil {
   178  			return &streamError{errH3FrameError, err.Error()}
   179  		}
   180  	}
   181  	st.lim = -1
   182  	return nil
   183  }
   184  
   185  // Write writes to the stream.
   186  func (st *stream) Write(b []byte) (int, error) { return st.stream.Write(b) }
   187  
   188  // Flush commits data written to the stream.
   189  func (st *stream) Flush() error { return st.stream.Flush() }
   190  
   191  // readVarint reads a QUIC variable-length integer from the stream.
   192  func (st *stream) readVarint() (v int64, err error) {
   193  	b, err := st.stream.ReadByte()
   194  	if err != nil {
   195  		return 0, err
   196  	}
   197  	v = int64(b & 0x3f)
   198  	n := 1 << (b >> 6)
   199  	for i := 1; i < n; i++ {
   200  		b, err := st.stream.ReadByte()
   201  		if err != nil {
   202  			return 0, errH3FrameError
   203  		}
   204  		v = (v << 8) | int64(b)
   205  	}
   206  	if err := st.recordBytesRead(n); err != nil {
   207  		return 0, err
   208  	}
   209  	return v, nil
   210  }
   211  
   212  // readVarint reads a varint of a particular type.
   213  func readVarint[T ~int64 | ~uint64](st *stream) (T, error) {
   214  	v, err := st.readVarint()
   215  	return T(v), err
   216  }
   217  
   218  // writeVarint writes a QUIC variable-length integer to the stream.
   219  func (st *stream) writeVarint(v int64) {
   220  	switch {
   221  	case v <= (1<<6)-1:
   222  		st.stream.WriteByte(byte(v))
   223  	case v <= (1<<14)-1:
   224  		st.stream.WriteByte((1 << 6) | byte(v>>8))
   225  		st.stream.WriteByte(byte(v))
   226  	case v <= (1<<30)-1:
   227  		st.stream.WriteByte((2 << 6) | byte(v>>24))
   228  		st.stream.WriteByte(byte(v >> 16))
   229  		st.stream.WriteByte(byte(v >> 8))
   230  		st.stream.WriteByte(byte(v))
   231  	case v <= (1<<62)-1:
   232  		st.stream.WriteByte((3 << 6) | byte(v>>56))
   233  		st.stream.WriteByte(byte(v >> 48))
   234  		st.stream.WriteByte(byte(v >> 40))
   235  		st.stream.WriteByte(byte(v >> 32))
   236  		st.stream.WriteByte(byte(v >> 24))
   237  		st.stream.WriteByte(byte(v >> 16))
   238  		st.stream.WriteByte(byte(v >> 8))
   239  		st.stream.WriteByte(byte(v))
   240  	default:
   241  		panic("varint too large")
   242  	}
   243  }
   244  
   245  // recordBytesRead records that n bytes have been read.
   246  // It returns an error if the read passes the current limit.
   247  func (st *stream) recordBytesRead(n int) error {
   248  	if st.lim < 0 {
   249  		return nil
   250  	}
   251  	st.lim -= int64(n)
   252  	if st.lim < 0 {
   253  		st.stream = nil // panic if we try to read again
   254  		return &connectionError{
   255  			code:    errH3FrameError,
   256  			message: "invalid HTTP/3 frame",
   257  		}
   258  	}
   259  	return nil
   260  }
   261  

View as plain text