1
2
3
4
5 package http2_test
6
7 import (
8 "bytes"
9 "compress/gzip"
10 "compress/zlib"
11 "context"
12 "crypto/tls"
13 "crypto/x509"
14 "errors"
15 "flag"
16 "fmt"
17 "io"
18 "log"
19 "math"
20 "net"
21 "net/http"
22 "net/http/httptest"
23 "os"
24 "reflect"
25 "runtime"
26 "slices"
27 "strconv"
28 "strings"
29 "sync"
30 "testing"
31 "testing/synctest"
32 "time"
33 _ "unsafe"
34
35 "net/http/internal/http2"
36 . "net/http/internal/http2"
37 "net/http/internal/testcert"
38
39 "golang.org/x/net/http2/hpack"
40 )
41
// stderrVerbose, when set with -stderr_verbose, mirrors server log output
// to stderr, unbuffered, in addition to the test log.
var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered")

// stderrv returns os.Stderr when -stderr_verbose is set and io.Discard
// otherwise.
func stderrv() io.Writer {
	var w io.Writer = io.Discard
	if *stderrVerbose {
		w = os.Stderr
	}
	return w
}
51
// safeBuffer is a bytes.Buffer guarded by a mutex. The tester installs it
// as part of the server's ErrorLog, so it may be written from server
// goroutines while the test reads it.
type safeBuffer struct {
	b bytes.Buffer
	m sync.Mutex
}

// Write appends d to the buffer while holding the lock.
func (sb *safeBuffer) Write(d []byte) (int, error) {
	sb.m.Lock()
	defer sb.m.Unlock()
	return sb.b.Write(d)
}

// Bytes returns a copy of the buffered contents.
//
// It copies rather than returning the buffer's internal slice. That slice
// would keep aliasing storage that later Writes may overwrite or reallocate
// after the lock is released, racing with the caller's reads.
func (sb *safeBuffer) Bytes() []byte {
	sb.m.Lock()
	defer sb.m.Unlock()
	return bytes.Clone(sb.b.Bytes())
}

// Len returns the number of buffered bytes.
func (sb *safeBuffer) Len() int {
	sb.m.Lock()
	defer sb.m.Unlock()
	return sb.b.Len()
}
74
// serverTester drives an HTTP/2 server from the client side of a single
// in-memory connection. Construct it with newServerTester.
type serverTester struct {
	// cc is the client end of the connection to the server.
	cc net.Conn
	t  testing.TB
	// h1server is the http.Server whose Serve loop handles the connection.
	h1server *http.Server
	// NOTE(review): h2server is not set by newServerTester in this file
	// chunk; confirm where (if anywhere) it is used.
	h2server *Server
	// serverLogBuf receives a copy of the server's ErrorLog output when
	// newServerTester installs its default logger.
	serverLogBuf safeBuffer
	// logFilter lists phrases; log writes containing any of them are not
	// forwarded to t.Logf (see twriter.Write).
	logFilter []string
	// scMu presumably guards sc — TODO confirm.
	scMu sync.Mutex
	// sc is the server-side connection, reported through NewConnContextKey.
	sc *ServerConn
	// wrotePreface makes writePreface a no-op once the client preface has
	// been written.
	wrotePreface bool
	// testConnFramer provides the client-side framer (st.fr) and frame
	// helpers.
	testConnFramer

	// Handler invocations recorded by serverTesterHandler and consumed by
	// nextHandlerCall.
	callsMu sync.Mutex
	calls   []*serverHandlerCall

	// Frame read/write logs, printed by Close when the test has failed.
	frameReadLogMu   sync.Mutex
	frameReadLogBuf  bytes.Buffer
	frameWriteLogMu  sync.Mutex
	frameWriteLogBuf bytes.Buffer

	// Scratch state used to HPACK-encode request headers.
	headerBuf bytes.Buffer
	hpackEnc  *hpack.Encoder
}
103
104 type twriter struct {
105 t testing.TB
106 st *serverTester
107 }
108
109 func (w twriter) Write(p []byte) (n int, err error) {
110 if w.st != nil {
111 ps := string(p)
112 for _, phrase := range w.st.logFilter {
113 if strings.Contains(ps, phrase) {
114 return len(p), nil
115 }
116 }
117 }
118 w.t.Logf("%s", p)
119 return len(p), nil
120 }
121
122 func newTestServer(t testing.TB, handler http.HandlerFunc, opts ...any) *httptest.Server {
123 t.Helper()
124 if handler == nil {
125 handler = func(w http.ResponseWriter, req *http.Request) {}
126 }
127 ts := httptest.NewUnstartedServer(handler)
128 ts.EnableHTTP2 = true
129 ts.Config.ErrorLog = log.New(twriter{t: t}, "", log.LstdFlags)
130 ts.Config.Protocols = protocols("h2")
131 for _, opt := range opts {
132 switch v := opt.(type) {
133 case func(*httptest.Server):
134 v(ts)
135 case func(*http.Server):
136 v(ts.Config)
137 case func(*http.HTTP2Config):
138 if ts.Config.HTTP2 == nil {
139 ts.Config.HTTP2 = &http.HTTP2Config{}
140 }
141 v(ts.Config.HTTP2)
142 default:
143 t.Fatalf("unknown newTestServer option type %T", v)
144 }
145 }
146
147 if ts.Config.Protocols.HTTP2() {
148 ts.TLS = testServerTLSConfig
149 if ts.Config.TLSConfig != nil {
150 ts.TLS = ts.Config.TLSConfig
151 }
152 ts.StartTLS()
153 } else if ts.Config.Protocols.UnencryptedHTTP2() {
154 ts.EnableHTTP2 = false
155 ts.Start()
156 } else {
157 t.Fatalf("Protocols contains neither HTTP2 nor UnencryptedHTTP2")
158 }
159
160 t.Cleanup(func() {
161 ts.CloseClientConnections()
162 ts.Close()
163 })
164
165 return ts
166 }
167
// serverTesterOpt is a string-valued test server option.
type serverTesterOpt string

var (
	// optFramerReuseFrames names the framer frame-reuse option.
	// NOTE(review): newServerTester's option switch has no serverTesterOpt
	// case and fails on unknown option types; confirm where this value is
	// consumed.
	optFramerReuseFrames = serverTesterOpt("frame_reuse_frames")

	// optQuiet silences a server by sending its ErrorLog to io.Discard.
	optQuiet = func(server *http.Server) {
		server.ErrorLog = log.New(io.Discard, "", 0)
	}
)
175
176 func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...any) *serverTester {
177 t.Helper()
178
179 h1server := &http.Server{}
180 var tlsState *tls.ConnectionState
181 for _, opt := range opts {
182 switch v := opt.(type) {
183 case func(*http.Server):
184 v(h1server)
185 case func(*http.HTTP2Config):
186 if h1server.HTTP2 == nil {
187 h1server.HTTP2 = &http.HTTP2Config{}
188 }
189 v(h1server.HTTP2)
190 case func(*tls.ConnectionState):
191 if tlsState == nil {
192 tlsState = &tls.ConnectionState{
193 Version: tls.VersionTLS13,
194 ServerName: "go.dev",
195 CipherSuite: tls.TLS_AES_128_GCM_SHA256,
196 }
197 }
198 v(tlsState)
199 default:
200 t.Fatalf("unknown newServerTester option type %T", v)
201 }
202 }
203
204 tlsConfig := h1server.TLSConfig
205 if tlsConfig == nil {
206 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
207 if err != nil {
208 t.Fatal(err)
209 }
210 tlsConfig = &tls.Config{
211 Certificates: []tls.Certificate{cert},
212 InsecureSkipVerify: true,
213 NextProtos: []string{"h2"},
214 }
215 h1server.TLSConfig = tlsConfig
216 }
217
218 var cli, srv net.Conn
219
220 cliPipe, srvPipe := synctestNetPipe()
221
222 if h1server.Protocols != nil && h1server.Protocols.UnencryptedHTTP2() {
223 cli, srv = cliPipe, srvPipe
224 } else {
225 cli = tls.Client(cliPipe, &tls.Config{
226 InsecureSkipVerify: true,
227 NextProtos: []string{"h2"},
228 })
229 srv = tls.Server(srvPipe, tlsConfig)
230 }
231
232 st := &serverTester{
233 t: t,
234 cc: cli,
235 h1server: h1server,
236 }
237 st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
238 if h1server.ErrorLog == nil {
239 h1server.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
240 }
241
242 if handler == nil {
243 handler = serverTesterHandler{st}.ServeHTTP
244 }
245 h1server.Handler = handler
246
247 t.Cleanup(func() {
248 st.Close()
249 time.Sleep(GoAwayTimeout)
250 })
251
252 connc := make(chan *ServerConn)
253 h1server.ConnContext = func(ctx context.Context, conn net.Conn) context.Context {
254 ctx = context.WithValue(ctx, NewConnContextKey, func(sc *ServerConn) {
255 connc <- sc
256 })
257 if tlsState != nil {
258 ctx = context.WithValue(ctx, ConnectionStateContextKey, func() tls.ConnectionState {
259 return *tlsState
260 })
261 }
262 return ctx
263 }
264 go func() {
265 li := newOneConnListener(srv)
266 t.Cleanup(func() {
267 li.Close()
268 })
269 h1server.Serve(li)
270 }()
271 if cliTLS, ok := cli.(*tls.Conn); ok {
272 if err := cliTLS.Handshake(); err != nil {
273 t.Fatalf("client TLS handshake: %v", err)
274 }
275 cliTLS.SetReadDeadline(time.Now())
276 } else {
277
278
279 st.writePreface()
280 st.wrotePreface = true
281 cliPipe.SetReadDeadline(time.Now())
282 }
283 st.sc = <-connc
284
285 st.fr = NewFramer(st.cc, st.cc)
286 st.testConnFramer = testConnFramer{
287 t: t,
288 fr: NewFramer(cli, cli),
289 dec: hpack.NewDecoder(InitialHeaderTableSize, nil),
290 }
291 synctest.Wait()
292 return st
293 }
294
// netConnWithConnectionState wraps a net.Conn and reports a fixed TLS
// connection state.
type netConnWithConnectionState struct {
	net.Conn
	state tls.ConnectionState
}

// ConnectionState returns the stored TLS connection state.
func (c *netConnWithConnectionState) ConnectionState() tls.ConnectionState {
	st := c.state
	return st
}

// HandshakeContext also returns the stored TLS connection state.
//
// NOTE(review): unlike (*tls.Conn).HandshakeContext, this takes no context
// and returns no error; confirm whether any caller relies on this signature.
func (c *netConnWithConnectionState) HandshakeContext() tls.ConnectionState {
	return c.ConnectionState()
}
307
308 type serverTesterHandler struct {
309 st *serverTester
310 }
311
312 func (h serverTesterHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
313 call := &serverHandlerCall{
314 w: w,
315 req: req,
316 ch: make(chan func()),
317 }
318 h.st.t.Cleanup(call.exit)
319 h.st.callsMu.Lock()
320 h.st.calls = append(h.st.calls, call)
321 h.st.callsMu.Unlock()
322 for f := range call.ch {
323 f()
324 }
325 }
326
327
// serverHandlerCall is one in-flight request held open by
// serverTesterHandler. Functions sent on ch run on the handler's goroutine
// with the request's ResponseWriter and *http.Request.
type serverHandlerCall struct {
	w         http.ResponseWriter
	req       *http.Request
	closeOnce sync.Once
	ch        chan func()
}

// do runs f inside the handler and blocks until f has returned.
func (call *serverHandlerCall) do(f func(http.ResponseWriter, *http.Request)) {
	finished := make(chan struct{})
	run := func() {
		defer close(finished)
		f(call.w, call.req)
	}
	call.ch <- run
	<-finished
}

