Source file src/vendor/golang.org/x/net/internal/http3/roundtrip.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http3
     6  
     7  import (
     8  	"errors"
     9  	"io"
    10  	"net/http"
    11  	"net/http/httptrace"
    12  	"net/textproto"
    13  	"strconv"
    14  	"sync"
    15  
    16  	"golang.org/x/net/http/httpguts"
    17  	"golang.org/x/net/internal/httpcommon"
    18  )
    19  
    20  type roundTripState struct {
    21  	cc *clientConn
    22  	st *stream
    23  
    24  	// Request body, provided by the caller.
    25  	onceCloseReqBody sync.Once
    26  	reqBody          io.ReadCloser
    27  
    28  	reqBodyWriter bodyWriter
    29  
    30  	// Response.Body, provided to the caller.
    31  	respBody io.ReadCloser
    32  
    33  	trace *httptrace.ClientTrace
    34  
    35  	errOnce sync.Once
    36  	err     error
    37  }
    38  
    39  // abort terminates the RoundTrip.
    40  // It returns the first fatal error encountered by the RoundTrip call.
    41  func (rt *roundTripState) abort(err error) error {
    42  	rt.errOnce.Do(func() {
    43  		rt.err = err
    44  		switch e := err.(type) {
    45  		case *connectionError:
    46  			rt.cc.abort(e)
    47  		case *streamError:
    48  			rt.st.stream.CloseRead()
    49  			rt.st.stream.Reset(uint64(e.code))
    50  		default:
    51  			rt.st.stream.CloseRead()
    52  			rt.st.stream.Reset(uint64(errH3NoError))
    53  		}
    54  	})
    55  	return rt.err
    56  }
    57  
    58  // closeReqBody closes the Request.Body, at most once.
    59  func (rt *roundTripState) closeReqBody() {
    60  	if rt.reqBody != nil {
    61  		rt.onceCloseReqBody.Do(func() {
    62  			rt.reqBody.Close()
    63  		})
    64  	}
    65  }
    66  
    67  // TODO: Set up the rest of the hooks that might be in rt.trace.
    68  func (rt *roundTripState) maybeCallGot1xxResponse(status int, h http.Header) error {
    69  	if rt.trace == nil || rt.trace.Got1xxResponse == nil {
    70  		return nil
    71  	}
    72  	return rt.trace.Got1xxResponse(status, textproto.MIMEHeader(h))
    73  }
    74  
    75  func (rt *roundTripState) maybeCallGot100Continue() {
    76  	if rt.trace == nil || rt.trace.Got100Continue == nil {
    77  		return
    78  	}
    79  	rt.trace.Got100Continue()
    80  }
    81  
    82  func (rt *roundTripState) maybeCallWait100Continue() {
    83  	if rt.trace == nil || rt.trace.Wait100Continue == nil {
    84  		return
    85  	}
    86  	rt.trace.Wait100Continue()
    87  }
    88  
    89  // RoundTrip sends a request on the connection.
    90  func (cc *clientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) {
    91  	// Each request gets its own QUIC stream.
    92  	st, err := newConnStream(req.Context(), cc.qconn, streamTypeRequest)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	rt := &roundTripState{
    97  		cc:      cc,
    98  		st:      st,
    99  		trace:   httptrace.ContextClientTrace(req.Context()),
   100  		reqBody: req.Body,
   101  	}
   102  	if rt.reqBody == nil {
   103  		rt.reqBody = http.NoBody
   104  	}
   105  	defer func() {
   106  		if err != nil {
   107  			err = rt.abort(err)
   108  		}
   109  	}()
   110  
   111  	// Cancel reads/writes on the stream when the request expires.
   112  	st.stream.SetReadContext(req.Context())
   113  	st.stream.SetWriteContext(req.Context())
   114  
   115  	headers := cc.enc.encode(func(yield func(itype indexType, name, value string)) {
   116  		_, err = httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{
   117  			Request: httpcommon.Request{
   118  				URL:                 req.URL,
   119  				Method:              req.Method,
   120  				Host:                req.Host,
   121  				Header:              req.Header,
   122  				Trailer:             req.Trailer,
   123  				ActualContentLength: actualContentLength(req),
   124  			},
   125  			AddGzipHeader:         false, // TODO: add when appropriate
   126  			PeerMaxHeaderListSize: 0,
   127  			DefaultUserAgent:      "Go-http-client/3",
   128  		}, func(name, value string) {
   129  			// Issue #71374: Consider supporting never-indexed fields.
   130  			yield(mayIndex, name, value)
   131  		})
   132  	})
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  
   137  	// Write the HEADERS frame.
   138  	st.writeVarint(int64(frameTypeHeaders))
   139  	st.writeVarint(int64(len(headers)))
   140  	st.Write(headers)
   141  	if err := st.Flush(); err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	var bodyAndTrailerWritten bool
   146  	is100ContinueReq := httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue")
   147  	if is100ContinueReq {
   148  		rt.maybeCallWait100Continue()
   149  	} else {
   150  		bodyAndTrailerWritten = true
   151  		go cc.writeBodyAndTrailer(rt, req)
   152  	}
   153  
   154  	// Read the response headers.
   155  	for {
   156  		ftype, err := st.readFrameHeader()
   157  		if err != nil {
   158  			return nil, err
   159  		}
   160  		switch ftype {
   161  		case frameTypeHeaders:
   162  			statusCode, h, err := cc.handleHeaders(st)
   163  			if err != nil {
   164  				return nil, err
   165  			}
   166  
   167  			// TODO: Handle 1xx responses.
   168  			if isInfoStatus(statusCode) {
   169  				if err := rt.maybeCallGot1xxResponse(statusCode, h); err != nil {
   170  					return nil, err
   171  				}
   172  				switch statusCode {
   173  				case 100:
   174  					rt.maybeCallGot100Continue()
   175  					if is100ContinueReq && !bodyAndTrailerWritten {
   176  						bodyAndTrailerWritten = true
   177  						go cc.writeBodyAndTrailer(rt, req)
   178  						continue
   179  					}
   180  					// If we did not send "Expect: 100-continue" request but
   181  					// received status 100 anyways, just continue per usual and
   182  					// let the caller decide what to do with the response.
   183  				default:
   184  					continue
   185  				}
   186  			}
   187  
   188  			// We have the response headers.
   189  			// Set up the response and return it to the caller.
   190  			contentLength, err := parseResponseContentLength(req.Method, statusCode, h)
   191  			if err != nil {
   192  				return nil, err
   193  			}
   194  
   195  			trailer := make(http.Header)
   196  			extractTrailerFromHeader(h, trailer)
   197  			delete(h, "Trailer")
   198  
   199  			if (contentLength != 0 && req.Method != http.MethodHead) || len(trailer) > 0 {
   200  				rt.respBody = &bodyReader{
   201  					st:      st,
   202  					remain:  contentLength,
   203  					trailer: trailer,
   204  				}
   205  			} else {
   206  				rt.respBody = http.NoBody
   207  			}
   208  			resp := &http.Response{
   209  				Proto:         "HTTP/3.0",
   210  				ProtoMajor:    3,
   211  				Header:        h,
   212  				StatusCode:    statusCode,
   213  				Status:        strconv.Itoa(statusCode) + " " + http.StatusText(statusCode),
   214  				ContentLength: contentLength,
   215  				Trailer:       trailer,
   216  				Body:          (*transportResponseBody)(rt),
   217  			}
   218  			// TODO: Automatic Content-Type: gzip decoding.
   219  			return resp, nil
   220  		case frameTypePushPromise:
   221  			if err := cc.handlePushPromise(st); err != nil {
   222  				return nil, err
   223  			}
   224  		default:
   225  			if err := st.discardUnknownFrame(ftype); err != nil {
   226  				return nil, err
   227  			}
   228  		}
   229  	}
   230  }
   231  
   232  // actualContentLength returns a sanitized version of req.ContentLength,
   233  // where 0 actually means zero (not unknown) and -1 means unknown.
   234  func actualContentLength(req *http.Request) int64 {
   235  	if req.Body == nil || req.Body == http.NoBody {
   236  		return 0
   237  	}
   238  	if req.ContentLength != 0 {
   239  		return req.ContentLength
   240  	}
   241  	return -1
   242  }
   243  
   244  // writeBodyAndTrailer handles writing the body and trailer for a given
   245  // request, if any. This function will close the write direction of the stream.
   246  func (cc *clientConn) writeBodyAndTrailer(rt *roundTripState, req *http.Request) {
   247  	defer rt.closeReqBody()
   248  
   249  	declaredTrailer := req.Trailer.Clone()
   250  
   251  	rt.reqBodyWriter.st = rt.st
   252  	rt.reqBodyWriter.remain = actualContentLength(req)
   253  	rt.reqBodyWriter.flush = true
   254  	rt.reqBodyWriter.name = "request"
   255  	rt.reqBodyWriter.trailer = req.Trailer
   256  	rt.reqBodyWriter.enc = &cc.enc
   257  
   258  	if _, err := io.Copy(&rt.reqBodyWriter, rt.reqBody); err != nil {
   259  		rt.abort(err)
   260  	}
   261  	// Get rid of any trailer that was not declared beforehand, before we
   262  	// close the request body which will cause the trailer headers to be
   263  	// written.
   264  	for name := range req.Trailer {
   265  		if _, ok := declaredTrailer[name]; !ok {
   266  			delete(req.Trailer, name)
   267  		}
   268  	}
   269  	if err := rt.reqBodyWriter.Close(); err != nil {
   270  		rt.abort(err)
   271  	}
   272  }
   273  
   274  // transportResponseBody is the Response.Body returned by RoundTrip.
   275  type transportResponseBody roundTripState
   276  
   277  // Read is Response.Body.Read.
   278  func (b *transportResponseBody) Read(p []byte) (n int, err error) {
   279  	return b.respBody.Read(p)
   280  }
   281  
   282  var errRespBodyClosed = errors.New("response body closed")
   283  
   284  // Close is Response.Body.Close.
   285  // Closing the response body is how the caller signals that they're done with a request.
   286  func (b *transportResponseBody) Close() error {
   287  	rt := (*roundTripState)(b)
   288  	// Close the request body, which should wake up copyRequestBody if it's
   289  	// currently blocked reading the body.
   290  	rt.closeReqBody()
   291  	// Close the request stream, since we're done with the request.
   292  	// Reset closes the sending half of the stream.
   293  	rt.st.stream.Reset(uint64(errH3NoError))
   294  	// respBody.Close is responsible for closing the receiving half.
   295  	err := rt.respBody.Close()
   296  	if err == nil {
   297  		err = errRespBodyClosed
   298  	}
   299  	err = rt.abort(err)
   300  	if err == errRespBodyClosed {
   301  		// No other errors occurred before closing Response.Body,
   302  		// so consider this a successful request.
   303  		return nil
   304  	}
   305  	return err
   306  }
   307  
   308  func parseResponseContentLength(method string, statusCode int, h http.Header) (int64, error) {
   309  	clens := h["Content-Length"]
   310  	if len(clens) == 0 {
   311  		return -1, nil
   312  	}
   313  
   314  	// We allow duplicate Content-Length headers,
   315  	// but only if they all have the same value.
   316  	for _, v := range clens[1:] {
   317  		if clens[0] != v {
   318  			return -1, &streamError{errH3MessageError, "mismatching Content-Length headers"}
   319  		}
   320  	}
   321  
   322  	// "A server MUST NOT send a Content-Length header field in any response
   323  	// with a status code of 1xx (Informational) or 204 (No Content).
   324  	// A server MUST NOT send a Content-Length header field in any 2xx (Successful)
   325  	// response to a CONNECT request [...]"
   326  	// https://www.rfc-editor.org/rfc/rfc9110#section-8.6-8
   327  	if (statusCode >= 100 && statusCode < 200) ||
   328  		statusCode == 204 ||
   329  		(method == "CONNECT" && statusCode >= 200 && statusCode < 300) {
   330  		// This is a protocol violation, but a fairly harmless one.
   331  		// Just ignore the header.
   332  		return -1, nil
   333  	}
   334  
   335  	contentLen, err := strconv.ParseUint(clens[0], 10, 63)
   336  	if err != nil {
   337  		return -1, &streamError{errH3MessageError, "invalid Content-Length header"}
   338  	}
   339  	return int64(contentLen), nil
   340  }
   341  
   342  func (cc *clientConn) handleHeaders(st *stream) (statusCode int, h http.Header, err error) {
   343  	haveStatus := false
   344  	cookie := ""
   345  	// Issue #71374: Consider tracking the never-indexed status of headers
   346  	// with the N bit set in their QPACK encoding.
   347  	err = cc.dec.decode(st, func(_ indexType, name, value string) error {
   348  		if !httpguts.ValidHeaderFieldValue(value) {
   349  			return &streamError{errH3MessageError, "invalid field value"}
   350  		}
   351  		switch {
   352  		case name == ":status":
   353  			if haveStatus {
   354  				return &streamError{errH3MessageError, "duplicate :status"}
   355  			}
   356  			haveStatus = true
   357  			statusCode, err = strconv.Atoi(value)
   358  			if err != nil {
   359  				return &streamError{errH3MessageError, "invalid :status"}
   360  			}
   361  		case name[0] == ':':
   362  			// "Endpoints MUST treat a request or response
   363  			// that contains undefined or invalid
   364  			// pseudo-header fields as malformed."
   365  			// https://www.rfc-editor.org/rfc/rfc9114.html#section-4.3-3
   366  			return &streamError{errH3MessageError, "undefined pseudo-header"}
   367  		case name == "cookie":
   368  			// "If a decompressed field section contains multiple cookie field lines,
   369  			// these MUST be concatenated into a single byte string [...]"
   370  			// using the two-byte delimiter of "; "''
   371  			// https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2.1-2
   372  			if cookie == "" {
   373  				cookie = value
   374  			} else {
   375  				cookie += "; " + value
   376  			}
   377  		default:
   378  			if !validWireHeaderFieldName(name) {
   379  				return &streamError{errH3MessageError, "invalid field name"}
   380  			}
   381  			if h == nil {
   382  				h = make(http.Header)
   383  			}
   384  			// TODO: Use a per-connection canonicalization cache as we do in HTTP/2.
   385  			// Maybe we could put this in the QPACK decoder and have it deliver
   386  			// pre-canonicalized headers to us here?
   387  			cname := httpcommon.CanonicalHeader(name)
   388  			// TODO: Consider using a single []string slice for all headers,
   389  			// as we do in the HTTP/1 and HTTP/2 cases.
   390  			// This is a bit tricky, since we don't know the number of headers
   391  			// at the start of decoding. Perhaps it's worth doing a two-pass decode,
   392  			// or perhaps we should just allocate header value slices in
   393  			// reasonably-sized chunks.
   394  			h[cname] = append(h[cname], value)
   395  		}
   396  		return nil
   397  	})
   398  	if !haveStatus {
   399  		// "[The :status] pseudo-header field MUST be included in all responses [...]"
   400  		// https://www.rfc-editor.org/rfc/rfc9114.html#section-4.3.2-1
   401  		err = errH3MessageError
   402  	}
   403  	if cookie != "" {
   404  		if h == nil {
   405  			h = make(http.Header)
   406  		}
   407  		h["Cookie"] = []string{cookie}
   408  	}
   409  	if err := st.endFrame(); err != nil {
   410  		return 0, nil, err
   411  	}
   412  	return statusCode, h, err
   413  }
   414  
   415  func (cc *clientConn) handlePushPromise(st *stream) error {
   416  	// "A client MUST treat receipt of a PUSH_PROMISE frame that contains a
   417  	// larger push ID than the client has advertised as a connection error of H3_ID_ERROR."
   418  	// https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.5-5
   419  	return &connectionError{
   420  		code:    errH3IDError,
   421  		message: "PUSH_PROMISE received when no MAX_PUSH_ID has been sent",
   422  	}
   423  }
   424  

View as plain text