Source file src/vendor/golang.org/x/net/internal/http3/server.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http3
     6  
     7  import (
     8  	"context"
     9  	"crypto/tls"
    10  	"fmt"
    11  	"io"
    12  	"maps"
    13  	"net/http"
    14  	"slices"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"time"
    19  
    20  	"golang.org/x/net/http/httpguts"
    21  	"golang.org/x/net/internal/httpcommon"
    22  	"golang.org/x/net/quic"
    23  )
    24  
    25  // A server is an HTTP/3 server.
    26  // The zero value for server is a valid server.
    27  type server struct {
    28  	// handler to invoke for requests, http.DefaultServeMux if nil.
    29  	handler http.Handler
    30  
    31  	config *quic.Config
    32  
    33  	listenQUIC func(addr string, config *quic.Config) (*quic.Endpoint, error)
    34  
    35  	initOnce sync.Once
    36  
    37  	serveCtx       context.Context
    38  	serveCtxCancel context.CancelFunc
    39  
    40  	// connClosed is used to signal that a connection has been unregistered
    41  	// from activeConns. That way, when shutting down gracefully, the server
    42  	// can avoid busy-waiting for activeConns to be empty.
    43  	connClosed  chan any
    44  	mu          sync.Mutex // Guards fields below.
    45  	activeConns map[*serverConn]struct{}
    46  }
    47  
    48  // netHTTPHandler is an interface that is implemented by
    49  // net/http.http3ServerHandler in std.
    50  //
    51  // It provides a way for information to be passed between x/net and net/http
    52  // that would otherwise be inaccessible, such as the TLS configs that users
    53  // have supplied to net/http servers.
    54  //
    55  // This allows us to integrate our HTTP/3 server implementation with the
    56  // net/http server when RegisterServer is called.
    57  type netHTTPHandler interface {
    58  	http.Handler
    59  	TLSConfig() *tls.Config
    60  	BaseContext() context.Context
    61  	Addr() string
    62  	ListenErrHook(err error)
    63  	ShutdownContext() context.Context
    64  }
    65  
    66  type ServerOpts struct {
    67  	// ListenQUIC determines how the server will open a QUIC endpoint.
    68  	// By default, quic.Listen("udp", addr, config) is used.
    69  	ListenQUIC func(addr string, config *quic.Config) (*quic.Endpoint, error)
    70  
    71  	// QUICConfig is the QUIC configuration used by the server.
    72  	// QUICConfig may be nil and should not be modified after calling
    73  	// RegisterServer.
    74  	// If QUICConfig.TLSConfig is nil, the TLSConfig of the net/http Server
    75  	// given to RegisterServer will be used.
    76  	QUICConfig *quic.Config
    77  }
    78  
    79  // RegisterServer adds HTTP/3 support to a net/http Server.
    80  //
    81  // RegisterServer must be called before s begins serving, and only affects
    82  // s.ListenAndServeTLS.
    83  func RegisterServer(s *http.Server, opts ServerOpts) {
    84  	if s.TLSNextProto == nil {
    85  		s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
    86  	}
    87  	s.TLSNextProto["http/3"] = func(s *http.Server, c *tls.Conn, h http.Handler) {
    88  		stdHandler, ok := h.(netHTTPHandler)
    89  		if !ok {
    90  			panic("RegisterServer was given a server that does not implement netHTTPHandler")
    91  		}
    92  		if opts.QUICConfig == nil {
    93  			opts.QUICConfig = &quic.Config{}
    94  		}
    95  		if opts.QUICConfig.TLSConfig == nil {
    96  			opts.QUICConfig.TLSConfig = stdHandler.TLSConfig()
    97  		}
    98  		s3 := &server{
    99  			config:     opts.QUICConfig,
   100  			listenQUIC: opts.ListenQUIC,
   101  			handler:    stdHandler,
   102  			serveCtx:   stdHandler.BaseContext(),
   103  		}
   104  		s3.init()
   105  		s.RegisterOnShutdown(func() {
   106  			s3.shutdown(stdHandler.ShutdownContext())
   107  		})
   108  		stdHandler.ListenErrHook(s3.listenAndServe(stdHandler.Addr()))
   109  	}
   110  }
   111  
   112  func (s *server) init() {
   113  	s.initOnce.Do(func() {
   114  		s.config = initConfig(s.config)
   115  		if s.handler == nil {
   116  			s.handler = http.DefaultServeMux
   117  		}
   118  		if s.serveCtx == nil {
   119  			s.serveCtx = context.Background()
   120  		}
   121  		if s.listenQUIC == nil {
   122  			s.listenQUIC = func(addr string, config *quic.Config) (*quic.Endpoint, error) {
   123  				return quic.Listen("udp", addr, config)
   124  			}
   125  		}
   126  		s.serveCtx, s.serveCtxCancel = context.WithCancel(s.serveCtx)
   127  		s.activeConns = make(map[*serverConn]struct{})
   128  		s.connClosed = make(chan any, 1)
   129  	})
   130  }
   131  
   132  // listenAndServe listens on the UDP network address addr
   133  // and then calls Serve to handle requests on incoming connections.
   134  func (s *server) listenAndServe(addr string) error {
   135  	s.init()
   136  	e, err := s.listenQUIC(addr, s.config)
   137  	if err != nil {
   138  		return err
   139  	}
   140  	go s.serve(e)
   141  	return nil
   142  }
   143  
   144  // serve accepts incoming connections on the QUIC endpoint e,
   145  // and handles requests from those connections.
   146  func (s *server) serve(e *quic.Endpoint) error {
   147  	s.init()
   148  	defer e.Close(canceledCtx)
   149  	for {
   150  		qconn, err := e.Accept(s.serveCtx)
   151  		if err != nil {
   152  			return err
   153  		}
   154  		go s.newServerConn(qconn, s.handler)
   155  	}
   156  }
   157  
   158  // shutdown attempts a graceful shutdown for the server.
   159  func (s *server) shutdown(ctx context.Context) {
   160  	// Set a reasonable default in case ctx is nil.
   161  	if ctx == nil {
   162  		var cancel context.CancelFunc
   163  		ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   164  		defer cancel()
   165  	}
   166  
   167  	// Send GOAWAY frames to all active connections to give a chance for them
   168  	// to gracefully terminate.
   169  	s.mu.Lock()
   170  	for sc := range s.activeConns {
   171  		// TODO: Modify x/net/quic stream API so that write errors from context
   172  		// deadline are sticky.
   173  		go sc.sendGoaway()
   174  	}
   175  	s.mu.Unlock()
   176  
   177  	// Complete shutdown as soon as there are no more active connections or ctx
   178  	// is done, whichever comes first.
   179  	defer func() {
   180  		s.mu.Lock()
   181  		defer s.mu.Unlock()
   182  		s.serveCtxCancel()
   183  		for sc := range s.activeConns {
   184  			sc.abort(&connectionError{
   185  				code:    errH3NoError,
   186  				message: "server is shutting down",
   187  			})
   188  		}
   189  	}()
   190  	noMoreConns := func() bool {
   191  		s.mu.Lock()
   192  		defer s.mu.Unlock()
   193  		return len(s.activeConns) == 0
   194  	}
   195  	for {
   196  		if noMoreConns() {
   197  			return
   198  		}
   199  		select {
   200  		case <-ctx.Done():
   201  			return
   202  		case <-s.connClosed:
   203  		}
   204  	}
   205  }
   206  
   207  func (s *server) registerConn(sc *serverConn) {
   208  	s.mu.Lock()
   209  	defer s.mu.Unlock()
   210  	s.activeConns[sc] = struct{}{}
   211  }
   212  
   213  func (s *server) unregisterConn(sc *serverConn) {
   214  	s.mu.Lock()
   215  	delete(s.activeConns, sc)
   216  	s.mu.Unlock()
   217  	select {
   218  	case s.connClosed <- struct{}{}:
   219  	default:
   220  		// Channel already full. No need to send more values since we are just
   221  		// using this channel as a simpler sync.Cond.
   222  	}
   223  }
   224  
   225  type serverConn struct {
   226  	qconn *quic.Conn
   227  
   228  	genericConn // for handleUnidirectionalStream
   229  	enc         qpackEncoder
   230  	dec         qpackDecoder
   231  	handler     http.Handler
   232  
   233  	// For handling shutdown.
   234  	controlStream      *stream
   235  	mu                 sync.Mutex // Guards everything below.
   236  	maxRequestStreamID int64
   237  	goawaySent         bool
   238  }
   239  
   240  func (s *server) newServerConn(qconn *quic.Conn, handler http.Handler) {
   241  	sc := &serverConn{
   242  		qconn:   qconn,
   243  		handler: handler,
   244  	}
   245  	s.registerConn(sc)
   246  	defer s.unregisterConn(sc)
   247  	sc.enc.init()
   248  
   249  	// Create control stream and send SETTINGS frame.
   250  	// TODO: Time out on creating stream.
   251  	var err error
   252  	sc.controlStream, err = newConnStream(context.Background(), sc.qconn, streamTypeControl)
   253  	if err != nil {
   254  		return
   255  	}
   256  	sc.controlStream.writeSettings()
   257  	sc.controlStream.Flush()
   258  
   259  	sc.acceptStreams(sc.qconn, sc)
   260  }
   261  
   262  func (sc *serverConn) handleControlStream(st *stream) error {
   263  	// "A SETTINGS frame MUST be sent as the first frame of each control stream [...]"
   264  	// https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-2
   265  	if err := st.readSettings(func(settingsType, settingsValue int64) error {
   266  		switch settingsType {
   267  		case settingsMaxFieldSectionSize:
   268  			_ = settingsValue // TODO
   269  		case settingsQPACKMaxTableCapacity:
   270  			_ = settingsValue // TODO
   271  		case settingsQPACKBlockedStreams:
   272  			_ = settingsValue // TODO
   273  		default:
   274  			// Unknown settings types are ignored.
   275  		}
   276  		return nil
   277  	}); err != nil {
   278  		return err
   279  	}
   280  
   281  	for {
   282  		ftype, err := st.readFrameHeader()
   283  		if err != nil {
   284  			return err
   285  		}
   286  		switch ftype {
   287  		case frameTypeCancelPush:
   288  			// "If a server receives a CANCEL_PUSH frame for a push ID
   289  			// that has not yet been mentioned by a PUSH_PROMISE frame,
   290  			// this MUST be treated as a connection error of type H3_ID_ERROR."
   291  			// https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.3-8
   292  			return &connectionError{
   293  				code:    errH3IDError,
   294  				message: "CANCEL_PUSH for unsent push ID",
   295  			}
   296  		case frameTypeGoaway:
   297  			return errH3NoError
   298  		default:
   299  			// Unknown frames are ignored.
   300  			if err := st.discardUnknownFrame(ftype); err != nil {
   301  				return err
   302  			}
   303  		}
   304  	}
   305  }
   306  
   307  func (sc *serverConn) handleEncoderStream(*stream) error {
   308  	// TODO
   309  	return nil
   310  }
   311  
   312  func (sc *serverConn) handleDecoderStream(*stream) error {
   313  	// TODO
   314  	return nil
   315  }
   316  
   317  func (sc *serverConn) handlePushStream(*stream) error {
   318  	// "[...] if a server receives a client-initiated push stream,
   319  	// this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR."
   320  	// https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3
   321  	return &connectionError{
   322  		code:    errH3StreamCreationError,
   323  		message: "client created push stream",
   324  	}
   325  }
   326  
   327  // hasDisallowedConnectionHeader reports whether h contains connnection headers
   328  // that are not allowed in HTTP/3:
   329  //
   330  // "An endpoint MUST NOT generate an HTTP/3 field section containing
   331  // connection-specific fields; any message containing connection-specific
   332  // fields MUST be treated as malformed."
   333  //
   334  // "The only exception to this is the TE header field, which MAY be present in
   335  // an HTTP/3 request header; when it is, it MUST NOT contain any value other
   336  // than "trailers"."
   337  func hasDisallowedConnectionHeader(h http.Header) bool {
   338  	neverAllowed := []string{
   339  		"Connection",
   340  		"Keep-Alive",
   341  		"Proxy-Connection",
   342  		"Transfer-Encoding",
   343  		"Upgrade",
   344  	}
   345  	for _, k := range neverAllowed {
   346  		if _, ok := h[k]; ok {
   347  			return true
   348  		}
   349  	}
   350  	if te, ok := h["Te"]; ok && (len(te) != 1 || te[0] != "trailers") {
   351  		return true
   352  	}
   353  	return false
   354  }
   355  
   356  type pseudoHeader struct {
   357  	method    string
   358  	scheme    string
   359  	path      string
   360  	authority string
   361  }
   362  
   363  func (sc *serverConn) parseHeader(st *stream) (http.Header, pseudoHeader, error) {
   364  	ftype, err := st.readFrameHeader()
   365  	if err != nil {
   366  		return nil, pseudoHeader{}, err
   367  	}
   368  	if ftype != frameTypeHeaders {
   369  		return nil, pseudoHeader{}, &streamError{errH3MessageError, "received other frames when expecting HEADERS"}
   370  	}
   371  	header := make(http.Header)
   372  	var pHeader pseudoHeader
   373  	var dec qpackDecoder
   374  	var hasMethod, hasScheme, hasPath, hasAuthority bool
   375  	if err := dec.decode(st, func(_ indexType, name, value string) error {
   376  		if !httpguts.ValidHeaderFieldValue(value) {
   377  			return &streamError{errH3MessageError, "invalid field value"}
   378  		}
   379  		switch name {
   380  		case ":method":
   381  			if hasMethod {
   382  				return &streamError{errH3MessageError, "duplicate :method"}
   383  			}
   384  			hasMethod = true
   385  			pHeader.method = value
   386  		case ":scheme":
   387  			if hasScheme {
   388  				return &streamError{errH3MessageError, "duplicate :scheme"}
   389  			}
   390  			hasScheme = true
   391  			pHeader.scheme = value
   392  		case ":path":
   393  			if hasPath {
   394  				return &streamError{errH3MessageError, "duplicate :path"}
   395  			}
   396  			hasPath = true
   397  			pHeader.path = value
   398  		case ":authority":
   399  			if hasAuthority {
   400  				return &streamError{errH3MessageError, "duplicate :authority"}
   401  			}
   402  			hasAuthority = true
   403  			pHeader.authority = value
   404  		default:
   405  			if !validWireHeaderFieldName(name) {
   406  				return &streamError{errH3MessageError, "invalid field name"}
   407  			}
   408  			header.Add(name, value)
   409  		}
   410  		return nil
   411  	}); err != nil {
   412  		return nil, pseudoHeader{}, err
   413  	}
   414  	if err := st.endFrame(); err != nil {
   415  		return nil, pseudoHeader{}, err
   416  	}
   417  	if hasDisallowedConnectionHeader(header) {
   418  		return nil, pseudoHeader{}, &streamError{errH3MessageError, "invalid connection-related header"}
   419  	}
   420  
   421  	// "All HTTP/3 requests MUST include exactly one value for the :method,
   422  	// :scheme, and :path pseudo-header fields, unless the request is a CONNECT
   423  	// request"
   424  	//
   425  	// "A CONNECT request MUST be constructed as follows:
   426  	// - The :method pseudo-header field is set to "CONNECT"
   427  	// - The :scheme and :path pseudo-header fields are omitted
   428  	// - The :authority pseudo-header field contains the host and port to connect to"
   429  	if !hasMethod {
   430  		return nil, pseudoHeader{}, &streamError{errH3MessageError, "missing :method"}
   431  	}
   432  	if pHeader.method != "CONNECT" && (!hasScheme || !hasPath) {
   433  		return nil, pseudoHeader{}, &streamError{errH3MessageError, "missing :scheme or :path for non-CONNECT requests"}
   434  	}
   435  	if pHeader.method == "CONNECT" && (hasScheme || hasPath || !hasAuthority) {
   436  		return nil, pseudoHeader{}, &streamError{
   437  			errH3MessageError, "CONNECT request must only have :method and :authority pseudo-headers",
   438  		}
   439  	}
   440  	return header, pHeader, nil
   441  }
   442  
   443  func (sc *serverConn) sendGoaway() {
   444  	sc.mu.Lock()
   445  	if sc.goawaySent || sc.controlStream == nil {
   446  		sc.mu.Unlock()
   447  		return
   448  	}
   449  	sc.goawaySent = true
   450  	sc.mu.Unlock()
   451  
   452  	// No lock in this section in case writing to stream blocks. This is safe
   453  	// since sc.maxRequestStreamID is only updated when sc.goawaySent is false.
   454  	sc.controlStream.writeVarint(int64(frameTypeGoaway))
   455  	sc.controlStream.writeVarint(int64(sizeVarint(uint64(sc.maxRequestStreamID))))
   456  	sc.controlStream.writeVarint(sc.maxRequestStreamID)
   457  	sc.controlStream.Flush()
   458  }
   459  
   460  // requestShouldGoAway returns true if st has a stream ID that is equal or
   461  // greater than the ID we have sent in a GOAWAY frame, if any.
   462  func (sc *serverConn) requestShouldGoaway(st *stream) bool {
   463  	sc.mu.Lock()
   464  	defer sc.mu.Unlock()
   465  	if sc.goawaySent {
   466  		return st.stream.ID() >= sc.maxRequestStreamID
   467  	} else {
   468  		sc.maxRequestStreamID = max(sc.maxRequestStreamID, st.stream.ID())
   469  		return false
   470  	}
   471  }
   472  
   473  func (sc *serverConn) handleRequestStream(st *stream) error {
   474  	if sc.requestShouldGoaway(st) {
   475  		return &streamError{
   476  			code:    errH3RequestRejected,
   477  			message: "GOAWAY request with equal or lower ID than the stream has been sent",
   478  		}
   479  	}
   480  	header, pHeader, err := sc.parseHeader(st)
   481  	if err != nil {
   482  		return err
   483  	}
   484  
   485  	reqInfo := httpcommon.NewServerRequest(httpcommon.ServerRequestParam{
   486  		Method:    pHeader.method,
   487  		Scheme:    pHeader.scheme,
   488  		Authority: pHeader.authority,
   489  		Path:      pHeader.path,
   490  		Header:    header,
   491  	})
   492  	if reqInfo.InvalidReason != "" {
   493  		return &streamError{
   494  			code:    errH3MessageError,
   495  			message: reqInfo.InvalidReason,
   496  		}
   497  	}
   498  
   499  	var body io.ReadCloser
   500  	contentLength := int64(-1)
   501  	if n, err := strconv.Atoi(header.Get("Content-Length")); err == nil {
   502  		contentLength = int64(n)
   503  	}
   504  	if contentLength != 0 || len(reqInfo.Trailer) != 0 {
   505  		body = &bodyReader{
   506  			st:      st,
   507  			remain:  contentLength,
   508  			trailer: reqInfo.Trailer,
   509  		}
   510  	} else {
   511  		body = http.NoBody
   512  	}
   513  
   514  	req := &http.Request{
   515  		Proto:         "HTTP/3.0",
   516  		Method:        pHeader.method,
   517  		Host:          pHeader.authority,
   518  		URL:           reqInfo.URL,
   519  		RequestURI:    reqInfo.RequestURI,
   520  		Trailer:       reqInfo.Trailer,
   521  		ProtoMajor:    3,
   522  		RemoteAddr:    sc.qconn.RemoteAddr().String(),
   523  		Body:          body,
   524  		Header:        header,
   525  		ContentLength: contentLength,
   526  	}
   527  	defer req.Body.Close()
   528  
   529  	rw := &responseWriter{
   530  		st:             st,
   531  		headers:        make(http.Header),
   532  		trailer:        make(http.Header),
   533  		bb:             make(bodyBuffer, 0, defaultBodyBufferCap),
   534  		cannotHaveBody: req.Method == "HEAD",
   535  		bw: &bodyWriter{
   536  			st:     st,
   537  			remain: -1,
   538  			flush:  false,
   539  			name:   "response",
   540  			enc:    &sc.enc,
   541  		},
   542  	}
   543  	defer rw.close()
   544  	if reqInfo.NeedsContinue {
   545  		req.Body.(*bodyReader).send100Continue = func() {
   546  			rw.WriteHeader(100)
   547  		}
   548  	}
   549  
   550  	// TODO: handle panic coming from the HTTP handler.
   551  	sc.handler.ServeHTTP(rw, req)
   552  	return nil
   553  }
   554  
   555  // abort closes the connection with an error.
   556  func (sc *serverConn) abort(err error) {
   557  	if e, ok := err.(*connectionError); ok {
   558  		sc.qconn.Abort(&quic.ApplicationError{
   559  			Code:   uint64(e.code),
   560  			Reason: e.message,
   561  		})
   562  	} else {
   563  		sc.qconn.Abort(err)
   564  	}
   565  }
   566  
   567  // responseCanHaveBody reports whether a given response status code permits a
   568  // body. See RFC 7230, section 3.3.
   569  func responseCanHaveBody(status int) bool {
   570  	switch {
   571  	case status >= 100 && status <= 199:
   572  		return false
   573  	case status == 204:
   574  		return false
   575  	case status == 304:
   576  		return false
   577  	}
   578  	return true
   579  }
   580  
   581  type responseWriter struct {
   582  	st             *stream
   583  	bw             *bodyWriter
   584  	mu             sync.Mutex
   585  	headers        http.Header
   586  	trailer        http.Header
   587  	bb             bodyBuffer
   588  	wroteHeader    bool // Non-1xx header has been (logically) written.
   589  	statusCode     int  // Status of the response that will be sent in HEADERS frame.
   590  	statusCodeSet  bool // Status of the response has been set via a call to WriteHeader.
   591  	cannotHaveBody bool // Response should not have a body (e.g. response to a HEAD request).
   592  	bodyLenLeft    int  // How much of the content body is left to be sent, set via "Content-Length" header. -1 if unknown.
   593  }
   594  
   595  func (rw *responseWriter) Header() http.Header {
   596  	return rw.headers
   597  }
   598  
   599  // prepareTrailerForWriteLocked populates any pre-declared trailer header with
   600  // its value, and passes it to bodyWriter so it can be written after body EOF.
   601  // Caller must hold rw.mu.
   602  func (rw *responseWriter) prepareTrailerForWriteLocked() {
   603  	for name := range rw.trailer {
   604  		if val, ok := rw.headers[name]; ok {
   605  			rw.trailer[name] = val
   606  		} else {
   607  			delete(rw.trailer, name)
   608  		}
   609  	}
   610  	if len(rw.trailer) > 0 {
   611  		rw.bw.trailer = rw.trailer
   612  	}
   613  }
   614  
   615  // writeHeaderLockedOnce writes the final response header. If rw.wroteHeader is
   616  // true, calling this method is a no-op. Sending informational status headers
   617  // should be done using writeInfoHeaderLocked, rather than this method.
   618  // Caller must hold rw.mu.
   619  func (rw *responseWriter) writeHeaderLockedOnce() {
   620  	if rw.wroteHeader {
   621  		return
   622  	}
   623  	if !responseCanHaveBody(rw.statusCode) {
   624  		rw.cannotHaveBody = true
   625  	}
   626  	// If there is any Trailer declared in headers, save them so we know which
   627  	// trailers have been pre-declared. Also, write back the extracted value,
   628  	// which is canonicalized, to rw.Header for consistency.
   629  	if _, ok := rw.headers["Trailer"]; ok {
   630  		extractTrailerFromHeader(rw.headers, rw.trailer)
   631  		rw.headers.Set("Trailer", strings.Join(slices.Sorted(maps.Keys(rw.trailer)), ", "))
   632  	}
   633  
   634  	rw.bb.inferHeader(rw.headers, rw.statusCode)
   635  	encHeaders := rw.bw.enc.encode(func(f func(itype indexType, name, value string)) {
   636  		f(mayIndex, ":status", strconv.Itoa(rw.statusCode))
   637  		for name, values := range rw.headers {
   638  			if !httpguts.ValidHeaderFieldName(name) {
   639  				continue
   640  			}
   641  			for _, val := range values {
   642  				if !httpguts.ValidHeaderFieldValue(val) {
   643  					continue
   644  				}
   645  				// Issue #71374: Consider supporting never-indexed fields.
   646  				f(mayIndex, name, val)
   647  			}
   648  		}
   649  	})
   650  
   651  	rw.st.writeVarint(int64(frameTypeHeaders))
   652  	rw.st.writeVarint(int64(len(encHeaders)))
   653  	rw.st.Write(encHeaders)
   654  	rw.wroteHeader = true
   655  }
   656  
   657  // writeHeaderLocked writes informational status headers (i.e. status 1XX).
   658  // If a non-informational status header has been written via
   659  // writeHeaderLockedOnce, this method is a no-op.
   660  // Caller must hold rw.mu.
   661  func (rw *responseWriter) writeHeaderLocked(statusCode int) {
   662  	if rw.wroteHeader {
   663  		return
   664  	}
   665  	encHeaders := rw.bw.enc.encode(func(f func(itype indexType, name, value string)) {
   666  		f(mayIndex, ":status", strconv.Itoa(statusCode))
   667  		for name, values := range rw.headers {
   668  			if name == "Content-Length" || name == "Transfer-Encoding" {
   669  				continue
   670  			}
   671  			if !httpguts.ValidHeaderFieldName(name) {
   672  				continue
   673  			}
   674  			for _, val := range values {
   675  				if !httpguts.ValidHeaderFieldValue(val) {
   676  					continue
   677  				}
   678  				// Issue #71374: Consider supporting never-indexed fields.
   679  				f(mayIndex, name, val)
   680  			}
   681  		}
   682  	})
   683  	rw.st.writeVarint(int64(frameTypeHeaders))
   684  	rw.st.writeVarint(int64(len(encHeaders)))
   685  	rw.st.Write(encHeaders)
   686  }
   687  
   688  func isInfoStatus(status int) bool {
   689  	return status >= 100 && status < 200
   690  }
   691  
   692  // checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode.
   693  func checkWriteHeaderCode(code int) {
   694  	// Issue 22880: require valid WriteHeader status codes.
   695  	// For now we only enforce that it's three digits.
   696  	// In the future we might block things over 599 (600 and above aren't defined
   697  	// at http://httpwg.org/specs/rfc7231.html#status.codes).
   698  	// But for now any three digits.
   699  	//
   700  	// We used to send "HTTP/1.1 000 0" on the wire in responses but there's
   701  	// no equivalent bogus thing we can realistically send in HTTP/3,
   702  	// so we'll consistently panic instead and help people find their bugs
   703  	// early. (We can't return an error from WriteHeader even if we wanted to.)
   704  	if code < 100 || code > 999 {
   705  		panic(fmt.Sprintf("invalid WriteHeader code %v", code))
   706  	}
   707  }
   708  
   709  func (rw *responseWriter) WriteHeader(statusCode int) {
   710  	// TODO: handle sending informational status headers (e.g. 103).
   711  	rw.mu.Lock()
   712  	defer rw.mu.Unlock()
   713  	if rw.statusCodeSet {
   714  		return
   715  	}
   716  	checkWriteHeaderCode(statusCode)
   717  
   718  	// Informational headers can be sent multiple times, and should be flushed
   719  	// immediately.
   720  	if isInfoStatus(statusCode) {
   721  		rw.writeHeaderLocked(statusCode)
   722  		rw.st.Flush()
   723  		return
   724  	}
   725  
   726  	// Non-informational headers should only be set once, and should be
   727  	// buffered.
   728  	rw.statusCodeSet = true
   729  	rw.statusCode = statusCode
   730  	if n, err := strconv.Atoi(rw.Header().Get("Content-Length")); err == nil {
   731  		rw.bodyLenLeft = n
   732  	} else {
   733  		rw.bodyLenLeft = -1 // Unknown.
   734  	}
   735  }
   736  
   737  // trimWriteLocked trims a byte slice, b, such that the length of b will not
   738  // exceed rw.bodyLenLeft. This method will update rw.bodyLenLeft when trimming
   739  // b, and will also return whether b was trimmed or not.
   740  // Caller must hold rw.mu.
   741  func (rw *responseWriter) trimWriteLocked(b []byte) ([]byte, bool) {
   742  	if rw.bodyLenLeft < 0 {
   743  		return b, false
   744  	}
   745  	n := min(len(b), rw.bodyLenLeft)
   746  	rw.bodyLenLeft -= n
   747  	return b[:n], n != len(b)
   748  }
   749  