// exit lets the handler return by closing its work channel. Calling it
// more than once is harmless.
func (call *serverHandlerCall) exit() {
	call.closeOnce.Do(func() { close(call.ch) })
}
351
352
353 func (st *serverTester) sync() {
354 synctest.Wait()
355 }
356
357
358 func (st *serverTester) advance(d time.Duration) {
359 time.Sleep(d)
360 synctest.Wait()
361 }
362
363 func (st *serverTester) authority() string {
364 return "dummy.tld"
365 }
366
367 func (st *serverTester) addLogFilter(phrase string) {
368 st.logFilter = append(st.logFilter, phrase)
369 }
370
371 func (st *serverTester) nextHandlerCall() *serverHandlerCall {
372 st.t.Helper()
373 synctest.Wait()
374 st.callsMu.Lock()
375 defer st.callsMu.Unlock()
376 if len(st.calls) == 0 {
377 st.t.Fatal("expected server handler call, got none")
378 }
379 call := st.calls[0]
380 st.calls = st.calls[1:]
381 return call
382 }
383
384 func (st *serverTester) streamExists(id uint32) bool {
385 return st.sc.TestStreamExists(id)
386 }
387
388 func (st *serverTester) streamState(id uint32) StreamState {
389 return st.sc.TestStreamState(id)
390 }
391
392 func (st *serverTester) Close() {
393 if st.t.Failed() {
394 st.frameReadLogMu.Lock()
395 if st.frameReadLogBuf.Len() > 0 {
396 st.t.Logf("Framer read log:\n%s", st.frameReadLogBuf.String())
397 }
398 st.frameReadLogMu.Unlock()
399
400 st.frameWriteLogMu.Lock()
401 if st.frameWriteLogBuf.Len() > 0 {
402 st.t.Logf("Framer write log:\n%s", st.frameWriteLogBuf.String())
403 }
404 st.frameWriteLogMu.Unlock()
405
406
407
408
409
410 if st.cc != nil {
411 st.cc.Close()
412 }
413 }
414 if st.cc != nil {
415 st.cc.Close()
416 }
417 log.SetOutput(os.Stderr)
418 }
419
420
421
422 func (st *serverTester) greet() {
423 st.t.Helper()
424 st.greetAndCheckSettings(func(Setting) error { return nil })
425 }
426
427 func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) {
428 st.t.Helper()
429 st.writePreface()
430 st.writeSettings()
431 st.sync()
432 readFrame[*SettingsFrame](st.t, st).ForeachSetting(checkSetting)
433 st.writeSettingsAck()
434
435
436 var gotSettingsAck bool
437 var gotWindowUpdate bool
438
439 for range 2 {
440 f := st.readFrame()
441 if f == nil {
442 st.t.Fatal("wanted a settings ACK and window update, got none")
443 }
444 switch f := f.(type) {
445 case *SettingsFrame:
446 if !f.Header().Flags.Has(FlagSettingsAck) {
447 st.t.Fatal("Settings Frame didn't have ACK set")
448 }
449 gotSettingsAck = true
450
451 case *WindowUpdateFrame:
452 if f.FrameHeader.StreamID != 0 {
453 st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
454 }
455 gotWindowUpdate = true
456
457 default:
458 st.t.Fatalf("Wanting a settings ACK or window update, received a %T", f)
459 }
460 }
461
462 if !gotSettingsAck {
463 st.t.Fatalf("Didn't get a settings ACK")
464 }
465 if !gotWindowUpdate {
466 st.t.Fatalf("Didn't get a window update")
467 }
468 }
469
470 func (st *serverTester) writePreface() {
471 if st.wrotePreface {
472 return
473 }
474 n, err := st.cc.Write([]byte(ClientPreface))
475 if err != nil {
476 st.t.Fatalf("Error writing client preface: %v", err)
477 }
478 if n != len(ClientPreface) {
479 st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(ClientPreface))
480 }
481 }
482
483 func (st *serverTester) encodeHeaderField(k, v string) {
484 err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
485 if err != nil {
486 st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
487 }
488 }
489
490
491
492 func (st *serverTester) encodeHeaderRaw(headers ...string) []byte {
493 if len(headers)%2 == 1 {
494 panic("odd number of kv args")
495 }
496 st.headerBuf.Reset()
497 for len(headers) > 0 {
498 k, v := headers[0], headers[1]
499 st.encodeHeaderField(k, v)
500 headers = headers[2:]
501 }
502 return st.headerBuf.Bytes()
503 }
504
505
506
507
508
509
510 func (st *serverTester) encodeHeader(headers ...string) []byte {
511 if len(headers)%2 == 1 {
512 panic("odd number of kv args")
513 }
514
515 st.headerBuf.Reset()
516 defaultAuthority := st.authority()
517
518 if len(headers) == 0 {
519
520
521 st.encodeHeaderField(":method", "GET")
522 st.encodeHeaderField(":scheme", "https")
523 st.encodeHeaderField(":authority", defaultAuthority)
524 st.encodeHeaderField(":path", "/")
525 return st.headerBuf.Bytes()
526 }
527
528 if len(headers) == 2 && headers[0] == ":method" {
529
530 st.encodeHeaderField(":method", headers[1])
531 st.encodeHeaderField(":scheme", "https")
532 st.encodeHeaderField(":authority", defaultAuthority)
533 st.encodeHeaderField(":path", "/")
534 return st.headerBuf.Bytes()
535 }
536
537 pseudoCount := map[string]int{}
538 keys := []string{":method", ":scheme", ":authority", ":path"}
539 vals := map[string][]string{
540 ":method": {"GET"},
541 ":scheme": {"https"},
542 ":authority": {defaultAuthority},
543 ":path": {"/"},
544 }
545 for len(headers) > 0 {
546 k, v := headers[0], headers[1]
547 headers = headers[2:]
548 if _, ok := vals[k]; !ok {
549 keys = append(keys, k)
550 }
551 if strings.HasPrefix(k, ":") {
552 pseudoCount[k]++
553 if pseudoCount[k] == 1 {
554 vals[k] = []string{v}
555 } else {
556
557 vals[k] = append(vals[k], v)
558 }
559 } else {
560 vals[k] = append(vals[k], v)
561 }
562 }
563 for _, k := range keys {
564 for _, v := range vals[k] {
565 st.encodeHeaderField(k, v)
566 }
567 }
568 return st.headerBuf.Bytes()
569 }
570
571
572 func (st *serverTester) bodylessReq1(headers ...string) {
573 st.writeHeaders(HeadersFrameParam{
574 StreamID: 1,
575 BlockFragment: st.encodeHeader(headers...),
576 EndStream: true,
577 EndHeaders: true,
578 })
579 }
580
581 func (st *serverTester) wantConnFlowControlConsumed(consumed int32) {
582 if got, want := st.sc.TestFlowControlConsumed(), consumed; got != want {
583 st.t.Errorf("connection flow control consumed: %v, want %v", got, want)
584 }
585 }
586
587 func TestServer(t *testing.T) { synctestTest(t, testServer) }
588 func testServer(t testing.TB) {
589 gotReq := make(chan bool, 1)
590 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
591 w.Header().Set("Foo", "Bar")
592 gotReq <- true
593 })
594 defer st.Close()
595
596 st.greet()
597 st.writeHeaders(HeadersFrameParam{
598 StreamID: 1,
599 BlockFragment: st.encodeHeader(),
600 EndStream: true,
601 EndHeaders: true,
602 })
603
604 <-gotReq
605 }
606
607 func TestServer_Request_Get(t *testing.T) { synctestTest(t, testServer_Request_Get) }
608 func testServer_Request_Get(t testing.TB) {
609 testServerRequest(t, func(st *serverTester) {
610 st.writeHeaders(HeadersFrameParam{
611 StreamID: 1,
612 BlockFragment: st.encodeHeader("foo-bar", "some-value"),
613 EndStream: true,
614 EndHeaders: true,
615 })
616 }, func(r *http.Request) {
617 if r.Method != "GET" {
618 t.Errorf("Method = %q; want GET", r.Method)
619 }
620 if r.URL.Path != "/" {
621 t.Errorf("URL.Path = %q; want /", r.URL.Path)
622 }
623 if r.ContentLength != 0 {
624 t.Errorf("ContentLength = %v; want 0", r.ContentLength)
625 }
626 if r.Close {
627 t.Error("Close = true; want false")
628 }
629 if !strings.Contains(r.RemoteAddr, ":") {
630 t.Errorf("RemoteAddr = %q; want something with a colon", r.RemoteAddr)
631 }
632 if r.Proto != "HTTP/2.0" || r.ProtoMajor != 2 || r.ProtoMinor != 0 {
633 t.Errorf("Proto = %q Major=%v,Minor=%v; want HTTP/2.0", r.Proto, r.ProtoMajor, r.ProtoMinor)
634 }
635 wantHeader := http.Header{
636 "Foo-Bar": []string{"some-value"},
637 }
638 if !reflect.DeepEqual(r.Header, wantHeader) {
639 t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
640 }
641 if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
642 t.Errorf("Read = %d, %v; want 0, EOF", n, err)
643 }
644 })
645 }
646
647 func TestServer_Request_Get_PathSlashes(t *testing.T) {
648 synctestTest(t, testServer_Request_Get_PathSlashes)
649 }
650 func testServer_Request_Get_PathSlashes(t testing.TB) {
651 testServerRequest(t, func(st *serverTester) {
652 st.writeHeaders(HeadersFrameParam{
653 StreamID: 1,
654 BlockFragment: st.encodeHeader(":path", "/%2f/"),
655 EndStream: true,
656 EndHeaders: true,
657 })
658 }, func(r *http.Request) {
659 if r.RequestURI != "/%2f/" {
660 t.Errorf("RequestURI = %q; want /%%2f/", r.RequestURI)
661 }
662 if r.URL.Path != "///" {
663 t.Errorf("URL.Path = %q; want ///", r.URL.Path)
664 }
665 })
666 }
667
668
669
670
671
672 func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
673 synctestTest(t, testServer_Request_Post_NoContentLength_EndStream)
674 }
675 func testServer_Request_Post_NoContentLength_EndStream(t testing.TB) {
676 testServerRequest(t, func(st *serverTester) {
677 st.writeHeaders(HeadersFrameParam{
678 StreamID: 1,
679 BlockFragment: st.encodeHeader(":method", "POST"),
680 EndStream: true,
681 EndHeaders: true,
682 })
683 }, func(r *http.Request) {
684 if r.Method != "POST" {
685 t.Errorf("Method = %q; want POST", r.Method)
686 }
687 if r.ContentLength != 0 {
688 t.Errorf("ContentLength = %v; want 0", r.ContentLength)
689 }
690 if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
691 t.Errorf("Read = %d, %v; want 0, EOF", n, err)
692 }
693 })
694 }
695
696 func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) {
697 synctestTest(t, testServer_Request_Post_Body_ImmediateEOF)
698 }
699 func testServer_Request_Post_Body_ImmediateEOF(t testing.TB) {
700 testBodyContents(t, -1, "", func(st *serverTester) {
701 st.writeHeaders(HeadersFrameParam{
702 StreamID: 1,
703 BlockFragment: st.encodeHeader(":method", "POST"),
704 EndStream: false,
705 EndHeaders: true,
706 })
707 st.writeData(1, true, nil)
708 })
709 }
710
711 func TestServer_Request_Post_Body_OneData(t *testing.T) {
712 synctestTest(t, testServer_Request_Post_Body_OneData)
713 }
714 func testServer_Request_Post_Body_OneData(t testing.TB) {
715 const content = "Some content"
716 testBodyContents(t, -1, content, func(st *serverTester) {
717 st.writeHeaders(HeadersFrameParam{
718 StreamID: 1,
719 BlockFragment: st.encodeHeader(":method", "POST"),
720 EndStream: false,
721 EndHeaders: true,
722 })
723 st.writeData(1, true, []byte(content))
724 })
725 }
726
727 func TestServer_Request_Post_Body_TwoData(t *testing.T) {
728 synctestTest(t, testServer_Request_Post_Body_TwoData)
729 }
730 func testServer_Request_Post_Body_TwoData(t testing.TB) {
731 const content = "Some content"
732 testBodyContents(t, -1, content, func(st *serverTester) {
733 st.writeHeaders(HeadersFrameParam{
734 StreamID: 1,
735 BlockFragment: st.encodeHeader(":method", "POST"),
736 EndStream: false,
737 EndHeaders: true,
738 })
739 st.writeData(1, false, []byte(content[:5]))
740 st.writeData(1, true, []byte(content[5:]))
741 })
742 }
743
744 func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) {
745 synctestTest(t, testServer_Request_Post_Body_ContentLength_Correct)
746 }
747 func testServer_Request_Post_Body_ContentLength_Correct(t testing.TB) {
748 const content = "Some content"
749 testBodyContents(t, int64(len(content)), content, func(st *serverTester) {
750 st.writeHeaders(HeadersFrameParam{
751 StreamID: 1,
752 BlockFragment: st.encodeHeader(
753 ":method", "POST",
754 "content-length", strconv.Itoa(len(content)),
755 ),
756 EndStream: false,
757 EndHeaders: true,
758 })
759 st.writeData(1, true, []byte(content))
760 })
761 }
762
763 func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) {
764 synctestTest(t, testServer_Request_Post_Body_ContentLength_TooLarge)
765 }
766 func testServer_Request_Post_Body_ContentLength_TooLarge(t testing.TB) {
767 testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes",
768 func(st *serverTester) {
769 st.writeHeaders(HeadersFrameParam{
770 StreamID: 1,
771 BlockFragment: st.encodeHeader(
772 ":method", "POST",
773 "content-length", "3",
774 ),
775 EndStream: false,
776 EndHeaders: true,
777 })
778 st.writeData(1, true, []byte("12"))
779 })
780 }
781
782 func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) {
783 synctestTest(t, testServer_Request_Post_Body_ContentLength_TooSmall)
784 }
785 func testServer_Request_Post_Body_ContentLength_TooSmall(t testing.TB) {
786 testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes",
787 func(st *serverTester) {
788 st.writeHeaders(HeadersFrameParam{
789 StreamID: 1,
790 BlockFragment: st.encodeHeader(
791 ":method", "POST",
792 "content-length", "4",
793 ),
794 EndStream: false,
795 EndHeaders: true,
796 })
797 st.writeData(1, true, []byte("12345"))
798
799
800 st.wantRSTStream(1, ErrCodeProtocol)
801 st.wantConnFlowControlConsumed(0)
802 })
803 }
804
805 func testBodyContents(t testing.TB, wantContentLength int64, wantBody string, write func(st *serverTester)) {
806 testServerRequest(t, write, func(r *http.Request) {
807 if r.Method != "POST" {
808 t.Errorf("Method = %q; want POST", r.Method)
809 }
810 if r.ContentLength != wantContentLength {
811 t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
812 }
813 all, err := io.ReadAll(r.Body)
814 if err != nil {
815 t.Fatal(err)
816 }
817 if string(all) != wantBody {
818 t.Errorf("Read = %q; want %q", all, wantBody)
819 }
820 if err := r.Body.Close(); err != nil {
821 t.Fatalf("Close: %v", err)
822 }
823 })
824 }
825
826 func testBodyContentsFail(t testing.TB, wantContentLength int64, wantReadError string, write func(st *serverTester)) {
827 testServerRequest(t, write, func(r *http.Request) {
828 if r.Method != "POST" {
829 t.Errorf("Method = %q; want POST", r.Method)
830 }
831 if r.ContentLength != wantContentLength {
832 t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
833 }
834 all, err := io.ReadAll(r.Body)
835 if err == nil {
836 t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.",
837 wantReadError, all)
838 }
839 if !strings.Contains(err.Error(), wantReadError) {
840 t.Fatalf("Body.Read = %v; want substring %q", err, wantReadError)
841 }
842 if err := r.Body.Close(); err != nil {
843 t.Fatalf("Close: %v", err)
844 }
845 })
846 }
847
848
849 func TestServer_Request_Get_Host(t *testing.T) { synctestTest(t, testServer_Request_Get_Host) }
850 func testServer_Request_Get_Host(t testing.TB) {
851 const host = "example.com"
852 testServerRequest(t, func(st *serverTester) {
853 st.writeHeaders(HeadersFrameParam{
854 StreamID: 1,
855 BlockFragment: st.encodeHeader(":authority", "", "host", host),
856 EndStream: true,
857 EndHeaders: true,
858 })
859 }, func(r *http.Request) {
860 if r.Host != host {
861 t.Errorf("Host = %q; want %q", r.Host, host)
862 }
863 })
864 }
865
866
867 func TestServer_Request_Get_Authority(t *testing.T) {
868 synctestTest(t, testServer_Request_Get_Authority)
869 }
870 func testServer_Request_Get_Authority(t testing.TB) {
871 const host = "example.com"
872 testServerRequest(t, func(st *serverTester) {
873 st.writeHeaders(HeadersFrameParam{
874 StreamID: 1,
875 BlockFragment: st.encodeHeader(":authority", host),
876 EndStream: true,
877 EndHeaders: true,
878 })
879 }, func(r *http.Request) {
880 if r.Host != host {
881 t.Errorf("Host = %q; want %q", r.Host, host)
882 }
883 })
884 }
885
886 func TestServer_Request_WithContinuation(t *testing.T) {
887 synctestTest(t, testServer_Request_WithContinuation)
888 }
889 func testServer_Request_WithContinuation(t testing.TB) {
890 wantHeader := http.Header{
891 "Foo-One": []string{"value-one"},
892 "Foo-Two": []string{"value-two"},
893 "Foo-Three": []string{"value-three"},
894 }
895 testServerRequest(t, func(st *serverTester) {
896 fullHeaders := st.encodeHeader(
897 "foo-one", "value-one",
898 "foo-two", "value-two",
899 "foo-three", "value-three",
900 )
901 remain := fullHeaders
902 chunks := 0
903 for len(remain) > 0 {
904 const maxChunkSize = 5
905 chunk := remain
906 if len(chunk) > maxChunkSize {
907 chunk = chunk[:maxChunkSize]
908 }
909 remain = remain[len(chunk):]
910
911 if chunks == 0 {
912 st.writeHeaders(HeadersFrameParam{
913 StreamID: 1,
914 BlockFragment: chunk,
915 EndStream: true,
916 EndHeaders: false,
917 })
918 } else {
919 err := st.fr.WriteContinuation(1, len(remain) == 0, chunk)
920 if err != nil {
921 t.Fatal(err)
922 }
923 }
924 chunks++
925 }
926 if chunks < 2 {
927 t.Fatal("too few chunks")
928 }
929 }, func(r *http.Request) {
930 if !reflect.DeepEqual(r.Header, wantHeader) {
931 t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
932 }
933 })
934 }
935
936
937 func TestServer_Request_CookieConcat(t *testing.T) { synctestTest(t, testServer_Request_CookieConcat) }
938 func testServer_Request_CookieConcat(t testing.TB) {
939 const host = "example.com"
940 testServerRequest(t, func(st *serverTester) {
941 st.bodylessReq1(
942 ":authority", host,
943 "cookie", "a=b",
944 "cookie", "c=d",
945 "cookie", "e=f",
946 )
947 }, func(r *http.Request) {
948 const want = "a=b; c=d; e=f"
949 if got := r.Header.Get("Cookie"); got != want {
950 t.Errorf("Cookie = %q; want %q", got, want)
951 }
952 })
953 }
954
955 func TestServer_Request_Reject_CapitalHeader(t *testing.T) {
956 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("UPPER", "v") })
957 }
958
959 func TestServer_Request_Reject_HeaderFieldNameColon(t *testing.T) {
960 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has:colon", "v") })
961 }
962
963 func TestServer_Request_Reject_HeaderFieldNameNULL(t *testing.T) {
964 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has\x00null", "v") })
965 }
966
967 func TestServer_Request_Reject_HeaderFieldNameEmpty(t *testing.T) {
968 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("", "v") })
969 }
970
971 func TestServer_Request_Reject_HeaderFieldValueNewline(t *testing.T) {
972 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\nnewline") })
973 }
974
975 func TestServer_Request_Reject_HeaderFieldValueCR(t *testing.T) {
976 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\rcarriage") })
977 }
978
979 func TestServer_Request_Reject_HeaderFieldValueDEL(t *testing.T) {
980 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\x7fdel") })
981 }
982
983 func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) {
984 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "") })
985 }
986
987 func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) {
988
989
990 testRejectRequest(t, func(st *serverTester) {
991 st.addLogFilter("duplicate pseudo-header")
992 st.bodylessReq1(":method", "GET", ":method", "POST")
993 })
994 }
995
996 func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
997
998
999
1000
1001
1002
1003 testRejectRequest(t, func(st *serverTester) {
1004 st.addLogFilter("pseudo-header after regular header")
1005 var buf bytes.Buffer
1006 enc := hpack.NewEncoder(&buf)
1007 enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
1008 enc.WriteField(hpack.HeaderField{Name: "regular", Value: "foobar"})
1009 enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
1010 enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
1011 st.writeHeaders(HeadersFrameParam{
1012 StreamID: 1,
1013 BlockFragment: buf.Bytes(),
1014 EndStream: true,
1015 EndHeaders: true,
1016 })
1017 })
1018 }
1019
1020 func TestServer_Request_Reject_Pseudo_Missing_path(t *testing.T) {
1021 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":path", "") })
1022 }
1023
1024 func TestServer_Request_Reject_Pseudo_Missing_scheme(t *testing.T) {
1025 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "") })
1026 }
1027
1028 func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) {
1029 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "bogus") })
1030 }
1031
1032 func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
1033 testRejectRequest(t, func(st *serverTester) {
1034 st.addLogFilter(`invalid pseudo-header ":unknown_thing"`)
1035 st.bodylessReq1(":unknown_thing", "")
1036 })
1037 }
1038
1039 func TestServer_Request_Reject_Authority_Userinfo(t *testing.T) {
1040
1041
1042
1043 testRejectRequest(t, func(st *serverTester) {
1044 var buf bytes.Buffer
1045 enc := hpack.NewEncoder(&buf)
1046 enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "userinfo@example.tld"})
1047 enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
1048 enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
1049 enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
1050 st.writeHeaders(HeadersFrameParam{
1051 StreamID: 1,
1052 BlockFragment: buf.Bytes(),
1053 EndStream: true,
1054 EndHeaders: true,
1055 })
1056 })
1057 }
1058
1059 func testRejectRequest(t *testing.T, send func(*serverTester)) {
1060 synctestTest(t, func(t testing.TB) {
1061 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1062 t.Error("server request made it to handler; should've been rejected")
1063 })
1064 defer st.Close()
1065
1066 st.greet()
1067 send(st)
1068 st.wantRSTStream(1, ErrCodeProtocol)
1069 })
1070 }
1071
1072 func newServerTesterForError(t testing.TB) *serverTester {
1073 t.Helper()
1074 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1075 t.Error("server request made it to handler; should've been rejected")
1076 }, optQuiet)
1077 st.greet()
1078 return st
1079 }
1080
1081
1082
1083
1084 func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) {
1085 synctestTest(t, testRejectFrameOnIdle_WindowUpdate)
1086 }
1087 func testRejectFrameOnIdle_WindowUpdate(t testing.TB) {
1088 st := newServerTesterForError(t)
1089 st.fr.WriteWindowUpdate(123, 456)
1090 st.wantGoAway(123, ErrCodeProtocol)
1091 }
1092 func TestRejectFrameOnIdle_Data(t *testing.T) { synctestTest(t, testRejectFrameOnIdle_Data) }
1093 func testRejectFrameOnIdle_Data(t testing.TB) {
1094 st := newServerTesterForError(t)
1095 st.fr.WriteData(123, true, nil)
1096 st.wantGoAway(123, ErrCodeProtocol)
1097 }
1098 func TestRejectFrameOnIdle_RSTStream(t *testing.T) { synctestTest(t, testRejectFrameOnIdle_RSTStream) }
1099 func testRejectFrameOnIdle_RSTStream(t testing.TB) {
1100 st := newServerTesterForError(t)
1101 st.fr.WriteRSTStream(123, ErrCodeCancel)
1102 st.wantGoAway(123, ErrCodeProtocol)
1103 }
1104
1105 func TestServer_Request_Connect(t *testing.T) { synctestTest(t, testServer_Request_Connect) }
1106 func testServer_Request_Connect(t testing.TB) {
1107 testServerRequest(t, func(st *serverTester) {
1108 st.writeHeaders(HeadersFrameParam{
1109 StreamID: 1,
1110 BlockFragment: st.encodeHeaderRaw(
1111 ":method", "CONNECT",
1112 ":authority", "example.com:123",
1113 ),
1114 EndStream: true,
1115 EndHeaders: true,
1116 })
1117 }, func(r *http.Request) {
1118 if g, w := r.Method, "CONNECT"; g != w {
1119 t.Errorf("Method = %q; want %q", g, w)
1120 }
1121 if g, w := r.RequestURI, "example.com:123"; g != w {
1122 t.Errorf("RequestURI = %q; want %q", g, w)
1123 }
1124 if g, w := r.URL.Host, "example.com:123"; g != w {
1125 t.Errorf("URL.Host = %q; want %q", g, w)
1126 }
1127 })
1128 }
1129
1130 func TestServer_Request_Connect_InvalidPath(t *testing.T) {
1131 synctestTest(t, testServer_Request_Connect_InvalidPath)
1132 }
1133 func testServer_Request_Connect_InvalidPath(t testing.TB) {
1134 testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1135 st.writeHeaders(HeadersFrameParam{
1136 StreamID: 1,
1137 BlockFragment: st.encodeHeaderRaw(
1138 ":method", "CONNECT",
1139 ":authority", "example.com:123",
1140 ":path", "/bogus",
1141 ),
1142 EndStream: true,
1143 EndHeaders: true,
1144 })
1145 })
1146 }
1147
1148 func TestServer_Request_Connect_InvalidScheme(t *testing.T) {
1149 synctestTest(t, testServer_Request_Connect_InvalidScheme)
1150 }
1151 func testServer_Request_Connect_InvalidScheme(t testing.TB) {
1152 testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1153 st.writeHeaders(HeadersFrameParam{
1154 StreamID: 1,
1155 BlockFragment: st.encodeHeaderRaw(
1156 ":method", "CONNECT",
1157 ":authority", "example.com:123",
1158 ":scheme", "https",
1159 ),
1160 EndStream: true,
1161 EndHeaders: true,
1162 })
1163 })
1164 }
1165
func TestServer_Ping(t *testing.T) { synctestTest(t, testServer_Ping) }

// testServer_Ping checks that a PING from the client is answered with a PING
// carrying the ACK flag and the same 8-byte payload.
func testServer_Ping(t testing.TB) {
	st := newServerTester(t, nil)
	defer st.Close()
	st.greet()

	// First send a PING that already has the ACK flag set. The first PING
	// read back below must carry pingData, so if the server replied to this
	// frame the test would fail.
	ackPingData := [8]byte{1, 2, 4, 8, 16, 32, 64, 128}
	if err := st.fr.WritePing(true, ackPingData); err != nil {
		t.Fatal(err)
	}

	// Now send a regular (non-ACK) PING, which the server is expected to
	// answer.
	pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
	if err := st.fr.WritePing(false, pingData); err != nil {
		t.Fatal(err)
	}

	pf := readFrame[*PingFrame](t, st)
	if !pf.Flags.Has(FlagPingAck) {
		t.Error("response ping doesn't have ACK set")
	}
	if pf.Data != pingData {
		t.Errorf("response ping has data %q; want %q", pf.Data, pingData)
	}
}
1192
1193 type filterListener struct {
1194 net.Listener
1195 accept func(conn net.Conn) (net.Conn, error)
1196 }
1197
1198 func (l *filterListener) Accept() (net.Conn, error) {
1199 c, err := l.Listener.Accept()
1200 if err != nil {
1201 return nil, err
1202 }
1203 return l.accept(c)
1204 }
1205
func TestServer_MaxQueuedControlFrames(t *testing.T) {
	synctestTest(t, testServer_MaxQueuedControlFrames)
}

// testServer_MaxQueuedControlFrames sends more PINGs than
// MaxQueuedControlFrames while the client is not reading the server's output,
// then expects the connection to end up closed.
// NOTE(review): presumably the server closes because its queue of pending
// control-frame replies exceeds MaxQueuedControlFrames; confirm in server code.
func testServer_MaxQueuedControlFrames(t testing.TB) {
	// NOTE(review): goroutine tracking is disabled here, presumably because
	// the flood leaves server goroutines in a state the tracker would flag;
	// confirm.
	DisableGoroutineTracking(t)

	st := newServerTester(t, nil)
	st.greet()

	// Stop the client side of the connection from accepting data so the
	// server's PING replies cannot drain (assumes SetReadBufferSize(0) means
	// "buffer nothing" — TODO confirm in synctestNetConn).
	st.cc.(*tls.Conn).NetConn().(*synctestNetConn).SetReadBufferSize(0)

	// Send extraPings more PINGs than MaxQueuedControlFrames. Write errors
	// are deliberately ignored.
	// NOTE(review): presumably because the server may start tearing down the
	// connection partway through the flood; confirm.
	const extraPings = 2
	for range MaxQueuedControlFrames + extraPings {
		pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
		st.fr.WritePing(false, pingData)
	}
	// Block until every goroutine in the synctest bubble is durably blocked,
	// so the server has had the chance to process every PING it will.
	synctest.Wait()

	// Let the client read again so whatever the server did write, and the
	// eventual close, become observable.
	st.cc.(*tls.Conn).NetConn().(*synctestNetConn).SetReadBufferSize(math.MaxInt)

	st.advance(GoAwayTimeout)

	// Drain up to 10 frames; readFrame presumably returns nil once nothing
	// more is readable.
	for range 10 {
		if st.readFrame() == nil {
			break
		}
	}
	st.wantClosed()
}
1240
func TestServer_RejectsLargeFrames(t *testing.T) { synctestTest(t, testServer_RejectsLargeFrames) }

// testServer_RejectsLargeFrames sends a frame whose payload is one byte over
// DefaultMaxReadFrameSize and expects a GOAWAY with FRAME_SIZE_ERROR,
// followed by the connection closing once GoAwayTimeout has elapsed.
func testServer_RejectsLargeFrames(t testing.TB) {
	if runtime.GOOS == "windows" || runtime.GOOS == "plan9" || runtime.GOOS == "zos" {
		t.Skip("see golang.org/issue/13434, golang.org/issue/37321")
	}
	st := newServerTester(t, nil)
	defer st.Close()
	st.greet()

	// Write a raw frame of type 0xff (not a defined HTTP/2 frame type) with
	// no flags on stream 0. Frames of unknown type are otherwise ignored
	// (RFC 9113 §4.1), so only the oversized length can trigger the GOAWAY.
	st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, DefaultMaxReadFrameSize+1))

	st.wantGoAway(0, ErrCodeFrameSize)
	st.advance(GoAwayTimeout)
	st.wantClosed()
}
1259
func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
	synctestTest(t, testServer_Handler_Sends_WindowUpdate)
}

