1
2
3
4
5 package http3
6
7 import (
8 "context"
9 "crypto/tls"
10 "fmt"
11 "io"
12 "maps"
13 "net/http"
14 "slices"
15 "strconv"
16 "strings"
17 "sync"
18 "time"
19
20 "golang.org/x/net/http/httpguts"
21 "golang.org/x/net/internal/httpcommon"
22 "golang.org/x/net/quic"
23 )
24
25
26
// server is the HTTP/3 server state shared by every connection accepted from
// a QUIC endpoint.
type server struct {
	// handler serves every request. init substitutes http.DefaultServeMux
	// when it is nil.
	handler http.Handler

	// config is the QUIC configuration handed to listenQUIC. init passes it
	// through initConfig before first use.
	config *quic.Config

	// listenQUIC opens the QUIC endpoint for an address. init defaults it to
	// quic.Listen on "udp" when it is nil.
	listenQUIC func(addr string, config *quic.Config) (*quic.Endpoint, error)

	// initOnce guards the one-time setup performed by init.
	initOnce sync.Once

	// serveCtx is the context used by serve when accepting connections.
	// serveCtxCancel cancels it and is invoked at the end of shutdown.
	serveCtx       context.Context
	serveCtxCancel context.CancelFunc

	// connClosed has capacity 1 (see init). unregisterConn performs a
	// non-blocking send on it each time a connection is removed so that
	// shutdown can re-check whether any connections remain.
	connClosed chan any

	// mu guards activeConns.
	mu sync.Mutex

	// activeConns is the set of connections registered via registerConn and
	// not yet removed by unregisterConn.
	activeConns map[*serverConn]struct{}
}
47
48
49
50
51
52
53
54
55
56
// netHTTPHandler is the interface that the handler passed to the "http/3"
// TLSNextProto hook must implement; RegisterServer panics if it does not.
// Besides serving requests, it exposes the settings that RegisterServer needs
// to start and stop the HTTP/3 server.
type netHTTPHandler interface {
	http.Handler
	// TLSConfig supplies the TLS configuration used when the QUIC config
	// does not already carry one.
	TLSConfig() *tls.Config
	// BaseContext supplies the context from which the serving context is
	// derived.
	BaseContext() context.Context
	// Addr supplies the address passed to the QUIC listener.
	Addr() string
	// ListenErrHook receives the result of starting the listener (nil on
	// success).
	ListenErrHook(err error)
	// ShutdownContext supplies the context handed to the HTTP/3 server's
	// shutdown when the net/http server shuts down.
	ShutdownContext() context.Context
}
65
// ServerOpts carries the optional settings accepted by RegisterServer.
type ServerOpts struct {
	// ListenQUIC, if non-nil, is called to open the QUIC endpoint for the
	// server's address. When it is nil the server falls back to
	// quic.Listen("udp", addr, config).
	ListenQUIC func(addr string, config *quic.Config) (*quic.Endpoint, error)

	// QUICConfig is the QUIC configuration to use. When it is nil,
	// RegisterServer substitutes an empty quic.Config. When its TLSConfig
	// field is nil, RegisterServer fills it in from the net/http server's
	// TLS configuration; note that this writes to the Config the caller
	// supplied.
	QUICConfig *quic.Config
}
78
79
80
81
82
83 func RegisterServer(s *http.Server, opts ServerOpts) {
84 if s.TLSNextProto == nil {
85 s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
86 }
87 s.TLSNextProto["http/3"] = func(s *http.Server, c *tls.Conn, h http.Handler) {
88 stdHandler, ok := h.(netHTTPHandler)
89 if !ok {
90 panic("RegisterServer was given a server that does not implement netHTTPHandler")
91 }
92 if opts.QUICConfig == nil {
93 opts.QUICConfig = &quic.Config{}
94 }
95 if opts.QUICConfig.TLSConfig == nil {
96 opts.QUICConfig.TLSConfig = stdHandler.TLSConfig()
97 }
98 s3 := &server{
99 config: opts.QUICConfig,
100 listenQUIC: opts.ListenQUIC,
101 handler: stdHandler,
102 serveCtx: stdHandler.BaseContext(),
103 }
104 s3.init()
105 s.RegisterOnShutdown(func() {
106 s3.shutdown(stdHandler.ShutdownContext())
107 })
108 stdHandler.ListenErrHook(s3.listenAndServe(stdHandler.Addr()))
109 }
110 }
111
112 func (s *server) init() {
113 s.initOnce.Do(func() {
114 s.config = initConfig(s.config)
115 if s.handler == nil {
116 s.handler = http.DefaultServeMux
117 }
118 if s.serveCtx == nil {
119 s.serveCtx = context.Background()
120 }
121 if s.listenQUIC == nil {
122 s.listenQUIC = func(addr string, config *quic.Config) (*quic.Endpoint, error) {
123 return quic.Listen("udp", addr, config)
124 }
125 }
126 s.serveCtx, s.serveCtxCancel = context.WithCancel(s.serveCtx)
127 s.activeConns = make(map[*serverConn]struct{})
128 s.connClosed = make(chan any, 1)
129 })
130 }
131
132
133
134 func (s *server) listenAndServe(addr string) error {
135 s.init()
136 e, err := s.listenQUIC(addr, s.config)
137 if err != nil {
138 return err
139 }
140 go s.serve(e)
141 return nil
142 }
143
144
145
146 func (s *server) serve(e *quic.Endpoint) error {
147 s.init()
148 defer e.Close(canceledCtx)
149 for {
150 qconn, err := e.Accept(s.serveCtx)
151 if err != nil {
152 return err
153 }
154 go s.newServerConn(qconn, s.handler)
155 }
156 }
157
158
159 func (s *server) shutdown(ctx context.Context) {
160
161 if ctx == nil {
162 var cancel context.CancelFunc
163 ctx, cancel = context.WithTimeout(context.Background(), time.Second)
164 defer cancel()
165 }
166
167
168
169 s.mu.Lock()
170 for sc := range s.activeConns {
171
172
173 go sc.sendGoaway()
174 }
175 s.mu.Unlock()
176
177
178
179 defer func() {
180 s.mu.Lock()
181 defer s.mu.Unlock()
182 s.serveCtxCancel()
183 for sc := range s.activeConns {
184 sc.abort(&connectionError{
185 code: errH3NoError,
186 message: "server is shutting down",
187 })
188 }
189 }()
190 noMoreConns := func() bool {
191 s.mu.Lock()
192 defer s.mu.Unlock()
193 return len(s.activeConns) == 0
194 }
195 for {
196 if noMoreConns() {
197 return
198 }
199 select {
200 case <-ctx.Done():
201 return
202 case <-s.connClosed:
203 }
204 }
205 }
206
207 func (s *server) registerConn(sc *serverConn) {
208 s.mu.Lock()
209 defer s.mu.Unlock()
210 s.activeConns[sc] = struct{}{}
211 }
212
213 func (s *server) unregisterConn(sc *serverConn) {
214 s.mu.Lock()
215 delete(s.activeConns, sc)
216 s.mu.Unlock()
217 select {
218 case s.connClosed <- struct{}{}:
219 default:
220
221
222 }
223 }
224
// serverConn is the server-side state of one HTTP/3 connection.
type serverConn struct {
	// qconn is the underlying QUIC connection.
	qconn *quic.Conn

	genericConn
	// enc is initialized in newServerConn and shared with the body writer of
	// every response on this connection.
	enc qpackEncoder
	dec qpackDecoder
	// handler serves the requests received on this connection.
	handler http.Handler

	// controlStream is the control stream opened by newServerConn; it stays
	// nil if opening failed. sendGoaway reads it while holding mu.
	controlStream *stream
	// mu guards maxRequestStreamID and goawaySent.
	mu sync.Mutex
	// maxRequestStreamID is the largest request stream ID seen before GOAWAY
	// was sent; it is no longer updated once goawaySent is true.
	maxRequestStreamID int64
	// goawaySent records that sendGoaway has committed to sending GOAWAY.
	goawaySent bool
}
239
240 func (s *server) newServerConn(qconn *quic.Conn, handler http.Handler) {
241 sc := &serverConn{
242 qconn: qconn,
243 handler: handler,
244 }
245 s.registerConn(sc)
246 defer s.unregisterConn(sc)
247 sc.enc.init()
248
249
250
251 var err error
252 sc.controlStream, err = newConnStream(context.Background(), sc.qconn, streamTypeControl)
253 if err != nil {
254 return
255 }
256 sc.controlStream.writeSettings()
257 sc.controlStream.Flush()
258
259 sc.acceptStreams(sc.qconn, sc)
260 }
261
262 func (sc *serverConn) handleControlStream(st *stream) error {
263
264
265 if err := st.readSettings(func(settingsType, settingsValue int64) error {
266 switch settingsType {
267 case settingsMaxFieldSectionSize:
268 _ = settingsValue
269 case settingsQPACKMaxTableCapacity:
270 _ = settingsValue
271 case settingsQPACKBlockedStreams:
272 _ = settingsValue
273 default:
274
275 }
276 return nil
277 }); err != nil {
278 return err
279 }
280
281 for {
282 ftype, err := st.readFrameHeader()
283 if err != nil {
284 return err
285 }
286 switch ftype {
287 case frameTypeCancelPush:
288
289
290
291
292 return &connectionError{
293 code: errH3IDError,
294 message: "CANCEL_PUSH for unsent push ID",
295 }
296 case frameTypeGoaway:
297 return errH3NoError
298 default:
299
300 if err := st.discardUnknownFrame(ftype); err != nil {
301 return err
302 }
303 }
304 }
305 }
306
// handleEncoderStream is the handler for an encoder stream opened by the peer
// (presumably the QPACK encoder stream; the dispatch is outside this section).
// It currently ignores the stream and reports success.
func (sc *serverConn) handleEncoderStream(*stream) error {

	return nil
}
311
// handleDecoderStream is the handler for a decoder stream opened by the peer
// (presumably the QPACK decoder stream; the dispatch is outside this section).
// It currently ignores the stream and reports success.
func (sc *serverConn) handleDecoderStream(*stream) error {

	return nil
}
316
317 func (sc *serverConn) handlePushStream(*stream) error {
318
319
320
321 return &connectionError{
322 code: errH3StreamCreationError,
323 message: "client created push stream",
324 }
325 }
326
327
328
329
330
331
332
333
334
335
336
// hasDisallowedConnectionHeader reports whether h contains a
// connection-specific header field that is not permitted here: any of
// Connection, Keep-Alive, Proxy-Connection, Transfer-Encoding or Upgrade, or
// a Te field whose value is anything other than the single value "trailers".
// Keys are matched exactly, so h is expected to use canonical header keys.
func hasDisallowedConnectionHeader(h http.Header) bool {
	for name, values := range h {
		switch name {
		case "Connection", "Keep-Alive", "Proxy-Connection", "Transfer-Encoding", "Upgrade":
			return true
		case "Te":
			if len(values) != 1 || values[0] != "trailers" {
				return true
			}
		}
	}
	return false
}
355
// pseudoHeader holds the values of the request pseudo-header fields collected
// by parseHeader. A field that was absent from the request is left empty.
type pseudoHeader struct {
	method    string // value of ":method"
	scheme    string // value of ":scheme"
	path      string // value of ":path"
	authority string // value of ":authority"
}
362
363 func (sc *serverConn) parseHeader(st *stream) (http.Header, pseudoHeader, error) {
364 ftype, err := st.readFrameHeader()
365 if err != nil {
366 return nil, pseudoHeader{}, err
367 }
368 if ftype != frameTypeHeaders {
369 return nil, pseudoHeader{}, &streamError{errH3MessageError, "received other frames when expecting HEADERS"}
370 }
371 header := make(http.Header)
372 var pHeader pseudoHeader
373 var dec qpackDecoder
374 var hasMethod, hasScheme, hasPath, hasAuthority bool
375 if err := dec.decode(st, func(_ indexType, name, value string) error {
376 if !httpguts.ValidHeaderFieldValue(value) {
377 return &streamError{errH3MessageError, "invalid field value"}
378 }
379 switch name {
380 case ":method":
381 if hasMethod {
382 return &streamError{errH3MessageError, "duplicate :method"}
383 }
384 hasMethod = true
385 pHeader.method = value
386 case ":scheme":
387 if hasScheme {
388 return &streamError{errH3MessageError, "duplicate :scheme"}
389 }
390 hasScheme = true
391 pHeader.scheme = value
392 case ":path":
393 if hasPath {
394 return &streamError{errH3MessageError, "duplicate :path"}
395 }
396 hasPath = true
397 pHeader.path = value
398 case ":authority":
399 if hasAuthority {
400 return &streamError{errH3MessageError, "duplicate :authority"}
401 }
402 hasAuthority = true
403 pHeader.authority = value
404 default:
405 if !validWireHeaderFieldName(name) {
406 return &streamError{errH3MessageError, "invalid field name"}
407 }
408 header.Add(name, value)
409 }
410 return nil
411 }); err != nil {
412 return nil, pseudoHeader{}, err
413 }
414 if err := st.endFrame(); err != nil {
415 return nil, pseudoHeader{}, err
416 }
417 if hasDisallowedConnectionHeader(header) {
418 return nil, pseudoHeader{}, &streamError{errH3MessageError, "invalid connection-related header"}
419 }
420
421
422
423
424
425
426
427
428
429 if !hasMethod {
430 return nil, pseudoHeader{}, &streamError{errH3MessageError, "missing :method"}
431 }
432 if pHeader.method != "CONNECT" && (!hasScheme || !hasPath) {
433 return nil, pseudoHeader{}, &streamError{errH3MessageError, "missing :scheme or :path for non-CONNECT requests"}
434 }
435 if pHeader.method == "CONNECT" && (hasScheme || hasPath || !hasAuthority) {
436 return nil, pseudoHeader{}, &streamError{
437 errH3MessageError, "CONNECT request must only have :method and :authority pseudo-headers",
438 }
439 }
440 return header, pHeader, nil
441 }
442
443 func (sc *serverConn) sendGoaway() {
444 sc.mu.Lock()
445 if sc.goawaySent || sc.controlStream == nil {
446 sc.mu.Unlock()
447 return
448 }
449 sc.goawaySent = true
450 sc.mu.Unlock()
451
452
453
454 sc.controlStream.writeVarint(int64(frameTypeGoaway))
455 sc.controlStream.writeVarint(int64(sizeVarint(uint64(sc.maxRequestStreamID))))
456 sc.controlStream.writeVarint(sc.maxRequestStreamID)
457 sc.controlStream.Flush()
458 }
459
460
461
462 func (sc *serverConn) requestShouldGoaway(st *stream) bool {
463 sc.mu.Lock()
464 defer sc.mu.Unlock()
465 if sc.goawaySent {
466 return st.stream.ID() >= sc.maxRequestStreamID
467 } else {
468 sc.maxRequestStreamID = max(sc.maxRequestStreamID, st.stream.ID())
469 return false
470 }
471 }
472
473 func (sc *serverConn) handleRequestStream(st *stream) error {
474 if sc.requestShouldGoaway(st) {
475 return &streamError{
476 code: errH3RequestRejected,
477 message: "GOAWAY request with equal or lower ID than the stream has been sent",
478 }
479 }
480 header, pHeader, err := sc.parseHeader(st)
481 if err != nil {
482 return err
483 }
484
485 reqInfo := httpcommon.NewServerRequest(httpcommon.ServerRequestParam{
486 Method: pHeader.method,
487 Scheme: pHeader.scheme,
488 Authority: pHeader.authority,
489 Path: pHeader.path,
490 Header: header,
491 })
492 if reqInfo.InvalidReason != "" {
493 return &streamError{
494 code: errH3MessageError,
495 message: reqInfo.InvalidReason,
496 }
497 }
498
499 var body io.ReadCloser
500 contentLength := int64(-1)
501 if n, err := strconv.Atoi(header.Get("Content-Length")); err == nil {
502 contentLength = int64(n)
503 }
504 if contentLength != 0 || len(reqInfo.Trailer) != 0 {
505 body = &bodyReader{
506 st: st,
507 remain: contentLength,
508 trailer: reqInfo.Trailer,
509 }
510 } else {
511 body = http.NoBody
512 }
513
514 req := &http.Request{
515 Proto: "HTTP/3.0",
516 Method: pHeader.method,
517 Host: pHeader.authority,
518 URL: reqInfo.URL,
519 RequestURI: reqInfo.RequestURI,
520 Trailer: reqInfo.Trailer,
521 ProtoMajor: 3,
522 RemoteAddr: sc.qconn.RemoteAddr().String(),
523 Body: body,
524 Header: header,
525 ContentLength: contentLength,
526 }
527 defer req.Body.Close()
528
529 rw := &responseWriter{
530 st: st,
531 headers: make(http.Header),
532 trailer: make(http.Header),
533 bb: make(bodyBuffer, 0, defaultBodyBufferCap),
534 cannotHaveBody: req.Method == "HEAD",
535 bw: &bodyWriter{
536 st: st,
537 remain: -1,
538 flush: false,
539 name: "response",
540 enc: &sc.enc,
541 },
542 }
543 defer rw.close()
544 if reqInfo.NeedsContinue {
545 req.Body.(*bodyReader).send100Continue = func() {
546 rw.WriteHeader(100)
547 }
548 }
549
550
551 sc.handler.ServeHTTP(rw, req)
552 return nil
553 }
554
555
556 func (sc *serverConn) abort(err error) {
557 if e, ok := err.(*connectionError); ok {
558 sc.qconn.Abort(&quic.ApplicationError{
559 Code: uint64(e.code),
560 Reason: e.message,
561 })
562 } else {
563 sc.qconn.Abort(err)
564 }
565 }
566
567
568
// responseCanHaveBody reports whether a response with the given status code
// may carry a body. Informational (100-199), 204 and 304 responses may not;
// every other status may.
func responseCanHaveBody(status int) bool {
	if status >= 100 && status <= 199 {
		return false
	}
	return status != 204 && status != 304
}
580
// responseWriter is the http.ResponseWriter handed to the handler for one
// request stream.
type responseWriter struct {
	// st is the request stream the response is written to.
	st *stream
	// bw writes the response body, and the trailer if one is set on it.
	bw *bodyWriter
	// mu is held by WriteHeader, Write, Flush and close while they touch the
	// fields below.
	mu sync.Mutex
	// headers is the map returned by Header.
	headers http.Header
	// trailer holds the trailer fields declared through the "Trailer"
	// header; their values are filled in from headers when the response is
	// closed.
	trailer http.Header
	// bb buffers body bytes written before they are sent to bw.
	bb bodyBuffer
	// wroteHeader is set once the final HEADERS frame has been written.
	wroteHeader bool
	// statusCode is the final (non-informational) status, valid once
	// statusCodeSet is true.
	statusCode    int
	statusCodeSet bool
	// cannotHaveBody is true for HEAD requests and for statuses that do not
	// permit a body; body bytes are then accepted but not sent.
	cannotHaveBody bool
	// bodyLenLeft is the number of body bytes still permitted by the declared
	// Content-Length, or -1 when no valid Content-Length was declared.
	bodyLenLeft int
}
594
// Header returns the response header map. The map itself is returned, not a
// copy, so changes made by the caller are seen by the response writer.
func (rw *responseWriter) Header() http.Header {
	return rw.headers
}
598
599
600
601
602 func (rw *responseWriter) prepareTrailerForWriteLocked() {
603 for name := range rw.trailer {
604 if val, ok := rw.headers[name]; ok {
605 rw.trailer[name] = val
606 } else {
607 delete(rw.trailer, name)
608 }
609 }
610 if len(rw.trailer) > 0 {
611 rw.bw.trailer = rw.trailer
612 }
613 }
614
615
616
617
618
619 func (rw *responseWriter) writeHeaderLockedOnce() {
620 if rw.wroteHeader {
621 return
622 }
623 if !responseCanHaveBody(rw.statusCode) {
624 rw.cannotHaveBody = true
625 }
626
627
628
629 if _, ok := rw.headers["Trailer"]; ok {
630 extractTrailerFromHeader(rw.headers, rw.trailer)
631 rw.headers.Set("Trailer", strings.Join(slices.Sorted(maps.Keys(rw.trailer)), ", "))
632 }
633
634 rw.bb.inferHeader(rw.headers, rw.statusCode)
635 encHeaders := rw.bw.enc.encode(func(f func(itype indexType, name, value string)) {
636 f(mayIndex, ":status", strconv.Itoa(rw.statusCode))
637 for name, values := range rw.headers {
638 if !httpguts.ValidHeaderFieldName(name) {
639 continue
640 }
641 for _, val := range values {
642 if !httpguts.ValidHeaderFieldValue(val) {
643 continue
644 }
645
646 f(mayIndex, name, val)
647 }
648 }
649 })
650
651 rw.st.writeVarint(int64(frameTypeHeaders))
652 rw.st.writeVarint(int64(len(encHeaders)))
653 rw.st.Write(encHeaders)
654 rw.wroteHeader = true
655 }
656
657
658
659
660
661 func (rw *responseWriter) writeHeaderLocked(statusCode int) {
662 if rw.wroteHeader {
663 return
664 }
665 encHeaders := rw.bw.enc.encode(func(f func(itype indexType, name, value string)) {
666 f(mayIndex, ":status", strconv.Itoa(statusCode))
667 for name, values := range rw.headers {
668 if name == "Content-Length" || name == "Transfer-Encoding" {
669 continue
670 }
671 if !httpguts.ValidHeaderFieldName(name) {
672 continue
673 }
674 for _, val := range values {
675 if !httpguts.ValidHeaderFieldValue(val) {
676 continue
677 }
678
679 f(mayIndex, name, val)
680 }
681 }
682 })
683 rw.st.writeVarint(int64(frameTypeHeaders))
684 rw.st.writeVarint(int64(len(encHeaders)))
685 rw.st.Write(encHeaders)
686 }
687
// isInfoStatus reports whether status is an informational (1xx) status code.
func isInfoStatus(status int) bool {
	return 100 <= status && status <= 199
}
691
692
// checkWriteHeaderCode panics if code is not a three-digit number, that is,
// if it lies outside the range 100 to 999 inclusive. It does not check that
// the code is a registered HTTP status.
func checkWriteHeaderCode(code int) {
	if 100 <= code && code <= 999 {
		return
	}
	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
708
709 func (rw *responseWriter) WriteHeader(statusCode int) {
710
711 rw.mu.Lock()
712 defer rw.mu.Unlock()
713 if rw.statusCodeSet {
714 return
715 }
716 checkWriteHeaderCode(statusCode)
717
718
719
720 if isInfoStatus(statusCode) {
721 rw.writeHeaderLocked(statusCode)
722 rw.st.Flush()
723 return
724 }
725
726
727
728 rw.statusCodeSet = true
729 rw.statusCode = statusCode
730 if n, err := strconv.Atoi(rw.Header().Get("Content-Length")); err == nil {
731 rw.bodyLenLeft = n
732 } else {
733 rw.bodyLenLeft = -1
734 }
735 }
736
737
738
739
740
741 func (rw *responseWriter) trimWriteLocked(b []byte) ([]byte, bool) {
742 if rw.bodyLenLeft < 0 {
743 return b, false
744 }
745 n := min(len(b), rw.bodyLenLeft)
746 rw.bodyLenLeft -= n
747 return b[:n], n != len(b)
748 }
749
// Write writes b as response body data.
//
// If no final status has been set it first records 200 OK. It returns
// http.ErrBodyNotAllowed for a 304 response, and http.ErrContentLength
// (together with the number of bytes accepted) when b exceeds what the
// declared Content-Length still permits. Data is collected in rw.bb first;
// the response header is written, and buffered data passed to the body
// writer, only once the data no longer fits in the buffer's spare capacity.
//
// NOTE(review): only status 304 is rejected here, whereas responseCanHaveBody
// also treats 1xx and 204 as body-less; for those the data is accepted and
// silently dropped via cannotHaveBody — confirm this difference is intended.
func (rw *responseWriter) Write(b []byte) (n int, err error) {
	// Make sure a final status exists. This is a no-op if the handler already
	// set one, and it takes rw.mu itself, so it must run before the lock
	// below.
	rw.WriteHeader(http.StatusOK)
	rw.mu.Lock()
	defer rw.mu.Unlock()

	if rw.statusCode == http.StatusNotModified {
		return 0, http.ErrBodyNotAllowed
	}

	b, trimmed := rw.trimWriteLocked(b)
	if trimmed {
		// Whatever happens to the permitted prefix, the caller must learn
		// that it wrote past Content-Length; this overrides the named result
		// on every return path below.
		defer func() {
			err = http.ErrContentLength
		}()
	}

	// Buffer the data when the header has not been written yet (so that
	// inferHeader can inspect the first bytes) or when all of it fits in the
	// buffer's spare capacity. If nothing is left over, the write is complete
	// without touching the stream.
	initialBLen := len(b)
	initialBufLen := len(rw.bb)
	if !rw.wroteHeader || len(b) <= cap(rw.bb)-len(rw.bb) {
		b = rw.bb.write(b)
		if len(b) == 0 {
			return initialBLen, nil
		}
	}

	// The buffer cannot hold everything: send the header (first time only),
	// then hand the buffered bytes and the remainder of b to the body writer.
	rw.writeHeaderLockedOnce()
	if rw.cannotHaveBody {
		// Report success but do not send any body bytes.
		return initialBLen, nil
	}
	if n, err := rw.bw.write(rw.bb, b); err != nil {
		// n presumably counts buffered bytes as well; subtract those that
		// were already in the buffer before this call — TODO confirm against
		// bodyWriter.write.
		return max(0, n-initialBufLen), err
	}
	rw.bb.discard()
	return initialBLen, nil
}
798
799 func (rw *responseWriter) Flush() {
800
801
802 rw.WriteHeader(http.StatusOK)
803 rw.mu.Lock()
804 defer rw.mu.Unlock()
805 rw.writeHeaderLockedOnce()
806 if !rw.cannotHaveBody {
807 rw.bw.Write(rw.bb)
808 rw.bb.discard()
809 }
810 rw.st.Flush()
811 }
812
813 func (rw *responseWriter) close() error {
814 rw.Flush()
815 rw.mu.Lock()
816 defer rw.mu.Unlock()
817 rw.prepareTrailerForWriteLocked()
818 if err := rw.bw.Close(); err != nil {
819 return err
820 }
821 return rw.st.stream.Close()
822 }
823
824
825
826
827
// defaultBodyBufferCap is the capacity, in bytes, of the bodyBuffer created
// for each response.
const defaultBodyBufferCap = 512

// bodyBuffer is a fixed-capacity byte buffer for the first bytes of a response
// body. It never grows: its capacity is whatever the slice was created with.
type bodyBuffer []byte

// write appends as much of b as fits into the buffer's remaining capacity and
// returns the part of b that did not fit (empty if everything was stored).
func (bb *bodyBuffer) write(b []byte) []byte {
	room := cap(*bb) - len(*bb)
	if room > len(b) {
		room = len(b)
	}
	*bb = append(*bb, b[:room]...)
	return b[room:]
}

// discard empties the buffer while keeping its capacity for reuse.
func (bb *bodyBuffer) discard() {
	buf := *bb
	*bb = buf[:0]
}
845
846
847
848
849
850
851 func (bb *bodyBuffer) inferHeader(h http.Header, status int) {
852 if _, ok := h["Date"]; !ok {
853 h.Set("Date", time.Now().UTC().Format(http.TimeFormat))
854 }
855
856
857 _, hasCE := h["Content-Encoding"]
858 _, hasCT := h["Content-Type"]
859 if !hasCE && !hasCT && responseCanHaveBody(status) && len(*bb) > 0 {
860 h.Set("Content-Type", http.DetectContentType(*bb))
861 }
862
863
864
865
866 }
867
View as plain text