// Write implements http.ResponseWriter. It buffers or sends the body bytes in
// b and returns how many of them were accepted.
//
// It first calls WriteHeader(http.StatusOK), which records 200 only if no
// final status has been set yet. For a 304 Not Modified response, Write
// returns http.ErrBodyNotAllowed. If a Content-Length was declared, b is
// trimmed to the bytes still allowed, and http.ErrContentLength is returned;
// the returned count then covers only the kept bytes.
//
// Bytes are staged in rw.bb. Until the final header has been written, b is
// always staged first, so that headers can be inferred from as much of the
// body as possible. After that, only writes that fit in the remaining buffer
// space are staged. When b does not fit, the header is written (if needed)
// and the staged bytes plus the rest of b go to rw.bw. For responses that
// cannot have a body (HEAD requests, or statuses such as 204), the bytes are
// counted as written but never sent.
func (rw *responseWriter) Write(b []byte) (n int, err error) {
	// Calling Write implicitly calls WriteHeader(200) if WriteHeader has not
	// been called before.
	rw.WriteHeader(http.StatusOK)
	rw.mu.Lock()
	defer rw.mu.Unlock()

	if rw.statusCode == http.StatusNotModified {
		return 0, http.ErrBodyNotAllowed
	}

	b, trimmed := rw.trimWriteLocked(b)
	if trimmed {
		// Deferred so it applies to every return below: a trimmed write
		// always reports http.ErrContentLength, replacing both a nil error
		// and any error from rw.bw.write.
		defer func() {
			err = http.ErrContentLength
		}()
	}

	// If b fits entirely in our body buffer, save it to the buffer and return
	// early so we can coalesce small writes.
	// As a special case, we always want to save b to the buffer even when b is
	// big if we had yet to write our header, so we can infer headers like
	// "Content-Type" with as much information as possible.
	initialBLen := len(b)
	initialBufLen := len(rw.bb)
	if !rw.wroteHeader || len(b) <= cap(rw.bb)-len(rw.bb) {
		b = rw.bb.write(b)
		if len(b) == 0 {
			return initialBLen, nil
		}
	}

	// Reaching this point means that our buffer has been sufficiently filled.
	// Therefore, we now want to:
	// 1. Infer and write response headers based on our body buffer, if not
	// done yet.
	// 2. Write our body buffer and the rest of b (if any).
	// 3. Reset the current body buffer so it can be used again.
	rw.writeHeaderLockedOnce()
	if rw.cannotHaveBody {
		return initialBLen, nil
	}
	// n and err shadow the named results here.
	// NOTE(review): this assumes rw.bw.write reports the bytes written from
	// rw.bb followed by b. Subtracting what was buffered before this call then
	// leaves only the bytes that came from this call's input. On this error
	// path rw.bb is not discarded.
	if n, err := rw.bw.write(rw.bb, b); err != nil {
		return max(0, n-initialBufLen), err
	}
	rw.bb.discard()
	return initialBLen, nil
}
   798  
   799  func (rw *responseWriter) Flush() {
   800  	// Calling Flush implicitly calls WriteHeader(200) if WriteHeader has not
   801  	// been called before.
   802  	rw.WriteHeader(http.StatusOK)
   803  	rw.mu.Lock()
   804  	defer rw.mu.Unlock()
   805  	rw.writeHeaderLockedOnce()
   806  	if !rw.cannotHaveBody {
   807  		rw.bw.Write(rw.bb)
   808  		rw.bb.discard()
   809  	}
   810  	rw.st.Flush()
   811  }
   812  
   813  func (rw *responseWriter) close() error {
   814  	rw.Flush()
   815  	rw.mu.Lock()
   816  	defer rw.mu.Unlock()
   817  	rw.prepareTrailerForWriteLocked()
   818  	if err := rw.bw.Close(); err != nil {
   819  		return err
   820  	}
   821  	return rw.st.stream.Close()
   822  }
   823  
   824  // defaultBodyBufferCap is the default number of bytes of body that we are
   825  // willing to save in a buffer for the sake of inferring headers and coalescing
   826  // small writes. 512 was chosen to be consistent with how much
   827  // http.DetectContentType is willing to read.
   828  const defaultBodyBufferCap = 512
   829  
   830  // bodyBuffer is a buffer used to store body content of a response.
   831  type bodyBuffer []byte
   832  
   833  // write writes b to the buffer. It returns a new slice of b, which contains
   834  // any remaining data that could not be written to the buffer, if any.
   835  func (bb *bodyBuffer) write(b []byte) []byte {
   836  	n := min(len(b), cap(*bb)-len(*bb))
   837  	*bb = append(*bb, b[:n]...)
   838  	return b[n:]
   839  }
   840  
   841  // discard resets the buffer so it can be used again.
   842  func (bb *bodyBuffer) discard() {
   843  	*bb = (*bb)[:0]
   844  }
   845  
   846  // inferHeader populates h with the header values that we can infer from our
   847  // current buffer content, if not already explicitly set. This method should be
   848  // called only once with as much body content as possible in the buffer, before
   849  // a HEADERS frame is sent, and before discard has been called. Doing so
   850  // properly is the responsibility of the caller.
   851  func (bb *bodyBuffer) inferHeader(h http.Header, status int) {
   852  	if _, ok := h["Date"]; !ok {
   853  		h.Set("Date", time.Now().UTC().Format(http.TimeFormat))
   854  	}
   855  	// If the Content-Encoding is non-blank, we shouldn't
   856  	// sniff the body. See Issue golang.org/issue/31753.
   857  	_, hasCE := h["Content-Encoding"]
   858  	_, hasCT := h["Content-Type"]
   859  	if !hasCE && !hasCT && responseCanHaveBody(status) && len(*bb) > 0 {
   860  		h.Set("Content-Type", http.DetectContentType(*bb))
   861  	}
   862  	// We can technically infer Content-Length too here, as long as the entire
   863  	// response body fits within hi.buf and does not require flushing. However,
   864  	// we have chosen not to do so for now as Content-Length is not very
   865  	// important for HTTP/3, and such inconsistent behavior might be confusing.
   866  }
   867  

View as plain text