// testServer_Handler_Sends_WindowUpdate checks the WINDOW_UPDATE frames the
// server sends for the connection (stream 0) and for stream 1 as the handler
// consumes a request body.
func testServer_Handler_Sends_WindowUpdate(t testing.TB) {
	// Configure connection and stream receive buffers of twice 65535 bytes,
	// so the client can send windowSize bytes of body below without waiting
	// for any WINDOW_UPDATE.
	const windowSize = 65535 * 2
	st := newServerTester(t, nil, func(h2 *http.HTTP2Config) {
		h2.MaxReceiveBufferPerConnection = windowSize
		h2.MaxReceiveBufferPerStream = windowSize
	})
	defer st.Close()

	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	call := st.nextHandlerCall()

	// Send the first 1024 bytes of body and run a handler step that reads
	// (and presumably verifies) exactly those bytes.
	data := make([]byte, windowSize)
	st.writeData(1, false, data[:1024])
	call.do(readBodyHandler(t, string(data[:1024])))

	// Send the rest of the body, then expect the server to return the 1024
	// bytes of credit the handler consumed, on both connection and stream.
	st.writeData(1, false, data[1024:])
	st.wantWindowUpdate(0, 1024)
	st.wantWindowUpdate(1, 1024)

	// After the handler reads the remainder, the rest of the credit is
	// returned for both the connection and the stream.
	call.do(readBodyHandler(t, string(data[1024:])))
	st.wantWindowUpdate(0, windowSize-1024)
	st.wantWindowUpdate(1, windowSize-1024)
}
1302
1303
1304
// TestServer_Handler_Sends_WindowUpdate_Padding checks that DATA-frame padding
// is counted in the flow-control credit the server returns.
func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) {
	synctestTest(t, testServer_Handler_Sends_WindowUpdate_Padding)
}
func testServer_Handler_Sends_WindowUpdate_Padding(t testing.TB) {
	const windowSize = 65535 * 2
	st := newServerTester(t, nil, func(h2 *http.HTTP2Config) {
		h2.MaxReceiveBufferPerConnection = windowSize
		h2.MaxReceiveBufferPerStream = windowSize
	})
	defer st.Close()

	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	call := st.nextHandlerCall()

	// Send half the window as one padded DATA frame: len(data) bytes of
	// payload plus 4 bytes of padding.
	data := make([]byte, windowSize/2)
	pad := make([]byte, 4)
	st.writeDataPadded(1, false, data, pad)

	// Once the handler has read the payload, the credit returned covers the
	// payload, the padding, and the 1-byte Pad Length field, on both the
	// connection and the stream.
	call.do(readBodyHandler(t, string(data)))
	st.wantWindowUpdate(0, uint32(len(data)+1+len(pad)))
	st.wantWindowUpdate(1, uint32(len(data)+1+len(pad)))
}
1339
1340 func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) {
1341 synctestTest(t, testServer_Send_GoAway_After_Bogus_WindowUpdate)
1342 }
1343 func testServer_Send_GoAway_After_Bogus_WindowUpdate(t testing.TB) {
1344 st := newServerTester(t, nil)
1345 defer st.Close()
1346 st.greet()
1347 if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
1348 t.Fatal(err)
1349 }
1350 st.wantGoAway(0, ErrCodeFlowControl)
1351 }
1352
func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
	synctestTest(t, testServer_Send_RstStream_After_Bogus_WindowUpdate)
}

// testServer_Send_RstStream_After_Bogus_WindowUpdate sends a maximal
// WINDOW_UPDATE on an open stream (stream 1) and expects the server to reset
// that stream with FLOW_CONTROL_ERROR.
func testServer_Send_RstStream_After_Bogus_WindowUpdate(t testing.TB) {
	inHandler := make(chan bool)
	blockHandler := make(chan bool)
	// The handler signals that it is running, then blocks so stream 1 stays
	// open while the WINDOW_UPDATE is sent.
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		inHandler <- true
		<-blockHandler
	})
	defer st.Close()
	// Deferred after st.Close, so it runs first: the handler is released
	// before the tester is closed.
	defer close(blockHandler)
	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	<-inHandler

	// 1<<31-1 is the largest legal increment; added to the stream's existing
	// send window it exceeds the 2^31-1 maximum.
	if err := st.fr.WriteWindowUpdate(1, 1<<31-1); err != nil {
		t.Fatal(err)
	}
	st.wantRSTStream(1, ErrCodeFlowControl)
}
1379
1380
1381
1382
// testServerPostUnblock opens a POST on stream 1 with the request body left
// open (plus any extra header key/value pairs in otherHeaders) and waits until
// handler is running. It then calls fn to act on the connection from the
// client side, waits for handler to return, and, if checkErr is non-nil,
// passes it the handler's returned error.
//
// Callers use it to check that an action in fn (RST_STREAM, closing the
// connection, an invalid DATA frame) unblocks a handler stuck in a blocking
// call.
func testServerPostUnblock(t testing.TB,
	handler func(http.ResponseWriter, *http.Request) error,
	fn func(*serverTester),
	checkErr func(error),
	otherHeaders ...string) {
	inHandler := make(chan bool)
	// Buffered so the handler goroutine can deliver its result without
	// waiting on the test goroutine.
	errc := make(chan error, 1)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		inHandler <- true
		errc <- handler(w, r)
	})
	defer st.Close()
	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(append([]string{":method", "POST"}, otherHeaders...)...),
		EndStream:     false,
		EndHeaders:    true,
	})
	<-inHandler
	fn(st)
	// Blocks until the handler returns; a handler that stays stuck hangs here.
	err := <-errc
	if checkErr != nil {
		checkErr(err)
	}
}
1409
func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
	synctestTest(t, testServer_RSTStream_Unblocks_Read)
}

// testServer_RSTStream_Unblocks_Read checks that a handler blocked reading the
// request body is released by a client RST_STREAM(CANCEL), and that the read
// fails with the matching StreamError.
func testServer_RSTStream_Unblocks_Read(t testing.TB) {
	readOneByte := func(w http.ResponseWriter, r *http.Request) error {
		_, err := r.Body.Read(make([]byte, 1))
		return err
	}
	cancelStream := func(st *serverTester) {
		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
			t.Fatal(err)
		}
	}
	checkCanceled := func(err error) {
		// ErrCodeCancel is HTTP/2 error code 0x8.
		want := StreamError{StreamID: 1, Code: ErrCodeCancel}
		if !reflect.DeepEqual(err, want) {
			t.Errorf("Read error = %v; want %v", err, want)
		}
	}
	testServerPostUnblock(t, readOneByte, cancelStream, checkCanceled)
}
1432
// TestServer_RSTStream_Unblocks_Header_Write runs the scenario many times
// (50 normally, 5 with -short), presumably to shake out ordering-dependent
// races between the RST_STREAM and the handler's header write.
func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
	iterations := 50
	if testing.Short() {
		iterations = 5
	}
	for range iterations {
		synctestTest(t, testServer_RSTStream_Unblocks_Header_Write)
	}
}
1444
// testServer_RSTStream_Unblocks_Header_Write has the handler write and flush
// response headers after the client has written RST_STREAM for the stream.
// It checks that the handler gets past Flush instead of blocking forever.
func testServer_RSTStream_Unblocks_Header_Write(t testing.TB) {
	inHandler := make(chan bool, 1)
	unblockHandler := make(chan bool, 1)
	headerWritten := make(chan bool, 1)
	wroteRST := make(chan bool, 1)

	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		inHandler <- true
		// Don't touch the response until the client has sent RST_STREAM.
		<-wroteRST
		w.Header().Set("foo", "bar")
		w.WriteHeader(200)
		w.(http.Flusher).Flush()
		headerWritten <- true
		<-unblockHandler
	})
	defer st.Close()

	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	<-inHandler
	if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
		t.Fatal(err)
	}
	wroteRST <- true
	// Wait until every goroutine in the bubble is durably blocked, giving the
	// server time to process the RST_STREAM and the handler time to write.
	synctest.Wait()
	// Receiving here means the handler got past Flush.
	<-headerWritten
	unblockHandler <- true
}
1478
func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
	synctestTest(t, testServer_DeadConn_Unblocks_Read)
}

// testServer_DeadConn_Unblocks_Read checks that closing the client's
// connection releases a handler blocked reading the request body, and that
// the read reports an error.
func testServer_DeadConn_Unblocks_Read(t testing.TB) {
	readOneByte := func(w http.ResponseWriter, r *http.Request) error {
		_, err := r.Body.Read(make([]byte, 1))
		return err
	}
	closeConn := func(st *serverTester) {
		st.cc.Close()
	}
	requireError := func(err error) {
		if err == nil {
			t.Error("unexpected nil error from Request.Body.Read")
		}
	}
	testServerPostUnblock(t, readOneByte, closeConn, requireError)
}
1496
1497 var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error {
1498 <-w.(http.CloseNotifier).CloseNotify()
1499 return nil
1500 }
1501
func TestServer_CloseNotify_After_RSTStream(t *testing.T) {
	synctestTest(t, testServer_CloseNotify_After_RSTStream)
}

// testServer_CloseNotify_After_RSTStream checks that a client RST_STREAM
// fires CloseNotify for a handler waiting on it.
func testServer_CloseNotify_After_RSTStream(t testing.TB) {
	resetStream := func(st *serverTester) {
		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
			t.Fatal(err)
		}
	}
	testServerPostUnblock(t, blockUntilClosed, resetStream, nil)
}
1512
func TestServer_CloseNotify_After_ConnClose(t *testing.T) {
	synctestTest(t, testServer_CloseNotify_After_ConnClose)
}

// testServer_CloseNotify_After_ConnClose checks that closing the client's
// connection fires CloseNotify for a handler waiting on it.
func testServer_CloseNotify_After_ConnClose(t testing.TB) {
	closeConn := func(st *serverTester) { st.cc.Close() }
	testServerPostUnblock(t, blockUntilClosed, closeConn, nil)
}
1519
1520
1521
1522
// TestServer_CloseNotify_After_StreamError checks that a stream-level error
// caused by the client also fires CloseNotify for a handler waiting on it.
func TestServer_CloseNotify_After_StreamError(t *testing.T) {
	synctestTest(t, testServer_CloseNotify_After_StreamError)
}
func testServer_CloseNotify_After_StreamError(t testing.TB) {
	// The request declares content-length: 3, but this DATA frame carries
	// four bytes, which is presumably what the server treats as the stream
	// error.
	sendTooMuchBody := func(st *serverTester) {
		st.writeData(1, true, []byte("1234"))
	}
	testServerPostUnblock(t, blockUntilClosed, sendTooMuchBody, nil, "content-length", "3")
}
1532
func TestServer_StateTransitions(t *testing.T) { synctestTest(t, testServer_StateTransitions) }

// testServer_StateTransitions follows stream 1 through the states the server
// reports:
//   - idle before any HEADERS;
//   - open while the handler runs with the request body still open;
//   - half-closed (remote) after the client's END_STREAM;
//   - closed and removed once the response has been sent.
//
// NOTE(review): unlike most tests here, st is never closed explicitly;
// confirm newServerTester registers its own cleanup.
func testServer_StateTransitions(t testing.TB) {
	// Declared before assignment so the handler closure can refer to st.
	var st *serverTester
	inHandler := make(chan bool)
	writeData := make(chan bool)
	leaveHandler := make(chan bool)
	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		inHandler <- true
		if !st.streamExists(1) {
			t.Errorf("stream 1 does not exist in handler")
		}
		if got, want := st.streamState(1), StateOpen; got != want {
			t.Errorf("in handler, state is %v; want %v", got, want)
		}
		// Ask the test goroutine to send END_STREAM, then wait for it to
		// show up as EOF on the body.
		writeData <- true
		if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
			t.Errorf("body read = %d, %v; want 0, EOF", n, err)
		}
		if got, want := st.streamState(1), StateHalfClosedRemote; got != want {
			t.Errorf("in handler, state is %v; want %v", got, want)
		}

		<-leaveHandler
	})
	st.greet()
	if st.streamExists(1) {
		t.Fatal("stream 1 should be empty")
	}
	if got := st.streamState(1); got != StateIdle {
		t.Fatalf("stream 1 should be idle; got %v", got)
	}

	// Open stream 1 without END_STREAM so it starts out open.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	<-inHandler
	<-writeData
	// Empty DATA frame with END_STREAM: the client half-closes the stream.
	st.writeData(1, true, nil)

	leaveHandler <- true
	// The handler wrote nothing, so the response is a HEADERS frame with
	// END_STREAM.
	st.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})

	if got, want := st.streamState(1), StateClosed; got != want {
		t.Errorf("at end, state is %v; want %v", got, want)
	}
	if st.streamExists(1) {
		t.Fatal("at end, stream 1 should be gone")
	}
}
1588
1589
// TestServer_Rejects_HeadersNoEnd_Then_Headers checks that a HEADERS frame for
// another stream, sent while a header block is still unterminated, is a
// connection error. RFC 9113 §6.2 requires CONTINUATION on the same stream to
// follow a HEADERS frame without END_HEADERS.
func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
	synctestTest(t, testServer_Rejects_HeadersNoEnd_Then_Headers)
}
func testServer_Rejects_HeadersNoEnd_Then_Headers(t testing.TB) {
	st := newServerTesterForError(t)
	// Start a header block on stream 1 but leave it unterminated.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    false,
	})
	// Interrupt it with HEADERS for stream 3.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      3,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	st.wantGoAway(0, ErrCodeProtocol)
}
1609
1610
// TestServer_Rejects_HeadersNoEnd_Then_Ping checks that any frame other than
// CONTINUATION (here a PING) arriving while a header block is unterminated is
// a connection error (RFC 9113 §6.2).
func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
	synctestTest(t, testServer_Rejects_HeadersNoEnd_Then_Ping)
}
func testServer_Rejects_HeadersNoEnd_Then_Ping(t testing.TB) {
	st := newServerTesterForError(t)
	// Start a header block on stream 1 but leave it unterminated.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    false,
	})
	if err := st.fr.WritePing(false, [8]byte{}); err != nil {
		t.Fatal(err)
	}
	st.wantGoAway(0, ErrCodeProtocol)
}
1627
1628
// TestServer_Rejects_HeadersEnd_Then_Continuation checks that a CONTINUATION
// frame arriving after a HEADERS frame that already had END_HEADERS set is a
// connection error (RFC 9113 §6.10).
func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
	synctestTest(t, testServer_Rejects_HeadersEnd_Then_Continuation)
}
func testServer_Rejects_HeadersEnd_Then_Continuation(t testing.TB) {
	// NOTE(review): optQuiet presumably suppresses server logging for this
	// expected failure; confirm.
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optQuiet)
	st.greet()
	// A complete request on stream 1 (END_HEADERS and END_STREAM)...
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	// ...which the empty handler answers with a body-less response.
	st.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})
	// This CONTINUATION has no open header block to continue.
	if err := st.fr.WriteContinuation(1, true, EncodeHeaderRaw(t, "foo", "bar")); err != nil {
		t.Fatal(err)
	}
	// NOTE(review): the first argument is presumably the last stream ID the
	// server processed, which is 1 here.
	st.wantGoAway(1, ErrCodeProtocol)
}
1650
1651
// TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream checks that a
// CONTINUATION for a different stream than the unterminated header block is a
// connection error (RFC 9113 §6.2).
func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
	synctestTest(t, testServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream)
}
func testServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t testing.TB) {
	st := newServerTesterForError(t)
	// Start a header block on stream 1 but leave it unterminated.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    false,
	})
	// Continue it on stream 3 instead of stream 1.
	if err := st.fr.WriteContinuation(3, true, EncodeHeaderRaw(t, "foo", "bar")); err != nil {
		t.Fatal(err)
	}
	st.wantGoAway(0, ErrCodeProtocol)
}
1668
1669
// TestServer_Rejects_Headers0 checks that HEADERS on stream 0 is a connection
// error (RFC 9113 §6.2).
func TestServer_Rejects_Headers0(t *testing.T) { synctestTest(t, testServer_Rejects_Headers0) }
func testServer_Rejects_Headers0(t testing.TB) {
	st := newServerTesterForError(t)
	// Presumably needed because the framer refuses to write HEADERS on
	// stream 0 otherwise.
	st.fr.AllowIllegalWrites = true
	st.writeHeaders(HeadersFrameParam{
		StreamID:      0,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	st.wantGoAway(0, ErrCodeProtocol)
}
1682
1683
// TestServer_Rejects_Continuation0 checks that CONTINUATION on stream 0 is a
// connection error (RFC 9113 §6.10).
func TestServer_Rejects_Continuation0(t *testing.T) {
	synctestTest(t, testServer_Rejects_Continuation0)
}
func testServer_Rejects_Continuation0(t testing.TB) {
	st := newServerTesterForError(t)
	// Presumably needed because the framer refuses stream-0 CONTINUATION
	// otherwise.
	st.fr.AllowIllegalWrites = true
	if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
		t.Fatal(err)
	}
	st.wantGoAway(0, ErrCodeProtocol)
}
1695
1696
// TestServer_Rejects_Priority0 checks that a PRIORITY frame on stream 0 is a
// connection error (RFC 9113 §6.3).
func TestServer_Rejects_Priority0(t *testing.T) { synctestTest(t, testServer_Rejects_Priority0) }
func testServer_Rejects_Priority0(t testing.TB) {
	st := newServerTesterForError(t)
	// Presumably needed because the framer refuses stream-0 PRIORITY
	// otherwise.
	st.fr.AllowIllegalWrites = true
	st.writePriority(0, PriorityParam{StreamDep: 1})
	st.wantGoAway(0, ErrCodeProtocol)
}
1704
1705
1706
// TestServer_Rejects_PriorityUpdate0 checks that a PRIORITY_UPDATE targeting
// stream ID 0 is a connection error. This assumes writePriorityUpdate's first
// argument is the prioritized stream ID; TODO confirm.
func TestServer_Rejects_PriorityUpdate0(t *testing.T) {
	synctestTest(t, testServer_Rejects_PriorityUpdate0)
}
func testServer_Rejects_PriorityUpdate0(t testing.TB) {
	st := newServerTesterForError(t)
	st.fr.AllowIllegalWrites = true
	st.writePriorityUpdate(0, "")
	st.wantGoAway(0, ErrCodeProtocol)
}
1716
1717
// TestServer_Rejects_PriorityUpdateUnparsable checks that a PRIORITY_UPDATE
// whose priority field value does not parse is answered with RST_STREAM
// (PROTOCOL_ERROR) for the stream it targets.
// NOTE(review): the implementation is named testServer_Rejects_PriorityUnparsable
// (without "Update"), unlike the Test/test naming pairs elsewhere in this file.
func TestServer_Rejects_PriorityUpdateUnparsable(t *testing.T) {
	synctestTest(t, testServer_Rejects_PriorityUnparsable)
}
func testServer_Rejects_PriorityUnparsable(t testing.TB) {
	st := newServerTester(t, nil)
	defer st.Close()
	st.greet()
	// Stream 1 has not been opened; the update targets it with a value that
	// is not a parsable dictionary.
	st.writePriorityUpdate(1, "Invalid dictionary: ((((")
	st.wantRSTStream(1, ErrCodeProtocol)
}
1728
1729
// TestServer_Rejects_HeadersSelfDependence checks that HEADERS declaring a
// priority dependency on its own stream is a stream error of type
// PROTOCOL_ERROR (RFC 9113 §5.3.1).
func TestServer_Rejects_HeadersSelfDependence(t *testing.T) {
	synctestTest(t, testServer_Rejects_HeadersSelfDependence)
}
func testServer_Rejects_HeadersSelfDependence(t testing.TB) {
	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
		st.fr.AllowIllegalWrites = true
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
			Priority:      PriorityParam{StreamDep: 1},
		})
	})
}
1745
1746
// TestServer_Rejects_PrioritySelfDependence checks that a PRIORITY frame making
// a stream depend on itself is a stream error of type PROTOCOL_ERROR
// (RFC 9113 §5.3.1).
func TestServer_Rejects_PrioritySelfDependence(t *testing.T) {
	synctestTest(t, testServer_Rejects_PrioritySelfDependence)
}
func testServer_Rejects_PrioritySelfDependence(t testing.TB) {
	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
		st.fr.AllowIllegalWrites = true
		st.writePriority(1, PriorityParam{StreamDep: 1})
	})
}
1756
// TestServer_Rejects_PushPromise checks that a PUSH_PROMISE sent by the client
// is a connection error (RFC 9113 §8.4: clients cannot push).
func TestServer_Rejects_PushPromise(t *testing.T) { synctestTest(t, testServer_Rejects_PushPromise) }
func testServer_Rejects_PushPromise(t testing.TB) {
	st := newServerTesterForError(t)
	pp := PushPromiseParam{
		StreamID:  1,
		PromiseID: 3,
	}
	if err := st.fr.WritePushPromise(pp); err != nil {
		t.Fatal(err)
	}
	// NOTE(review): the first argument is presumably the last stream ID the
	// server reports as processed; confirm why it is 1 here.
	st.wantGoAway(1, ErrCodeProtocol)
}
1769
1770
1771
// testServerRejectsStream starts a server whose handler does nothing,
// performs the initial handshake via st.greet, lets writeReq send a request on
// stream 1, and expects the server to reset stream 1 with the given code.
func testServerRejectsStream(t testing.TB, code ErrCode, writeReq func(*serverTester)) {
	noop := func(http.ResponseWriter, *http.Request) {}
	st := newServerTester(t, noop)
	defer st.Close()

	st.greet()
	writeReq(st)
	st.wantRSTStream(1, code)
}
1779
1780
1781
1782
// testServerRequest starts a server, lets writeReq send a request, and runs
// checkReq on the *http.Request the handler receives. It returns once the
// handler has finished.
//
// The handler, and therefore checkReq, runs on the server's goroutine, not
// the test goroutine. It must report failures with t.Error/t.Errorf:
// t.FailNow (and t.Fatal) may only be called from the test goroutine.
func testServerRequest(t testing.TB, writeReq func(*serverTester), checkReq func(*http.Request)) {
	gotReq := make(chan bool, 1)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// Always signal completion, even if checkReq calls t.FailNow (which
		// exits this goroutine via runtime.Goexit, still running defers).
		// Otherwise the test goroutine would block on gotReq forever.
		defer func() { gotReq <- true }()
		if r.Body == nil {
			t.Error("nil Body")
			return
		}
		checkReq(r)
	})
	defer st.Close()

	st.greet()
	writeReq(st)
	<-gotReq
}
1798
1799 func getSlash(st *serverTester) { st.bodylessReq1() }
1800
func TestServer_Response_NoData(t *testing.T) { synctestTest(t, testServer_Response_NoData) }

// testServer_Response_NoData checks that a handler which returns without
// writing anything produces a HEADERS frame with END_STREAM set.
func testServer_Response_NoData(t testing.TB) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// Write nothing: no headers, no body.
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: true,
		})
	})
}
1814
func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
	synctestTest(t, testServer_Response_NoData_Header_FooBar)
}

// testServer_Response_NoData_Header_FooBar checks that a header set by a
// handler that writes no body is sent in the HEADERS frame that ends the
// stream, together with :status 200 and content-length 0.
func testServer_Response_NoData_Header_FooBar(t testing.TB) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Foo-Bar", "some-value")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: true,
			header: http.Header{
				":status":        []string{"200"},
				"foo-bar":        []string{"some-value"},
				"content-length": []string{"0"},
			},
		})
	})
}
1835
1836
1837
// TestServerIgnoresContentLengthSignWhenWritingChunks checks how the server
// treats a content-length header set by a handler that writes no body.
func TestServerIgnoresContentLengthSignWhenWritingChunks(t *testing.T) {
	synctestTest(t, testServerIgnoresContentLengthSignWhenWritingChunks)
}

// testServerIgnoresContentLengthSignWhenWritingChunks runs each case through
// testServerResponse. The handler sets content-length to tt.cl and writes no
// body, and the response must carry content-length tt.wantCL. Values with a
// leading sign, or that overflow int64, are expected to come out as "0".
func testServerIgnoresContentLengthSignWhenWritingChunks(t testing.TB) {
	tests := []struct {
		name   string
		cl     string
		wantCL string
	}{
		{
			name:   "proper content-length",
			cl:     "3",
			wantCL: "3",
		},
		{
			name:   "ignore cl with plus sign",
			cl:     "+3",
			wantCL: "0",
		},
		{
			name:   "ignore cl with minus sign",
			cl:     "-3",
			wantCL: "0",
		},
		{
			name:   "max int64, for safe uint64->int64 conversion",
			cl:     "9223372036854775807",
			wantCL: "9223372036854775807",
		},
		{
			name:   "overflows int64, so ignored",
			cl:     "9223372036854775808",
			wantCL: "0",
		},
	}

	for _, tt := range tests {
		// All cases run inside this one test. Log the case first so a failure
		// reported from inside testServerResponse, which knows nothing about
		// tt, can be attributed to the right case.
		t.Logf("case %q: handler sets content-length %q", tt.name, tt.cl)
		testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
			w.Header().Set("content-length", tt.cl)
			return nil
		}, func(st *serverTester) {
			getSlash(st)
			st.wantHeaders(wantHeader{
				streamID:  1,
				endStream: true,
				header: http.Header{
					":status":        []string{"200"},
					"content-length": []string{tt.wantCL},
				},
			})
		})
	}
}
1891
1892
1893
// TestServerRejectsContentLengthWithSignNewRequests checks how the server
// parses a request's content-length header. Only an unsigned decimal value
// that fits in an int64 becomes r.ContentLength. A leading sign or an int64
// overflow is expected to leave r.ContentLength at 0.
func TestServerRejectsContentLengthWithSignNewRequests(t *testing.T) {
	tests := []struct {
		name   string
		cl     string
		wantCL int64
	}{
		{
			name:   "proper content-length",
			cl:     "3",
			wantCL: 3,
		},
		{
			name:   "ignore cl with plus sign",
			cl:     "+3",
			wantCL: 0,
		},
		{
			name:   "ignore cl with minus sign",
			cl:     "-3",
			wantCL: 0,
		},
		{
			name:   "max int64, for safe uint64->int64 conversion",
			cl:     "9223372036854775807",
			wantCL: 9223372036854775807,
		},
		{
			name:   "overflows int64, so ignored",
			cl:     "9223372036854775808",
			wantCL: 0,
		},
	}

	for _, tt := range tests {
		synctestSubtest(t, tt.name, func(t testing.TB) {
			writeReq := func(st *serverTester) {
				st.writeHeaders(HeadersFrameParam{
					StreamID:      1,
					BlockFragment: st.encodeHeader("content-length", tt.cl),
					EndStream:     false,
					EndHeaders:    true,
				})
				st.writeData(1, false, []byte(""))
			}
			// checkReq runs on the server's handler goroutine. t.Fatalf
			// (FailNow) must only be called from the test goroutine, so
			// report with t.Errorf instead.
			checkReq := func(r *http.Request) {
				if r.ContentLength != tt.wantCL {
					t.Errorf("Got: %d\nWant: %d", r.ContentLength, tt.wantCL)
				}
			}
			testServerRequest(t, writeReq, checkReq)
		})
	}
}
1947
func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
	synctestTest(t, testServer_Response_Data_Sniff_DoesntOverride)
}

// testServer_Response_Data_Sniff_DoesntOverride checks that a Content-Type
// set explicitly by the handler is sent unchanged, even though the body looks
// like HTML, along with a content-length equal to len(msg).
func testServer_Response_Data_Sniff_DoesntOverride(t testing.TB) {
	const msg = "<html>this is HTML."
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Content-Type", "foo/bar")
		io.WriteString(w, msg)
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"foo/bar"},
				"content-length": []string{strconv.Itoa(len(msg))},
			},
		})
		st.wantData(wantData{
			streamID:  1,
			endStream: true,
			data:      []byte(msg),
		})
	})
}
1975
func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
	synctestTest(t, testServer_Response_TransferEncoding_chunked)
}

// testServer_Response_TransferEncoding_chunked checks the response headers
// when the handler sets "Transfer-Encoding: chunked". The expected header set
// has no transfer-encoding (RFC 9113 §8.2.2 forbids connection-specific
// fields in HTTP/2), and has a content-length for the body instead. This
// assumes wantHeaders compares the complete header set; TODO confirm.
func testServer_Response_TransferEncoding_chunked(t testing.TB) {
	const msg = "hi"
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Transfer-Encoding", "chunked")
		io.WriteString(w, msg)
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"text/plain; charset=utf-8"},
				"content-length": []string{strconv.Itoa(len(msg))},
			},
		})
	})
}
1998
1999
// TestServer_Response_Data_IgnoreHeaderAfterWrite_After checks that a header
// set after the handler's first Write is not sent: the response headers are
// fixed at that point.
func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) {
	synctestTest(t, testServer_Response_Data_IgnoreHeaderAfterWrite_After)
}
func testServer_Response_Data_IgnoreHeaderAfterWrite_After(t testing.TB) {
	const msg = "<html>this is HTML."
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		io.WriteString(w, msg)
		w.Header().Set("foo", "should be ignored")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"text/html; charset=utf-8"},
				"content-length": []string{strconv.Itoa(len(msg))},
			},
		})
	})
}
2022
2023
// TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite checks that
// changing a header after the first Write does not alter the value that was
// in place when the headers were fixed.
func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) {
	synctestTest(t, testServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite)
}
func testServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t testing.TB) {
	const msg = "<html>this is HTML."
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("foo", "proper value")
		io.WriteString(w, msg)
		w.Header().Set("foo", "should be ignored")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"foo":            []string{"proper value"},
				"content-type":   []string{"text/html; charset=utf-8"},
				"content-length": []string{strconv.Itoa(len(msg))},
			},
		})
	})
}
2048
func TestServer_Response_Data_SniffLenType(t *testing.T) {
	synctestTest(t, testServer_Response_Data_SniffLenType)
}

// testServer_Response_Data_SniffLenType checks that, when the handler sets no
// headers, the response gets a content-type sniffed from the body
// ("text/html; charset=utf-8" for an <html> prefix) and a content-length
// equal to len(msg), followed by the body in one DATA frame ending the stream.
func testServer_Response_Data_SniffLenType(t testing.TB) {
	const msg = "<html>this is HTML."
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		io.WriteString(w, msg)
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"text/html; charset=utf-8"},
				"content-length": []string{strconv.Itoa(len(msg))},
			},
		})
		st.wantData(wantData{
			streamID:  1,
			endStream: true,
			data:      []byte(msg),
		})
	})
}
2075
func TestServer_Response_Header_Flush_MidWrite(t *testing.T) {
	synctestTest(t, testServer_Response_Header_Flush_MidWrite)
}

// testServer_Response_Header_Flush_MidWrite checks that a Flush between two
// writes sends the headers and the first chunk immediately. The second chunk
// then follows in its own DATA frame, which ends the stream.
func testServer_Response_Header_Flush_MidWrite(t testing.TB) {
	const msg = "<html>this is HTML"
	const msg2 = ", and this is the next chunk"
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		io.WriteString(w, msg)
		w.(http.Flusher).Flush()
		io.WriteString(w, msg2)
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":      []string{"200"},
				"content-type": []string{"text/html; charset=utf-8"},
				// No content-length: presumably because the headers were
				// flushed before the full body length was known.
			},
		})
		st.wantData(wantData{
			streamID:  1,
			endStream: false,
			data:      []byte(msg),
		})
		st.wantData(wantData{
			streamID:  1,
			endStream: true,
			data:      []byte(msg2),
		})
	})
}
2110
func TestServer_Response_LargeWrite(t *testing.T) { synctestTest(t, testServer_Response_LargeWrite) }

// testServer_Response_LargeWrite has the handler write 1 MiB of 'a' in a
// single Write. It checks that the client receives exactly that many 'a'
// bytes, split into a plausible number of DATA frames given the advertised
// SETTINGS_MAX_FRAME_SIZE of 16 KiB.
func testServer_Response_LargeWrite(t testing.TB) {
	const size = 1 << 20
	const maxFrameSize = 16 << 10
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		n, err := w.Write(bytes.Repeat([]byte("a"), size))
		if err != nil {
			return fmt.Errorf("Write error: %v", err)
		}
		if n != size {
			return fmt.Errorf("wrong size %d from Write", n)
		}
		return nil
	}, func(st *serverTester) {
		// Start every stream window at 0 and cap frames at maxFrameSize.
		if err := st.fr.WriteSettings(
			Setting{SettingInitialWindowSize, 0},
			Setting{SettingMaxFrameSize, maxFrameSize},
		); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st)

		// Grant stream 1 enough window for the whole response.
		if err := st.fr.WriteWindowUpdate(1, size); err != nil {
			t.Fatal(err)
		}

		// SETTINGS_INITIAL_WINDOW_SIZE does not affect the connection-level
		// window, which still starts at 65535; grow it by the response size
		// as well.
		if err := st.fr.WriteWindowUpdate(0, size); err != nil {
			t.Fatal(err)
		}
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":      []string{"200"},
				"content-type": []string{"text/plain; charset=utf-8"},
				// No content-length: presumably because the body is larger
				// than the server buffers before sending headers.
			},
		})
		var nbytes, frames int
		for {
			df := readFrame[*DataFrame](t, st)
			nbytes += len(df.Data())
			frames++
			for _, b := range df.Data() {
				if b != 'a' {
					t.Fatal("non-'a' byte seen in DATA")
				}
			}
			if df.StreamEnded() {
				break
			}
		}
		if nbytes != size {
			t.Errorf("Got %d bytes; want %d", nbytes, size)
		}
		// Each frame carries at most maxFrameSize bytes, so at least
		// size/maxFrameSize frames are needed; allow up to twice that for
		// smaller-than-maximum frames.
		if want := int(size / maxFrameSize); frames < want || frames > want*2 {
			t.Errorf("Got %d frames; want between %d and %d", frames, want, want*2)
		}
	})
}
2175
2176
// TestServer_Response_LargeWrite_FlowControlled checks that the server sends
// response DATA exactly as fast as the client's stream window allows.
func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) {
	synctestTest(t, testServer_Response_LargeWrite_FlowControlled)
}
func testServer_Response_LargeWrite_FlowControlled(t testing.TB) {
	// Successive stream-window grants from the client: the first is the
	// initial window size, the rest are sent as WINDOW_UPDATEs. The response
	// body is exactly their sum.
	reads := []int{123, 1, 13, 127}
	size := 0
	for _, n := range reads {
		size += n
	}

	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// Send the response headers before writing the body.
		w.(http.Flusher).Flush()
		n, err := w.Write(bytes.Repeat([]byte("a"), size))
		if err != nil {
			return fmt.Errorf("Write error: %v", err)
		}
		if n != size {
			return fmt.Errorf("wrong size %d from Write", n)
		}
		return nil
	}, func(st *serverTester) {
		// Limit the initial stream window to reads[0] bytes. The connection
		// window (65535 by default) is larger than the whole body, so only
		// the stream window limits the server here.
		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, uint32(reads[0])}); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st)

		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
		})

		// The first DATA frame uses up the initial window.
		st.wantData(wantData{
			streamID:  1,
			endStream: false,
			size:      reads[0],
		})

		// Each further WINDOW_UPDATE on stream 1 releases exactly one DATA
		// frame of that size; the last one ends the stream.
		for i, quota := range reads[1:] {
			if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil {
				t.Fatal(err)
			}
			st.wantData(wantData{
				streamID:  1,
				endStream: i == len(reads[1:])-1,
				size:      quota,
			})
		}
	})
}
2232
2233
// TestServer_Response_RST_Unblocks_LargeWrite checks that a handler Write
// blocked on flow control fails, rather than hanging, once the client resets
// the stream.
func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) {
	synctestTest(t, testServer_Response_RST_Unblocks_LargeWrite)
}
func testServer_Response_RST_Unblocks_LargeWrite(t testing.TB) {
	const size = 1 << 20
	const maxFrameSize = 16 << 10
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.(http.Flusher).Flush()
		// With a zero stream window this Write cannot make progress; it
		// must return an error after the RST_STREAM below.
		_, err := w.Write(bytes.Repeat([]byte("a"), size))
		if err == nil {
			return errors.New("unexpected nil error from Write in handler")
		}
		return nil
	}, func(st *serverTester) {
		if err := st.fr.WriteSettings(
			Setting{SettingInitialWindowSize, 0},
			Setting{SettingMaxFrameSize, maxFrameSize},
		); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st)

		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
		})

		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
			t.Fatal(err)
		}
	})
}
2268
func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) {
	synctestTest(t, testServer_Response_Empty_Data_Not_FlowControlled)
}

// testServer_Response_Empty_Data_Not_FlowControlled checks that the server
// can end a response with an empty DATA frame even when the stream window is
// zero, since a zero-length DATA frame consumes no flow-control window.
func testServer_Response_Empty_Data_Not_FlowControlled(t testing.TB) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.(http.Flusher).Flush()
		// Return without writing a body: the stream must still be ended.
		return nil
	}, func(st *serverTester) {
		// Give every stream a zero-size send window.
		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, 0}); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st)

		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
		})

		st.wantData(wantData{
			streamID:  1,
			endStream: true,
			size:      0,
		})
	})
}
2298
func TestServer_Response_Automatic100Continue(t *testing.T) {
	synctestTest(t, testServer_Response_Automatic100Continue)
}

// testServer_Response_Automatic100Continue checks the handling of
// "expect: 100-continue":
//   - the server sends a 100 interim response by itself;
//   - the handler does not see the Expect header;
//   - after the body arrives, the final 200 response follows normally.
func testServer_Response_Automatic100Continue(t testing.TB) {
	const msg = "foo"
	const reply = "bar"
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		if v := r.Header.Get("Expect"); v != "" {
			t.Errorf("Expect header = %q; want empty", v)
		}
		buf := make([]byte, len(msg))
		// Reading the body is what the client is waiting on; the 100
		// response must already have been sent for msg to arrive.
		if n, err := io.ReadFull(r.Body, buf); err != nil || n != len(msg) || string(buf) != msg {
			return fmt.Errorf("ReadFull = %q, %v; want %q, nil", buf[:n], err, msg)
		}
		_, err := io.WriteString(w, reply)
		return err
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-Continue"),
			EndStream:     false,
			EndHeaders:    true,
		})
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status": []string{"100"},
			},
		})

		// Having seen the 100 Continue, the client now sends the body and
		// closes its side of the stream.
		st.writeData(1, true, []byte(msg))

		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"text/plain; charset=utf-8"},
				"content-length": []string{strconv.Itoa(len(reply))},
			},
		})

		st.wantData(wantData{
			streamID:  1,
			endStream: true,
			data:      []byte(reply),
		})
	})
}
2352
2353 func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
2354 synctestTest(t, testServer_HandlerWriteErrorOnDisconnect)
2355 }
2356 func testServer_HandlerWriteErrorOnDisconnect(t testing.TB) {
2357 errc := make(chan error, 1)
2358 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2359 p := []byte("some data.\n")
2360 for {
2361 _, err := w.Write(p)
2362 if err != nil {
2363 errc <- err
2364 return nil
2365 }
2366 }
2367 }, func(st *serverTester) {
2368 st.writeHeaders(HeadersFrameParam{
2369 StreamID: 1,
2370 BlockFragment: st.encodeHeader(),
2371 EndStream: false,
2372 EndHeaders: true,
2373 })
2374 st.wantHeaders(wantHeader{
2375 streamID: 1,
2376 endStream: false,
2377 })
2378
2379 st.cc.Close()
2380 _ = <-errc
2381 })
2382 }
2383
2384 func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
2385 synctestTest(t, testServer_Rejects_Too_Many_Streams)
2386 }
2387 func testServer_Rejects_Too_Many_Streams(t testing.TB) {
2388 st := newServerTester(t, nil)
2389 st.greet()
2390 nextStreamID := uint32(1)
2391 streamID := func() uint32 {
2392 defer func() { nextStreamID += 2 }()
2393 return nextStreamID
2394 }
2395 sendReq := func(id uint32) {
2396 st.writeHeaders(HeadersFrameParam{
2397 StreamID: id,
2398 BlockFragment: st.encodeHeader(
2399 ":path", fmt.Sprintf("/%v", id),
2400 ),
2401 EndStream: true,
2402 EndHeaders: true,
2403 })
2404 }
2405 var calls []*serverHandlerCall
2406 for range DefaultMaxStreams {
2407 sendReq(streamID())
2408 calls = append(calls, st.nextHandlerCall())
2409 }
2410
2411
2412
2413
2414 rejectID := streamID()
2415 headerBlock := st.encodeHeader(":path", fmt.Sprintf("/%v", rejectID))
2416 frag1, frag2 := headerBlock[:3], headerBlock[3:]
2417 st.writeHeaders(HeadersFrameParam{
2418 StreamID: rejectID,
2419 BlockFragment: frag1,
2420 EndStream: true,
2421 EndHeaders: false,
2422 })
2423 if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil {
2424 t.Fatal(err)
2425 }
2426 st.sync()
2427 st.wantRSTStream(rejectID, ErrCodeProtocol)
2428
2429
2430 calls[0].exit()
2431 st.sync()
2432 st.wantHeaders(wantHeader{
2433 streamID: 1,
2434 endStream: true,
2435 })
2436
2437
2438 goodID := streamID()
2439 sendReq(goodID)
2440 call := st.nextHandlerCall()
2441 if got, want := call.req.URL.Path, fmt.Sprintf("/%d", goodID); got != want {
2442 t.Errorf("Got request for %q, want %q", got, want)
2443 }
2444 }
2445
2446
2447 func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) {
2448 synctestTest(t, testServer_Response_ManyHeaders_With_Continuation)
2449 }
2450 func testServer_Response_ManyHeaders_With_Continuation(t testing.TB) {
2451 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2452 h := w.Header()
2453 for i := range 5000 {
2454 h.Set(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("x-value-%d", i))
2455 }
2456 return nil
2457 }, func(st *serverTester) {
2458 getSlash(st)
2459 hf := readFrame[*HeadersFrame](t, st)
2460 if hf.HeadersEnded() {
2461 t.Fatal("got unwanted END_HEADERS flag")
2462 }
2463 n := 0
2464 for {
2465 n++
2466 cf := readFrame[*ContinuationFrame](t, st)
2467 if cf.HeadersEnded() {
2468 break
2469 }
2470 }
2471 if n < 5 {
2472 t.Errorf("Only got %d CONTINUATION frames; expected 5+ (currently 6)", n)
2473 }
2474 })
2475 }
2476
2477
2478
2479
2480
2481
2482
2483
2484 func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) {
2485 synctestTest(t, testServer_NoCrash_HandlerClose_Then_ClientClose)
2486 }
2487 func testServer_NoCrash_HandlerClose_Then_ClientClose(t testing.TB) {
2488 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2489
2490 return nil
2491 }, func(st *serverTester) {
2492 st.writeHeaders(HeadersFrameParam{
2493 StreamID: 1,
2494 BlockFragment: st.encodeHeader(),
2495 EndStream: false,
2496 EndHeaders: true,
2497 })
2498 st.wantHeaders(wantHeader{
2499 streamID: 1,
2500 endStream: true,
2501 })
2502
2503
2504
2505 st.wantRSTStream(1, ErrCodeNo)
2506
2507
2508
2509
2510
2511
2512 st.writeData(1, true, []byte("foo"))
2513
2514
2515
2516
2517
2518 st.wantRSTStream(1, ErrCodeStreamClosed)
2519
2520
2521
2522 st.wantConnFlowControlConsumed(0)
2523
2524
2525
2526 var (
2527 panMu sync.Mutex
2528 panicVal any
2529 )
2530
2531 SetTestHookOnPanic(t, func(sc *ServerConn, pv any) bool {
2532 panMu.Lock()
2533 panicVal = pv
2534 panMu.Unlock()
2535 return true
2536 })
2537
2538
2539 st.cc.Close()
2540 synctest.Wait()
2541
2542 panMu.Lock()
2543 got := panicVal
2544 panMu.Unlock()
2545 if got != nil {
2546 t.Errorf("Got panic: %v", got)
2547 }
2548 })
2549 }
2550
2551 func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) }
2552 func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) }
2553
2554 func testRejectTLS(t *testing.T, version uint16) {
2555 synctestTest(t, func(t testing.TB) {
2556 st := newServerTester(t, nil, func(state *tls.ConnectionState) {
2557
2558
2559
2560 state.Version = version
2561 })
2562 defer st.Close()
2563 st.wantGoAway(0, ErrCodeInadequateSecurity)
2564 })
2565 }
2566
2567 func TestServer_Rejects_TLSBadCipher(t *testing.T) { synctestTest(t, testServer_Rejects_TLSBadCipher) }
2568 func testServer_Rejects_TLSBadCipher(t testing.TB) {
2569 st := newServerTester(t, nil, func(state *tls.ConnectionState) {
2570 state.Version = tls.VersionTLS12
2571 state.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA
2572 })
2573 defer st.Close()
2574 st.wantGoAway(0, ErrCodeInadequateSecurity)
2575 }
2576
2577 func TestServer_Advertises_Common_Cipher(t *testing.T) {
2578 synctestTest(t, testServer_Advertises_Common_Cipher)
2579 }
2580 func testServer_Advertises_Common_Cipher(t testing.TB) {
2581 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2582 }, func(srv *http.Server) {
2583
2584
2585 srv.TLSConfig = nil
2586 })
2587
2588
2589 const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
2590 tlsConfig := tlsConfigInsecure.Clone()
2591 tlsConfig.MaxVersion = tls.VersionTLS12
2592 tlsConfig.CipherSuites = []uint16{requiredSuite}
2593 tr := &http.Transport{
2594 TLSClientConfig: tlsConfig,
2595 Protocols: protocols("h2"),
2596 }
2597 defer tr.CloseIdleConnections()
2598
2599 req, err := http.NewRequest("GET", ts.URL, nil)
2600 if err != nil {
2601 t.Fatal(err)
2602 }
2603 res, err := tr.RoundTrip(req)
2604 if err != nil {
2605 t.Fatal(err)
2606 }
2607 res.Body.Close()
2608 }
2609
2610
2611
2612 func testServerResponse(t testing.TB,
2613 handler func(http.ResponseWriter, *http.Request) error,
2614 client func(*serverTester),
2615 ) {
2616 errc := make(chan error, 1)
2617 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2618 if r.Body == nil {
2619 t.Fatal("nil Body")
2620 }
2621 err := handler(w, r)
2622 select {
2623 case errc <- err:
2624 default:
2625 t.Errorf("unexpected duplicate request")
2626 }
2627 })
2628 defer st.Close()
2629
2630 st.greet()
2631 client(st)
2632
2633 if err := <-errc; err != nil {
2634 t.Fatalf("Error in handler: %v", err)
2635 }
2636 }
2637
2638
2639
2640
// readBodyHandler returns a handler that reads exactly len(want) bytes of the
// request body. It reports a test error if the read fails or the bytes differ
// from want. It only uses t.Error/t.Errorf, since a handler may run on a
// non-test goroutine.
func readBodyHandler(t testing.TB, want string) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		got := make([]byte, len(want))
		if _, err := io.ReadFull(r.Body, got); err != nil {
			t.Error(err)
			return
		}
		if want != string(got) {
			t.Errorf("read %q; want %q", got, want)
		}
	}
}
2654
2655 func TestServer_MaxDecoderHeaderTableSize(t *testing.T) {
2656 synctestTest(t, testServer_MaxDecoderHeaderTableSize)
2657 }
2658 func testServer_MaxDecoderHeaderTableSize(t testing.TB) {
2659 wantHeaderTableSize := uint32(InitialHeaderTableSize * 2)
2660 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(h2 *http.HTTP2Config) {
2661 h2.MaxDecoderHeaderTableSize = int(wantHeaderTableSize)
2662 })
2663 defer st.Close()
2664
2665 var advHeaderTableSize *uint32
2666 st.greetAndCheckSettings(func(s Setting) error {
2667 switch s.ID {
2668 case SettingHeaderTableSize:
2669 advHeaderTableSize = &s.Val
2670 }
2671 return nil
2672 })
2673
2674 if advHeaderTableSize == nil {
2675 t.Errorf("server didn't advertise a header table size")
2676 } else if got, want := *advHeaderTableSize, wantHeaderTableSize; got != want {
2677 t.Errorf("server advertised a header table size of %d, want %d", got, want)
2678 }
2679 }
2680
2681 func TestServer_MaxEncoderHeaderTableSize(t *testing.T) {
2682 synctestTest(t, testServer_MaxEncoderHeaderTableSize)
2683 }
2684 func testServer_MaxEncoderHeaderTableSize(t testing.TB) {
2685 wantHeaderTableSize := uint32(InitialHeaderTableSize / 2)
2686 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(h2 *http.HTTP2Config) {
2687 h2.MaxEncoderHeaderTableSize = int(wantHeaderTableSize)
2688 })
2689 defer st.Close()
2690
2691 st.greet()
2692
2693 if got, want := st.sc.TestHPACKEncoder().MaxDynamicTableSize(), wantHeaderTableSize; got != want {
2694 t.Errorf("server encoder is using a header table size of %d, want %d", got, want)
2695 }
2696 }
2697
2698
2699 func TestServerDoS_MaxHeaderListSize(t *testing.T) { synctestTest(t, testServerDoS_MaxHeaderListSize) }
2700 func testServerDoS_MaxHeaderListSize(t testing.TB) {
2701 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
2702 defer st.Close()
2703
2704
2705 frameSize := DefaultMaxReadFrameSize
2706 var advHeaderListSize *uint32
2707 st.greetAndCheckSettings(func(s Setting) error {
2708 switch s.ID {
2709 case SettingMaxFrameSize:
2710 if s.Val < MinMaxFrameSize {
2711 frameSize = MinMaxFrameSize
2712 } else if s.Val > MaxFrameSize {
2713 frameSize = MaxFrameSize
2714 } else {
2715 frameSize = int(s.Val)
2716 }
2717 case SettingMaxHeaderListSize:
2718 advHeaderListSize = &s.Val
2719 }
2720 return nil
2721 })
2722
2723 if advHeaderListSize == nil {
2724 t.Errorf("server didn't advertise a max header list size")
2725 } else if *advHeaderListSize == 0 {
2726 t.Errorf("server advertised a max header list size of 0")
2727 }
2728
2729 st.encodeHeaderField(":method", "GET")
2730 st.encodeHeaderField(":path", "/")
2731 st.encodeHeaderField(":scheme", "https")
2732 cookie := strings.Repeat("*", 4058)
2733 st.encodeHeaderField("cookie", cookie)
2734 st.writeHeaders(HeadersFrameParam{
2735 StreamID: 1,
2736 BlockFragment: st.headerBuf.Bytes(),
2737 EndStream: true,
2738 EndHeaders: false,
2739 })
2740
2741
2742
2743 st.headerBuf.Reset()
2744 st.encodeHeaderField("cookie", cookie)
2745
2746
2747 const size = 1 << 20
2748 b := bytes.Repeat(st.headerBuf.Bytes(), size/st.headerBuf.Len())
2749 for len(b) > 0 {
2750 chunk := b
2751 if len(chunk) > frameSize {
2752 chunk = chunk[:frameSize]
2753 }
2754 b = b[len(chunk):]
2755 st.fr.WriteContinuation(1, len(b) == 0, chunk)
2756 }
2757
2758 st.wantHeaders(wantHeader{
2759 streamID: 1,
2760 endStream: false,
2761 header: http.Header{
2762 ":status": []string{"431"},
2763 "content-type": []string{"text/html; charset=utf-8"},
2764 "content-length": []string{"63"},
2765 },
2766 })
2767 }
2768
2769 func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) {
2770 synctestTest(t, testServer_Response_Stream_With_Missing_Trailer)
2771 }
2772 func testServer_Response_Stream_With_Missing_Trailer(t testing.TB) {
2773 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2774 w.Header().Set("Trailer", "test-trailer")
2775 return nil
2776 }, func(st *serverTester) {
2777 getSlash(st)
2778 st.wantHeaders(wantHeader{
2779 streamID: 1,
2780 endStream: false,
2781 })
2782 st.wantData(wantData{
2783 streamID: 1,
2784 endStream: true,
2785 size: 0,
2786 })
2787 })
2788 }
2789
2790 func TestCompressionErrorOnWrite(t *testing.T) { synctestTest(t, testCompressionErrorOnWrite) }
2791 func testCompressionErrorOnWrite(t testing.TB) {
2792 const maxStrLen = 8 << 10
2793 var serverConfig *http.Server
2794 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2795
2796 }, func(s *http.Server) {
2797 serverConfig = s
2798 serverConfig.MaxHeaderBytes = maxStrLen
2799 })
2800 st.addLogFilter("connection error: COMPRESSION_ERROR")
2801 defer st.Close()
2802 st.greet()
2803
2804 maxAllowed := st.sc.TestFramerMaxHeaderStringLen()
2805
2806
2807
2808
2809
2810
2811 serverConfig.MaxHeaderBytes = 1 << 20
2812
2813
2814
2815
2816
2817 hbf := st.encodeHeader("foo", strings.Repeat("a", maxAllowed))
2818
2819 st.writeHeaders(HeadersFrameParam{
2820 StreamID: 1,
2821 BlockFragment: hbf,
2822 EndStream: true,
2823 EndHeaders: true,
2824 })
2825 st.wantHeaders(wantHeader{
2826 streamID: 1,
2827 endStream: false,
2828 header: http.Header{
2829 ":status": []string{"431"},
2830 "content-type": []string{"text/html; charset=utf-8"},
2831 "content-length": []string{"63"},
2832 },
2833 })
2834 df := readFrame[*DataFrame](t, st)
2835 if !strings.Contains(string(df.Data()), "HTTP Error 431") {
2836 t.Errorf("Unexpected data body: %q", df.Data())
2837 }
2838 if !df.StreamEnded() {
2839 t.Fatalf("expect data stream end")
2840 }
2841
2842
2843 hbf = st.encodeHeader("bar", strings.Repeat("b", maxAllowed+1))
2844 st.writeHeaders(HeadersFrameParam{
2845 StreamID: 3,
2846 BlockFragment: hbf,
2847 EndStream: true,
2848 EndHeaders: true,
2849 })
2850 st.wantGoAway(3, ErrCodeCompression)
2851 }
2852
2853 func TestCompressionErrorOnClose(t *testing.T) { synctestTest(t, testCompressionErrorOnClose) }
2854 func testCompressionErrorOnClose(t testing.TB) {
2855 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2856
2857 })
2858 st.addLogFilter("connection error: COMPRESSION_ERROR")
2859 defer st.Close()
2860 st.greet()
2861
2862 hbf := st.encodeHeader("foo", "bar")
2863 hbf = hbf[:len(hbf)-1]
2864 st.writeHeaders(HeadersFrameParam{
2865 StreamID: 1,
2866 BlockFragment: hbf,
2867 EndStream: true,
2868 EndHeaders: true,
2869 })
2870 st.wantGoAway(1, ErrCodeCompression)
2871 }
2872
2873
2874 func TestServerReadsTrailers(t *testing.T) { synctestTest(t, testServerReadsTrailers) }
2875 func testServerReadsTrailers(t testing.TB) {
2876 const testBody = "some test body"
2877 writeReq := func(st *serverTester) {
2878 st.writeHeaders(HeadersFrameParam{
2879 StreamID: 1,
2880 BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"),
2881 EndStream: false,
2882 EndHeaders: true,
2883 })
2884 st.writeData(1, false, []byte(testBody))
2885 st.writeHeaders(HeadersFrameParam{
2886 StreamID: 1,
2887 BlockFragment: st.encodeHeaderRaw(
2888 "foo", "foov",
2889 "bar", "barv",
2890 "baz", "bazv",
2891 "surprise", "wasn't declared; shouldn't show up",
2892 ),
2893 EndStream: true,
2894 EndHeaders: true,
2895 })
2896 }
2897 checkReq := func(r *http.Request) {
2898 wantTrailer := http.Header{
2899 "Foo": nil,
2900 "Bar": nil,
2901 "Baz": nil,
2902 }
2903 if !reflect.DeepEqual(r.Trailer, wantTrailer) {
2904 t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer)
2905 }
2906 slurp, err := io.ReadAll(r.Body)
2907 if string(slurp) != testBody {
2908 t.Errorf("read body %q; want %q", slurp, testBody)
2909 }
2910 if err != nil {
2911 t.Fatalf("Body slurp: %v", err)
2912 }
2913 wantTrailerAfter := http.Header{
2914 "Foo": {"foov"},
2915 "Bar": {"barv"},
2916 "Baz": {"bazv"},
2917 }
2918 if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) {
2919 t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter)
2920 }
2921 }
2922 testServerRequest(t, writeReq, checkReq)
2923 }
2924
2925
2926 func TestServerWritesTrailers_WithFlush(t *testing.T) {
2927 synctestTest(t, func(t testing.TB) {
2928 testServerWritesTrailers(t, true)
2929 })
2930 }
2931 func TestServerWritesTrailers_WithoutFlush(t *testing.T) {
2932 synctestTest(t, func(t testing.TB) {
2933 testServerWritesTrailers(t, false)
2934 })
2935 }
2936
2937 func testServerWritesTrailers(t testing.TB, withFlush bool) {
2938
2939 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2940 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
2941 w.Header().Add("Trailer", "Server-Trailer-C")
2942 w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer")
2943
2944
2945 w.Header().Set("Foo", "Bar")
2946 w.Header().Set("Content-Length", "5")
2947
2948 io.WriteString(w, "Hello")
2949 if withFlush {
2950 w.(http.Flusher).Flush()
2951 }
2952 w.Header().Set("Server-Trailer-A", "valuea")
2953 w.Header().Set("Server-Trailer-C", "valuec")
2954
2955 w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")
2956
2957
2958
2959 w.Header().Set("Trailer:Post-Header-Trailer", "hi1")
2960 w.Header().Set("Trailer:post-header-trailer2", "hi2")
2961 w.Header().Set("Trailer:Range", "invalid")
2962 w.Header().Set("Trailer:Foo\x01Bogus", "invalid")
2963 w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 7230 4.1.2")
2964 w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 7230 4.1.2")
2965 w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2")
2966 return nil
2967 }, func(st *serverTester) {
2968
2969 st.h1server.ErrorLog = log.New(io.Discard, "", 0)
2970 getSlash(st)
2971 st.wantHeaders(wantHeader{
2972 streamID: 1,
2973 endStream: false,
2974 header: http.Header{
2975 ":status": []string{"200"},
2976 "foo": []string{"Bar"},
2977 "trailer": []string{
2978 "Server-Trailer-A, Server-Trailer-B",
2979 "Server-Trailer-C",
2980 "Transfer-Encoding, Content-Length, Trailer",
2981 },
2982 "content-type": []string{"text/plain; charset=utf-8"},
2983 "content-length": []string{"5"},
2984 },
2985 })
2986 st.wantData(wantData{
2987 streamID: 1,
2988 endStream: false,
2989 data: []byte("Hello"),
2990 })
2991 st.wantHeaders(wantHeader{
2992 streamID: 1,
2993 endStream: true,
2994 header: http.Header{
2995 "post-header-trailer": []string{"hi1"},
2996 "post-header-trailer2": []string{"hi2"},
2997 "server-trailer-a": []string{"valuea"},
2998 "server-trailer-c": []string{"valuec"},
2999 },
3000 })
3001 })
3002 }
3003
3004 func TestServerWritesUndeclaredTrailers(t *testing.T) {
3005 synctestTest(t, testServerWritesUndeclaredTrailers)
3006 }
3007 func testServerWritesUndeclaredTrailers(t testing.TB) {
3008 const trailer = "Trailer-Header"
3009 const value = "hi1"
3010 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3011 w.Header().Set(http.TrailerPrefix+trailer, value)
3012 })
3013
3014 tr := &http.Transport{
3015 TLSClientConfig: tlsConfigInsecure,
3016 Protocols: protocols("h2"),
3017 }
3018 defer tr.CloseIdleConnections()
3019
3020 cl := &http.Client{Transport: tr}
3021 resp, err := cl.Get(ts.URL)
3022 if err != nil {
3023 t.Fatal(err)
3024 }
3025 io.Copy(io.Discard, resp.Body)
3026 resp.Body.Close()
3027
3028 if got, want := resp.Trailer.Get(trailer), value; got != want {
3029 t.Errorf("trailer %v = %q, want %q", trailer, got, want)
3030 }
3031 }
3032
3033
3034
3035 func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
3036 synctestTest(t, testServerDoesntWriteInvalidHeaders)
3037 }
3038 func testServerDoesntWriteInvalidHeaders(t testing.TB) {
3039 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
3040 w.Header().Add("OK1", "x")
3041 w.Header().Add("Bad:Colon", "x")
3042 w.Header().Add("Bad1\x00", "x")
3043 w.Header().Add("Bad2", "x\x00y")
3044 return nil
3045 }, func(st *serverTester) {
3046 getSlash(st)
3047 st.wantHeaders(wantHeader{
3048 streamID: 1,
3049 endStream: true,
3050 header: http.Header{
3051 ":status": []string{"200"},
3052 "ok1": []string{"x"},
3053 "content-length": []string{"0"},
3054 },
3055 })
3056 })
3057 }
3058
3059 func TestIssue53(t *testing.T) { synctestTest(t, testIssue53) }
3060 func testIssue53(t testing.TB) {
3061 const data = "PRI * HTTP/2.0\r\n\r\nSM" +
3062 "\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad"
3063 st := newServerTester(t, func(w http.ResponseWriter, req *http.Request) {
3064 w.Write([]byte("hello"))
3065 })
3066
3067 st.cc.Write([]byte(data))
3068 st.wantFrameType(FrameSettings)
3069 st.wantFrameType(FrameWindowUpdate)
3070 st.wantFrameType(FrameGoAway)
3071 time.Sleep(GoAwayTimeout)
3072 st.wantClosed()
3073 }
3074
3075 func TestServerServeNoBannedCiphers(t *testing.T) {
3076 tests := []struct {
3077 name string
3078 tlsConfig *tls.Config
3079 wantErr string
3080 }{
3081 {
3082 name: "empty CipherSuites",
3083 tlsConfig: &tls.Config{},
3084 },
3085 {
3086 name: "bad CipherSuites but MinVersion TLS 1.3",
3087 tlsConfig: &tls.Config{
3088 MinVersion: tls.VersionTLS13,
3089 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384},
3090 },
3091 },
3092 {
3093 name: "just the required cipher suite",
3094 tlsConfig: &tls.Config{
3095 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
3096 },
3097 },
3098 {
3099 name: "just the alternative required cipher suite",
3100 tlsConfig: &tls.Config{
3101 CipherSuites: []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
3102 },
3103 },
3104 {
3105 name: "missing required cipher suite",
3106 tlsConfig: &tls.Config{
3107 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384},
3108 },
3109 wantErr: "is missing an HTTP/2-required",
3110 },
3111 {
3112 name: "required after bad",
3113 tlsConfig: &tls.Config{
3114 CipherSuites: []uint16{tls.TLS_RSA_WITH_RC4_128_SHA, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
3115 },
3116 },
3117 {
3118 name: "bad after required",
3119 tlsConfig: &tls.Config{
3120 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_RC4_128_SHA},
3121 },
3122 },
3123 }
3124 for _, tt := range tests {
3125 tt.tlsConfig.Certificates = testServerTLSConfig.Certificates
3126
3127 srv := &http.Server{
3128 TLSConfig: tt.tlsConfig,
3129 Protocols: protocols("h2"),
3130 }
3131
3132 err := srv.ServeTLS(errListener{}, "", "")
3133 if (err != net.ErrClosed) != (tt.wantErr != "") {
3134 if tt.wantErr != "" {
3135 t.Errorf("%s: success, but want error", tt.name)
3136 } else {
3137 t.Errorf("%s: unexpected error: %v", tt.name, err)
3138 }
3139 }
3140 if err != nil && tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr) {
3141 t.Errorf("%s: err = %v; want substring %q", tt.name, err, tt.wantErr)
3142 }
3143 if err == nil && !srv.TLSConfig.PreferServerCipherSuites {
3144 t.Errorf("%s: PreferServerCipherSuite is false; want true", tt.name)
3145 }
3146 }
3147 }
3148
// errListener is a net.Listener whose Accept always fails with net.ErrClosed.
// TestServerServeNoBannedCiphers hands it to ServeTLS, so the server returns
// as soon as it tries to accept a connection.
type errListener struct{}

var _ net.Listener = errListener{}

// Accept reports net.ErrClosed immediately and never yields a connection.
func (errListener) Accept() (net.Conn, error) { return nil, net.ErrClosed }

// Close is a no-op.
func (errListener) Close() error { return nil }

// Addr returns a nil address.
func (errListener) Addr() net.Addr { return nil }
3154
3155 func TestServerNoAutoContentLengthOnHead(t *testing.T) {
3156 synctestTest(t, testServerNoAutoContentLengthOnHead)
3157 }
3158 func testServerNoAutoContentLengthOnHead(t testing.TB) {
3159 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3160
3161 })
3162 defer st.Close()
3163 st.greet()
3164 st.writeHeaders(HeadersFrameParam{
3165 StreamID: 1,
3166 BlockFragment: st.encodeHeader(":method", "HEAD"),
3167 EndStream: true,
3168 EndHeaders: true,
3169 })
3170 st.wantHeaders(wantHeader{
3171 streamID: 1,
3172 endStream: true,
3173 header: http.Header{
3174 ":status": []string{"200"},
3175 },
3176 })
3177 }
3178
3179
3180 func TestServerNoDuplicateContentType(t *testing.T) {
3181 synctestTest(t, testServerNoDuplicateContentType)
3182 }
3183 func testServerNoDuplicateContentType(t testing.TB) {
3184 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3185 w.Header()["Content-Type"] = []string{""}
3186 fmt.Fprintf(w, "<html><head></head><body>hi</body></html>")
3187 })
3188 defer st.Close()
3189 st.greet()
3190 st.writeHeaders(HeadersFrameParam{
3191 StreamID: 1,
3192 BlockFragment: st.encodeHeader(),
3193 EndStream: true,
3194 EndHeaders: true,
3195 })
3196 st.wantHeaders(wantHeader{
3197 streamID: 1,
3198 endStream: false,
3199 header: http.Header{
3200 ":status": []string{"200"},
3201 "content-type": []string{""},
3202 "content-length": []string{"41"},
3203 },
3204 })
3205 }
3206
3207 func TestServerContentLengthCanBeDisabled(t *testing.T) {
3208 synctestTest(t, testServerContentLengthCanBeDisabled)
3209 }
3210 func testServerContentLengthCanBeDisabled(t testing.TB) {
3211 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3212 w.Header()["Content-Length"] = nil
3213 fmt.Fprintf(w, "OK")
3214 })
3215 defer st.Close()
3216 st.greet()
3217 st.writeHeaders(HeadersFrameParam{
3218 StreamID: 1,
3219 BlockFragment: st.encodeHeader(),
3220 EndStream: true,
3221 EndHeaders: true,
3222 })
3223 st.wantHeaders(wantHeader{
3224 streamID: 1,
3225 endStream: false,
3226 header: http.Header{
3227 ":status": []string{"200"},
3228 "content-type": []string{"text/plain; charset=utf-8"},
3229 },
3230 })
3231 }
3232
3233
3234 func TestServer_Rejects_ConnHeaders(t *testing.T) { synctestTest(t, testServer_Rejects_ConnHeaders) }
3235 func testServer_Rejects_ConnHeaders(t testing.TB) {
3236 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3237 t.Error("should not get to Handler")
3238 })
3239 defer st.Close()
3240 st.greet()
3241 st.bodylessReq1("connection", "foo")
3242 st.wantHeaders(wantHeader{
3243 streamID: 1,
3244 endStream: false,
3245 header: http.Header{
3246 ":status": []string{"400"},
3247 "content-type": []string{"text/plain; charset=utf-8"},
3248 "x-content-type-options": []string{"nosniff"},
3249 "content-length": []string{"51"},
3250 },
3251 })
3252 }
3253
3254 type hpackEncoder struct {
3255 enc *hpack.Encoder
3256 buf bytes.Buffer
3257 }
3258
3259 func (he *hpackEncoder) encodeHeaderRaw(t testing.TB, headers ...string) []byte {
3260 if len(headers)%2 == 1 {
3261 panic("odd number of kv args")
3262 }
3263 he.buf.Reset()
3264 if he.enc == nil {
3265 he.enc = hpack.NewEncoder(&he.buf)
3266 }
3267 for len(headers) > 0 {
3268 k, v := headers[0], headers[1]
3269 err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v})
3270 if err != nil {
3271 t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
3272 }
3273 headers = headers[2:]
3274 }
3275 return he.buf.Bytes()
3276 }
3277
3278
3279 func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
3280 synctestTest(t, testExpect100ContinueAfterHandlerWrites)
3281 }
3282 func testExpect100ContinueAfterHandlerWrites(t testing.TB) {
3283 const msg = "Hello"
3284 const msg2 = "World"
3285
3286 doRead := make(chan bool, 1)
3287 defer close(doRead)
3288
3289 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3290 io.WriteString(w, msg)
3291 w.(http.Flusher).Flush()
3292
3293
3294 <-doRead
3295 r.Body.Read(make([]byte, 10))
3296
3297 io.WriteString(w, msg2)
3298 })
3299
3300 tr := &http.Transport{
3301 TLSClientConfig: tlsConfigInsecure,
3302 Protocols: protocols("h2"),
3303 }
3304 defer tr.CloseIdleConnections()
3305
3306 req, _ := http.NewRequest("POST", ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
3307 req.Header.Set("Expect", "100-continue")
3308
3309 res, err := tr.RoundTrip(req)
3310 if err != nil {
3311 t.Fatal(err)
3312 }
3313 defer res.Body.Close()
3314
3315 buf := make([]byte, len(msg))
3316 if _, err := io.ReadFull(res.Body, buf); err != nil {
3317 t.Fatal(err)
3318 }
3319 if string(buf) != msg {
3320 t.Fatalf("msg = %q; want %q", buf, msg)
3321 }
3322
3323 doRead <- true
3324
3325 if _, err := io.ReadFull(res.Body, buf); err != nil {
3326 t.Fatal(err)
3327 }
3328 if string(buf) != msg2 {
3329 t.Fatalf("second msg = %q; want %q", buf, msg2)
3330 }
3331 }
3332
// funcReader adapts a plain function to the io.Reader interface.
type funcReader func([]byte) (n int, err error)

var _ io.Reader = funcReader(nil)

// Read delegates directly to the underlying function.
func (fn funcReader) Read(p []byte) (n int, err error) {
	n, err = fn(p)
	return n, err
}
3336
3337
3338
3339 func TestUnreadFlowControlReturned_Server(t *testing.T) {
3340 for _, tt := range []struct {
3341 name string
3342 reqFn func(r *http.Request)
3343 }{
3344 {
3345 "body-open",
3346 func(r *http.Request) {},
3347 },
3348 {
3349 "body-closed",
3350 func(r *http.Request) {
3351 r.Body.Close()
3352 },
3353 },
3354 {
3355 "read-1-byte-and-close",
3356 func(r *http.Request) {
3357 b := make([]byte, 1)
3358 r.Body.Read(b)
3359 r.Body.Close()
3360 },
3361 },
3362 } {
3363 synctestSubtest(t, tt.name, func(t testing.TB) {
3364 unblock := make(chan bool, 1)
3365 defer close(unblock)
3366
3367 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3368
3369
3370
3371 tt.reqFn(r)
3372 <-unblock
3373 })
3374
3375 tr := &http.Transport{
3376 TLSClientConfig: tlsConfigInsecure,
3377 Protocols: protocols("h2"),
3378 }
3379 defer tr.CloseIdleConnections()
3380
3381
3382 iters := 100
3383 if testing.Short() {
3384 iters = 20
3385 }
3386 for i := 0; i < iters; i++ {
3387 body := io.MultiReader(
3388 io.LimitReader(neverEnding('A'), 16<<10),
3389 funcReader(func([]byte) (n int, err error) {
3390 unblock <- true
3391 return 0, io.EOF
3392 }),
3393 )
3394 req, _ := http.NewRequest("POST", ts.URL, body)
3395 res, err := tr.RoundTrip(req)
3396 if err != nil {
3397 t.Fatal(tt.name, err)
3398 }
3399 res.Body.Close()
3400 }
3401 })
3402 }
3403 }
3404
3405 func TestServerReturnsStreamAndConnFlowControlOnBodyClose(t *testing.T) {
3406 synctestTest(t, testServerReturnsStreamAndConnFlowControlOnBodyClose)
3407 }
3408 func testServerReturnsStreamAndConnFlowControlOnBodyClose(t testing.TB) {
3409 unblockHandler := make(chan struct{})
3410 defer close(unblockHandler)
3411
3412 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3413 r.Body.Close()
3414 w.WriteHeader(200)
3415 w.(http.Flusher).Flush()
3416 <-unblockHandler
3417 })
3418 defer st.Close()
3419
3420 st.greet()
3421 st.writeHeaders(HeadersFrameParam{
3422 StreamID: 1,
3423 BlockFragment: st.encodeHeader(),
3424 EndHeaders: true,
3425 })
3426 st.wantHeaders(wantHeader{
3427 streamID: 1,
3428 endStream: false,
3429 })
3430 const size = InflowMinRefresh
3431 st.writeData(1, false, make([]byte, size))
3432 st.wantWindowUpdate(0, size)
3433 unblockHandler <- struct{}{}
3434 st.wantData(wantData{
3435 streamID: 1,
3436 endStream: true,
3437 })
3438 }
3439
3440 func TestServerIdleTimeout(t *testing.T) { synctestTest(t, testServerIdleTimeout) }
3441 func testServerIdleTimeout(t testing.TB) {
3442 if testing.Short() {
3443 t.Skip("skipping in short mode")
3444 }
3445
3446 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3447 }, func(s *http.Server) {
3448 s.IdleTimeout = 500 * time.Millisecond
3449 })
3450 defer st.Close()
3451
3452 st.greet()
3453 st.advance(500 * time.Millisecond)
3454 st.wantGoAway(0, ErrCodeNo)
3455 }
3456
3457 func TestServerIdleTimeout_AfterRequest(t *testing.T) {
3458 synctestTest(t, testServerIdleTimeout_AfterRequest)
3459 }
3460 func testServerIdleTimeout_AfterRequest(t testing.TB) {
3461 if testing.Short() {
3462 t.Skip("skipping in short mode")
3463 }
3464 const (
3465 requestTimeout = 2 * time.Second
3466 idleTimeout = 1 * time.Second
3467 )
3468
3469 var st *serverTester
3470 st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3471 time.Sleep(requestTimeout)
3472 }, func(s *http.Server) {
3473 s.IdleTimeout = idleTimeout
3474 })
3475 defer st.Close()
3476
3477 st.greet()
3478
3479
3480
3481 st.bodylessReq1()
3482 st.advance(requestTimeout)
3483 st.wantHeaders(wantHeader{
3484 streamID: 1,
3485 endStream: true,
3486 })
3487
3488
3489
3490 st.advance(idleTimeout)
3491 st.wantGoAway(1, ErrCodeNo)
3492 }
3493
3494
3495
3496
3497 func TestRequestBodyReadCloseRace(t *testing.T) { synctestTest(t, testRequestBodyReadCloseRace) }
3498 func testRequestBodyReadCloseRace(t testing.TB) {
3499 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3500 go r.Body.Close()
3501 io.Copy(io.Discard, r.Body)
3502 })
3503 st.greet()
3504
3505 data := make([]byte, 1024)
3506 for i := range 100 {
3507 streamID := uint32(1 + (i * 2))
3508 st.writeHeaders(HeadersFrameParam{
3509 StreamID: streamID,
3510 BlockFragment: st.encodeHeader(),
3511 EndHeaders: true,
3512 })
3513 st.writeData(1, false, data)
3514
3515 for {
3516
3517
3518 fr := st.readFrame()
3519 if fr == nil {
3520 t.Fatalf("got no RSTStreamFrame, want one")
3521 }
3522 rst, ok := fr.(*RSTStreamFrame)
3523 if !ok {
3524 continue
3525 }
3526
3527 if rst.ErrCode != ErrCodeNo && rst.ErrCode != ErrCodeStreamClosed {
3528 t.Fatalf("got RSTStreamFrame with error code %v, want ErrCodeNo or ErrCodeStreamClosed", rst.ErrCode)
3529 }
3530 break
3531 }
3532 }
3533 }
3534
3535 func TestIssue20704Race(t *testing.T) { synctestTest(t, testIssue20704Race) }
3536 func testIssue20704Race(t testing.TB) {
3537 if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" {
3538 t.Skip("skipping in short mode")
3539 }
3540 const (
3541 itemSize = 1 << 10
3542 itemCount = 100
3543 )
3544
3545 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3546 for range itemCount {
3547 _, err := w.Write(make([]byte, itemSize))
3548 if err != nil {
3549 return
3550 }
3551 }
3552 })
3553
3554 tr := &http.Transport{
3555 TLSClientConfig: tlsConfigInsecure,
3556 Protocols: protocols("h2"),
3557 }
3558 defer tr.CloseIdleConnections()
3559 cl := &http.Client{Transport: tr}
3560
3561 for range 1000 {
3562 resp, err := cl.Get(ts.URL)
3563 if err != nil {
3564 t.Fatal(err)
3565 }
3566
3567
3568 resp.Body.Close()
3569 }
3570 }
3571
// TestServer_Rejects_TooSmall checks that a request whose body is larger than
// its declared content-length is rejected. The server is expected to reset
// the stream with PROTOCOL_ERROR, and the rejected DATA must not remain
// counted as consumed connection flow control.
func TestServer_Rejects_TooSmall(t *testing.T) { synctestTest(t, testServer_Rejects_TooSmall) }
func testServer_Rejects_TooSmall(t testing.TB) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		io.ReadAll(r.Body)
		return nil
	}, func(st *serverTester) {
		// Declare a 4-byte body...
		st.writeHeaders(HeadersFrameParam{
			StreamID: 1,
			BlockFragment: st.encodeHeader(
				":method", "POST",
				"content-length", "4",
			),
			EndStream:  false,
			EndHeaders: true,
		})
		// ...then send 5 bytes.
		st.writeData(1, true, []byte("12345"))
		st.wantRSTStream(1, ErrCodeProtocol)
		st.wantConnFlowControlConsumed(0)
	})
}
3592
3593
3594
3595 func TestServerHandlerConnectionClose(t *testing.T) {
3596 synctestTest(t, testServerHandlerConnectionClose)
3597 }
3598 func testServerHandlerConnectionClose(t testing.TB) {
3599 unblockHandler := make(chan bool, 1)
3600 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
3601 w.Header().Set("Connection", "close")
3602 w.Header().Set("Foo", "bar")
3603 w.(http.Flusher).Flush()
3604 <-unblockHandler
3605 return nil
3606 }, func(st *serverTester) {
3607 defer close(unblockHandler)
3608 st.writeHeaders(HeadersFrameParam{
3609 StreamID: 1,
3610 BlockFragment: st.encodeHeader(),
3611 EndStream: true,
3612 EndHeaders: true,
3613 })
3614 var sawGoAway bool
3615 var sawRes bool
3616 var sawWindowUpdate bool
3617 for {
3618 f := st.readFrame()
3619 if f == nil {
3620 break
3621 }
3622 switch f := f.(type) {
3623 case *GoAwayFrame:
3624 sawGoAway = true
3625 if f.LastStreamID != 1 || f.ErrCode != ErrCodeNo {
3626 t.Errorf("unexpected GOAWAY frame: %v", SummarizeFrame(f))
3627 }
3628
3629
3630 st.writeHeaders(HeadersFrameParam{
3631 StreamID: 3,
3632 BlockFragment: st.encodeHeader(),
3633 EndStream: false,
3634 EndHeaders: true,
3635 })
3636 st.fr.WriteRSTStream(3, ErrCodeCancel)
3637
3638
3639
3640 st.writeHeaders(HeadersFrameParam{
3641 StreamID: 5,
3642 BlockFragment: st.encodeHeader(),
3643 EndStream: false,
3644 EndHeaders: true,
3645 })
3646
3647 st.writeData(5, true, make([]byte, 1<<19))
3648 case *HeadersFrame:
3649 goth := st.decodeHeader(f.HeaderBlockFragment())
3650 wanth := [][2]string{
3651 {":status", "200"},
3652 {"foo", "bar"},
3653 }
3654 if !reflect.DeepEqual(goth, wanth) {
3655 t.Errorf("got headers %v; want %v", goth, wanth)
3656 }
3657 sawRes = true
3658 case *DataFrame:
3659 if f.StreamID != 1 || !f.StreamEnded() || len(f.Data()) != 0 {
3660 t.Errorf("unexpected DATA frame: %v", SummarizeFrame(f))
3661 }
3662 case *WindowUpdateFrame:
3663 if !sawGoAway {
3664 t.Errorf("unexpected WINDOW_UPDATE frame: %v", SummarizeFrame(f))
3665 return
3666 }
3667 if f.StreamID != 0 {
3668 st.t.Fatalf("WindowUpdate StreamID = %d; want 5", f.FrameHeader.StreamID)
3669 return
3670 }
3671 sawWindowUpdate = true
3672 unblockHandler <- true
3673 st.sync()
3674 st.advance(GoAwayTimeout)
3675 default:
3676 t.Logf("unexpected frame: %v", SummarizeFrame(f))
3677 }
3678 }
3679 if !sawGoAway {
3680 t.Errorf("didn't see GOAWAY")
3681 }
3682 if !sawRes {
3683 t.Errorf("didn't see response")
3684 }
3685 if !sawWindowUpdate {
3686 t.Errorf("didn't see WINDOW_UPDATE")
3687 }
3688 })
3689 }
3690
// TestServer_Headers_HalfCloseRemote checks server-side stream state:
//   - A stream opened by HEADERS without END_STREAM is open in the handler.
//   - It becomes half-closed (remote) once the client sends an empty DATA
//     frame with END_STREAM.
//   - A further HEADERS frame on that stream is answered with
//     RST_STREAM(STREAM_CLOSED).
func TestServer_Headers_HalfCloseRemote(t *testing.T) {
	synctestTest(t, testServer_Headers_HalfCloseRemote)
}
func testServer_Headers_HalfCloseRemote(t testing.TB) {
	// st is assigned below. The handler reads it, but the handler can only
	// run after the HEADERS written later in this function.
	var st *serverTester
	// These channels interleave the handler's checks with the client frames.
	writeData := make(chan bool)
	writeHeaders := make(chan bool)
	leaveHandler := make(chan bool)
	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		if !st.streamExists(1) {
			t.Errorf("stream 1 does not exist in handler")
		}
		if got, want := st.streamState(1), StateOpen; got != want {
			t.Errorf("in handler, state is %v; want %v", got, want)
		}
		// Ask the client to send END_STREAM, then expect an immediate EOF.
		writeData <- true
		if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
			t.Errorf("body read = %d, %v; want 0, EOF", n, err)
		}
		if got, want := st.streamState(1), StateHalfClosedRemote; got != want {
			t.Errorf("in handler, state is %v; want %v", got, want)
		}
		// Ask the client to send the extra HEADERS frame.
		writeHeaders <- true

		// Keep the stream alive until the test is finished.
		<-leaveHandler
	})
	st.greet()

	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     false,
		EndHeaders:    true,
	})
	<-writeData
	// Empty DATA with END_STREAM half-closes the stream from the client side.
	st.writeData(1, true, nil)

	<-writeHeaders

	// HEADERS on a half-closed (remote) stream is not allowed.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     false,
		EndHeaders:    true,
	})

	defer close(leaveHandler)

	st.wantRSTStream(1, ErrCodeStreamClosed)
}
3741
// TestServerGracefulShutdown checks that http.Server.Shutdown sends a GOAWAY
// while a handler is still running, lets that handler's response complete,
// and then closes the connection.
func TestServerGracefulShutdown(t *testing.T) { synctestTest(t, testServerGracefulShutdown) }
func testServerGracefulShutdown(t testing.TB) {
	// The handler blocks until handlerDone is closed.
	handlerDone := make(chan struct{})
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		<-handlerDone
		w.Header().Set("x-foo", "bar")
	})
	defer st.Close()

	st.greet()
	st.bodylessReq1()

	// Let stream 1 reach the handler before starting the shutdown.
	st.sync()

	// Shutdown blocks until connections are closed, so run it in the background.
	shutdownc := make(chan struct{})
	go func() {
		defer close(shutdownc)
		st.h1server.Shutdown(context.Background())
	}()

	// The GOAWAY arrives while the handler is still blocked.
	st.wantGoAway(1, ErrCodeNo)

	close(handlerDone)
	st.sync()

	st.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
		header: http.Header{
			":status":        []string{"200"},
			"x-foo":          []string{"bar"},
			"content-length": []string{"0"},
		},
	})

	// Once the in-flight response is done the server closes the connection,
	// so a read returns no data and an error.
	n, err := st.cc.Read([]byte{0})
	if n != 0 || err == nil {
		t.Errorf("Read = %v, %v; want 0, non-nil", n, err)
	}

	// Wait for Shutdown to return before ending the test.
	<-shutdownc
}
3785
3786
3787 func TestContentEncodingNoSniffing(t *testing.T) {
3788 type resp struct {
3789 name string
3790 body []byte
3791
3792
3793
3794 contentEncoding any
3795 wantContentType string
3796 }
3797
3798 resps := []*resp{
3799 {
3800 name: "gzip content-encoding, gzipped",
3801 contentEncoding: "application/gzip",
3802 wantContentType: "",
3803 body: func() []byte {
3804 buf := new(bytes.Buffer)
3805 gzw := gzip.NewWriter(buf)
3806 gzw.Write([]byte("doctype html><p>Hello</p>"))
3807 gzw.Close()
3808 return buf.Bytes()
3809 }(),
3810 },
3811 {
3812 name: "zlib content-encoding, zlibbed",
3813 contentEncoding: "application/zlib",
3814 wantContentType: "",
3815 body: func() []byte {
3816 buf := new(bytes.Buffer)
3817 zw := zlib.NewWriter(buf)
3818 zw.Write([]byte("doctype html><p>Hello</p>"))
3819 zw.Close()
3820 return buf.Bytes()
3821 }(),
3822 },
3823 {
3824 name: "no content-encoding",
3825 wantContentType: "application/x-gzip",
3826 body: func() []byte {
3827 buf := new(bytes.Buffer)
3828 gzw := gzip.NewWriter(buf)
3829 gzw.Write([]byte("doctype html><p>Hello</p>"))
3830 gzw.Close()
3831 return buf.Bytes()
3832 }(),
3833 },
3834 {
3835 name: "phony content-encoding",
3836 contentEncoding: "foo/bar",
3837 body: []byte("doctype html><p>Hello</p>"),
3838 },
3839 {
3840 name: "empty but set content-encoding",
3841 contentEncoding: "",
3842 wantContentType: "audio/mpeg",
3843 body: []byte("ID3"),
3844 },
3845 }
3846
3847 for _, tt := range resps {
3848 synctestSubtest(t, tt.name, func(t testing.TB) {
3849 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3850 if tt.contentEncoding != nil {
3851 w.Header().Set("Content-Encoding", tt.contentEncoding.(string))
3852 }
3853 w.Write(tt.body)
3854 })
3855
3856 tr := &http.Transport{
3857 TLSClientConfig: tlsConfigInsecure,
3858 Protocols: protocols("h2"),
3859 }
3860 defer tr.CloseIdleConnections()
3861
3862 req, _ := http.NewRequest("GET", ts.URL, nil)
3863 res, err := tr.RoundTrip(req)
3864 if err != nil {
3865 t.Fatalf("GET %s: %v", ts.URL, err)
3866 }
3867 defer res.Body.Close()
3868
3869 g := res.Header.Get("Content-Encoding")
3870 t.Logf("%s: Content-Encoding: %s", ts.URL, g)
3871
3872 if w := tt.contentEncoding; g != w {
3873 if w != nil {
3874 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
3875 } else if g != "" {
3876 t.Errorf("Unexpected Content-Encoding %q", g)
3877 }
3878 }
3879
3880 g = res.Header.Get("Content-Type")
3881 if w := tt.wantContentType; g != w {
3882 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
3883 }
3884 t.Logf("%s: Content-Type: %s", ts.URL, g)
3885 })
3886 }
3887 }
3888
3889 func TestServerWindowUpdateOnBodyClose(t *testing.T) {
3890 synctestTest(t, testServerWindowUpdateOnBodyClose)
3891 }
3892 func testServerWindowUpdateOnBodyClose(t testing.TB) {
3893 const windowSize = 65535 * 2
3894 content := make([]byte, windowSize)
3895 errc := make(chan error)
3896 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3897 buf := make([]byte, 4)
3898 n, err := io.ReadFull(r.Body, buf)
3899 if err != nil {
3900 errc <- err
3901 return
3902 }
3903 if n != len(buf) {
3904 errc <- fmt.Errorf("too few bytes read: %d", n)
3905 return
3906 }
3907 r.Body.Close()
3908 errc <- nil
3909 }, func(h2 *http.HTTP2Config) {
3910 h2.MaxReceiveBufferPerConnection = windowSize
3911 h2.MaxReceiveBufferPerStream = windowSize
3912 })
3913 defer st.Close()
3914
3915 st.greet()
3916 st.writeHeaders(HeadersFrameParam{
3917 StreamID: 1,
3918 BlockFragment: st.encodeHeader(
3919 ":method", "POST",
3920 "content-length", strconv.Itoa(len(content)),
3921 ),
3922 EndStream: false,
3923 EndHeaders: true,
3924 })
3925 st.writeData(1, false, content[:windowSize/2])
3926 if err := <-errc; err != nil {
3927 t.Fatal(err)
3928 }
3929
3930
3931 increments := windowSize / 2
3932 for {
3933 f := st.readFrame()
3934 if f == nil {
3935 break
3936 }
3937 if wu, ok := f.(*WindowUpdateFrame); ok && wu.StreamID == 0 {
3938 increments -= int(wu.Increment)
3939 if increments == 0 {
3940 break
3941 }
3942 }
3943 }
3944
3945
3946 st.writeData(1, false, content[windowSize/2:])
3947 st.wantWindowUpdate(0, windowSize/2)
3948 }
3949
// TestNoErrorLoggedOnPostAfterGOAWAY checks that request body DATA arriving
// after the server has begun a graceful shutdown does not cause the server to
// log a PROTOCOL_ERROR.
func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) {
	synctestTest(t, testNoErrorLoggedOnPostAfterGOAWAY)
}
func testNoErrorLoggedOnPostAfterGOAWAY(t testing.TB) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
	defer st.Close()

	st.greet()

	// Start a POST whose declared body has not been sent yet.
	content := "some content"
	st.writeHeaders(HeadersFrameParam{
		StreamID: 1,
		BlockFragment: st.encodeHeader(
			":method", "POST",
			"content-length", strconv.Itoa(len(content)),
		),
		EndStream:  false,
		EndHeaders: true,
	})
	// The empty handler returns immediately, ending the response.
	st.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})

	// Begin a graceful shutdown. The test expects RST_STREAM(1, NO) (the
	// request body is still outstanding) followed by GOAWAY.
	st.sc.StartGracefulShutdown()
	st.wantRSTStream(1, ErrCodeNo)
	st.wantGoAway(1, ErrCodeNo)

	// Now deliver the body the client declared earlier.
	st.writeData(1, true, []byte(content))
	// NOTE(review): closing explicitly (the deferred Close also runs) is
	// presumably so that anything the server logs about the late DATA is in
	// serverLogBuf before it is inspected.
	st.Close()

	if bytes.Contains(st.serverLogBuf.Bytes(), []byte("PROTOCOL_ERROR")) {
		t.Error("got protocol error")
	}
}
3985
// TestServerSendsProcessing checks that a handler writing the informational
// 102 (Processing) status produces its own HEADERS frame ahead of the final
// 200 response headers.
func TestServerSendsProcessing(t *testing.T) { synctestTest(t, testServerSendsProcessing) }
func testServerSendsProcessing(t testing.TB) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.WriteHeader(http.StatusProcessing)
		// The first Write after the 1xx status sends the implicit 200 headers.
		w.Write([]byte("stuff"))

		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status": []string{"102"},
			},
		})
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"text/plain; charset=utf-8"},
				"content-length": []string{"5"},
			},
		})
	})
}
4013
4014 func TestServerSendsEarlyHints(t *testing.T) { synctestTest(t, testServerSendsEarlyHints) }
4015 func testServerSendsEarlyHints(t testing.TB) {
4016 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
4017 h := w.Header()
4018 h.Add("Content-Length", "123")
4019 h.Add("Link", "</style.css>; rel=preload; as=style")
4020 h.Add("Link", "</script.js>; rel=preload; as=script")
4021 w.WriteHeader(http.StatusEarlyHints)
4022
4023 h.Add("Link", "</foo.js>; rel=preload; as=script")
4024 w.WriteHeader(http.StatusEarlyHints)
4025
4026 w.Write([]byte("stuff"))
4027
4028 return nil
4029 }, func(st *serverTester) {
4030 getSlash(st)
4031 st.wantHeaders(wantHeader{
4032 streamID: 1,
4033 endStream: false,
4034 header: http.Header{
4035 ":status": []string{"103"},
4036 "link": []string{
4037 "</style.css>; rel=preload; as=style",
4038 "</script.js>; rel=preload; as=script",
4039 },
4040 },
4041 })
4042 st.wantHeaders(wantHeader{
4043 streamID: 1,
4044 endStream: false,
4045 header: http.Header{
4046 ":status": []string{"103"},
4047 "link": []string{
4048 "</style.css>; rel=preload; as=style",
4049 "</script.js>; rel=preload; as=script",
4050 "</foo.js>; rel=preload; as=script",
4051 },
4052 },
4053 })
4054 st.wantHeaders(wantHeader{
4055 streamID: 1,
4056 endStream: false,
4057 header: http.Header{
4058 ":status": []string{"200"},
4059 "link": []string{
4060 "</style.css>; rel=preload; as=style",
4061 "</script.js>; rel=preload; as=script",
4062 "</foo.js>; rel=preload; as=script",
4063 },
4064 "content-type": []string{"text/plain; charset=utf-8"},
4065 "content-length": []string{"123"},
4066 },
4067 })
4068 })
4069 }
4070
// TestProtocolErrorAfterGoAway checks how the server handles a client that
// sends its own GOAWAY and then misbehaves while a request body is only
// partially sent. The misbehavior is a connection-level WINDOW_UPDATE with
// an oversized increment. The test expects the server to send
// GOAWAY(1, NO_ERROR) once GoAwayTimeout elapses and then close the
// connection.
func TestProtocolErrorAfterGoAway(t *testing.T) { synctestTest(t, testProtocolErrorAfterGoAway) }
func testProtocolErrorAfterGoAway(t testing.TB) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		io.Copy(io.Discard, r.Body)
	})
	defer st.Close()

	st.greet()
	content := "some content"
	st.writeHeaders(HeadersFrameParam{
		StreamID: 1,
		BlockFragment: st.encodeHeader(
			":method", "POST",
			"content-length", strconv.Itoa(len(content)),
		),
		EndStream:  false,
		EndHeaders: true,
	})
	// Send only the first 5 bytes of the declared body.
	st.writeData(1, false, []byte(content[:5]))

	// Client GOAWAY, then a WINDOW_UPDATE of 1<<31-1. Added to the current
	// send window this would exceed the 2^31-1 maximum, which RFC 9113 treats
	// as a flow-control error.
	if err := st.fr.WriteGoAway(1, ErrCodeNo, nil); err != nil {
		t.Fatal(err)
	}
	if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
		t.Fatal(err)
	}

	st.advance(GoAwayTimeout)
	st.wantGoAway(1, ErrCodeNo)
	st.wantClosed()
}
4104
// TestServerInitialFlowControlWindow checks that the server's
// connection-level receive window matches MaxReceiveBufferPerConnection. The
// window is the protocol's initial 65535 plus the stream-0 WINDOW_UPDATE
// increments the server sends before its first response.
func TestServerInitialFlowControlWindow(t *testing.T) {
	for _, want := range []int32{
		65535,
		1 << 19,
		1 << 21,
		// Twice the initial window of 65535 assumed below.
		65535 * 2,
	} {
		synctestSubtest(t, fmt.Sprint(want), func(t testing.TB) {
			// Perform the handshake manually so no WINDOW_UPDATE is consumed.
			st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
			}, func(h2 *http.HTTP2Config) {
				h2.MaxReceiveBufferPerConnection = int(want)
			})
			st.writePreface()
			st.writeSettings()
			_ = readFrame[*SettingsFrame](t, st)
			st.writeSettingsAck()
			st.writeHeaders(HeadersFrameParam{
				StreamID:      1,
				BlockFragment: st.encodeHeader(),
				EndStream:     true,
				EndHeaders:    true,
			})
			// Sum connection-level increments until the response HEADERS
			// arrive or no more frames are available.
			window := 65535
		Frames:
			for {
				f := st.readFrame()
				switch f := f.(type) {
				case *WindowUpdateFrame:
					if f.FrameHeader.StreamID != 0 {
						t.Errorf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
						return
					}
					window += int(f.Increment)
				case *HeadersFrame:
					break Frames
				case nil:
					break Frames
				default:
				}
			}
			if window != int(want) {
				t.Errorf("got initial flow control window = %v, want %v", window, want)
			}
		})
	}
}
4157
4158
4159
4160
4161
4162
// TestServerWriteDoesNotRetainBufferAfterReturn checks that
// ResponseWriter.Write does not keep referencing the caller's buffer after it
// returns. The handler overwrites its buffer immediately after every Write,
// while the client closes the response body early. NOTE(review): a retained
// buffer would presumably surface as a data race under -race.
func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) {
	synctestTest(t, testServerWriteDoesNotRetainBufferAfterReturn)
}
func testServerWriteDoesNotRetainBufferAfterReturn(t testing.TB) {
	// donec is closed when the handler returns.
	donec := make(chan struct{})
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		defer close(donec)
		buf := make([]byte, 1<<20)
		var i byte
		for {
			i++
			_, err := w.Write(buf)
			// Mutate the buffer as soon as Write returns.
			for j := range buf {
				buf[j] = byte(i)
			}
			// Keep writing until the client's early close makes Write fail.
			if err != nil {
				return
			}
		}
	})

	tr := &http.Transport{
		TLSClientConfig: tlsConfigInsecure,
		Protocols:       protocols("h2"),
	}
	defer tr.CloseIdleConnections()

	req, _ := http.NewRequest("GET", ts.URL, nil)
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	<-donec
}
4198
4199
4200
4201
4202
4203
// TestServerWriteDoesNotRetainBufferAfterServerClose is like
// TestServerWriteDoesNotRetainBufferAfterReturn, but the handler's Writes are
// ended by closing the http.Server rather than by the client closing the
// response body.
func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) {
	synctestTest(t, testServerWriteDoesNotRetainBufferAfterServerClose)
}
func testServerWriteDoesNotRetainBufferAfterServerClose(t testing.TB) {
	// donec receives one value when the handler starts and is closed when the
	// handler returns; the buffer lets the first send proceed without a reader.
	donec := make(chan struct{}, 1)
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		donec <- struct{}{}
		defer close(donec)
		buf := make([]byte, 1<<20)
		var i byte
		for {
			i++
			_, err := w.Write(buf)
			// Mutate the buffer as soon as Write returns.
			for j := range buf {
				buf[j] = byte(i)
			}
			if err != nil {
				return
			}
		}
	})

	tr := &http.Transport{
		TLSClientConfig: tlsConfigInsecure,
		Protocols:       protocols("h2"),
	}
	defer tr.CloseIdleConnections()

	req, _ := http.NewRequest("GET", ts.URL, nil)
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	// Wait for the handler to start, close the server, then wait for the
	// handler to return.
	<-donec
	ts.Config.Close()
	<-donec
}
4242
// TestServerMaxHandlerGoroutines checks that the server caps concurrently
// running handlers at MaxConcurrentStreams. The cap holds even when the
// client resets streams whose handlers are still running. Further requests
// wait until a handler exits, and a client that keeps opening and resetting
// streams is expected to get GOAWAY(ENHANCE_YOUR_CALM).
func TestServerMaxHandlerGoroutines(t *testing.T) { synctestTest(t, testServerMaxHandlerGoroutines) }
func testServerMaxHandlerGoroutines(t testing.TB) {
	const maxHandlers = 10
	// Each handler sends its stop channel on handlerc, then blocks until told
	// to exit: false returns normally, true panics with http.ErrAbortHandler.
	// Closing donec releases every handler at test end.
	handlerc := make(chan chan bool)
	donec := make(chan struct{})
	defer close(donec)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		stopc := make(chan bool, 1)
		select {
		case handlerc <- stopc:
		case <-donec:
		}
		select {
		case shouldPanic := <-stopc:
			if shouldPanic {
				panic(http.ErrAbortHandler)
			}
		case <-donec:
		}
	}, func(h2 *http.HTTP2Config) {
		h2.MaxConcurrentStreams = maxHandlers
	})
	defer st.Close()

	st.greet()

	// Start maxHandlers requests. Reset each stream from the client once its
	// handler is running; the handlers themselves keep running.
	var stops []chan bool
	streamID := uint32(1)
	for range maxHandlers {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		stops = append(stops, <-handlerc)
		st.fr.WriteRSTStream(streamID, ErrCodeCancel)
		streamID += 2
	}

	// One more request, reset immediately by the client.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      streamID,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	st.fr.WriteRSTStream(streamID, ErrCodeCancel)
	streamID += 2

	// Two more requests that are left open.
	for range 2 {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		streamID += 2
	}

	// No further handler may start while maxHandlers handlers are running.
	select {
	case <-handlerc:
		t.Errorf("handler unexpectedly started while maxHandlers are already running")
	case <-time.After(1 * time.Millisecond):
	}

	// Let two handlers exit, one normally and one by panicking. Two queued
	// requests should then start their handlers.
	stops[0] <- false
	stops[1] <- true
	stops = stops[2:]
	stops = append(stops, <-handlerc)
	stops = append(stops, <-handlerc)

	// Flood the server with many more open-then-reset requests; the test
	// expects the server to give up with ENHANCE_YOUR_CALM.
	for range 5 * maxHandlers {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		st.fr.WriteRSTStream(streamID, ErrCodeCancel)
		streamID += 2
	}
	fr := readFrame[*GoAwayFrame](t, st)
	if fr.ErrCode != ErrCodeEnhanceYourCalm {
		t.Errorf("err code = %v; want %v", fr.ErrCode, ErrCodeEnhanceYourCalm)
	}

	// Release the handlers that are still running.
	for _, s := range stops {
		close(s)
	}
}
4343
// TestServerContinuationFlood checks that a header block which keeps growing
// through CONTINUATION frames beyond MaxHeaderBytes is rejected. The server
// must not answer the request with HEADERS; if a GOAWAY is seen, it must
// report stream 1 as the last stream processed.
func TestServerContinuationFlood(t *testing.T) { synctestTest(t, testServerContinuationFlood) }
func testServerContinuationFlood(t testing.TB) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		fmt.Println(r.Header)
	}, func(s *http.Server) {
		s.MaxHeaderBytes = 4096
	})
	defer st.Close()

	st.greet()

	// HEADERS without END_HEADERS, followed by 1000 CONTINUATION frames, each
	// adding a header, and finally one with END_HEADERS.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
	})
	for i := range 1000 {
		st.fr.WriteContinuation(1, false, st.encodeHeaderRaw(
			fmt.Sprintf("x-%v", i), "1234567890",
		))
	}
	st.fr.WriteContinuation(1, true, st.encodeHeaderRaw(
		"x-last-header", "1",
	))

	for {
		f := st.readFrame()
		if f == nil {
			break
		}
		switch f := f.(type) {
		case *HeadersFrame:
			t.Fatalf("received HEADERS frame; want GOAWAY and a closed connection")
		case *GoAwayFrame:
			// If a GOAWAY is sent, it should name stream 1 as processed so
			// a client does not retry the request.
			if got, want := f.LastStreamID, uint32(1); got != want {
				t.Errorf("received GOAWAY with LastStreamId %v, want %v", got, want)
			}

		}
	}
	// The loop ends only when readFrame returns nil. Seeing a GOAWAY is not
	// required here; the test only requires that no HEADERS response was
	// sent. NOTE(review): presumably because the connection may be closed
	// before the GOAWAY is fully written.
}
4396
// TestServerContinuationAfterInvalidHeader checks that an invalid header value
// (here a NUL byte) carried in a CONTINUATION frame makes the server send a
// GOAWAY and no response HEADERS.
func TestServerContinuationAfterInvalidHeader(t *testing.T) {
	synctestTest(t, testServerContinuationAfterInvalidHeader)
}
func testServerContinuationAfterInvalidHeader(t testing.TB) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		fmt.Println(r.Header)
	})
	defer st.Close()

	st.greet()

	// HEADERS without END_HEADERS, then the invalid header in a CONTINUATION,
	// then a final CONTINUATION with a valid header and END_HEADERS.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
	})
	st.fr.WriteContinuation(1, false, st.encodeHeaderRaw(
		"x-invalid-header", "\x00",
	))
	st.fr.WriteContinuation(1, true, st.encodeHeaderRaw(
		"x-valid-header", "1",
	))

	// Read until no frames remain, recording whether a GOAWAY was seen.
	var sawGoAway bool
	for {
		f := st.readFrame()
		if f == nil {
			break
		}
		switch f.(type) {
		case *GoAwayFrame:
			sawGoAway = true
		case *HeadersFrame:
			t.Fatalf("received HEADERS frame; want GOAWAY")
		}
	}
	if !sawGoAway {
		t.Errorf("connection closed with no GOAWAY frame; want one")
	}
}
4437
4438
// TestServerRequestCancelOnError checks that a protocol error on an active
// request's stream cancels that request's context. The error is a second
// HEADERS frame on stream 1, whose request already ended with END_STREAM.
func TestServerRequestCancelOnError(t *testing.T) { synctestTest(t, testServerRequestCancelOnError) }
func testServerRequestCancelOnError(t testing.TB) {
	// recvc is closed when the handler starts; donec when its context is done.
	recvc := make(chan struct{})
	donec := make(chan struct{})
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		close(recvc)
		<-r.Context().Done()
		close(donec)
	})
	defer st.Close()

	st.greet()

	// Open stream 1 and wait for its handler to start.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	<-recvc

	// Send a second HEADERS frame with END_STREAM on the same stream. This is
	// invalid on a stream the client has already half-closed, and should
	// cancel the handler's request context.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	<-donec
}
4472
4473 func TestServerSetReadWriteDeadlineRace(t *testing.T) {
4474 synctestTest(t, testServerSetReadWriteDeadlineRace)
4475 }
4476 func testServerSetReadWriteDeadlineRace(t testing.TB) {
4477 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4478 ctl := http.NewResponseController(w)
4479 ctl.SetReadDeadline(time.Now().Add(3600 * time.Second))
4480 ctl.SetWriteDeadline(time.Now().Add(3600 * time.Second))
4481 })
4482 resp, err := ts.Client().Get(ts.URL)
4483 if err != nil {
4484 t.Fatal(err)
4485 }
4486 resp.Body.Close()
4487 }
4488
// TestServerWriteByteTimeout checks HTTP2Config.WriteByteTimeout. The
// connection stays up while the peer reads at least one byte before each
// timeout expires, and is closed once the peer stops reading.
func TestServerWriteByteTimeout(t *testing.T) { synctestTest(t, testServerWriteByteTimeout) }
func testServerWriteByteTimeout(t testing.TB) {
	const timeout = 1 * time.Second
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		w.Write(make([]byte, 100))
	}, func(s *http.Server) {
		// NOTE(review): h2c (no TLS) is presumably used so bytes the client
		// reads correspond directly to bytes the server wrote.
		s.Protocols = protocols("h2c")
	}, func(h2 *http.HTTP2Config) {
		h2.WriteByteTimeout = timeout
	})
	st.greet()

	// With a 1-byte client read buffer, the server's writes presumably make
	// progress only as the client reads.
	st.cc.(*synctestNetConn).SetReadBufferSize(1)
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})

	// Read a single byte just before each timeout would expire; every read
	// should succeed.
	for i := range 10 {
		st.advance(timeout - 1)
		if n, err := st.cc.Read(make([]byte, 1)); n != 1 || err != nil {
			t.Fatalf("read %v: %v, %v; want 1, nil", i, n, err)
		}
	}

	// Stop reading. The test expects the connection to be closed after two
	// more seconds without progress.
	st.advance(1 * time.Second)
	st.advance(1 * time.Second)
	st.wantClosed()
}
4526
// TestServerPingSent checks HTTP2Config.SendPingTimeout. After the connection
// has been idle for SendPingTimeout the server sends a PING. If the client
// never answers, the test expects the connection to be closed 15 seconds
// after the PING.
func TestServerPingSent(t *testing.T) { synctestTest(t, testServerPingSent) }
func testServerPingSent(t testing.TB) {
	const sendPingTimeout = 15 * time.Second
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
	}, func(h2 *http.HTTP2Config) {
		h2.SendPingTimeout = sendPingTimeout
	})
	st.greet()

	st.wantIdle()

	st.advance(sendPingTimeout)
	_ = readFrame[*PingFrame](t, st)
	st.wantIdle()

	// Leave the PING unanswered: still up after 14s, closed at 15s.
	// NOTE(review): the 15s presumably comes from the default PingTimeout.
	st.advance(14 * time.Second)
	st.wantIdle()
	st.advance(1 * time.Second)
	st.wantClosed()
}
4547
// TestServerPingResponded checks that when the client acknowledges the
// server's keepalive PING, the connection is not closed.
func TestServerPingResponded(t *testing.T) { synctestTest(t, testServerPingResponded) }
func testServerPingResponded(t testing.TB) {
	const sendPingTimeout = 15 * time.Second
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
	}, func(h2 *http.HTTP2Config) {
		h2.SendPingTimeout = sendPingTimeout
	})
	st.greet()

	st.wantIdle()

	// After sendPingTimeout of inactivity, the server sends a PING.
	st.advance(sendPingTimeout)
	pf := readFrame[*PingFrame](t, st)
	st.wantIdle()

	st.advance(14 * time.Second)
	st.wantIdle()

	// Acknowledge the PING, echoing its payload.
	st.writePing(true, pf.Data)

	// Advance past the point where TestServerPingSent expects an unanswered
	// PING to close the connection; here the connection should stay idle.
	st.advance(2 * time.Second)
	st.wantIdle()
}
4571
4572
4573
4574
4575
// TestServerSendDataAfterRequestBodyClose checks two things about a handler
// that closes its request body mid-response:
//   - It can keep writing its response afterwards.
//   - Request body DATA that arrives after the close produces no frames from
//     the server.
func TestServerSendDataAfterRequestBodyClose(t *testing.T) {
	synctestTest(t, testServerSendDataAfterRequestBodyClose)
}
func testServerSendDataAfterRequestBodyClose(t testing.TB) {
	// No handler function: the test drives the handler step by step through
	// nextHandlerCall and call.do.
	st := newServerTester(t, nil)
	st.greet()

	// Open stream 1, leaving the request body unfinished.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     false,
		EndHeaders:    true,
	})

	// The handler writes and flushes the first part of the response.
	call := st.nextHandlerCall()
	call.do(func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("one"))
		http.NewResponseController(w).Flush()
	})
	st.wantFrameType(FrameHeaders)
	st.wantData(wantData{
		streamID:  1,
		endStream: false,
		data:      []byte("one"),
	})
	st.wantIdle()

	// The handler closes the request body; nothing should be sent.
	call.do(func(w http.ResponseWriter, req *http.Request) {
		req.Body.Close()
	})
	st.wantIdle()

	// The client sends more body data; still nothing should be sent.
	st.writeData(1, false, []byte("client-sent data"))
	st.wantIdle()

	// The handler can still write to the response.
	call.do(func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("two"))
		http.NewResponseController(w).Flush()
	})
	st.wantData(wantData{
		streamID:  1,
		endStream: false,
		data:      []byte("two"),
	})
	st.wantIdle()
}
4628
4629 func TestServerSettingNoRFC7540Priorities(t *testing.T) {
4630 synctestTest(t, testServerSettingNoRFC7540Priorities)
4631 }
4632 func testServerSettingNoRFC7540Priorities(t testing.TB) {
4633 const wantNoRFC7540Setting = true
4634 st := newServerTester(t, nil)
4635 defer st.Close()
4636
4637 var gotNoRFC7540Setting bool
4638 st.greetAndCheckSettings(func(s Setting) error {
4639 if s.ID != SettingNoRFC7540Priorities {
4640 return nil
4641 }
4642 gotNoRFC7540Setting = s.Val == 1
4643 return nil
4644 })
4645 if wantNoRFC7540Setting != gotNoRFC7540Setting {
4646 t.Errorf("want SETTINGS_NO_RFC7540_PRIORITIES to be %v, got %v", wantNoRFC7540Setting, gotNoRFC7540Setting)
4647 }
4648 }
4649
// TestServerSettingNoRFC7540PrioritiesInvalid checks that a client value for
// SETTINGS_NO_RFC7540_PRIORITIES other than 0 or 1 (here 2) is treated as a
// connection error. RFC 9218 requires a PROTOCOL_ERROR for such values, so
// the server is expected to send GOAWAY(PROTOCOL_ERROR).
func TestServerSettingNoRFC7540PrioritiesInvalid(t *testing.T) {
	synctestTest(t, testServerSettingNoRFC7540PrioritiesInvalid)
}
func testServerSettingNoRFC7540PrioritiesInvalid(t testing.TB) {
	st := newServerTester(t, nil)
	defer st.Close()

	st.writePreface()
	st.writeSettings(Setting{ID: SettingNoRFC7540Priorities, Val: 2})
	synctest.Wait()
	// Discard the server's first two frames. NOTE(review): presumably its
	// own SETTINGS and the initial WINDOW_UPDATE.
	st.readFrame()
	st.readFrame()
	st.wantGoAway(0, ErrCodeProtocol)
}
4664
4665
4666
4667 func TestServerRFC9218PrioritySmallPayload(t *testing.T) {
4668 synctestTest(t, testServerRFC9218PrioritySmallPayload)
4669 }
4670 func testServerRFC9218PrioritySmallPayload(t testing.TB) {
4671 endTest := false
4672 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4673 for !endTest {
4674 w.Write([]byte("a"))
4675 if f, ok := w.(http.Flusher); ok {
4676 f.Flush()
4677 }
4678 }
4679 }, func(s *http.Server) {
4680 s.Protocols = protocols("h2c")
4681 })
4682 st.greet()
4683 if syncConn, ok := st.cc.(*synctestNetConn); ok {
4684 syncConn.SetReadBufferSize(1)
4685 } else {
4686 t.Fatal("Server connection is not synctestNetConn")
4687 }
4688 defer st.Close()
4689 defer func() { endTest = true }()
4690
4691
4692
4693
4694
4695
4696 for i := 1; i <= 19; i += 2 {
4697 urgency := uint8(0)
4698 if i > 10 {
4699 urgency = 7
4700 }
4701 st.writeHeaders(HeadersFrameParam{
4702 StreamID: uint32(i),
4703 BlockFragment: st.encodeHeader("priority", fmt.Sprintf("u=%d", urgency)),
4704 EndStream: true,
4705 EndHeaders: true,
4706 })
4707 synctest.Wait()
4708 }
4709
4710
4711
4712 streamWriteCount := make(map[uint32]int)
4713 totalWriteCount := 10000
4714 for range totalWriteCount {
4715 f := st.readFrame()
4716 if f == nil {
4717 break
4718 }
4719 streamWriteCount[f.Header().StreamID] += 1
4720 }
4721 for streamID, writeCount := range streamWriteCount {
4722 expectedWriteCount := totalWriteCount / len(streamWriteCount)
4723 errorMargin := expectedWriteCount / 100
4724 if writeCount >= expectedWriteCount+errorMargin || writeCount <= expectedWriteCount-errorMargin {
4725 t.Errorf("Expected stream %v to receive %v±%v writes, got %v", streamID, expectedWriteCount, errorMargin, writeCount)
4726 }
4727 }
4728 }
4729
// TestServerRFC9218Priority runs testServerRFC9218Priority through the
// synctestTest helper (presumably inside a testing/synctest bubble, since
// the body calls synctest.Wait).
func TestServerRFC9218Priority(t *testing.T) {
	synctestTest(t, testServerRFC9218Priority)
}
4733 func testServerRFC9218Priority(t testing.TB) {
4734 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4735 w.Write(slices.Repeat([]byte("a"), 16<<20))
4736 if f, ok := w.(http.Flusher); ok {
4737 f.Flush()
4738 }
4739 }, func(s *http.Server) {
4740 s.Protocols = protocols("h2c")
4741 })
4742 defer st.Close()
4743 st.greet()
4744 if syncConn, ok := st.cc.(*synctestNetConn); ok {
4745 syncConn.SetReadBufferSize(1)
4746 } else {
4747 t.Fatal("Server connection is not synctestNetConn")
4748 }
4749 st.writeWindowUpdate(0, 1<<30)
4750 synctest.Wait()
4751
4752
4753
4754 for i := range 8 {
4755 streamID := uint32(i*2 + 1)
4756 urgency := 7 - i
4757 st.writeHeaders(HeadersFrameParam{
4758 StreamID: streamID,
4759 BlockFragment: st.encodeHeader("priority", fmt.Sprintf("u=%d", urgency)),
4760 EndStream: true,
4761 EndHeaders: true,
4762 })
4763 }
4764 synctest.Wait()
4765
4766
4767
4768 lastFrame := make(map[uint32]int)
4769 for i := 0; ; i++ {
4770 f := st.readFrame()
4771 if f == nil {
4772 break
4773 }
4774 lastFrame[f.Header().StreamID] = i
4775 }
4776 for i := range 7 {
4777 streamID := uint32(i*2 + 1)
4778 nextStreamID := streamID + 2
4779 if lastFrame[streamID] < lastFrame[nextStreamID] {
4780 t.Errorf("stream %d finished before stream %d unexpectedly", streamID, nextStreamID)
4781 }
4782 }
4783 }
4784
// TestServerRFC9218PriorityIgnoredWhenProxied runs
// testServerRFC9218PriorityIgnoredWhenProxied through the synctestTest helper
// (presumably inside a testing/synctest bubble, since the body calls
// synctest.Wait).
func TestServerRFC9218PriorityIgnoredWhenProxied(t *testing.T) {
	synctestTest(t, testServerRFC9218PriorityIgnoredWhenProxied)
}
4788 func testServerRFC9218PriorityIgnoredWhenProxied(t testing.TB) {
4789 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4790 w.Write(slices.Repeat([]byte("a"), 16<<20))
4791 if f, ok := w.(http.Flusher); ok {
4792 f.Flush()
4793 }
4794 }, func(s *http.Server) {
4795 s.Protocols = protocols("h2c")
4796 })
4797 defer st.Close()
4798 st.greet()
4799 if syncConn, ok := st.cc.(*synctestNetConn); ok {
4800 syncConn.SetReadBufferSize(1)
4801 } else {
4802 t.Fatal("Server connection is not synctestNetConn")
4803 }
4804 st.writeWindowUpdate(0, 1<<30)
4805 synctest.Wait()
4806
4807
4808
4809
4810 for i := range 8 {
4811 streamID := uint32(i*2 + 1)
4812 urgency := 7 - i
4813 st.writeHeaders(HeadersFrameParam{
4814 StreamID: streamID,
4815 BlockFragment: st.encodeHeader("priority", fmt.Sprintf("u=%d", urgency), "via", "a proxy"),
4816 EndStream: true,
4817 EndHeaders: true,
4818 })
4819 }
4820 synctest.Wait()
4821 var streamFrameOrder []uint32
4822 for f := st.readFrame(); f != nil; f = st.readFrame() {
4823 streamFrameOrder = append(streamFrameOrder, f.Header().StreamID)
4824 }
4825
4826
4827
4828 half := streamFrameOrder[len(streamFrameOrder)/4 : len(streamFrameOrder)*3/4]
4829 if !slices.Equal(slices.Compact(half), half) {
4830 t.Errorf("want stream to be processed in round-robin manner when proxied, got: %v", streamFrameOrder)
4831 }
4832 }
4833
// TestServerRFC9218PriorityAware runs testServerRFC9218PriorityAware through
// the synctestTest helper (presumably inside a testing/synctest bubble, since
// the body calls synctest.Wait).
func TestServerRFC9218PriorityAware(t *testing.T) {
	synctestTest(t, testServerRFC9218PriorityAware)
}
4837 func testServerRFC9218PriorityAware(t testing.TB) {
4838 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4839 w.Write(slices.Repeat([]byte("a"), 16<<20))
4840 if f, ok := w.(http.Flusher); ok {
4841 f.Flush()
4842 }
4843 }, func(s *http.Server) {
4844 s.Protocols = protocols("h2c")
4845 })
4846 defer st.Close()
4847 st.greet()
4848 if syncConn, ok := st.cc.(*synctestNetConn); ok {
4849 syncConn.SetReadBufferSize(1)
4850 } else {
4851 t.Fatal("Server connection is not synctestNetConn")
4852 }
4853 st.writeWindowUpdate(0, 1<<30)
4854 synctest.Wait()
4855
4856
4857
4858 streamCount := 10
4859 for i := range streamCount {
4860 streamID := uint32(i*2 + 1)
4861 st.writeHeaders(HeadersFrameParam{
4862 StreamID: streamID,
4863 BlockFragment: st.encodeHeader(),
4864 EndStream: true,
4865 EndHeaders: true,
4866 })
4867 }
4868 synctest.Wait()
4869 var streamFrameOrder []uint32
4870 for f := st.readFrame(); f != nil; f = st.readFrame() {
4871 streamFrameOrder = append(streamFrameOrder, f.Header().StreamID)
4872 }
4873
4874
4875
4876 half := streamFrameOrder[len(streamFrameOrder)/4 : len(streamFrameOrder)*3/4]
4877 if !slices.Equal(slices.Compact(half), half) {
4878 t.Errorf("want stream to be processed in round-robin manner when unaware of priority, got: %v", streamFrameOrder)
4879 }
4880
4881
4882
4883
4884 st.writePriorityUpdate(1, "")
4885 synctest.Wait()
4886
4887
4888
4889
4890
4891 streamFrameOrder = []uint32{}
4892 for i := range streamCount {
4893 i += streamCount
4894 streamID := uint32(i*2 + 1)
4895 st.writeHeaders(HeadersFrameParam{
4896 StreamID: streamID,
4897 BlockFragment: st.encodeHeader(),
4898 EndStream: true,
4899 EndHeaders: true,
4900 })
4901 }
4902 for f := st.readFrame(); f != nil; f = st.readFrame() {
4903 streamFrameOrder = append(streamFrameOrder, f.Header().StreamID)
4904 }
4905 if !slices.Equal(slices.Compact(half), half) {
4906 t.Errorf("want stream to be processed one-by-one to completion when aware of priority, got: %v", streamFrameOrder)
4907 }
4908 }
4909
4910 func TestConsistentConstants(t *testing.T) {
4911 if h1, h2 := http.DefaultMaxHeaderBytes, http2.DefaultMaxHeaderBytes; h1 != h2 {
4912 t.Errorf("DefaultMaxHeaderBytes: http (%v) != http2 (%v)", h1, h2)
4913 }
4914 if h1, h2 := http.TimeFormat, http2.TimeFormat; h1 != h2 {
4915 t.Errorf("TimeFormat: http (%v) != http2 (%v)", h1, h2)
4916 }
4917 }
4918
// testServerTLSConfig and testClientTLSConfig are the shared TLS
// configurations for server and client test connections. Both start nil and
// are populated by this file's init function.
var testServerTLSConfig, testClientTLSConfig *tls.Config
4923
4924 func init() {
4925 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
4926 if err != nil {
4927 panic(err)
4928 }
4929 testServerTLSConfig = &tls.Config{
4930 Certificates: []tls.Certificate{cert},
4931 NextProtos: []string{"h2"},
4932 }
4933
4934 x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
4935 if err != nil {
4936 panic(err)
4937 }
4938 certpool := x509.NewCertPool()
4939 certpool.AddCert(x509Cert)
4940 testClientTLSConfig = &tls.Config{
4941 InsecureSkipVerify: true,
4942 RootCAs: certpool,
4943 NextProtos: []string{"h2"},
4944 }
4945 }
4946
// protocols returns a new *http.Protocols with each named protocol enabled:
// "h1" (HTTP/1), "h2" (HTTP/2 over TLS) and "h2c" (unencrypted HTTP/2). It
// panics on any other name. With no arguments, nothing is enabled.
func protocols(protos ...string) *http.Protocols {
	p := &http.Protocols{}
	enable := map[string]func(bool){
		"h1":  p.SetHTTP1,
		"h2":  p.SetHTTP2,
		"h2c": p.SetUnencryptedHTTP2,
	}
	for _, name := range protos {
		set, known := enable[name]
		if !known {
			panic("unknown protocol: " + name)
		}
		set(true)
	}
	return p
}
4963
4964
// transportFromH1Transport is declared without a body; the value it returns
// for tr is typed only as any here.
// NOTE(review): a body-less Go function must be implemented outside this
// declaration (e.g. through a //go:linkname directive, which the blank
// "unsafe" import would permit). No such directive is visible in this file,
// so confirm it exists or the package will not compile.
func transportFromH1Transport(tr *http.Transport) any
4966
View as plain text