1
2
3
4
5 package http2_test
6
7 import (
8 "bufio"
9 "bytes"
10 "compress/gzip"
11 "context"
12 crand "crypto/rand"
13 "crypto/tls"
14 "encoding/hex"
15 "errors"
16 "flag"
17 "fmt"
18 "io"
19 "log"
20 "math/rand"
21 "net"
22 "net/http"
23 "net/http/httptest"
24 "net/http/httptrace"
25 "net/textproto"
26 "net/url"
27 "os"
28 "reflect"
29 "sort"
30 "strconv"
31 "strings"
32 "sync"
33 "sync/atomic"
34 "testing"
35 "testing/synctest"
36 "time"
37
38 . "net/http/internal/http2"
39 "net/http/internal/httpcommon"
40
41 "golang.org/x/net/http2/hpack"
42 )
43
// Command-line flags for tests that reach out to the real network.
var (
	// extNet opts in to tests that dial hosts outside the test process.
	extNet = flag.Bool("extnet", false, "do external network tests")
	// transportHost is the host that TestTransportExternal fetches over HTTPS.
	transportHost = flag.String("transporthost", "go.dev", "hostname to use for TestTransport")
)

// tlsConfigInsecure disables certificate verification; presumably so that
// clients accept the test servers' self-signed certificates.
var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}

// canceledCtx is a context whose cancel function has already been called.
var canceledCtx = func() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	return ctx
}()
58
59
60 func newTransport(t testing.TB, opts ...any) *http.Transport {
61 tr1 := &http.Transport{
62 TLSClientConfig: tlsConfigInsecure,
63 Protocols: protocols("h2"),
64 HTTP2: &http.HTTP2Config{},
65 }
66 for _, o := range opts {
67 switch o := o.(type) {
68 case func(*http.Transport):
69 o(tr1)
70 case func(*http.HTTP2Config):
71 o(tr1.HTTP2)
72 default:
73 t.Fatalf("unknown newTransport option %T", o)
74 }
75 }
76 t.Cleanup(tr1.CloseIdleConnections)
77 return tr1
78 }
79
80 func TestTransportExternal(t *testing.T) {
81 if !*extNet {
82 t.Skip("skipping external network test")
83 }
84 req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
85 rt := newTransport(t)
86 res, err := rt.RoundTrip(req)
87 if err != nil {
88 t.Fatalf("%v", err)
89 }
90 res.Write(os.Stdout)
91 }
92
// TestIdleConnTimeout checks whether a client connection left idle after a
// request gets closed, based on the transport's IdleConnTimeout and on how
// long the test waits after each request.
//
// Each subtest issues three requests in sequence. The first request always
// needs a new connection. Later requests need a new connection only when
// wantNewConn is set, meaning the previous connection should have been closed
// while idle.
//
// NOTE(review): the baseTransport field is never read by the subtest body, so
// the 2s IdleConnTimeout configured on it in "H1TransportTimeoutExpires" has
// no effect. Confirm whether it was meant to be applied to the test transport.
func TestIdleConnTimeout(t *testing.T) {
	for _, test := range []struct {
		name            string
		idleConnTimeout time.Duration
		wait            time.Duration
		baseTransport   *http.Transport
		wantNewConn     bool
	}{{
		name:            "NoExpiry",
		idleConnTimeout: 2 * time.Second,
		wait:            1 * time.Second,
		baseTransport:   nil,
		wantNewConn:     false,
	}, {
		name:            "H2TransportTimeoutExpires",
		idleConnTimeout: 1 * time.Second,
		wait:            2 * time.Second,
		baseTransport:   nil,
		wantNewConn:     true,
	}, {
		name:            "H1TransportTimeoutExpires",
		idleConnTimeout: 0 * time.Second,
		wait:            1 * time.Second,
		baseTransport: newTransport(t, func(tr1 *http.Transport) {
			tr1.IdleConnTimeout = 2 * time.Second
		}),
		wantNewConn: false,
	}} {
		synctestSubtest(t, test.name, func(t testing.TB) {
			tt := newTestTransport(t, func(tr *http.Transport) {
				tr.IdleConnTimeout = test.idleConnTimeout
			})
			var tc *testClientConn
			for i := range 3 {
				req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
				rt := tt.roundTrip(req)

				// The transport should have dialed a new conn for the first
				// request, and for later ones only when the previous conn is
				// expected to have been closed by the idle timeout.
				wantConn := i == 0 || test.wantNewConn
				if has := tt.hasConn(); has != wantConn {
					t.Fatalf("request %v: hasConn=%v, want %v", i, has, wantConn)
				}
				if wantConn {
					tc = tt.getConn()
					// Consume the client's opening SETTINGS and WINDOW_UPDATE
					// frames, then send the server's SETTINGS.
					tc.wantFrameType(FrameSettings)
					tc.wantFrameType(FrameWindowUpdate)
					tc.writeSettings()
				}
				if tt.hasConn() {
					t.Fatalf("request %v: Transport has more than one conn", i)
				}

				// Answer the request's HEADERS with a complete 200 response.
				hf := readFrame[*HeadersFrame](t, tc)
				tc.writeHeaders(HeadersFrameParam{
					StreamID:   hf.StreamID,
					EndHeaders: true,
					EndStream:  true,
					BlockFragment: tc.makeHeaderBlockFragment(
						":status", "200",
					),
				})
				rt.wantStatus(200)

				// On a fresh conn the client sends one more SETTINGS frame,
				// presumably its ACK of the server's SETTINGS.
				if wantConn {
					tc.wantFrameType(FrameSettings)
				}

				// Leave the conn idle for test.wait. It should be closed
				// exactly when a new conn is expected for the next request.
				// NOTE(review): assumes synctestSubtest runs the body in a
				// synctest bubble, so this sleep advances fake time. Confirm.
				time.Sleep(test.wait)
				if got, want := tc.isClosed(), test.wantNewConn; got != want {
					t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want)
				}
			}
		})
	}
}
174
175 func TestTransportH2c(t *testing.T) {
176 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
177 fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil)
178 }, func(s *http.Server) {
179 s.Protocols = protocols("h2c")
180 })
181 req, err := http.NewRequest("GET", ts.URL+"/foobar", nil)
182 if err != nil {
183 t.Fatal(err)
184 }
185 var gotConnCnt int32
186 trace := &httptrace.ClientTrace{
187 GotConn: func(connInfo httptrace.GotConnInfo) {
188 if !connInfo.Reused {
189 atomic.AddInt32(&gotConnCnt, 1)
190 }
191 },
192 }
193 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
194 tr := newTransport(t)
195 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
196 return net.Dial(network, addr)
197 }
198 tr.Protocols = protocols("h2c")
199 res, err := tr.RoundTrip(req)
200 if err != nil {
201 t.Fatal(err)
202 }
203 if res.ProtoMajor != 2 {
204 t.Fatal("proto not h2c")
205 }
206 body, err := io.ReadAll(res.Body)
207 if err != nil {
208 t.Fatal(err)
209 }
210 if got, want := string(body), "Hello, /foobar, http: true"; got != want {
211 t.Fatalf("response got %v, want %v", got, want)
212 }
213 if got, want := gotConnCnt, int32(1); got != want {
214 t.Errorf("Too many got connections: %d", gotConnCnt)
215 }
216 }
217
218 func TestTransport(t *testing.T) {
219 const body = "sup"
220 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
221 io.WriteString(w, body)
222 })
223
224 tr := ts.Client().Transport.(*http.Transport)
225 defer tr.CloseIdleConnections()
226
227 u, err := url.Parse(ts.URL)
228 if err != nil {
229 t.Fatal(err)
230 }
231 for i, m := range []string{"GET", ""} {
232 req := &http.Request{
233 Method: m,
234 URL: u,
235 Header: http.Header{},
236 }
237 res, err := tr.RoundTrip(req)
238 if err != nil {
239 t.Fatalf("%d: %s", i, err)
240 }
241
242 t.Logf("%d: Got res: %+v", i, res)
243 if g, w := res.StatusCode, 200; g != w {
244 t.Errorf("%d: StatusCode = %v; want %v", i, g, w)
245 }
246 if g, w := res.Status, "200 OK"; g != w {
247 t.Errorf("%d: Status = %q; want %q", i, g, w)
248 }
249 wantHeader := http.Header{
250 "Content-Length": []string{"3"},
251 "Content-Type": []string{"text/plain; charset=utf-8"},
252 "Date": []string{"XXX"},
253 }
254
255 if d := res.Header["Date"]; len(d) == 1 {
256 d[0] = "XXX"
257 }
258 if !reflect.DeepEqual(res.Header, wantHeader) {
259 t.Errorf("%d: res Header = %v; want %v", i, res.Header, wantHeader)
260 }
261 if res.Request != req {
262 t.Errorf("%d: Response.Request = %p; want %p", i, res.Request, req)
263 }
264 if res.TLS == nil {
265 t.Errorf("%d: Response.TLS = nil; want non-nil", i)
266 }
267 slurp, err := io.ReadAll(res.Body)
268 if err != nil {
269 t.Errorf("%d: Body read: %v", i, err)
270 } else if string(slurp) != body {
271 t.Errorf("%d: Body = %q; want %q", i, slurp, body)
272 }
273 res.Body.Close()
274 }
275 }
276
// TestTransportFailureErrorForHTTP1Response checks that when an h2c client
// talks to a plain httptest server (presumably answering in HTTP/1 only), the
// RoundTrip error says the peer's bytes looked like an HTTP/1.1 header rather
// than an HTTP/2 frame. It runs once with the default MaxReadFrameSize and
// once with a limit just above InvalidHTTP1LookingFrameHeader().Length.
//
// The test is currently skipped unconditionally as racy, so nothing after the
// t.Skip call runs. The cause of the race is not recorded in this file.
func TestTransportFailureErrorForHTTP1Response(t *testing.T) {
	t.Skip("test is racy")

	const expectedHTTP1PayloadHint = "frame header looked like an HTTP/1.1 header"

	ts := httptest.NewServer(http.NewServeMux())
	t.Cleanup(ts.Close)

	// NOTE(review): expectedErrorIs is never set or read; the check below
	// only inspects the error text.
	for _, tc := range []struct {
		name            string
		maxFrameSize    uint32
		expectedErrorIs error
	}{
		{
			name:         "with default max frame size",
			maxFrameSize: 0,
		},
		{
			name:         "with enough frame size to start reading",
			maxFrameSize: InvalidHTTP1LookingFrameHeader().Length + 1,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			tr := newTransport(t)
			tr.HTTP2.MaxReadFrameSize = int(tc.maxFrameSize)
			// Speak HTTP/2 over cleartext to the server.
			tr.Protocols = protocols("h2c")

			req, err := http.NewRequest("GET", ts.URL, nil)
			if err != nil {
				t.Fatal(err)
			}

			_, err = tr.RoundTrip(req)
			if err == nil || !strings.Contains(err.Error(), expectedHTTP1PayloadHint) {
				t.Errorf("expected error to contain %q, got %v", expectedHTTP1PayloadHint, err)
			}
		})
	}
}
337
338 func testTransportReusesConns(t *testing.T, wantSame bool, modReq func(*http.Request)) {
339 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
340 io.WriteString(w, r.RemoteAddr)
341 }, func(ts *httptest.Server) {
342 ts.Config.ConnState = func(c net.Conn, st http.ConnState) {
343 t.Logf("conn %v is now state %v", c.RemoteAddr(), st)
344 }
345 })
346 tr := newTransport(t)
347 get := func() string {
348 req, err := http.NewRequest("GET", ts.URL, nil)
349 if err != nil {
350 t.Fatal(err)
351 }
352 modReq(req)
353 res, err := tr.RoundTrip(req)
354 if err != nil {
355 t.Fatal(err)
356 }
357 defer res.Body.Close()
358 slurp, err := io.ReadAll(res.Body)
359 if err != nil {
360 t.Fatalf("Body read: %v", err)
361 }
362 addr := strings.TrimSpace(string(slurp))
363 if addr == "" {
364 t.Fatalf("didn't get an addr in response")
365 }
366 return addr
367 }
368 first := get()
369 second := get()
370 if got := first == second; got != wantSame {
371 t.Errorf("first and second responses on same connection: %v; want %v", got, wantSame)
372 }
373 }
374
375 func TestTransportReusesConns(t *testing.T) {
376 for _, test := range []struct {
377 name string
378 modReq func(*http.Request)
379 wantSame bool
380 }{{
381 name: "ReuseConn",
382 modReq: func(*http.Request) {},
383 wantSame: true,
384 }, {
385 name: "RequestClose",
386 modReq: func(r *http.Request) { r.Close = true },
387 wantSame: false,
388 }, {
389 name: "ConnClose",
390 modReq: func(r *http.Request) { r.Header.Set("Connection", "close") },
391 wantSame: false,
392 }} {
393 t.Run(test.name, func(t *testing.T) {
394 testTransportReusesConns(t, test.wantSame, test.modReq)
395 })
396 }
397 }
398
399 func TestTransportGetGotConnHooks_HTTP2Transport(t *testing.T) {
400 testTransportGetGotConnHooks(t, false)
401 }
402 func TestTransportGetGotConnHooks_Client(t *testing.T) { testTransportGetGotConnHooks(t, true) }
403
404 func testTransportGetGotConnHooks(t *testing.T, useClient bool) {
405 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
406 io.WriteString(w, r.RemoteAddr)
407 })
408
409 tr := newTransport(t)
410 client := ts.Client()
411
412 var (
413 getConns int32
414 gotConns int32
415 )
416 for i := range 2 {
417 trace := &httptrace.ClientTrace{
418 GetConn: func(hostport string) {
419 atomic.AddInt32(&getConns, 1)
420 },
421 GotConn: func(connInfo httptrace.GotConnInfo) {
422 got := atomic.AddInt32(&gotConns, 1)
423 wantReused, wantWasIdle := false, false
424 if got > 1 {
425 wantReused, wantWasIdle = true, true
426 }
427 if connInfo.Reused != wantReused || connInfo.WasIdle != wantWasIdle {
428 t.Errorf("GotConn %v: Reused=%v (want %v), WasIdle=%v (want %v)", i, connInfo.Reused, wantReused, connInfo.WasIdle, wantWasIdle)
429 }
430 },
431 }
432 req, err := http.NewRequest("GET", ts.URL, nil)
433 if err != nil {
434 t.Fatal(err)
435 }
436 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
437
438 var res *http.Response
439 if useClient {
440 res, err = client.Do(req)
441 } else {
442 res, err = tr.RoundTrip(req)
443 }
444 if err != nil {
445 t.Fatal(err)
446 }
447 res.Body.Close()
448 if get := atomic.LoadInt32(&getConns); get != int32(i+1) {
449 t.Errorf("after request %v, %v calls to GetConns: want %v", i, get, i+1)
450 }
451 if got := atomic.LoadInt32(&gotConns); got != int32(i+1) {
452 t.Errorf("after request %v, %v calls to GotConns: want %v", i, got, i+1)
453 }
454 }
455 }
456
457 func TestTransportAbortClosesPipes(t *testing.T) {
458 shutdown := make(chan struct{})
459 ts := newTestServer(t,
460 func(w http.ResponseWriter, r *http.Request) {
461 w.(http.Flusher).Flush()
462 <-shutdown
463 },
464 )
465 defer close(shutdown)
466
467 errCh := make(chan error)
468 go func() {
469 defer close(errCh)
470 tr := newTransport(t)
471 req, err := http.NewRequest("GET", ts.URL, nil)
472 if err != nil {
473 errCh <- err
474 return
475 }
476 res, err := tr.RoundTrip(req)
477 if err != nil {
478 errCh <- err
479 return
480 }
481 defer res.Body.Close()
482 ts.CloseClientConnections()
483 _, err = io.ReadAll(res.Body)
484 if err == nil {
485 errCh <- errors.New("expected error from res.Body.Read")
486 return
487 }
488 }()
489
490 select {
491 case err := <-errCh:
492 if err != nil {
493 t.Fatal(err)
494 }
495
496 case <-time.After(3 * time.Second):
497 t.Fatal("timeout")
498 }
499 }
500
501
502
503 func TestTransportPath(t *testing.T) {
504 gotc := make(chan *url.URL, 1)
505 ts := newTestServer(t,
506 func(w http.ResponseWriter, r *http.Request) {
507 gotc <- r.URL
508 },
509 )
510
511 tr := newTransport(t)
512 const (
513 path = "/testpath"
514 query = "q=1"
515 )
516 surl := ts.URL + path + "?" + query
517 req, err := http.NewRequest("POST", surl, nil)
518 if err != nil {
519 t.Fatal(err)
520 }
521 c := &http.Client{Transport: tr}
522 res, err := c.Do(req)
523 if err != nil {
524 t.Fatal(err)
525 }
526 defer res.Body.Close()
527 got := <-gotc
528 if got.Path != path {
529 t.Errorf("Read Path = %q; want %q", got.Path, path)
530 }
531 if got.RawQuery != query {
532 t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query)
533 }
534 }
535
// randString returns a pseudo-random string of exactly n bytes. The generator
// is seeded with n, so calls with the same length always return the same
// content. Bytes range over 0-255 and need not form valid UTF-8.
func randString(n int) string {
	src := rand.New(rand.NewSource(int64(n)))
	out := make([]byte, n)
	for i := 0; i < len(out); i++ {
		out[i] = byte(src.Intn(256))
	}
	return string(out)
}
544
545 func TestTransportBody(t *testing.T) {
546 bodyTests := []struct {
547 body string
548 noContentLen bool
549 }{
550 {body: "some message"},
551 {body: "some message", noContentLen: true},
552 {body: strings.Repeat("a", 1<<20), noContentLen: true},
553 {body: strings.Repeat("a", 1<<20)},
554 {body: randString(16<<10 - 1)},
555 {body: randString(16 << 10)},
556 {body: randString(16<<10 + 1)},
557 {body: randString(512<<10 - 1)},
558 {body: randString(512 << 10)},
559 {body: randString(512<<10 + 1)},
560 {body: randString(1<<20 - 1)},
561 {body: randString(1 << 20)},
562 {body: randString(1<<20 + 2)},
563 }
564
565 type reqInfo struct {
566 req *http.Request
567 slurp []byte
568 err error
569 }
570 gotc := make(chan reqInfo, 1)
571 ts := newTestServer(t,
572 func(w http.ResponseWriter, r *http.Request) {
573 slurp, err := io.ReadAll(r.Body)
574 if err != nil {
575 gotc <- reqInfo{err: err}
576 } else {
577 gotc <- reqInfo{req: r, slurp: slurp}
578 }
579 },
580 )
581
582 for i, tt := range bodyTests {
583 tr := newTransport(t)
584
585 var body io.Reader = strings.NewReader(tt.body)
586 if tt.noContentLen {
587 body = struct{ io.Reader }{body}
588 }
589 req, err := http.NewRequest("POST", ts.URL, body)
590 if err != nil {
591 t.Fatalf("#%d: %v", i, err)
592 }
593 c := &http.Client{Transport: tr}
594 res, err := c.Do(req)
595 if err != nil {
596 t.Fatalf("#%d: %v", i, err)
597 }
598 defer res.Body.Close()
599 ri := <-gotc
600 if ri.err != nil {
601 t.Errorf("#%d: read error: %v", i, ri.err)
602 continue
603 }
604 if got := string(ri.slurp); got != tt.body {
605 t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
606 }
607 wantLen := int64(len(tt.body))
608 if tt.noContentLen && tt.body != "" {
609 wantLen = -1
610 }
611 if ri.req.ContentLength != wantLen {
612 t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
613 }
614 }
615 }
616
// shortString shortens v for use in failure messages. Strings of up to 100
// bytes are returned unchanged. Longer ones keep their first and last 50
// bytes, with a marker in between stating how many bytes were dropped.
func shortString(v string) string {
	const maxLen = 100
	if len(v) > maxLen {
		const keep = maxLen / 2
		head, tail := v[:keep], v[len(v)-keep:]
		return fmt.Sprintf("%v[...%d bytes omitted...]%v", head, len(v)-maxLen, tail)
	}
	return v
}
624
// capitalizeReader wraps r and converts ASCII letters a-z to upper case in
// everything read through it. All other bytes pass through unchanged.
type capitalizeReader struct {
	r io.Reader
}

// Read fills p from the wrapped reader, then upper-cases any ASCII lowercase
// letters among the bytes just read. The wrapped reader's count and error are
// returned unchanged.
func (cr capitalizeReader) Read(p []byte) (int, error) {
	n, err := cr.r.Read(p)
	chunk := p[:n]
	for i := range chunk {
		if c := chunk[i]; 'a' <= c && c <= 'z' {
			chunk[i] = c - 'a' + 'A'
		}
	}
	return n, err
}
638
// flushWriter forwards writes to w. When w also implements http.Flusher, it
// flushes after every write so each chunk is pushed out right away.
type flushWriter struct {
	w io.Writer
}

// Write writes p to the wrapped writer and then flushes it if it can. The
// flush happens even when the write reports an error. The write's count and
// error are returned unchanged.
func (fw flushWriter) Write(p []byte) (int, error) {
	flusher, canFlush := fw.w.(http.Flusher)
	n, err := fw.w.Write(p)
	if canFlush {
		flusher.Flush()
	}
	return n, err
}
650
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// prefers IPv4 (127.0.0.1) and falls back to IPv6 ([::1]). If neither works,
// the test fails and the message includes both errors, so the IPv4 failure
// is not lost.
func newLocalListener(t *testing.T) net.Listener {
	t.Helper()
	ln, err4 := net.Listen("tcp4", "127.0.0.1:0")
	if err4 == nil {
		return ln
	}
	ln, err6 := net.Listen("tcp6", "[::1]:0")
	if err6 != nil {
		t.Fatalf("listening on loopback: tcp4: %v; tcp6: %v", err4, err6)
	}
	return ln
}
662
663 func TestTransportReqBodyAfterResponse_200(t *testing.T) {
664 synctestTest(t, func(t testing.TB) {
665 testTransportReqBodyAfterResponse(t, 200)
666 })
667 }
668 func TestTransportReqBodyAfterResponse_403(t *testing.T) {
669 synctestTest(t, func(t testing.TB) {
670 testTransportReqBodyAfterResponse(t, 403)
671 })
672 }
673
// testTransportReqBodyAfterResponse sends a PUT whose 1024-byte body is only
// half available when the server replies with the given status and
// END_STREAM. After the response arrives, the rest of the body is made
// available and the body is closed.
//
// With a 200 response the client is expected to finish sending the body.
// With any other status it is expected to reset the stream instead. In both
// cases the client sees an empty response body.
func testTransportReqBodyAfterResponse(t testing.TB, status int) {
	const bodySize = 1 << 10

	tc := newTestClientConn(t)
	tc.greet()

	body := tc.newRequestBody()
	body.writeBytes(bodySize / 2)
	req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
	rt := tc.roundTrip(req)

	tc.wantHeaders(wantHeader{
		streamID:  rt.streamID(),
		endStream: false,
		header: http.Header{
			":authority": []string{"dummy.tld"},
			":method":    []string{"PUT"},
			":path":      []string{"/"},
		},
	})

	// Grant enough flow-control window, on both the connection and the
	// stream, for the entire body.
	tc.writeWindowUpdate(0, bodySize)
	tc.writeWindowUpdate(rt.streamID(), bodySize)

	// The first half of the body, the part available so far, is sent.
	tc.wantData(wantData{
		streamID:  rt.streamID(),
		endStream: false,
		size:      bodySize / 2,
	})

	// Respond, and end the server's side of the stream, before the request
	// body is complete.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", strconv.Itoa(status),
		),
	})

	res := rt.response()
	if res.StatusCode != status {
		t.Fatalf("status code = %v; want %v", res.StatusCode, status)
	}

	// Supply the second half of the body and end it.
	body.writeBytes(bodySize / 2)
	body.closeWithError(io.EOF)

	if status == 200 {
		// After a 200, the rest of the body is still sent, ending the stream;
		// `multiple: true` presumably lets it arrive split over several DATA
		// frames.
		tc.wantData(wantData{
			streamID:  rt.streamID(),
			endStream: true,
			size:      bodySize / 2,
			multiple:  true,
		})
	} else {
		// After any other status, the client abandons the body and resets
		// the stream.
		tc.wantFrameType(FrameRSTStream)
	}

	rt.wantBody(nil)
}
737
738
739 func TestTransportFullDuplex(t *testing.T) {
740 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
741 w.WriteHeader(200)
742 w.(http.Flusher).Flush()
743 io.Copy(flushWriter{w}, capitalizeReader{r.Body})
744 fmt.Fprintf(w, "bye.\n")
745 })
746
747 tr := newTransport(t)
748 c := &http.Client{Transport: tr}
749
750 pr, pw := io.Pipe()
751 req, err := http.NewRequest("PUT", ts.URL, io.NopCloser(pr))
752 if err != nil {
753 t.Fatal(err)
754 }
755 req.ContentLength = -1
756 res, err := c.Do(req)
757 if err != nil {
758 t.Fatal(err)
759 }
760 defer res.Body.Close()
761 if res.StatusCode != 200 {
762 t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
763 }
764 bs := bufio.NewScanner(res.Body)
765 want := func(v string) {
766 if !bs.Scan() {
767 t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
768 }
769 }
770 write := func(v string) {
771 _, err := io.WriteString(pw, v)
772 if err != nil {
773 t.Fatalf("pipe write: %v", err)
774 }
775 }
776 write("foo\n")
777 want("FOO")
778 write("bar\n")
779 want("BAR")
780 pw.Close()
781 want("bye.")
782 if err := bs.Err(); err != nil {
783 t.Fatal(err)
784 }
785 }
786
787 func TestTransportConnectRequest(t *testing.T) {
788 gotc := make(chan *http.Request, 1)
789 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
790 gotc <- r
791 })
792
793 u, err := url.Parse(ts.URL)
794 if err != nil {
795 t.Fatal(err)
796 }
797
798 tr := newTransport(t)
799 c := &http.Client{Transport: tr}
800
801 tests := []struct {
802 req *http.Request
803 want string
804 }{
805 {
806 req: &http.Request{
807 Method: "CONNECT",
808 Header: http.Header{},
809 URL: u,
810 },
811 want: u.Host,
812 },
813 {
814 req: &http.Request{
815 Method: "CONNECT",
816 Header: http.Header{},
817 URL: u,
818 Host: "example.com:123",
819 },
820 want: "example.com:123",
821 },
822 }
823
824 for i, tt := range tests {
825 res, err := c.Do(tt.req)
826 if err != nil {
827 t.Errorf("%d. RoundTrip = %v", i, err)
828 continue
829 }
830 res.Body.Close()
831 req := <-gotc
832 if req.Method != "CONNECT" {
833 t.Errorf("method = %q; want CONNECT", req.Method)
834 }
835 if req.Host != tt.want {
836 t.Errorf("Host = %q; want %q", req.Host, tt.want)
837 }
838 if req.URL.Host != tt.want {
839 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
840 }
841 }
842 }
843
// headerType selects how a test writes a header block through
// writeHeadersMode (defined elsewhere): noHeader, oneHeader or splitHeader.
// NOTE(review): presumably noHeader writes nothing, oneHeader writes a single
// HEADERS frame, and splitHeader spreads the block over HEADERS plus
// CONTINUATION frames. Confirm against writeHeadersMode.
type headerType int

const (
	noHeader headerType = iota
	oneHeader
	splitHeader
)

// Short aliases that keep the TestTransportResPattern_* declarations compact:
// f0/f1/f2 are the three headerType values and d0/d1 are the withData flag.
const (
	f0 = noHeader
	f1 = oneHeader
	f2 = splitHeader
	d0 = false
	d1 = true
)

// The TestTransportResPattern_cNhNdNtN tests enumerate the arguments to
// testTransportResPattern: cN is expect100Continue, hN is resHeader, dN is
// withData and tN is trailers, with each digit naming the f*/d* constant
// passed. resHeader is never f0, because testTransportResPatternBubble panics
// when resHeader == noHeader.
func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
901
902 func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
903 synctestTest(t, func(t testing.TB) {
904 testTransportResPatternBubble(t, expect100Continue, resHeader, withData, trailers)
905 })
906 }
// testTransportResPatternBubble sends one POST with a short body and answers
// it with a response shaped by the arguments:
//   - expect100Continue: when not noHeader, the request carries
//     "Expect: 100-continue". A ":status 100" block is always passed to
//     writeHeadersMode in this mode (presumably nothing is written for
//     noHeader) before the request body is expected.
//   - resHeader: mode for the final ":status 200" block; must not be noHeader.
//   - withData: whether a DATA frame with resBody follows the headers.
//   - trailers: mode for a trailer block carrying "some-trailer". When not
//     noHeader, the response headers declare it via a "trailer" header.
//
// It then checks the status, body and trailers the client observes.
func testTransportResPatternBubble(t testing.TB, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
	const reqBody = "some request body"
	const resBody = "some response body"

	if resHeader == noHeader {
		// Every response needs a final header block, so this combination
		// indicates a mistake in the caller.
		panic("invalid combination")
	}

	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
	if expect100Continue != noHeader {
		req.Header.Set("Expect", "100-continue")
	}
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	// Interim 100 Continue, written in expect100Continue's mode.
	tc.writeHeadersMode(expect100Continue, HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "100",
		),
	})

	// The whole request body arrives and ends the client's side of the stream.
	tc.wantData(wantData{
		streamID:  rt.streamID(),
		endStream: true,
		size:      len(reqBody),
	})

	hdr := []string{
		":status", "200",
		"x-foo", "blah",
		"x-bar", "more",
	}
	if trailers != noHeader {
		hdr = append(hdr, "trailer", "some-trailer")
	}
	// The final headers end the stream only when neither data nor trailers
	// follow.
	tc.writeHeadersMode(resHeader, HeadersFrameParam{
		StreamID:      rt.streamID(),
		EndHeaders:    true,
		EndStream:     withData == false && trailers == noHeader,
		BlockFragment: tc.makeHeaderBlockFragment(hdr...),
	})
	if withData {
		endStream := trailers == noHeader
		tc.writeData(rt.streamID(), endStream, []byte(resBody))
	}
	tc.writeHeadersMode(trailers, HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			"some-trailer", "some-value",
		),
	})

	rt.wantStatus(200)
	if !withData {
		rt.wantBody(nil)
	} else {
		rt.wantBody([]byte(resBody))
	}
	if trailers == noHeader {
		rt.wantTrailers(nil)
	} else {
		rt.wantTrailers(http.Header{
			"Some-Trailer": {"some-value"},
		})
	}
}
986
987
988 func TestTransportUnknown1xx(t *testing.T) { synctestTest(t, testTransportUnknown1xx) }
989 func testTransportUnknown1xx(t testing.TB) {
990 var buf bytes.Buffer
991 SetTestHookGot1xx(t, func(code int, header textproto.MIMEHeader) error {
992 fmt.Fprintf(&buf, "code=%d header=%v\n", code, header)
993 return nil
994 })
995
996 tc := newTestClientConn(t)
997 tc.greet()
998
999 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
1000 rt := tc.roundTrip(req)
1001
1002 for i := 110; i <= 114; i++ {
1003 tc.writeHeaders(HeadersFrameParam{
1004 StreamID: rt.streamID(),
1005 EndHeaders: true,
1006 EndStream: false,
1007 BlockFragment: tc.makeHeaderBlockFragment(
1008 ":status", fmt.Sprint(i),
1009 "foo-bar", fmt.Sprint(i),
1010 ),
1011 })
1012 }
1013 tc.writeHeaders(HeadersFrameParam{
1014 StreamID: rt.streamID(),
1015 EndHeaders: true,
1016 EndStream: true,
1017 BlockFragment: tc.makeHeaderBlockFragment(
1018 ":status", "204",
1019 ),
1020 })
1021
1022 res := rt.response()
1023 if res.StatusCode != 204 {
1024 t.Fatalf("status code = %v; want 204", res.StatusCode)
1025 }
1026 want := `code=110 header=map[Foo-Bar:[110]]
1027 code=111 header=map[Foo-Bar:[111]]
1028 code=112 header=map[Foo-Bar:[112]]
1029 code=113 header=map[Foo-Bar:[113]]
1030 code=114 header=map[Foo-Bar:[114]]
1031 `
1032 if got := buf.String(); got != want {
1033 t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
1034 }
1035 }
1036
// TestTransportReceiveUndeclaredTrailer checks that a trailer the server
// never announced in a "trailer" response header is still delivered to the
// client in the response trailers.
func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
	synctestTest(t, testTransportReceiveUndeclaredTrailer)
}
func testTransportReceiveUndeclaredTrailer(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	// Response headers with no "trailer" declaration; the stream stays open.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	// A second header block that ends the stream, i.e. the trailers.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			"some-trailer", "I'm an undeclared Trailer!",
		),
	})

	rt.wantStatus(200)
	rt.wantBody(nil)
	rt.wantTrailers(http.Header{
		"Some-Trailer": []string{"I'm an undeclared Trailer!"},
	})
}
1070
1071 func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
1072 testTransportInvalidTrailer_Pseudo(t, oneHeader)
1073 }
1074 func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
1075 testTransportInvalidTrailer_Pseudo(t, splitHeader)
1076 }
1077 func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
1078 testInvalidTrailer(t, trailers, PseudoHeaderError(":colon"),
1079 ":colon", "foo",
1080 "foo", "bar",
1081 )
1082 }
1083
1084 func TestTransportInvalidTrailer_Capital1(t *testing.T) {
1085 testTransportInvalidTrailer_Capital(t, oneHeader)
1086 }
1087 func TestTransportInvalidTrailer_Capital2(t *testing.T) {
1088 testTransportInvalidTrailer_Capital(t, splitHeader)
1089 }
1090 func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
1091 testInvalidTrailer(t, trailers, HeaderFieldNameError("Capital"),
1092 "foo", "bar",
1093 "Capital", "bad",
1094 )
1095 }
1096 func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
1097 testInvalidTrailer(t, oneHeader, HeaderFieldNameError(""),
1098 "", "bad",
1099 )
1100 }
1101 func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
1102 testInvalidTrailer(t, oneHeader, HeaderFieldValueError("x"),
1103 "x", "has\nnewline",
1104 )
1105 }
1106
1107 func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) {
1108 synctestTest(t, func(t testing.TB) {
1109 testInvalidTrailerBubble(t, mode, wantErr, trailers...)
1110 })
1111 }
// testInvalidTrailerBubble sends a GET and answers with a 200 whose headers
// declare a trailer named "declared". It then ends the stream with a trailer
// block built from the trailers key/value pairs and written in the given
// mode. Reading the response body must fail with a StreamError whose Cause
// equals wantErr, and must return no body bytes.
func testInvalidTrailerBubble(t testing.TB, mode headerType, wantErr error, trailers ...string) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"trailer", "declared",
		),
	})
	// The invalid trailer block; it also ends the stream.
	tc.writeHeadersMode(mode, HeadersFrameParam{
		StreamID:      rt.streamID(),
		EndHeaders:    true,
		EndStream:     true,
		BlockFragment: tc.makeHeaderBlockFragment(trailers...),
	})

	rt.wantStatus(200)
	body, err := rt.readBody()
	// The error is expected to be a StreamError value (not wrapped) whose
	// Cause compares equal to wantErr.
	se, ok := err.(StreamError)
	if !ok || se.Cause != wantErr {
		t.Fatalf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", body, err, wantErr, wantErr)
	}
	if len(body) > 0 {
		t.Fatalf("body = %q; want nothing", body)
	}
}
1145
1146
1147
1148
1149
1150 func headerListSize(h http.Header) (size uint32) {
1151 for k, vv := range h {
1152 for _, v := range vv {
1153 hf := hpack.HeaderField{Name: k, Value: v}
1154 size += hf.Size()
1155 }
1156 }
1157 return size
1158 }
1159
1160
1161
1162
1163
1164
1165
1166
1167 func padHeaders(t testing.TB, h http.Header, limit uint64, filler string) {
1168 if limit > 0xffffffff {
1169 t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit)
1170 }
1171 hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
1172 minPadding := uint64(hf.Size())
1173 size := uint64(headerListSize(h))
1174
1175 minlimit := size + minPadding
1176 if limit < minlimit {
1177 t.Fatalf("padHeaders: limit %v < %v", limit, minlimit)
1178 }
1179
1180
1181
1182 nameFmt := "Pad-Headers-%06d"
1183 hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler}
1184 fieldSize := uint64(hf.Size())
1185
1186
1187
1188 limit = limit - minPadding
1189 for i := 0; size+fieldSize < limit; i++ {
1190 name := fmt.Sprintf(nameFmt, i)
1191 h.Add(name, filler)
1192 size += fieldSize
1193 }
1194
1195
1196 remain := limit - size
1197 lastValue := strings.Repeat("*", int(remain))
1198 h.Add("Pad-Headers", lastValue)
1199 }
1200
1201 func TestPadHeaders(t *testing.T) {
1202 check := func(h http.Header, limit uint32, fillerLen int) {
1203 if h == nil {
1204 h = make(http.Header)
1205 }
1206 filler := strings.Repeat("f", fillerLen)
1207 padHeaders(t, h, uint64(limit), filler)
1208 gotSize := headerListSize(h)
1209 if gotSize != limit {
1210 t.Errorf("Got size = %v; want %v", gotSize, limit)
1211 }
1212 }
1213
1214 hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
1215 minLimit := hf.Size()
1216 for limit := minLimit; limit <= 128; limit++ {
1217 for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ {
1218 check(nil, limit, fillerLen)
1219 }
1220 }
1221
1222
1223
1224
1225
1226
1227 tests := []struct {
1228 fillerLen int
1229 limit uint32
1230 }{
1231 {
1232 fillerLen: 64,
1233 limit: 1024,
1234 },
1235 {
1236 fillerLen: 1024,
1237 limit: 1286,
1238 },
1239 {
1240 fillerLen: 256,
1241 limit: 2048,
1242 },
1243 {
1244 fillerLen: 1024,
1245 limit: 10 * 1024,
1246 },
1247 {
1248 fillerLen: 1023,
1249 limit: 11 * 1024,
1250 },
1251 }
1252 h := make(http.Header)
1253 for _, tc := range tests {
1254 check(nil, tc.limit, tc.fillerLen)
1255 check(h, tc.limit, tc.fillerLen)
1256 }
1257 }
1258
// TestTransportChecksRequestHeaderListSize runs
// testTransportChecksRequestHeaderListSize under synctestTest.
func TestTransportChecksRequestHeaderListSize(t *testing.T) {
	synctestTest(t, testTransportChecksRequestHeaderListSize)
}

// testTransportChecksRequestHeaderListSize checks that the transport honors the
// peer's SETTINGS_MAX_HEADER_LIST_SIZE for outgoing requests. Headers and
// trailers that fit are sent. Requests whose headers or trailers exceed the
// limit fail with ErrRequestHeaderListSize.
func testTransportChecksRequestHeaderListSize(t testing.TB) {
	const peerSize = 16 << 10

	tc := newTestClientConn(t)
	// The peer advertises a 16 KiB header list limit.
	tc.greet(Setting{SettingMaxHeaderListSize, peerSize})

	// checkRoundTrip sends req. With a non-nil wantErr, RoundTrip must fail
	// with an error matching wantErr (via errors.Is). Otherwise the request's
	// HEADERS frame must be written, and the peer answers with a bare 200.
	checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
		t.Helper()
		rt := tc.roundTrip(req)
		if wantErr != nil {
			if err := rt.err(); !errors.Is(err, wantErr) {
				t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr)
			}
			return
		}

		tc.wantFrameType(FrameHeaders)
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  true,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
			),
		})

		rt.wantStatus(http.StatusOK)
	}
	// headerListSizeForRequest returns the summed HPACK field sizes of every
	// field httpcommon.EncodeHeaders emits for req, with the gzip header
	// enabled and no peer limit. The total includes fields the encoder adds
	// itself, not only the ones set on req.Header.
	headerListSizeForRequest := func(req *http.Request) (size uint64) {
		_, err := httpcommon.EncodeHeaders(context.Background(), httpcommon.EncodeHeadersParam{
			Request: httpcommon.Request{
				Header:              req.Header,
				Trailer:             req.Trailer,
				URL:                 req.URL,
				Host:                req.Host,
				Method:              req.Method,
				ActualContentLength: req.ContentLength,
			},
			AddGzipHeader:         true,
			PeerMaxHeaderListSize: 0xffffffffffffffff,
		}, func(name, value string) {
			hf := hpack.HeaderField{Name: name, Value: value}
			size += uint64(hf.Size())
		})
		if err != nil {
			t.Fatal(err)
		}
		return size
	}

	// newRequest builds a fresh request for each case, so padding added for
	// one case never leaks into another.
	newRequest := func() *http.Request {
		// NOTE(review): the non-nil body is presumably what allows trailers
		// to be sent — confirm.
		const bodytext = "hello"
		body := strings.NewReader(bodytext)
		req, err := http.NewRequest("POST", "https://example.tld/", body)
		if err != nil {
			t.Fatalf("newRequest: NewRequest: %v", err)
		}
		req.ContentLength = int64(len(bodytext))
		// NOTE(review): a nil User-Agent entry appears to suppress the
		// default User-Agent so encoded sizes stay predictable — confirm
		// against httpcommon.EncodeHeaders.
		req.Header = http.Header{"User-Agent": nil}
		return req
	}

	// Headers and trailers, each padded to fit within the limit.
	req := newRequest()
	req.Trailer = make(http.Header)
	filler := strings.Repeat("*", 1024)
	padHeaders(t, req.Trailer, peerSize, filler)

	// Measure after the trailers are filled in, because req.Trailer is an
	// input to EncodeHeaders. Then pad the headers so that the encoded list
	// totals peerSize.
	defaultBytes := headerListSizeForRequest(req)
	padHeaders(t, req.Header, peerSize-defaultBytes, filler)
	checkRoundTrip(req, nil, "Headers & Trailers under limit")

	// Headers alone padded to peerSize. The fields the encoder adds on top
	// push the total over the limit.
	req = newRequest()
	padHeaders(t, req.Header, peerSize, filler)
	checkRoundTrip(req, ErrRequestHeaderListSize, "Headers over limit")

	// Trailers one byte over the limit.
	req = newRequest()
	req.Trailer = make(http.Header)
	padHeaders(t, req.Trailer, peerSize+1, filler)
	checkRoundTrip(req, ErrRequestHeaderListSize, "Trailers over limit")

	// A single header value as long as the whole limit.
	req = newRequest()
	filler = strings.Repeat("*", int(peerSize))
	req.Header.Set("Big", filler)
	checkRoundTrip(req, ErrRequestHeaderListSize, "Single large header")

	// A single trailer value as long as the whole limit (reuses filler).
	req = newRequest()
	req.Trailer = make(http.Header)
	req.Trailer.Set("Big", filler)
	checkRoundTrip(req, ErrRequestHeaderListSize, "Single large trailer")
}
1361
// TestTransportChecksResponseHeaderListSize runs
// testTransportChecksResponseHeaderListSize under synctestTest.
func TestTransportChecksResponseHeaderListSize(t *testing.T) {
	synctestTest(t, testTransportChecksResponseHeaderListSize)
}

// testTransportChecksResponseHeaderListSize sends a response whose HPACK block
// is small on the wire but decodes to over 10 MB of header fields. RoundTrip
// must fail with ErrResponseHeaderListSize, either bare or as the Cause of a
// StreamError.
func testTransportChecksResponseHeaderListSize(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	// 5042 copies of the same 1 KiB name/value pair.
	hdr := []string{":status", "200"}
	large := strings.Repeat("a", 1<<10)
	for range 5042 {
		hdr = append(hdr, large, large)
	}
	hbf := tc.makeHeaderBlockFragment(hdr...)

	// Every pair is identical, so the encoded block stays tiny (presumably
	// because later copies refer back to an indexed earlier one). Pin its size
	// so the test notices if that compression stops happening.
	if size, want := len(hbf), 6329; size != want {
		t.Fatalf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
	}
	tc.writeHeaders(HeadersFrameParam{
		StreamID:      rt.streamID(),
		EndHeaders:    true,
		EndStream:     true,
		BlockFragment: hbf,
	})

	res, err := rt.result()
	if e, ok := err.(StreamError); ok {
		err = e.Cause
	}
	if err != ErrResponseHeaderListSize {
		// Report roughly how much header data got through, counting each
		// field as name + value + 32 bytes, as HPACK does.
		size := int64(0)
		if res != nil {
			res.Body.Close()
			for k, vv := range res.Header {
				for _, v := range vv {
					size += int64(len(k)) + int64(len(v)) + 32
				}
			}
		}
		t.Fatalf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
	}
}
1410
// TestTransportCookieHeaderSplit runs testTransportCookieHeaderSplit under
// synctestTest.
func TestTransportCookieHeaderSplit(t *testing.T) { synctestTest(t, testTransportCookieHeaderSplit) }

// testTransportCookieHeaderSplit sends several Cookie header values containing
// semicolon-separated crumbs. It expects the outgoing HEADERS frame to carry
// each crumb as its own "cookie" field, with surrounding spaces and empty
// pieces dropped, in order.
func testTransportCookieHeaderSplit(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	req.Header.Add("Cookie", "a=b;c=d; e=f;")
	req.Header.Add("Cookie", "e=f;g=h; ")
	req.Header.Add("Cookie", "i=j")
	rt := tc.roundTrip(req)

	tc.wantHeaders(wantHeader{
		streamID:  rt.streamID(),
		endStream: true,
		header: http.Header{
			"cookie": []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"},
		},
	})
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "204",
		),
	})

	if err := rt.err(); err != nil {
		t.Fatalf("RoundTrip = %v, want success", err)
	}
}
1442
1443
1444
1445
1446 func TestTransportBodyReadErrorType(t *testing.T) {
1447 doPanic := make(chan bool, 1)
1448 ts := newTestServer(t,
1449 func(w http.ResponseWriter, r *http.Request) {
1450 w.(http.Flusher).Flush()
1451 <-doPanic
1452 panic("boom")
1453 },
1454 optQuiet,
1455 )
1456
1457 tr := newTransport(t)
1458 c := &http.Client{Transport: tr}
1459
1460 res, err := c.Get(ts.URL)
1461 if err != nil {
1462 t.Fatal(err)
1463 }
1464 defer res.Body.Close()
1465 doPanic <- true
1466 buf := make([]byte, 100)
1467 n, err := res.Body.Read(buf)
1468 got, ok := err.(StreamError)
1469 want := StreamError{StreamID: 0x1, Code: 0x2}
1470 if !ok || got.StreamID != want.StreamID || got.Code != want.Code {
1471 t.Errorf("Read = %v, %#v; want error %#v", n, err, want)
1472 }
1473 }
1474
1475
1476
1477
1478 func TestTransportDoubleCloseOnWriteError(t *testing.T) {
1479 var (
1480 mu sync.Mutex
1481 conn net.Conn
1482 )
1483
1484 ts := newTestServer(t,
1485 func(w http.ResponseWriter, r *http.Request) {
1486 mu.Lock()
1487 defer mu.Unlock()
1488 if conn != nil {
1489 conn.Close()
1490 }
1491 },
1492 )
1493
1494 tr := newTransport(t)
1495 tr.DialTLS = func(network, addr string) (net.Conn, error) {
1496 tc, err := tls.Dial(network, addr, tlsConfigInsecure)
1497 if err != nil {
1498 return nil, err
1499 }
1500 mu.Lock()
1501 defer mu.Unlock()
1502 conn = tc
1503 return tc, nil
1504 }
1505 c := &http.Client{Transport: tr}
1506 c.Get(ts.URL)
1507 }
1508
1509
1510
1511
1512 func TestTransportDisableKeepAlives(t *testing.T) {
1513 ts := newTestServer(t,
1514 func(w http.ResponseWriter, r *http.Request) {
1515 io.WriteString(w, "hi")
1516 },
1517 )
1518
1519 connClosed := make(chan struct{})
1520 tr := newTransport(t)
1521 tr.Dial = func(network, addr string) (net.Conn, error) {
1522 tc, err := net.Dial(network, addr)
1523 if err != nil {
1524 return nil, err
1525 }
1526 return ¬eCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
1527 }
1528 tr.DisableKeepAlives = true
1529 c := &http.Client{Transport: tr}
1530 res, err := c.Get(ts.URL)
1531 if err != nil {
1532 t.Fatal(err)
1533 }
1534 if _, err := io.ReadAll(res.Body); err != nil {
1535 t.Fatal(err)
1536 }
1537 defer res.Body.Close()
1538
1539 select {
1540 case <-connClosed:
1541 case <-time.After(1 * time.Second):
1542 t.Errorf("timeout")
1543 }
1544
1545 }
1546
1547
1548
1549 func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
1550 const D = 25 * time.Millisecond
1551 ts := newTestServer(t,
1552 func(w http.ResponseWriter, r *http.Request) {
1553 time.Sleep(D)
1554 io.WriteString(w, "hi")
1555 },
1556 )
1557
1558 var dials int32
1559 var conns sync.WaitGroup
1560 tr := newTransport(t)
1561 tr.Dial = func(network, addr string) (net.Conn, error) {
1562 tc, err := net.Dial(network, addr)
1563 if err != nil {
1564 return nil, err
1565 }
1566 atomic.AddInt32(&dials, 1)
1567 conns.Add(1)
1568 return ¬eCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
1569 }
1570 tr.DisableKeepAlives = true
1571 c := &http.Client{Transport: tr}
1572 var reqs sync.WaitGroup
1573 const N = 20
1574 for i := range N {
1575 reqs.Add(1)
1576 if i == N-1 {
1577
1578
1579
1580
1581
1582
1583 time.Sleep(D * 2)
1584 }
1585 go func() {
1586 defer reqs.Done()
1587 res, err := c.Get(ts.URL)
1588 if err != nil {
1589 t.Error(err)
1590 return
1591 }
1592 if _, err := io.ReadAll(res.Body); err != nil {
1593 t.Error(err)
1594 return
1595 }
1596 res.Body.Close()
1597 }()
1598 }
1599 reqs.Wait()
1600 conns.Wait()
1601 t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
1602 }
1603
// noteCloseConn wraps a net.Conn and calls closefn the first time Close is
// called. Every Close call, including later ones, is forwarded to the wrapped
// conn.
type noteCloseConn struct {
	net.Conn
	onceClose sync.Once
	closefn   func()
}

// Close runs closefn (at most once across all calls), then closes the wrapped
// connection and returns its error.
func (nc *noteCloseConn) Close() error {
	notify := nc.closefn
	nc.onceClose.Do(notify)
	err := nc.Conn.Close()
	return err
}
1614
// isTimeout reports whether err is a timeout. A *url.Error is unwrapped,
// recursively, to the error it carries. Any other net.Error reports its own
// Timeout result. nil and errors that are not net.Errors are not timeouts.
func isTimeout(err error) bool {
	if ue, ok := err.(*url.Error); ok {
		return isTimeout(ue.Err)
	}
	ne, ok := err.(net.Error)
	return ok && ne.Timeout()
}
1626
1627
// TestTransportResponseHeaderTimeout_NoBody runs
// testTransportResponseHeaderTimeout for a GET with no request body.
func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportResponseHeaderTimeout(t, false)
	})
}

// TestTransportResponseHeaderTimeout_Body runs
// testTransportResponseHeaderTimeout for a POST with a 4 MiB request body.
func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportResponseHeaderTimeout(t, true)
	})
}

// testTransportResponseHeaderTimeout sets ResponseHeaderTimeout to 5ms and
// never sends response headers. When body is true, the request carries a
// bodySize body, and the peer grants enough flow control for all of it to be
// written first. After that, RoundTrip must still be pending following a 4ms
// sleep. After 1ms more it must have failed with a timeout error, as judged by
// isTimeout.
func testTransportResponseHeaderTimeout(t testing.TB, body bool) {
	const bodySize = 4 << 20
	tc := newTestClientConn(t, func(t1 *http.Transport) {
		t1.ResponseHeaderTimeout = 5 * time.Millisecond
	})
	tc.greet()

	var req *http.Request
	var reqBody *testRequestBody
	if body {
		reqBody = tc.newRequestBody()
		reqBody.writeBytes(bodySize)
		reqBody.closeWithError(io.EOF)
		req, _ = http.NewRequest("POST", "https://dummy.tld/", reqBody)
		req.Header.Set("Content-Type", "text/foo")
	} else {
		req, _ = http.NewRequest("GET", "https://dummy.tld/", nil)
	}

	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	// Open the connection and stream windows wide enough for the whole body.
	tc.writeWindowUpdate(0, bodySize)
	tc.writeWindowUpdate(rt.streamID(), bodySize)

	if body {
		// The full body, possibly split over several DATA frames.
		tc.wantData(wantData{
			endStream: true,
			size:      bodySize,
			multiple:  true,
		})
	}

	time.Sleep(4 * time.Millisecond)
	if rt.done() {
		t.Fatalf("RoundTrip is done after 4ms; want still waiting")
	}
	time.Sleep(1 * time.Millisecond)

	if err := rt.err(); !isTimeout(err) {
		t.Fatalf("RoundTrip error: %v; want timeout error", err)
	}
}
1683
1684
// TestTransportWindowUpdateBeyondLimit runs
// testTransportWindowUpdateBeyondLimit under synctestTest.
func TestTransportWindowUpdateBeyondLimit(t *testing.T) {
	synctestTest(t, testTransportWindowUpdateBeyondLimit)
}

// testTransportWindowUpdateBeyondLimit sends WINDOW_UPDATEs with an increment of
// 2^31-1, the largest window HTTP/2 allows, so the already-positive window
// overflows. On the stream this must draw RST_STREAM(FLOW_CONTROL_ERROR). On
// the connection (stream 0) it must close the connection.
func testTransportWindowUpdateBeyondLimit(t testing.TB) {
	const windowIncrease uint32 = (1 << 31) - 1
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)
	tc.wantHeaders(wantHeader{
		streamID:  rt.streamID(),
		endStream: true,
	})

	tc.writeWindowUpdate(rt.streamID(), windowIncrease)
	tc.wantRSTStream(rt.streamID(), ErrCodeFlowControl)

	tc.writeWindowUpdate(0, windowIncrease)
	tc.wantClosed()
}
1706
1707 func TestTransportDisableCompression(t *testing.T) {
1708 const body = "sup"
1709 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
1710 want := http.Header{
1711 "User-Agent": []string{"Go-http-client/2.0"},
1712 }
1713 if !reflect.DeepEqual(r.Header, want) {
1714 t.Errorf("request headers = %v; want %v", r.Header, want)
1715 }
1716 })
1717
1718 tr := newTransport(t)
1719 tr.DisableCompression = true
1720
1721 req, err := http.NewRequest("GET", ts.URL, nil)
1722 if err != nil {
1723 t.Fatal(err)
1724 }
1725 res, err := tr.RoundTrip(req)
1726 if err != nil {
1727 t.Fatal(err)
1728 }
1729 defer res.Body.Close()
1730 }
1731
1732
1733 func TestTransportRejectsConnHeaders(t *testing.T) {
1734 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
1735 var got []string
1736 for k := range r.Header {
1737 got = append(got, k)
1738 }
1739 sort.Strings(got)
1740 w.Header().Set("Got-Header", strings.Join(got, ","))
1741 })
1742
1743 tr := newTransport(t)
1744
1745 tests := []struct {
1746 key string
1747 value []string
1748 want string
1749 }{
1750 {
1751 key: "Upgrade",
1752 value: []string{"anything"},
1753 want: "ERROR: http2: invalid Upgrade request header: [\"anything\"]",
1754 },
1755 {
1756 key: "Connection",
1757 value: []string{"foo"},
1758 want: "ERROR: http2: invalid Connection request header: [\"foo\"]",
1759 },
1760 {
1761 key: "Connection",
1762 value: []string{"close"},
1763 want: "Accept-Encoding,User-Agent",
1764 },
1765 {
1766 key: "Connection",
1767 value: []string{"CLoSe"},
1768 want: "Accept-Encoding,User-Agent",
1769 },
1770 {
1771 key: "Connection",
1772 value: []string{"close", "something-else"},
1773 want: "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]",
1774 },
1775 {
1776 key: "Connection",
1777 value: []string{"keep-alive"},
1778 want: "Accept-Encoding,User-Agent",
1779 },
1780 {
1781 key: "Connection",
1782 value: []string{"Keep-ALIVE"},
1783 want: "Accept-Encoding,User-Agent",
1784 },
1785 {
1786 key: "Proxy-Connection",
1787 value: []string{"keep-alive"},
1788 want: "Accept-Encoding,User-Agent",
1789 },
1790 {
1791 key: "Transfer-Encoding",
1792 value: []string{""},
1793 want: "Accept-Encoding,User-Agent",
1794 },
1795 {
1796 key: "Transfer-Encoding",
1797 value: []string{"foo"},
1798 want: "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]",
1799 },
1800 {
1801 key: "Transfer-Encoding",
1802 value: []string{"chunked"},
1803 want: "Accept-Encoding,User-Agent",
1804 },
1805 {
1806 key: "Transfer-Encoding",
1807 value: []string{"chunKed"},
1808 want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunKed\"]",
1809 },
1810 {
1811 key: "Transfer-Encoding",
1812 value: []string{"chunked", "other"},
1813 want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]",
1814 },
1815 {
1816 key: "Content-Length",
1817 value: []string{"123"},
1818 want: "Accept-Encoding,User-Agent",
1819 },
1820 {
1821 key: "Keep-Alive",
1822 value: []string{"doop"},
1823 want: "Accept-Encoding,User-Agent",
1824 },
1825 }
1826
1827 for _, tt := range tests {
1828 req, _ := http.NewRequest("GET", ts.URL, nil)
1829 req.Header[tt.key] = tt.value
1830 res, err := tr.RoundTrip(req)
1831 var got string
1832 if err != nil {
1833 got = fmt.Sprintf("ERROR: %v", err)
1834 } else {
1835 got = res.Header.Get("Got-Header")
1836 res.Body.Close()
1837 }
1838 if got != tt.want {
1839 t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
1840 }
1841 }
1842 }
1843
1844
1845
1846 func TestTransportRejectsContentLengthWithSign(t *testing.T) {
1847 tests := []struct {
1848 name string
1849 cl []string
1850 wantCL string
1851 }{
1852 {
1853 name: "proper content-length",
1854 cl: []string{"3"},
1855 wantCL: "3",
1856 },
1857 {
1858 name: "ignore cl with plus sign",
1859 cl: []string{"+3"},
1860 wantCL: "",
1861 },
1862 {
1863 name: "ignore cl with minus sign",
1864 cl: []string{"-3"},
1865 wantCL: "",
1866 },
1867 {
1868 name: "max int64, for safe uint64->int64 conversion",
1869 cl: []string{"9223372036854775807"},
1870 wantCL: "9223372036854775807",
1871 },
1872 {
1873 name: "overflows int64, so ignored",
1874 cl: []string{"9223372036854775808"},
1875 wantCL: "",
1876 },
1877 }
1878
1879 for _, tt := range tests {
1880 t.Run(tt.name, func(t *testing.T) {
1881 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
1882 w.Header().Set("Content-Length", tt.cl[0])
1883 })
1884 tr := newTransport(t)
1885
1886 req, _ := http.NewRequest("HEAD", ts.URL, nil)
1887 res, err := tr.RoundTrip(req)
1888
1889 var got string
1890 if err != nil {
1891 got = fmt.Sprintf("ERROR: %v", err)
1892 } else {
1893 got = res.Header.Get("Content-Length")
1894 res.Body.Close()
1895 }
1896
1897 if got != tt.wantCL {
1898 t.Fatalf("Got: %q\nWant: %q", got, tt.wantCL)
1899 }
1900 })
1901 }
1902 }
1903
1904
1905
1906 func TestTransportFailsOnInvalidHeadersAndTrailers(t *testing.T) {
1907 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
1908 var got []string
1909 for k := range r.Header {
1910 got = append(got, k)
1911 }
1912 sort.Strings(got)
1913 w.Header().Set("Got-Header", strings.Join(got, ","))
1914 })
1915
1916 tests := [...]struct {
1917 h http.Header
1918 t http.Header
1919 wantErr string
1920 }{
1921 0: {
1922 h: http.Header{"with space": {"foo"}},
1923 wantErr: `net/http: invalid header field name "with space"`,
1924 },
1925 1: {
1926 h: http.Header{"name": {"Брэд"}},
1927 wantErr: "",
1928 },
1929 2: {
1930 h: http.Header{"имя": {"Brad"}},
1931 wantErr: `net/http: invalid header field name "имя"`,
1932 },
1933 3: {
1934 h: http.Header{"foo": {"foo\x01bar"}},
1935 wantErr: `net/http: invalid header field value for "foo"`,
1936 },
1937 4: {
1938 t: http.Header{"foo": {"foo\x01bar"}},
1939 wantErr: `net/http: invalid trailer field value for "foo"`,
1940 },
1941 5: {
1942 t: http.Header{"x-\r\nda": {"foo\x01bar"}},
1943 wantErr: `net/http: invalid trailer field name "x-\r\nda"`,
1944 },
1945 }
1946
1947 tr := newTransport(t)
1948
1949 for i, tt := range tests {
1950 req, _ := http.NewRequest("GET", ts.URL, nil)
1951 req.Header = tt.h
1952 if req.Header == nil {
1953 req.Header = http.Header{}
1954 }
1955 req.Trailer = tt.t
1956 res, err := tr.RoundTrip(req)
1957 var bad bool
1958 if tt.wantErr == "" {
1959 if err != nil {
1960 bad = true
1961 t.Errorf("case %d: error = %v; want no error", i, err)
1962 }
1963 } else {
1964 if !strings.Contains(fmt.Sprint(err), tt.wantErr) {
1965 bad = true
1966 t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr)
1967 }
1968 }
1969 if err == nil {
1970 if bad {
1971 t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header"))
1972 }
1973 res.Body.Close()
1974 }
1975 }
1976 }
1977
1978
1979
1980
// TestTransportReadHeadResponse runs testTransportReadHeadResponse under
// synctestTest.
func TestTransportReadHeadResponse(t *testing.T) { synctestTest(t, testTransportReadHeadResponse) }

// testTransportReadHeadResponse answers a HEAD request with a 200 response that
// declares content-length 123, then ends the stream with an empty DATA frame.
// The response must report ContentLength 123 and have an empty body.
func testTransportReadHeadResponse(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", "123",
		),
	})
	tc.writeData(rt.streamID(), true, nil)

	res := rt.response()
	if res.ContentLength != 123 {
		t.Fatalf("Content-Length = %d; want 123", res.ContentLength)
	}
	rt.wantBody(nil)
}
2007
// TestTransportReadHeadResponseWithBody runs
// testTransportReadHeadResponseWithBody under synctestTest.
func TestTransportReadHeadResponseWithBody(t *testing.T) {
	synctestTest(t, testTransportReadHeadResponseWithBody)
}

// testTransportReadHeadResponseWithBody answers a HEAD request with a response
// that, incorrectly, carries body bytes in a DATA frame. The client must
// still report the declared ContentLength and expose an empty body.
func testTransportReadHeadResponseWithBody(t testing.TB) {
	// Silence the standard logger for the duration of the test. Presumably the
	// transport logs about the unexpected DATA on a HEAD response.
	log.SetOutput(io.Discard)
	defer log.SetOutput(os.Stderr)

	response := "redirecting to /elsewhere"
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", strconv.Itoa(len(response)),
		),
	})
	tc.writeData(rt.streamID(), true, []byte(response))

	res := rt.response()
	if res.ContentLength != int64(len(response)) {
		t.Fatalf("Content-Length = %d; want %d", res.ContentLength, len(response))
	}
	rt.wantBody(nil)
}
2042
// neverEnding is an io.Reader that yields an endless stream of its own byte
// value.
type neverEnding byte

// Read fills all of p with the byte b and never returns an error.
func (b neverEnding) Read(p []byte) (n int, err error) {
	for n < len(p) {
		p[n] = byte(b)
		n++
	}
	return n, nil
}
2051
2052
2053
// TestTransportStreamEndsWhileBodyIsBeingWritten runs
// testTransportStreamEndsWhileBodyIsBeingWritten under synctestTest.
func TestTransportStreamEndsWhileBodyIsBeingWritten(t *testing.T) {
	synctestTest(t, testTransportStreamEndsWhileBodyIsBeingWritten)
}

// testTransportStreamEndsWhileBodyIsBeingWritten sends a PUT whose body is
// longer than the peer's initial stream window. The client can therefore only
// send the first windowSize bytes. The peer then ends the stream with a 413
// response, and RoundTrip must return that 413 even though the rest of the
// body was never written.
func testTransportStreamEndsWhileBodyIsBeingWritten(t testing.TB) {
	body := "this is the client request body"
	const windowSize = 10

	tc := newTestClientConn(t)
	tc.greet(Setting{SettingInitialWindowSize, windowSize})

	// The client fills the stream window with the first windowSize bytes of
	// the body, then has to wait for more flow control.
	req, _ := http.NewRequest("PUT", "https://dummy.tld/", strings.NewReader(body))
	rt := tc.roundTrip(req)
	tc.wantFrameType(FrameHeaders)
	tc.wantData(wantData{
		streamID:  rt.streamID(),
		endStream: false,
		size:      windowSize,
	})

	// Respond and end the stream while the client is still mid-body.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "413",
		),
	})
	rt.wantStatus(413)
}
2085
2086 func TestTransportFlowControl(t *testing.T) { synctestTest(t, testTransportFlowControl) }
2087 func testTransportFlowControl(t testing.TB) {
2088 const maxBuffer = 64 << 10
2089 tc := newTestClientConn(t, func(tr *http.Transport) {
2090 tr.HTTP2 = &http.HTTP2Config{
2091 MaxReceiveBufferPerConnection: maxBuffer,
2092 MaxReceiveBufferPerStream: maxBuffer,
2093 MaxReadFrameSize: 16 << 20,
2094 }
2095 })
2096 tc.greet()
2097
2098 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2099 rt := tc.roundTrip(req)
2100 tc.wantFrameType(FrameHeaders)
2101
2102 tc.writeHeaders(HeadersFrameParam{
2103 StreamID: rt.streamID(),
2104 EndHeaders: true,
2105 EndStream: false,
2106 BlockFragment: tc.makeHeaderBlockFragment(
2107 ":status", "200",
2108 ),
2109 })
2110 rt.wantStatus(200)
2111
2112
2113
2114
2115 tc.writeData(rt.streamID(), false, make([]byte, maxBuffer))
2116 tc.wantIdle()
2117
2118
2119
2120 resp := rt.response()
2121 if _, err := io.ReadFull(resp.Body, make([]byte, maxBuffer)); err != nil {
2122 t.Fatalf("io.Body.Read: %v", err)
2123 }
2124 var connTokens, streamTokens uint32
2125 for {
2126 f := tc.readFrame()
2127 if f == nil {
2128 break
2129 }
2130 wu, ok := f.(*WindowUpdateFrame)
2131 if !ok {
2132 t.Fatalf("received unexpected frame %T (want WINDOW_UPDATE)", f)
2133 }
2134 switch wu.StreamID {
2135 case 0:
2136 connTokens += wu.Increment
2137 case wu.StreamID:
2138 streamTokens += wu.Increment
2139 default:
2140 t.Fatalf("received unexpected WINDOW_UPDATE for stream %v", wu.StreamID)
2141 }
2142 }
2143 if got, want := connTokens, uint32(maxBuffer); got != want {
2144 t.Errorf("transport provided %v bytes of connection WINDOW_UPDATE, want %v", got, want)
2145 }
2146 if got, want := streamTokens, uint32(maxBuffer); got != want {
2147 t.Errorf("transport provided %v bytes of stream WINDOW_UPDATE, want %v", got, want)
2148 }
2149 }
2150
2151
2152
2153
2154
2155
// TestTransportUsesGoAwayDebugError_RoundTrip checks the GOAWAY error returned
// from RoundTrip when no response headers have arrived.
func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportUsesGoAwayDebugError(t, false)
	})
}

// TestTransportUsesGoAwayDebugError_Body checks the GOAWAY error returned from
// Body.Read when the connection goes away mid-body.
func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportUsesGoAwayDebugError(t, true)
	})
}

// testTransportUsesGoAwayDebugError sends two GOAWAY frames and then closes
// the peer's write side. The first frame carries ErrCodeNo plus debug data;
// the second carries goAwayErrCode and no debug data. The error the client
// reports, from RoundTrip or, when failMidBody is set, from the first
// Body.Read, must be a GoAwayError. That error must combine LastStreamID 5,
// the second frame's error code, and the first frame's debug data.
func testTransportUsesGoAwayDebugError(t testing.TB, failMidBody bool) {
	tc := newTestClientConn(t)
	tc.greet()

	const goAwayErrCode = ErrCodeHTTP11Required
	const goAwayDebugData = "some debug data"

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	if failMidBody {
		// Send headers promising a 123-byte body that never arrives.
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  false,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
				"content-length", "123",
			),
		})
	}

	// Two GOAWAYs, so the client has to take the debug data from one and the
	// error code from the other.
	tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
	tc.writeGoAway(5, goAwayErrCode, nil)
	tc.closeWrite()

	res, err := rt.result()
	whence := "RoundTrip"
	if failMidBody {
		whence = "Body.Read"
		if err != nil {
			t.Fatalf("RoundTrip error = %v, want success", err)
		}
		_, err = res.Body.Read(make([]byte, 1))
	}

	want := GoAwayError{
		LastStreamID: 5,
		ErrCode:      goAwayErrCode,
		DebugData:    goAwayDebugData,
	}
	if !reflect.DeepEqual(err, want) {
		t.Errorf("%v error = %T: %#v, want %T (%#v)", whence, err, err, want, want)
	}
}
2217
// testTransportReturnsUnusedFlowControl answers a GET with a response
// declaring a 5000-byte body. The test reads 1 byte and closes the body, and
// checks two things:
//   - the client sends RST_STREAM(CANCEL) plus a WINDOW_UPDATE of 5000;
//   - the connection inflow window ends up back at its starting value.
//
// With oneDataFrame, all 5000 bytes are sent up front. Otherwise only 1 byte
// is sent up front and the remaining 4999 are sent after the RST_STREAM is
// seen. In that case the WINDOW_UPDATE must not arrive before those 4999
// bytes have been written.
func testTransportReturnsUnusedFlowControl(t testing.TB, oneDataFrame bool) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", "5000",
		),
	})
	initialInflow := tc.inflowWindow(0)

	// Send either the whole body or just its first byte. The stream is left
	// open either way.
	const streamNotEnded = false
	if oneDataFrame {
		tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 5000))
	} else {
		tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 1))
	}

	res := rt.response()
	if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
		t.Fatalf("body read = %v, %v; want 1, nil", n, err)
	}
	res.Body.Close()
	synctest.Wait()

	sentAdditionalData := false
	tc.wantUnorderedFrames(
		func(f *RSTStreamFrame) bool {
			if f.ErrCode != ErrCodeCancel {
				t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", SummarizeFrame(f))
			}
			if !oneDataFrame {
				// Send the rest of the body only now, after the client has
				// reset the stream.
				tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 4999))
				sentAdditionalData = true
			}
			return true
		},
		func(f *WindowUpdateFrame) bool {
			if !oneDataFrame && !sentAdditionalData {
				t.Fatalf("Got WindowUpdateFrame, don't expect one yet")
			}
			if f.Increment != 5000 {
				t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", SummarizeFrame(f))
			}
			return true
		},
	)

	if got, want := tc.inflowWindow(0), initialInflow; got != want {
		t.Fatalf("connection flow tokens = %v, want %v", got, want)
	}
}

// TestTransportReturnsUnusedFlowControlSingleWrite covers the case where the
// whole body arrives in one DATA frame before the stream is reset.
func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportReturnsUnusedFlowControl(t, true)
	})
}

// TestTransportReturnsUnusedFlowControlMultipleWrites covers the case where
// most of the body arrives after the stream has been reset.
func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportReturnsUnusedFlowControl(t, false)
	})
}
2303
2304
2305
// TestTransportAdjustsFlowControl runs testTransportAdjustsFlowControl under
// synctestTest.
func TestTransportAdjustsFlowControl(t *testing.T) { synctestTest(t, testTransportAdjustsFlowControl) }

// testTransportAdjustsFlowControl starts a 1 MiB POST before the peer has sent
// any SETTINGS. Partway through the body, the peer raises the initial stream
// window to bodySize and grants bodySize of connection credit. The client
// must apply the new settings to the in-flight stream and finish sending the
// whole body.
func testTransportAdjustsFlowControl(t testing.TB) {
	const bodySize = 1 << 20

	tc := newTestClientConn(t)
	// Consume the client preface frames without sending the peer's SETTINGS.
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)

	body := tc.newRequestBody()
	body.writeBytes(bodySize)
	body.closeWithError(io.EOF)

	req, _ := http.NewRequest("POST", "https://dummy.tld/", body)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	gotBytes := int64(0)
	for {
		f := readFrame[*DataFrame](t, tc)
		gotBytes += int64(len(f.Data()))

		// Stop once at least half of the default initial window has been
		// received. NOTE(review): the exact threshold appears arbitrary —
		// any point before the default window is exhausted would do; confirm.
		if gotBytes >= InitialWindowSize/2 {
			break
		}
	}

	tc.writeSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
	tc.writeWindowUpdate(0, bodySize)
	tc.writeSettingsAck()

	// The client's SETTINGS frame (presumably its ACK of ours) and the rest of
	// the body may arrive in either order.
	tc.wantUnorderedFrames(
		func(f *SettingsFrame) bool { return true },
		func(f *DataFrame) bool {
			gotBytes += int64(len(f.Data()))
			return f.StreamEnded()
		},
	)

	if gotBytes != bodySize {
		t.Fatalf("server received %v bytes of body, want %v", gotBytes, bodySize)
	}

	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt.wantStatus(200)
}
2361
2362
// TestTransportReturnsDataPaddingFlowControl runs
// testTransportReturnsDataPaddingFlowControl under synctestTest.
func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
	synctestTest(t, testTransportReturnsDataPaddingFlowControl)
}

// testTransportReturnsDataPaddingFlowControl sends 5000 data bytes in a DATA
// frame with 5 bytes of padding. Padding counts against flow control on the
// wire, but the client is expected to credit it back immediately. After the
// frame is processed, the connection and stream inflow windows must each be
// exactly 5000 below their starting values.
func testTransportReturnsDataPaddingFlowControl(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", "5000",
		),
	})

	initialConnWindow := tc.inflowWindow(0)
	initialStreamWindow := tc.inflowWindow(rt.streamID())

	pad := make([]byte, 5)
	tc.writeDataPadded(rt.streamID(), false, make([]byte, 5000), pad)

	// Let the client finish processing the frame before inspecting windows.
	synctest.Wait()
	if got, want := tc.inflowWindow(0), initialConnWindow-5000; got != want {
		t.Errorf("conn inflow window = %v, want %v", got, want)
	}
	if got, want := tc.inflowWindow(rt.streamID()), initialStreamWindow-5000; got != want {
		t.Errorf("stream inflow window = %v, want %v", got, want)
	}
}
2399
2400
2401
// TestTransportReturnsErrorOnBadResponseHeaders runs
// testTransportReturnsErrorOnBadResponseHeaders under synctestTest.
func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
	synctestTest(t, testTransportReturnsErrorOnBadResponseHeaders)
}

// testTransportReturnsErrorOnBadResponseHeaders sends a response header whose
// name has a leading space. RoundTrip must fail with a protocol-error
// StreamError whose Cause is a HeaderFieldNameError. The client must also
// reset stream 1 with RST_STREAM(PROTOCOL_ERROR).
func testTransportReturnsErrorOnBadResponseHeaders(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			" content-type", "bogus",
		),
	})

	err := rt.err()
	want := StreamError{1, ErrCodeProtocol, HeaderFieldNameError(" content-type")}
	if !reflect.DeepEqual(err, want) {
		t.Fatalf("RoundTrip error = %#v; want %#v", err, want)
	}

	fr := readFrame[*RSTStreamFrame](t, tc)
	if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol {
		t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", SummarizeFrame(fr))
	}
}
2434
2435
2436
// byteAndEOFReader is an io.Reader that, on every Read, returns its own byte
// value together with io.EOF in the same call.
type byteAndEOFReader byte

// Read stores b in p[0] and reports (1, io.EOF). Calling it with an empty
// buffer is treated as a programming error and panics.
func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		p[0] = byte(b)
		return 1, io.EOF
	}
	panic("unexpected useless call")
}
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456 func TestTransportBodyDoubleEndStream(t *testing.T) {
2457 synctestTest(t, testTransportBodyDoubleEndStream)
2458 }
2459 func testTransportBodyDoubleEndStream(t testing.TB) {
2460 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2461
2462 })
2463
2464 tr := newTransport(t)
2465
2466 for i := range 2 {
2467 req, _ := http.NewRequest("POST", ts.URL, byteAndEOFReader('a'))
2468 req.ContentLength = 1
2469 res, err := tr.RoundTrip(req)
2470 if err != nil {
2471 t.Fatalf("failure on req %d: %v", i+1, err)
2472 }
2473 defer res.Body.Close()
2474 }
2475 }
2476
2477
2478 func TestTransportRequestPathPseudo(t *testing.T) {
2479 type result struct {
2480 path string
2481 err string
2482 }
2483 tests := []struct {
2484 req *http.Request
2485 want result
2486 }{
2487 0: {
2488 req: &http.Request{
2489 Method: "GET",
2490 URL: &url.URL{
2491 Host: "foo.com",
2492 Path: "/foo",
2493 },
2494 },
2495 want: result{path: "/foo"},
2496 },
2497
2498
2499
2500 1: {
2501 req: &http.Request{
2502 Method: "GET",
2503 URL: &url.URL{
2504 Host: "foo.com",
2505 Path: "//foo",
2506 },
2507 },
2508 want: result{path: "//foo"},
2509 },
2510
2511
2512 2: {
2513 req: &http.Request{
2514 Method: "GET",
2515 URL: &url.URL{
2516 Scheme: "https",
2517 Opaque: "//foo.com/path",
2518 Host: "foo.com",
2519 Path: "/ignored",
2520 },
2521 },
2522 want: result{path: "/path"},
2523 },
2524
2525
2526 3: {
2527 req: &http.Request{
2528 Method: "GET",
2529 Host: "bar.com",
2530 URL: &url.URL{
2531 Scheme: "https",
2532 Opaque: "//bar.com/path",
2533 Host: "foo.com",
2534 Path: "/ignored",
2535 },
2536 },
2537 want: result{path: "/path"},
2538 },
2539
2540
2541 4: {
2542 req: &http.Request{
2543 Method: "GET",
2544 URL: &url.URL{
2545 Opaque: "/path",
2546 Host: "foo.com",
2547 Path: "/ignored",
2548 },
2549 },
2550 want: result{path: "/path"},
2551 },
2552
2553
2554 5: {
2555 req: &http.Request{
2556 Method: "GET",
2557 URL: &url.URL{
2558 Scheme: "https",
2559 Opaque: "//unknown_host/path",
2560 Host: "foo.com",
2561 Path: "/ignored",
2562 },
2563 },
2564 want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`},
2565 },
2566
2567
2568 6: {
2569 req: &http.Request{
2570 Method: "CONNECT",
2571 URL: &url.URL{
2572 Host: "foo.com",
2573 },
2574 },
2575 want: result{},
2576 },
2577 }
2578 for i, tt := range tests {
2579 hbuf := &bytes.Buffer{}
2580 henc := hpack.NewEncoder(hbuf)
2581 _, err := httpcommon.EncodeHeaders(context.Background(), httpcommon.EncodeHeadersParam{
2582 Request: httpcommon.Request{
2583 Header: tt.req.Header,
2584 Trailer: tt.req.Trailer,
2585 URL: tt.req.URL,
2586 Host: tt.req.Host,
2587 Method: tt.req.Method,
2588 ActualContentLength: tt.req.ContentLength,
2589 },
2590 AddGzipHeader: false,
2591 PeerMaxHeaderListSize: 0xffffffffffffffff,
2592 }, func(name, value string) {
2593 henc.WriteField(hpack.HeaderField{Name: name, Value: value})
2594 })
2595 hdrs := hbuf.Bytes()
2596 var got result
2597 hpackDec := hpack.NewDecoder(InitialHeaderTableSize, func(f hpack.HeaderField) {
2598 if f.Name == ":path" {
2599 got.path = f.Value
2600 }
2601 })
2602 if err != nil {
2603 got.err = err.Error()
2604 } else if len(hdrs) > 0 {
2605 if _, err := hpackDec.Write(hdrs); err != nil {
2606 t.Errorf("%d. bogus hpack: %v", i, err)
2607 continue
2608 }
2609 }
2610 if got != tt.want {
2611 t.Errorf("%d. got %+v; want %+v", i, got, tt.want)
2612 }
2613
2614 }
2615
2616 }
2617
2618
2619
2620 func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
2621 synctestTest(t, testRoundTripDoesntConsumeRequestBodyEarly)
2622 }
2623 func testRoundTripDoesntConsumeRequestBodyEarly(t testing.TB) {
2624 tc := newTestClientConn(t)
2625 tc.greet()
2626 tc.closeWrite()
2627 synctest.Wait()
2628
2629 const body = "foo"
2630 req, _ := http.NewRequest("POST", "http://foo.com/", io.NopCloser(strings.NewReader(body)))
2631 rt := tc.roundTrip(req)
2632 if err := rt.err(); err != ErrClientConnNotEstablished {
2633 t.Fatalf("RoundTrip = %v; want errClientConnNotEstablished", err)
2634 }
2635
2636 slurp, err := io.ReadAll(req.Body)
2637 if err != nil {
2638 t.Errorf("ReadAll = %v", err)
2639 }
2640 if string(slurp) != body {
2641 t.Errorf("Body = %q; want %q", slurp, body)
2642 }
2643 }
2644
2645
2646
2647
2648
2649 func TestTransportCancelDataResponseRace(t *testing.T) {
2650 cancel := make(chan struct{})
2651 clientGotResponse := make(chan bool, 1)
2652
2653 const msg = "Hello."
2654 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2655 if strings.Contains(r.URL.Path, "/hello") {
2656 time.Sleep(50 * time.Millisecond)
2657 io.WriteString(w, msg)
2658 return
2659 }
2660 for i := range 50 {
2661 io.WriteString(w, "Some data.")
2662 w.(http.Flusher).Flush()
2663 if i == 2 {
2664 <-clientGotResponse
2665 close(cancel)
2666 }
2667 time.Sleep(10 * time.Millisecond)
2668 }
2669 })
2670
2671 tr := newTransport(t)
2672
2673 c := &http.Client{Transport: tr}
2674 req, _ := http.NewRequest("GET", ts.URL, nil)
2675 req.Cancel = cancel
2676 res, err := c.Do(req)
2677 clientGotResponse <- true
2678 if err != nil {
2679 t.Fatal(err)
2680 }
2681 if _, err = io.Copy(io.Discard, res.Body); err == nil {
2682 t.Fatal("unexpected success")
2683 }
2684
2685 res, err = c.Get(ts.URL + "/hello")
2686 if err != nil {
2687 t.Fatal(err)
2688 }
2689 slurp, err := io.ReadAll(res.Body)
2690 if err != nil {
2691 t.Fatal(err)
2692 }
2693 if string(slurp) != msg {
2694 t.Errorf("Got = %q; want %q", slurp, msg)
2695 }
2696 }
2697
2698
2699
2700 func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
2701 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2702 w.WriteHeader(200)
2703 io.WriteString(w, "body")
2704 })
2705
2706 tr := newTransport(t)
2707
2708 req, _ := http.NewRequest("GET", ts.URL, nil)
2709 resp, err := tr.RoundTrip(req)
2710 if err != nil {
2711 t.Fatal(err)
2712 }
2713 if _, err = io.Copy(io.Discard, resp.Body); err != nil {
2714 t.Fatalf("error reading response body: %v", err)
2715 }
2716 if err := resp.Body.Close(); err != nil {
2717 t.Fatalf("error closing response body: %v", err)
2718 }
2719
2720
2721 req.Header = http.Header{}
2722 }
2723
2724 func TestTransportCloseAfterLostPing(t *testing.T) { synctestTest(t, testTransportCloseAfterLostPing) }
2725 func testTransportCloseAfterLostPing(t testing.TB) {
2726 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
2727 h2.PingTimeout = 1 * time.Second
2728 h2.SendPingTimeout = 1 * time.Second
2729 })
2730 tc.greet()
2731
2732 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2733 rt := tc.roundTrip(req)
2734 tc.wantFrameType(FrameHeaders)
2735
2736 time.Sleep(1 * time.Second)
2737 tc.wantFrameType(FramePing)
2738
2739 time.Sleep(1 * time.Second)
2740 err := rt.err()
2741 if err == nil || !strings.Contains(err.Error(), "client connection lost") {
2742 t.Fatalf("expected to get error about \"connection lost\", got %v", err)
2743 }
2744 }
2745
2746 func TestTransportPingWriteBlocks(t *testing.T) {
2747 ts := newTestServer(t,
2748 func(w http.ResponseWriter, r *http.Request) {},
2749 )
2750 tr := newTransport(t)
2751 tr.Dial = func(network, addr string) (net.Conn, error) {
2752 s, c := net.Pipe()
2753 go func() {
2754 srv := tls.Server(s, tlsConfigInsecure)
2755 srv.Handshake()
2756
2757
2758
2759
2760 var buf [1024]byte
2761 s.Read(buf[:])
2762 }()
2763 return c, nil
2764 }
2765 tr.HTTP2.PingTimeout = 1 * time.Millisecond
2766 tr.HTTP2.SendPingTimeout = 1 * time.Millisecond
2767 c := &http.Client{Transport: tr}
2768 _, err := c.Get(ts.URL)
2769 if err == nil {
2770 t.Fatalf("Get = nil, want error")
2771 }
2772 }
2773
2774 func TestTransportPingWhenReadingMultiplePings(t *testing.T) {
2775 synctestTest(t, testTransportPingWhenReadingMultiplePings)
2776 }
2777 func testTransportPingWhenReadingMultiplePings(t testing.TB) {
2778 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
2779 h2.SendPingTimeout = 1000 * time.Millisecond
2780 })
2781 tc.greet()
2782
2783 ctx, cancel := context.WithCancel(context.Background())
2784 req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
2785 rt := tc.roundTrip(req)
2786
2787 tc.wantFrameType(FrameHeaders)
2788 tc.writeHeaders(HeadersFrameParam{
2789 StreamID: rt.streamID(),
2790 EndHeaders: true,
2791 EndStream: false,
2792 BlockFragment: tc.makeHeaderBlockFragment(
2793 ":status", "200",
2794 ),
2795 })
2796
2797 for range 5 {
2798
2799 time.Sleep(999 * time.Millisecond)
2800 if f := tc.readFrame(); f != nil {
2801 t.Fatalf("unexpected frame: %v", f)
2802 }
2803
2804
2805 time.Sleep(1 * time.Millisecond)
2806 f := readFrame[*PingFrame](t, tc)
2807 tc.writePing(true, f.Data)
2808 }
2809
2810
2811 cancel()
2812 synctest.Wait()
2813
2814 tc.wantFrameType(FrameRSTStream)
2815 _, err := rt.readBody()
2816 if err == nil {
2817 t.Fatalf("Response.Body.Read() = %v, want error", err)
2818 }
2819 }
2820
2821 func TestTransportPingWhenReadingPingDisabled(t *testing.T) {
2822 synctestTest(t, testTransportPingWhenReadingPingDisabled)
2823 }
2824 func testTransportPingWhenReadingPingDisabled(t testing.TB) {
2825 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
2826 h2.SendPingTimeout = 0
2827 })
2828 tc.greet()
2829
2830 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2831 rt := tc.roundTrip(req)
2832
2833 tc.wantFrameType(FrameHeaders)
2834 tc.writeHeaders(HeadersFrameParam{
2835 StreamID: rt.streamID(),
2836 EndHeaders: true,
2837 EndStream: false,
2838 BlockFragment: tc.makeHeaderBlockFragment(
2839 ":status", "200",
2840 ),
2841 })
2842
2843
2844 time.Sleep(1 * time.Minute)
2845 if f := tc.readFrame(); f != nil {
2846 t.Fatalf("unexpected frame: %v", f)
2847 }
2848 }
2849
2850 func TestTransportRetryAfterGOAWAYNoRetry(t *testing.T) {
2851 synctestTest(t, testTransportRetryAfterGOAWAYNoRetry)
2852 }
2853 func testTransportRetryAfterGOAWAYNoRetry(t testing.TB) {
2854 tt := newTestTransport(t)
2855
2856 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2857 rt := tt.roundTrip(req)
2858
2859
2860
2861
2862
2863 tc := tt.getConn()
2864 tc.wantFrameType(FrameSettings)
2865 tc.wantFrameType(FrameWindowUpdate)
2866 tc.wantHeaders(wantHeader{
2867 streamID: 1,
2868 endStream: true,
2869 })
2870 tc.writeSettings()
2871 tc.writeGoAway(0 , ErrCodeInternal, nil)
2872 if rt.err() == nil {
2873 t.Fatalf("after GOAWAY, RoundTrip is not done, want error")
2874 }
2875 }
2876
2877 func TestTransportRetryAfterGOAWAYRetry(t *testing.T) {
2878 synctestTest(t, testTransportRetryAfterGOAWAYRetry)
2879 }
2880 func testTransportRetryAfterGOAWAYRetry(t testing.TB) {
2881 tt := newTestTransport(t)
2882
2883 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2884 rt := tt.roundTrip(req)
2885
2886
2887
2888
2889
2890 tc := tt.getConn()
2891 tc.wantFrameType(FrameSettings)
2892 tc.wantFrameType(FrameWindowUpdate)
2893 tc.wantHeaders(wantHeader{
2894 streamID: 1,
2895 endStream: true,
2896 })
2897 tc.writeSettings()
2898 tc.writeGoAway(0 , ErrCodeNo, nil)
2899 if rt.done() {
2900 t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying")
2901 }
2902
2903
2904 tc = tt.getConn()
2905 tc.wantFrameType(FrameSettings)
2906 tc.wantFrameType(FrameWindowUpdate)
2907 tc.wantHeaders(wantHeader{
2908 streamID: 1,
2909 endStream: true,
2910 })
2911 tc.writeSettings()
2912 tc.writeHeaders(HeadersFrameParam{
2913 StreamID: 1,
2914 EndHeaders: true,
2915 EndStream: true,
2916 BlockFragment: tc.makeHeaderBlockFragment(
2917 ":status", "200",
2918 ),
2919 })
2920
2921 rt.wantStatus(200)
2922 }
2923
2924 func TestTransportRetryAfterGOAWAYSecondRequest(t *testing.T) {
2925 synctestTest(t, testTransportRetryAfterGOAWAYSecondRequest)
2926 }
2927 func testTransportRetryAfterGOAWAYSecondRequest(t testing.TB) {
2928 tt := newTestTransport(t)
2929
2930
2931 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2932 rt1 := tt.roundTrip(req)
2933 tc := tt.getConn()
2934 tc.wantFrameType(FrameSettings)
2935 tc.wantFrameType(FrameWindowUpdate)
2936 tc.wantHeaders(wantHeader{
2937 streamID: 1,
2938 endStream: true,
2939 })
2940 tc.writeSettings()
2941 tc.wantFrameType(FrameSettings)
2942 tc.writeHeaders(HeadersFrameParam{
2943 StreamID: 1,
2944 EndHeaders: true,
2945 EndStream: true,
2946 BlockFragment: tc.makeHeaderBlockFragment(
2947 ":status", "200",
2948 ),
2949 })
2950 rt1.wantStatus(200)
2951
2952
2953
2954
2955
2956 req, _ = http.NewRequest("GET", "https://dummy.tld/", nil)
2957 rt2 := tt.roundTrip(req)
2958
2959
2960 tc.wantHeaders(wantHeader{
2961 streamID: 3,
2962 endStream: true,
2963 })
2964 tc.writeSettings()
2965 tc.writeGoAway(1 , ErrCodeProtocol, nil)
2966 if rt2.done() {
2967 t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying")
2968 }
2969
2970
2971 tc = tt.getConn()
2972 tc.wantFrameType(FrameSettings)
2973 tc.wantFrameType(FrameWindowUpdate)
2974 tc.wantHeaders(wantHeader{
2975 streamID: 1,
2976 endStream: true,
2977 })
2978 tc.writeSettings()
2979 tc.writeHeaders(HeadersFrameParam{
2980 StreamID: 1,
2981 EndHeaders: true,
2982 EndStream: true,
2983 BlockFragment: tc.makeHeaderBlockFragment(
2984 ":status", "200",
2985 ),
2986 })
2987 rt2.wantStatus(200)
2988 }
2989
2990 func TestTransportRetryAfterRefusedStream(t *testing.T) {
2991 synctestTest(t, testTransportRetryAfterRefusedStream)
2992 }
2993 func testTransportRetryAfterRefusedStream(t testing.TB) {
2994 tt := newTestTransport(t)
2995
2996 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2997 rt := tt.roundTrip(req)
2998
2999
3000 tc := tt.getConn()
3001 tc.wantFrameType(FrameSettings)
3002 tc.wantFrameType(FrameWindowUpdate)
3003 tc.wantHeaders(wantHeader{
3004 streamID: 1,
3005 endStream: true,
3006 })
3007 tc.writeSettings()
3008 tc.wantFrameType(FrameSettings)
3009 tc.writeRSTStream(1, ErrCodeRefusedStream)
3010 if rt.done() {
3011 t.Fatalf("after RST_STREAM, RoundTrip is done; want it to be retrying")
3012 }
3013
3014
3015 tc.wantHeaders(wantHeader{
3016 streamID: 3,
3017 endStream: true,
3018 })
3019 tc.writeSettings()
3020 tc.writeHeaders(HeadersFrameParam{
3021 StreamID: 3,
3022 EndHeaders: true,
3023 EndStream: true,
3024 BlockFragment: tc.makeHeaderBlockFragment(
3025 ":status", "204",
3026 ),
3027 })
3028
3029 rt.wantStatus(204)
3030 }
3031
3032 func TestTransportRetryHasLimit(t *testing.T) { synctestTest(t, testTransportRetryHasLimit) }
3033 func testTransportRetryHasLimit(t testing.TB) {
3034 tt := newTestTransport(t)
3035
3036 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3037 rt := tt.roundTrip(req)
3038
3039 tc := tt.getConn()
3040 tc.netconn.SetReadDeadline(time.Time{})
3041 tc.wantFrameType(FrameSettings)
3042 tc.wantFrameType(FrameWindowUpdate)
3043
3044 count := 0
3045 start := time.Now()
3046 for streamID := uint32(1); !rt.done(); streamID += 2 {
3047 count++
3048 tc.wantHeaders(wantHeader{
3049 streamID: streamID,
3050 endStream: true,
3051 })
3052 if streamID == 1 {
3053 tc.writeSettings()
3054 tc.wantFrameType(FrameSettings)
3055 }
3056 tc.writeRSTStream(streamID, ErrCodeRefusedStream)
3057
3058 if totalDelay := time.Since(start); totalDelay > 5*time.Minute {
3059 t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay)
3060 }
3061 synctest.Wait()
3062 }
3063 if got, want := count, 5; got < count {
3064 t.Errorf("RoundTrip made %v attempts, want at least %v", got, want)
3065 }
3066 if rt.err() == nil {
3067 t.Errorf("RoundTrip succeeded, want error")
3068 }
3069 }
3070
3071 func TestTransportResponseDataBeforeHeaders(t *testing.T) {
3072 synctestTest(t, testTransportResponseDataBeforeHeaders)
3073 }
3074 func testTransportResponseDataBeforeHeaders(t testing.TB) {
3075
3076 log.SetOutput(io.Discard)
3077 t.Cleanup(func() { log.SetOutput(os.Stderr) })
3078
3079 tc := newTestClientConn(t)
3080 tc.greet()
3081
3082
3083 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3084 rt1 := tc.roundTrip(req)
3085 tc.wantFrameType(FrameHeaders)
3086 tc.writeHeaders(HeadersFrameParam{
3087 StreamID: rt1.streamID(),
3088 EndHeaders: true,
3089 EndStream: true,
3090 BlockFragment: tc.makeHeaderBlockFragment(
3091 ":status", "200",
3092 ),
3093 })
3094 rt1.wantStatus(200)
3095
3096
3097 rt2 := tc.roundTrip(req)
3098 tc.wantFrameType(FrameHeaders)
3099 tc.writeData(rt2.streamID(), true, []byte("payload"))
3100 if err, ok := rt2.err().(StreamError); !ok || err.Code != ErrCodeProtocol {
3101 t.Fatalf("expected stream PROTOCOL_ERROR, got: %v", err)
3102 }
3103 }
3104
3105 func TestTransportMaxFrameReadSize(t *testing.T) {
3106 for _, test := range []struct {
3107 maxReadFrameSize uint32
3108 want uint32
3109 }{{
3110 maxReadFrameSize: 64000,
3111 want: 64000,
3112 }, {
3113 maxReadFrameSize: 1024,
3114
3115
3116
3117
3118
3119
3120
3121
3122 want: DefaultMaxReadFrameSize,
3123 }} {
3124 synctestSubtest(t, fmt.Sprint(test.maxReadFrameSize), func(t testing.TB) {
3125 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
3126 h2.MaxReadFrameSize = int(test.maxReadFrameSize)
3127 })
3128
3129 fr := readFrame[*SettingsFrame](t, tc)
3130 got, ok := fr.Value(SettingMaxFrameSize)
3131 if !ok {
3132 t.Errorf("Transport.MaxReadFrameSize = %v; server got no setting, want %v", test.maxReadFrameSize, test.want)
3133 } else if got != test.want {
3134 t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
3135 }
3136 })
3137 }
3138 }
3139
3140 func TestTransportRequestsLowServerLimit(t *testing.T) {
3141 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3142 }, func(h2 *http.HTTP2Config) {
3143 h2.MaxConcurrentStreams = 1
3144 })
3145
3146 var (
3147 connCountMu sync.Mutex
3148 connCount int
3149 )
3150 tr := newTransport(t)
3151 tr.DialTLS = func(network, addr string) (net.Conn, error) {
3152 connCountMu.Lock()
3153 defer connCountMu.Unlock()
3154 connCount++
3155 return tls.Dial(network, addr, tlsConfigInsecure)
3156 }
3157
3158 const reqCount = 3
3159 for range reqCount {
3160 req, err := http.NewRequest("GET", ts.URL, nil)
3161 if err != nil {
3162 t.Fatal(err)
3163 }
3164 res, err := tr.RoundTrip(req)
3165 if err != nil {
3166 t.Fatal(err)
3167 }
3168 if got, want := res.StatusCode, 200; got != want {
3169 t.Errorf("StatusCode = %v; want %v", got, want)
3170 }
3171 if res != nil && res.Body != nil {
3172 res.Body.Close()
3173 }
3174 }
3175
3176 if connCount != 1 {
3177 t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount)
3178 }
3179 }
3180
3181
// TestTransportRequestsStallAtServerLimit verifies behavior when
// StrictMaxConcurrentRequests is set:
//   - requests beyond the server's SETTINGS_MAX_CONCURRENT_STREAMS are held
//     back rather than sent;
//   - a held-back request can be canceled;
//   - a held-back request is sent once an active stream completes.
//
// NOTE(review): unlike neighbouring tests, this calls synctest.Test directly
// with a *testing.T helper instead of synctestTest; confirm that is intended.
func TestTransportRequestsStallAtServerLimit(t *testing.T) {
	synctest.Test(t, testTransportRequestsStallAtServerLimit)
}
func testTransportRequestsStallAtServerLimit(t *testing.T) {
	const maxConcurrent = 2

	tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
		h2.StrictMaxConcurrentRequests = true
	})
	// The server advertises a limit of maxConcurrent concurrent streams.
	tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})

	// Cancel channel for request number maxConcurrent, the first request over
	// the limit.
	cancelClientRequest := make(chan struct{})

	// Start maxConcurrent+2 requests, recording each one in rts.
	var rts []*testRoundTrip
	for k := range maxConcurrent + 2 {
		req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
		if k == maxConcurrent {
			req.Cancel = cancelClientRequest
		}
		rt := tc.roundTrip(req)
		rts = append(rts, rt)

		if k < maxConcurrent {
			// Within the limit: this request's HEADERS go out immediately.
			tc.wantHeaders(wantHeader{
				streamID:  rt.streamID(),
				endStream: true,
				header: http.Header{
					":authority": []string{"dummy.tld"},
					":method":    []string{"GET"},
					":path":      []string{fmt.Sprintf("/%d", k)},
				},
			})
		} else {
			// Over the limit: the request must not put anything on the wire.
			if fr := tc.readFrame(); fr != nil {
				t.Fatalf("after making new request while at stream limit, got unexpected frame: %v", fr)
			}
		}

		// No response has been sent yet, so no request may be done.
		if rt.done() {
			t.Fatalf("rt %v done", k)
		}
	}

	// Cancel the first held-back request; it should fail.
	close(cancelClientRequest)
	synctest.Wait()
	if err := rts[maxConcurrent].err(); err == nil {
		t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent)
	}

	// Every other request is still pending.
	for i, rt := range rts {
		if i != maxConcurrent && rt.done() {
			t.Fatalf("RoundTrip(%d) is done, but should not be", i)
		}
	}

	// Complete the first stream. That frees a slot, and the remaining
	// held-back request (maxConcurrent+1) is then sent.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rts[0].streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	synctest.Wait()
	tc.wantHeaders(wantHeader{
		streamID:  rts[maxConcurrent+1].streamID(),
		endStream: true,
		header: http.Header{
			":authority": []string{"dummy.tld"},
			":method":    []string{"GET"},
			":path":      []string{fmt.Sprintf("/%d", maxConcurrent+1)},
		},
	})
	rts[0].wantStatus(200)
}
3266
3267 func TestTransportMaxDecoderHeaderTableSize(t *testing.T) {
3268 synctestTest(t, testTransportMaxDecoderHeaderTableSize)
3269 }
3270 func testTransportMaxDecoderHeaderTableSize(t testing.TB) {
3271 var reqSize, resSize uint32 = 8192, 16384
3272 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
3273 h2.MaxDecoderHeaderTableSize = int(reqSize)
3274 })
3275
3276 fr := readFrame[*SettingsFrame](t, tc)
3277 if v, ok := fr.Value(SettingHeaderTableSize); !ok {
3278 t.Fatalf("missing SETTINGS_HEADER_TABLE_SIZE setting")
3279 } else if v != reqSize {
3280 t.Fatalf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", v, reqSize)
3281 }
3282
3283 tc.writeSettings(Setting{SettingHeaderTableSize, resSize})
3284 synctest.Wait()
3285 if got, want := tc.cc.TestPeerMaxHeaderTableSize(), resSize; got != want {
3286 t.Fatalf("peerHeaderTableSize = %d, want %d", got, want)
3287 }
3288 }
3289
3290 func TestTransportMaxEncoderHeaderTableSize(t *testing.T) {
3291 synctestTest(t, testTransportMaxEncoderHeaderTableSize)
3292 }
3293 func testTransportMaxEncoderHeaderTableSize(t testing.TB) {
3294 var peerAdvertisedMaxHeaderTableSize uint32 = 16384
3295 const wantMaxEncoderHeaderTableSize = 8192
3296 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
3297 h2.MaxEncoderHeaderTableSize = wantMaxEncoderHeaderTableSize
3298 })
3299 tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize})
3300
3301 if got, want := tc.cc.TestHPACKEncoder().MaxDynamicTableSize(), uint32(wantMaxEncoderHeaderTableSize); got != want {
3302 t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
3303 }
3304 }
3305
3306
3307
3308 func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
3309 synctestTest(t, testTransportAllocationsAfterResponseBodyClose)
3310 }
3311 func testTransportAllocationsAfterResponseBodyClose(t testing.TB) {
3312 tc := newTestClientConn(t)
3313 tc.greet()
3314
3315
3316 req, _ := http.NewRequest("PUT", "https://dummy.tld/", nil)
3317 rt := tc.roundTrip(req)
3318 tc.wantFrameType(FrameHeaders)
3319
3320
3321 tc.writeHeaders(HeadersFrameParam{
3322 StreamID: rt.streamID(),
3323 EndHeaders: true,
3324 EndStream: false,
3325 BlockFragment: tc.makeHeaderBlockFragment(
3326 ":status", "200",
3327 ),
3328 })
3329 tc.writeData(rt.streamID(), false, make([]byte, 64))
3330 tc.wantIdle()
3331
3332
3333 respBody := rt.response().Body
3334 var buf [1]byte
3335 if _, err := respBody.Read(buf[:]); err != nil {
3336 t.Error(err)
3337 }
3338 if err := respBody.Close(); err != nil {
3339 t.Error(err)
3340 }
3341 tc.wantFrameType(FrameRSTStream)
3342
3343
3344 tc.writeData(rt.streamID(), false, make([]byte, 64))
3345
3346 if _, err := respBody.Read(buf[:]); err == nil {
3347 t.Error("read from closed body unexpectedly succeeded")
3348 }
3349 }
3350
3351
3352
3353 func TestTransportNoBodyMeansNoDATA(t *testing.T) { synctestTest(t, testTransportNoBodyMeansNoDATA) }
3354 func testTransportNoBodyMeansNoDATA(t testing.TB) {
3355 tc := newTestClientConn(t)
3356 tc.greet()
3357
3358 req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
3359 rt := tc.roundTrip(req)
3360
3361 tc.wantHeaders(wantHeader{
3362 streamID: rt.streamID(),
3363 endStream: true,
3364 header: http.Header{
3365 ":authority": []string{"dummy.tld"},
3366 ":method": []string{"GET"},
3367 ":path": []string{"/"},
3368 },
3369 })
3370 if fr := tc.readFrame(); fr != nil {
3371 t.Fatalf("unexpected frame after headers: %v", fr)
3372 }
3373 }
3374
3375 func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) {
3376 DisableGoroutineTracking(b)
3377 b.ReportAllocs()
3378 ts := newTestServer(b,
3379 func(w http.ResponseWriter, r *http.Request) {
3380 for i := range nResHeader {
3381 name := fmt.Sprint("A-", i)
3382 w.Header().Set(name, "*")
3383 }
3384 },
3385 optQuiet,
3386 )
3387
3388 tr := newTransport(b)
3389
3390 req, err := http.NewRequest("GET", ts.URL, nil)
3391 if err != nil {
3392 b.Fatal(err)
3393 }
3394
3395 for i := range nReqHeaders {
3396 name := fmt.Sprint("A-", i)
3397 req.Header.Set(name, "*")
3398 }
3399
3400 b.ResetTimer()
3401
3402 for i := 0; i < b.N; i++ {
3403 res, err := tr.RoundTrip(req)
3404 if err != nil {
3405 if res != nil {
3406 res.Body.Close()
3407 }
3408 b.Fatalf("RoundTrip err = %v; want nil", err)
3409 }
3410 res.Body.Close()
3411 if res.StatusCode != http.StatusOK {
3412 b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3413 }
3414 }
3415 }
3416
// infiniteReader is an io.Reader that never ends: each Read reports the entire
// buffer as filled (without touching its contents) and never returns an error.
type infiniteReader struct{}

// Read returns len(b), nil and leaves b unmodified.
func (infiniteReader) Read(b []byte) (int, error) {
	n := len(b)
	return n, nil
}
3422
3423
3424
3425 func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
3426 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3427 w.WriteHeader(http.StatusOK)
3428 })
3429
3430 tr := newTransport(t)
3431
3432
3433 req, _ := http.NewRequest("PUT", ts.URL, infiniteReader{})
3434 res, err := tr.RoundTrip(req)
3435 if err != nil {
3436 t.Fatal(err)
3437 }
3438 if res.StatusCode != http.StatusOK {
3439 t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3440 }
3441 }
3442
3443
3444
3445 func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
3446 synctestTest(t, testTransportHandlesInvalidStatuslessResponse)
3447 }
3448 func testTransportHandlesInvalidStatuslessResponse(t testing.TB) {
3449 tc := newTestClientConn(t)
3450 tc.greet()
3451
3452 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3453 rt := tc.roundTrip(req)
3454
3455 tc.wantFrameType(FrameHeaders)
3456 tc.writeHeaders(HeadersFrameParam{
3457 StreamID: rt.streamID(),
3458 EndHeaders: true,
3459 EndStream: false,
3460 BlockFragment: tc.makeHeaderBlockFragment(
3461 "content-type", "text/html",
3462 ),
3463 })
3464 tc.writeData(rt.streamID(), true, []byte("payload"))
3465 }
3466
3467 func BenchmarkClientRequestHeaders(b *testing.B) {
3468 b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
3469 b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10, 0) })
3470 b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100, 0) })
3471 b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000, 0) })
3472 }
3473
3474 func BenchmarkClientResponseHeaders(b *testing.B) {
3475 b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
3476 b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 10) })
3477 b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 100) })
3478 b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) })
3479 }
3480
3481 func BenchmarkDownloadFrameSize(b *testing.B) {
3482 b.Run(" 16k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 16*1024) })
3483 b.Run(" 64k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 64*1024) })
3484 b.Run("128k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 128*1024) })
3485 b.Run("256k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 256*1024) })
3486 b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) })
3487 }
3488 func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
3489 DisableGoroutineTracking(b)
3490 const transferSize = 1024 * 1024 * 1024
3491 b.ReportAllocs()
3492 ts := newTestServer(b,
3493 func(w http.ResponseWriter, r *http.Request) {
3494
3495 w.Header().Set("Content-Length", strconv.Itoa(transferSize))
3496 w.Header().Set("Content-Transfer-Encoding", "binary")
3497 var data [1024 * 1024]byte
3498 for range transferSize / (1024 * 1024) {
3499 w.Write(data[:])
3500 }
3501 }, optQuiet,
3502 )
3503
3504 tr := newTransport(b)
3505 tr.HTTP2.MaxReadFrameSize = int(frameSize)
3506
3507 req, err := http.NewRequest("GET", ts.URL, nil)
3508 if err != nil {
3509 b.Fatal(err)
3510 }
3511
3512 b.N = 3
3513 b.SetBytes(transferSize)
3514 b.ResetTimer()
3515
3516 for i := 0; i < b.N; i++ {
3517 res, err := tr.RoundTrip(req)
3518 if err != nil {
3519 if res != nil {
3520 res.Body.Close()
3521 }
3522 b.Fatalf("RoundTrip err = %v; want nil", err)
3523 }
3524 data, _ := io.ReadAll(res.Body)
3525 if len(data) != transferSize {
3526 b.Fatalf("Response length invalid")
3527 }
3528 res.Body.Close()
3529 if res.StatusCode != http.StatusOK {
3530 b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3531 }
3532 }
3533 }
3534
3535 func BenchmarkClientGzip(b *testing.B) {
3536 DisableGoroutineTracking(b)
3537 b.ReportAllocs()
3538
3539 const responseSize = 1024 * 1024
3540
3541 var buf bytes.Buffer
3542 gz := gzip.NewWriter(&buf)
3543 if _, err := io.CopyN(gz, crand.Reader, responseSize); err != nil {
3544 b.Fatal(err)
3545 }
3546 gz.Close()
3547
3548 data := buf.Bytes()
3549 ts := newTestServer(b,
3550 func(w http.ResponseWriter, r *http.Request) {
3551 w.Header().Set("Content-Encoding", "gzip")
3552 w.Write(data)
3553 },
3554 optQuiet,
3555 )
3556
3557 tr := newTransport(b)
3558
3559 req, err := http.NewRequest("GET", ts.URL, nil)
3560 if err != nil {
3561 b.Fatal(err)
3562 }
3563
3564 b.ResetTimer()
3565
3566 for i := 0; i < b.N; i++ {
3567 res, err := tr.RoundTrip(req)
3568 if err != nil {
3569 b.Fatalf("RoundTrip err = %v; want nil", err)
3570 }
3571 if res.StatusCode != http.StatusOK {
3572 b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3573 }
3574 n, err := io.Copy(io.Discard, res.Body)
3575 res.Body.Close()
3576 if err != nil {
3577 b.Fatalf("RoundTrip err = %v; want nil", err)
3578 }
3579 if n != responseSize {
3580 b.Fatalf("RoundTrip expected %d bytes, got %d", responseSize, n)
3581 }
3582 }
3583 }
3584
3585
3586
3587
3588 func TestClientConnCloseAtHeaders(t *testing.T) { synctestTest(t, testClientConnCloseAtHeaders) }
3589 func testClientConnCloseAtHeaders(t testing.TB) {
3590 tc := newTestClientConn(t)
3591 tc.greet()
3592
3593 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3594 rt := tc.roundTrip(req)
3595 tc.wantFrameType(FrameHeaders)
3596
3597 tc.cc.Close()
3598 synctest.Wait()
3599 if err := rt.err(); err != ErrClientConnForceClosed {
3600 t.Fatalf("RoundTrip error = %v, want errClientConnForceClosed", err)
3601 }
3602 }
3603
3604
3605
// TestClientConnCloseAtBody checks that closing a ClientConn after response
// headers and part of the body have arrived makes reading the remainder of
// the response body fail.
func TestClientConnCloseAtBody(t *testing.T) { synctestTest(t, testClientConnCloseAtBody) }
func testClientConnCloseAtBody(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)
	tc.wantFrameType(FrameHeaders)

	// Send headers and 64 bytes of body without END_STREAM, so the response
	// body is still open when the conn is closed.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	tc.writeData(rt.streamID(), false, make([]byte, 64))
	resp := rt.response()
	tc.cc.Close()
	synctest.Wait()

	// Draining the body must report an error rather than a clean EOF.
	if _, err := io.Copy(io.Discard, resp.Body); err == nil {
		t.Error("expected a Copy error, got nil")
	}
}
3632
3633
3634
// TestClientConnShutdown checks that ClientConn.Shutdown sends GOAWAY, still
// lets the in-flight request complete, and then closes the connection.
func TestClientConnShutdown(t *testing.T) { synctestTest(t, testClientConnShutdown) }
func testClientConnShutdown(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)
	tc.wantFrameType(FrameHeaders)

	// Shutdown runs in its own goroutine: it is expected not to return while
	// the request is still in flight (presumably it waits for it).
	go tc.cc.Shutdown(context.Background())
	synctest.Wait()

	// Shutdown announces itself with GOAWAY and then writes nothing more.
	tc.wantFrameType(FrameGoAway)
	tc.wantIdle()
	// The server can still answer the request that was already in flight.
	body := []byte("body")
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	tc.writeData(rt.streamID(), true, body)

	rt.wantStatus(200)
	rt.wantBody(body)

	// With the last stream finished, the conn must now be closed.
	tc.wantClosed()
}
3666
3667
3668
3669
// TestClientConnShutdownCancel checks that cancelling the context passed to
// ClientConn.Shutdown makes Shutdown return context.Canceled, and that doing
// so does not complete the request that was in flight.
func TestClientConnShutdownCancel(t *testing.T) { synctestTest(t, testClientConnShutdownCancel) }
func testClientConnShutdownCancel(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)
	tc.wantFrameType(FrameHeaders)

	ctx, cancel := context.WithCancel(t.Context())
	var shutdownErr error
	go func() {
		shutdownErr = tc.cc.Shutdown(ctx)
	}()
	synctest.Wait()

	tc.wantFrameType(FrameGoAway)
	tc.wantIdle()

	cancel()
	synctest.Wait()

	if shutdownErr != context.Canceled {
		t.Fatalf("ClientConn.Shutdown(ctx) did not return context.Canceled after cancelling context")
	}

	// Abandoning the Shutdown must not abort the request: the server never
	// answered it, so RoundTrip must still be pending.
	if rt.done() {
		t.Fatal("RoundTrip unexpectedly returned during shutdown")
	}
}
3709
// errReader is an io.Reader that yields the bytes in body and, once they are
// exhausted, fails every further Read with err.
type errReader struct {
	body []byte
	err  error
}

// Read copies as much of the remaining body into p as fits and returns a nil
// error; once body is empty it returns (0, r.err) on every call.
func (r *errReader) Read(p []byte) (int, error) {
	if len(r.body) == 0 {
		return 0, r.err
	}
	copied := copy(p, r.body)
	r.body = r.body[copied:]
	return copied, nil
}
3723
// testTransportBodyReadError runs testTransportBodyReadErrorBubble inside
// synctestTest.
func testTransportBodyReadError(t *testing.T, body []byte) {
	synctestTest(t, func(t testing.TB) {
		testTransportBodyReadErrorBubble(t, body)
	})
}

// testTransportBodyReadErrorBubble sends a PUT whose request body yields body
// and then fails with a read error. It checks that the bytes read before the
// failure are sent as DATA, that the stream is then reset with RST_STREAM,
// and that RoundTrip returns the body's read error.
func testTransportBodyReadErrorBubble(t testing.TB, body []byte) {
	tc := newTestClientConn(t)
	tc.greet()

	bodyReadError := errors.New("body read error")
	b := tc.newRequestBody()
	b.Write(body)
	b.closeWithError(bodyReadError)
	req, _ := http.NewRequest("PUT", "https://dummy.tld/", b)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	// Collect DATA payloads until the client resets the stream.
	var receivedBody []byte
readFrames:
	for {
		switch f := tc.readFrame().(type) {
		case *DataFrame:
			receivedBody = append(receivedBody, f.Data()...)
		case *RSTStreamFrame:
			break readFrames
		default:
			t.Fatalf("unexpected frame: %v", f)
		case nil:
			t.Fatalf("transport is idle, want RST_STREAM")
		}
	}
	if !bytes.Equal(receivedBody, body) {
		t.Fatalf("body: %q; expected %q", receivedBody, body)
	}

	if err := rt.err(); err != bodyReadError {
		t.Fatalf("err = %v; want %v", err, bodyReadError)
	}
}

// The body read error occurs before any bytes are read (_Immediately) or
// after "123" has been read (_Some).
func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
func TestTransportBodyReadError_Some(t *testing.T)        { testTransportBodyReadError(t, []byte("123")) }
3766
3767
3768
3769
3770 func TestTransportBodyEagerEndStream(t *testing.T) { synctestTest(t, testTransportBodyEagerEndStream) }
3771 func testTransportBodyEagerEndStream(t testing.TB) {
3772 const reqBody = "some request body"
3773 const resBody = "some response body"
3774
3775 tc := newTestClientConn(t)
3776 tc.greet()
3777
3778 body := strings.NewReader(reqBody)
3779 req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
3780 tc.roundTrip(req)
3781
3782 tc.wantFrameType(FrameHeaders)
3783 f := readFrame[*DataFrame](t, tc)
3784 if !f.StreamEnded() {
3785 t.Fatalf("data frame without END_STREAM %v", f)
3786 }
3787 }
3788
// chunkReader is an io.Reader that returns exactly one of its chunks per Read
// call, in order, and panics if read again after the last chunk.
type chunkReader struct {
	chunks [][]byte
}

// Read pops the next chunk and copies it into p. If the chunk is longer than
// p, the excess is dropped. Read never returns io.EOF; calling it with no
// chunks left panics.
func (r *chunkReader) Read(p []byte) (int, error) {
	if len(r.chunks) == 0 {
		panic("shouldn't read this many times")
	}
	next := r.chunks[0]
	r.chunks = r.chunks[1:]
	return copy(p, next), nil
}
3801
3802
3803
3804
3805
3806
3807
3808
3809
// TestTransportBodyLargerThanSpecifiedContentLength_len3 sends a body that
// produces 6 bytes (two 3-byte reads) on a request declaring ContentLength 3.
func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) {
	body := &chunkReader{[][]byte{
		[]byte("123"),
		[]byte("456"),
	}}
	synctestTest(t, func(t testing.TB) {
		testTransportBodyLargerThanSpecifiedContentLength(t, body, 3)
	})
}

// TestTransportBodyLargerThanSpecifiedContentLength_len2 sends a body that
// produces 3 bytes in one read on a request declaring ContentLength 2.
func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) {
	body := &chunkReader{[][]byte{
		[]byte("123"),
	}}
	synctestTest(t, func(t testing.TB) {
		testTransportBodyLargerThanSpecifiedContentLength(t, body, 2)
	})
}

// testTransportBodyLargerThanSpecifiedContentLength POSTs body with
// req.ContentLength set to contentLen and checks that RoundTrip fails with
// ErrReqBodyTooLong. Because body is a chunkReader, which panics if read
// after its last chunk, any extra Read by the Transport also fails the test.
func testTransportBodyLargerThanSpecifiedContentLength(t testing.TB, body *chunkReader, contentLen int64) {
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		r.Body.Read(make([]byte, 6))
	})

	tr := newTransport(t)

	req, _ := http.NewRequest("POST", ts.URL, body)
	req.ContentLength = contentLen
	_, err := tr.RoundTrip(req)
	if err != ErrReqBodyTooLong {
		t.Fatalf("expected %v, got %v", ErrReqBodyTooLong, err)
	}
}
3843
3844
// TestTransportNewClientConnCloseOnWriteError is meant to check that a write
// error on a freshly created ClientConn causes its net.Conn to be closed.
// It is currently skipped; see the t.Skip message.
func TestTransportNewClientConnCloseOnWriteError(t *testing.T) {
	synctestTest(t, testTransportNewClientConnCloseOnWriteError)
}
func testTransportNewClientConnCloseOnWriteError(t testing.TB) {
	t.Skip("TODO: test fails because write errors don't cause the conn to close")

	tc := newTestClientConn(t)

	synctest.Wait()
	// Make writes on the conn fail from here on (presumably on the client's
	// local end, given the "loc" field; confirm).
	writeErr := errors.New("write error")
	tc.netconn.loc.setWriteError(writeErr)

	// The server's SETTINGS should provoke a client write, which now fails.
	tc.writeSettings()
	tc.wantIdle()

	// NOTE(review): these expect the client's initial SETTINGS and
	// WINDOW_UPDATE to still be readable; presumably they were buffered
	// before the write error was installed. Confirm.
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	tc.wantIdle()

	synctest.Wait()
	if !tc.netconn.IsClosedByPeer() {
		t.Error("expected closed conn")
	}
}
3877
// TestTransportRoundtripCloseOnWriteError checks that a write error on the
// connection fails the in-flight RoundTrip with that error, and that later
// requests on the same ClientConn fail with ErrClientConnUnusable.
func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
	synctestTest(t, testTransportRoundtripCloseOnWriteError)
}
func testTransportRoundtripCloseOnWriteError(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	body := tc.newRequestBody()
	body.writeBytes(1)
	req, _ := http.NewRequest("GET", "https://dummy.tld/", body)
	rt := tc.roundTrip(req)

	writeErr := errors.New("write error")
	tc.closeWriteWithError(writeErr)

	// More request body presumably makes the Transport write again, this
	// time hitting the injected write error.
	body.writeBytes(1)
	if err := rt.err(); err != writeErr {
		t.Fatalf("RoundTrip error %v, want %v", err, writeErr)
	}

	// Once a write has failed, the conn must refuse new requests.
	rt2 := tc.roundTrip(req)
	if err := rt2.err(); err != ErrClientConnUnusable {
		t.Fatalf("RoundTrip error %v, want errClientConnUnusable", err)
	}
}
3903
3904
3905
3906
3907 func TestTransportBodyRewindRace(t *testing.T) {
3908 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3909 w.Header().Set("Connection", "close")
3910 w.WriteHeader(http.StatusOK)
3911 return
3912 })
3913
3914 tr := newTransport(t)
3915 tr.MaxConnsPerHost = 1
3916 client := &http.Client{
3917 Transport: tr,
3918 }
3919
3920 const clients = 50
3921
3922 var wg sync.WaitGroup
3923 wg.Add(clients)
3924 for range clients {
3925 req, err := http.NewRequest("POST", ts.URL, bytes.NewBufferString("abcdef"))
3926 if err != nil {
3927 t.Fatalf("unexpected new request error: %v", err)
3928 }
3929
3930 go func() {
3931 defer wg.Done()
3932 res, err := client.Do(req)
3933 if err == nil {
3934 res.Body.Close()
3935 }
3936 }()
3937 }
3938
3939 wg.Wait()
3940 }
3941
// errorReader is an io.Reader that never produces data and reports err on
// every Read.
type errorReader struct {
	err error
}

func (r errorReader) Read([]byte) (n int, err error) {
	err = r.err
	return
}
3945
3946
3947
3948 func TestTransportServerResetStreamAtHeaders(t *testing.T) {
3949 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3950 w.WriteHeader(http.StatusUnauthorized)
3951 return
3952 })
3953
3954 tr := newTransport(t)
3955 tr.MaxConnsPerHost = 1
3956 tr.ExpectContinueTimeout = 10 * time.Second
3957
3958 client := &http.Client{
3959 Transport: tr,
3960 }
3961
3962 req, err := http.NewRequest("POST", ts.URL, errorReader{io.EOF})
3963 if err != nil {
3964 t.Fatalf("unexpected new request error: %v", err)
3965 }
3966 req.ContentLength = 0
3967 req.Header.Set("Expect", "100-continue")
3968 res, err := client.Do(req)
3969 if err != nil {
3970 t.Fatal(err)
3971 }
3972 res.Body.Close()
3973 }
3974
// trackingReader wraps rdr and records, safely for concurrent use, whether
// Read has ever been called on it.
type trackingReader struct {
	rdr io.Reader
	// wasRead is set to 1 with an atomic store on the first Read.
	wasRead uint32
}

// Read marks the reader as read, then delegates to the wrapped reader.
func (r *trackingReader) Read(p []byte) (int, error) {
	atomic.StoreUint32(&r.wasRead, 1)
	n, err := r.rdr.Read(p)
	return n, err
}

// WasRead reports whether Read has been called at least once.
func (r *trackingReader) WasRead() bool {
	flag := atomic.LoadUint32(&r.wasRead)
	return flag != 0
}
3988
// TestTransportExpectContinue sends requests with "Expect: 100-continue" and
// checks the response status, whether the request body was read, and that
// each request finishes before ExpectContinueTimeout elapses (so the
// Transport reacted to the server instead of waiting out the timeout).
func TestTransportExpectContinue(t *testing.T) {
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/reject":
			// Respond without consuming the request body.
			w.WriteHeader(403)
		default:
			io.Copy(io.Discard, r.Body)
		}
	})

	tr := newTransport(t)
	tr.MaxConnsPerHost = 1
	tr.ExpectContinueTimeout = 10 * time.Second

	client := &http.Client{
		Transport: tr,
	}

	testCases := []struct {
		Name         string
		Path         string
		Body         *trackingReader
		ExpectedCode int
		ShouldRead   bool
	}{
		{
			Name:         "read-all",
			Path:         "/",
			Body:         &trackingReader{rdr: strings.NewReader("hello")},
			ExpectedCode: 200,
			ShouldRead:   true,
		},
		{
			Name:         "reject",
			Path:         "/reject",
			Body:         &trackingReader{rdr: strings.NewReader("hello")},
			ExpectedCode: 403,
			ShouldRead:   false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			startTime := time.Now()

			req, err := http.NewRequest("POST", ts.URL+tc.Path, tc.Body)
			if err != nil {
				t.Fatal(err)
			}
			req.Header.Set("Expect", "100-continue")
			res, err := client.Do(req)
			if err != nil {
				t.Fatal(err)
			}
			res.Body.Close()

			if delta := time.Since(startTime); delta >= tr.ExpectContinueTimeout {
				t.Error("Request didn't finish before expect continue timeout")
			}
			if res.StatusCode != tc.ExpectedCode {
				t.Errorf("Unexpected status code, got %d, expected %d", res.StatusCode, tc.ExpectedCode)
			}
			// A rejected request must not have had its body read at all.
			if tc.Body.WasRead() != tc.ShouldRead {
				t.Errorf("Unexpected read status, got %v, expected %v", tc.Body.WasRead(), tc.ShouldRead)
			}
		})
	}
}
4057
// closeChecker wraps an io.ReadCloser and records when Close is called.
// Reads after Close fail, and isClosed lets a test wait for the Close.
type closeChecker struct {
	io.ReadCloser
	closed chan struct{}
}

// newCloseChecker wraps r. Close must be called at most once: a second call
// would close the already-closed channel and panic.
func newCloseChecker(r io.ReadCloser) *closeChecker {
	return &closeChecker{r, make(chan struct{})}
}

// newStaticCloseChecker returns a closeChecker whose underlying reader yields
// the given body string.
func newStaticCloseChecker(body string) *closeChecker {
	// Serve the caller's body; this previously ignored the parameter and
	// always served the literal "body".
	return newCloseChecker(io.NopCloser(strings.NewReader(body)))
}

// Read fails once Close has been called; otherwise it delegates to the
// wrapped reader.
func (rc *closeChecker) Read(b []byte) (n int, err error) {
	select {
	default:
	case <-rc.closed:
		// A Read after Close indicates the body was used after being closed.
		return 0, errors.New("read after Body.Close")
	}
	return rc.ReadCloser.Read(b)
}

// Close marks the checker as closed and closes the wrapped reader.
func (rc *closeChecker) Close() error {
	close(rc.closed)
	return rc.ReadCloser.Close()
}

// isClosed waits up to 10 seconds for Close to be called. It returns nil if
// Close was called in time, and an error otherwise.
func (rc *closeChecker) isClosed() error {
	const timeout = 10 * time.Second
	select {
	case <-rc.closed:
		return nil
	case <-time.After(timeout):
		return fmt.Errorf("body not closed after %v", timeout)
	}
}
4100
4101
// blockingWriteConn wraps a net.Conn and stalls any Write that would push the
// running byte total past limit, until unblock is called.
//
// count is updated without locking, so concurrent Writes would race on it.
type blockingWriteConn struct {
	net.Conn

	// writeOnce guards closing writec.
	writeOnce sync.Once
	// writec is closed the first time a Write would exceed limit.
	writec chan struct{}
	// unblockc is closed by unblock to release over-limit Writes.
	unblockc chan struct{}

	// count is the number of bytes written so far; limit is the threshold.
	count, limit int
}

// newBlockingWriteConn returns a blockingWriteConn over conn that blocks
// writes once more than limit bytes would have been written.
func newBlockingWriteConn(conn net.Conn, limit int) *blockingWriteConn {
	c := &blockingWriteConn{Conn: conn, limit: limit}
	c.writec = make(chan struct{})
	c.unblockc = make(chan struct{})
	return c
}

// wait blocks until some Write has hit the limit.
func (c *blockingWriteConn) wait() { <-c.writec }

// unblock lets over-limit Writes proceed. It must be called at most once.
func (c *blockingWriteConn) unblock() { close(c.unblockc) }

// Write forwards b to the underlying conn. If that would take the running
// total past limit, it first releases wait (only once) and then blocks until
// unblock has been called.
func (c *blockingWriteConn) Write(b []byte) (n int, err error) {
	if exceeds := c.count+len(b) > c.limit; exceeds {
		c.writeOnce.Do(func() { close(c.writec) })
		<-c.unblockc
	}
	n, err = c.Conn.Write(b)
	c.count += n
	return n, err
}
4140
4141
4142
// TestTransportFrameBufferReuse sends 10 concurrent requests, each carrying the
// same hex filler in a header, the body, and a trailer, and has the handler
// verify that all three arrive intact. The name suggests this guards against
// corruption from reused frame buffers across concurrent writes (presumed).
func TestTransportFrameBufferReuse(t *testing.T) {
	filler := hex.EncodeToString([]byte(randString(2048)))

	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		if got, want := r.Header.Get("Big"), filler; got != want {
			t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want)
		}
		b, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("error reading request body: %v", err)
		}
		if got, want := string(b), filler; got != want {
			t.Errorf("request body = %q, want %q", got, want)
		}
		if got, want := r.Trailer.Get("Big"), filler; got != want {
			t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, want)
		}
	})

	tr := newTransport(t)

	var wg sync.WaitGroup
	// Wait for every request goroutine before the test returns.
	defer wg.Wait()
	for range 10 {
		wg.Go(func() {
			req, err := http.NewRequest("POST", ts.URL, strings.NewReader(filler))
			if err != nil {
				t.Error(err)
				return
			}
			req.Header.Set("Big", filler)
			req.Trailer = make(http.Header)
			req.Trailer.Set("Big", filler)
			res, err := tr.RoundTrip(req)
			if err != nil {
				t.Error(err)
				return
			}
			if got, want := res.StatusCode, 200; got != want {
				t.Errorf("StatusCode = %v; want %v", got, want)
			}
			if res != nil && res.Body != nil {
				res.Body.Close()
			}
		})
	}

}
4191
4192
4193
4194
4195
// TestTransportBlockingRequestWrite runs testTransportBlockingRequestWrite
// with a second request whose large hex filler is carried in its headers,
// its body, or its trailer.
func TestTransportBlockingRequestWrite(t *testing.T) {
	filler := hex.EncodeToString([]byte(randString(2048)))
	for _, test := range []struct {
		name string
		req  *http.Request
	}{{
		name: "headers",
		req: func() *http.Request {
			req, _ := http.NewRequest("POST", "https://dummy.tld/", nil)
			req.Header.Set("Big", filler)
			return req
		}(),
	}, {
		name: "body",
		req: func() *http.Request {
			req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(filler))
			return req
		}(),
	}, {
		name: "trailer",
		req: func() *http.Request {
			req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader("body"))
			req.Trailer = make(http.Header)
			req.Trailer.Set("Big", filler)
			return req
		}(),
	}} {
		t.Run(test.name, func(t *testing.T) {
			synctestTest(t, func(t testing.TB) {
				testTransportBlockingRequestWrite(t, test.req)
			})
		})
	}
}
4230 func testTransportBlockingRequestWrite(t testing.TB, req2 *http.Request) {
4231 tt := newTestTransport(t)
4232
4233 smallReq := func() *http.Request {
4234 req, _ := http.NewRequest("GET", req2.URL.String(), nil)
4235 return req
4236 }
4237
4238
4239 rt1 := tt.roundTrip(smallReq())
4240 tc1 := tt.getConn()
4241 tc1.wantFrameType(FrameSettings)
4242 tc1.wantFrameType(FrameWindowUpdate)
4243 tc1.wantHeaders(wantHeader{
4244 streamID: 1,
4245 endStream: true,
4246 })
4247 tc1.writeSettings(Setting{SettingMaxConcurrentStreams, 1})
4248 tc1.writeHeaders(HeadersFrameParam{
4249 StreamID: 1,
4250 EndHeaders: true,
4251 EndStream: true,
4252 BlockFragment: tc1.makeHeaderBlockFragment(
4253 ":status", "200",
4254 ),
4255 })
4256 rt1.wantStatus(200)
4257 tc1.wantFrameType(FrameSettings)
4258
4259
4260 tc1.netconn.SetReadBufferSize(1024)
4261 rt2 := tt.roundTrip(req2)
4262
4263
4264
4265 rt3 := tt.roundTrip(smallReq())
4266 tc2 := tt.getConn()
4267 tc2.wantFrameType(FrameSettings)
4268 tc2.wantFrameType(FrameWindowUpdate)
4269 tc2.wantHeaders(wantHeader{
4270 streamID: 1,
4271 endStream: true,
4272 })
4273 tc2.writeSettings()
4274 tc2.writeHeaders(HeadersFrameParam{
4275 StreamID: 1,
4276 EndHeaders: true,
4277 EndStream: true,
4278 BlockFragment: tc1.makeHeaderBlockFragment(
4279 ":status", "200",
4280 ),
4281 })
4282 rt3.wantStatus(200)
4283 tc2.wantFrameType(FrameSettings)
4284
4285 if rt2.done() {
4286 t.Errorf("RoundTrip 2 is done, expect it to be still pending")
4287 }
4288 }
4289
// TestTransportCloseRequestBody checks that for both a 200 and a 401
// response, the request body is closed once the exchange is finished, even
// though the body never produced any data before the response arrived.
func TestTransportCloseRequestBody(t *testing.T) {
	// statusCode is read by the handler. Each subtest sets it before sending
	// its request.
	var statusCode int
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(statusCode)
	})

	tr := newTransport(t)
	ctx := context.Background()
	cc, err := tr.NewClientConn(ctx, "https", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cc.Close()

	for _, status := range []int{200, 401} {
		t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) {
			statusCode = status
			// The body is a pipe that stays empty until pw.Close below.
			pr, pw := io.Pipe()
			body := newCloseChecker(pr)
			req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
			if err != nil {
				t.Fatal(err)
			}
			res, err := cc.RoundTrip(req)
			if err != nil {
				t.Fatal(err)
			}
			res.Body.Close()
			pw.Close()
			// isClosed waits (up to its timeout) for the Transport to call body.Close.
			if err := body.isClosed(); err != nil {
				t.Fatal(err)
			}
		})
	}
}
4325
// TestTransportNoRetryOnStreamProtocolError checks that when the server resets
// one stream with PROTOCOL_ERROR, only that request fails. It must not be
// retried on a new connection, and other requests on the same connection
// must be unaffected.
func TestTransportNoRetryOnStreamProtocolError(t *testing.T) {
	synctestTest(t, testTransportNoRetryOnStreamProtocolError)
}
func testTransportNoRetryOnStreamProtocolError(t testing.TB) {
	tt := newTestTransport(t)

	// Request #1 opens a connection. Stream 1 is left without a response so
	// it stays open.
	req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt1 := tt.roundTrip(req1)
	tc1 := tt.getConn()
	tc1.wantFrameType(FrameSettings)
	tc1.wantFrameType(FrameWindowUpdate)
	tc1.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})
	tc1.writeSettings()
	tc1.wantFrameType(FrameSettings)

	// Request #2 reuses the same connection, on stream 3.
	req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt2 := tt.roundTrip(req2)
	tc1.wantHeaders(wantHeader{
		streamID:  3,
		endStream: true,
	})

	// Resetting stream 3 must finish RoundTrip #2 but leave #1 pending.
	tc1.writeRSTStream(3, ErrCodeProtocol)
	if rt1.done() {
		t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #1 is done; want still in progress")
	}
	if !rt2.done() {
		t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is in progress; want done")
	}

	// A new connection here would mean RoundTrip #2 was retried.
	if tt.hasConn() {
		t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is unexpectedly retried")
	}

	// The original connection still works: complete RoundTrip #1 on it.
	tc1.writeHeaders(HeadersFrameParam{
		StreamID:   1,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc1.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt1.wantStatus(200)
}
4386
// TestClientConnReservations checks that ClientConn.ReserveNewRequest respects
// the server's SETTINGS_MAX_CONCURRENT_STREAMS limit and that completed round
// trips free reservation capacity again.
func TestClientConnReservations(t *testing.T) { synctestTest(t, testClientConnReservations) }
func testClientConnReservations(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet(
		Setting{ID: SettingMaxConcurrentStreams, Val: InitialMaxConcurrentStreams},
	)

	// doRoundTrip performs one complete request/response exchange on tc.
	doRoundTrip := func() {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		rt := tc.roundTrip(req)
		tc.wantFrameType(FrameHeaders)
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  true,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
			),
		})
		rt.wantStatus(200)
	}

	// Reserve until refused. The loop bound allows one attempt past the limit,
	// and that attempt must fail.
	n := 0
	for n <= InitialMaxConcurrentStreams && tc.cc.ReserveNewRequest() {
		n++
	}
	if n != InitialMaxConcurrentStreams {
		t.Errorf("did %v reservations; want %v", n, InitialMaxConcurrentStreams)
	}
	// After one round trip completes, exactly one more reservation must succeed.
	doRoundTrip()
	n2 := 0
	for n2 <= 5 && tc.cc.ReserveNewRequest() {
		n2++
	}
	if n2 != 1 {
		t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2)
	}

	// Complete n more round trips. Afterwards, n reservations must be
	// available again.
	for i := 0; i < n; i++ {
		doRoundTrip()
	}

	n2 = 0
	for n2 <= InitialMaxConcurrentStreams && tc.cc.ReserveNewRequest() {
		n2++
	}
	if n2 != n {
		t.Errorf("after reset, reservations = %v; want %v", n2, n)
	}
}
4438
// TestTransportTimeoutServerHangs checks that a request to a server that never
// responds stays pending, with the client sending nothing further after 5
// seconds. Once its context is cancelled, it must return context.Canceled.
func TestTransportTimeoutServerHangs(t *testing.T) { synctestTest(t, testTransportTimeoutServerHangs) }
func testTransportTimeoutServerHangs(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	ctx, cancel := context.WithCancel(context.Background())
	req, _ := http.NewRequestWithContext(ctx, "PUT", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	// Let time pass with no server response. This presumably uses the
	// bubble's fake clock under synctestTest.
	time.Sleep(5 * time.Second)
	if f := tc.readFrame(); f != nil {
		t.Fatalf("unexpected frame: %v", f)
	}
	if rt.done() {
		t.Fatalf("after 5 seconds with no response, RoundTrip unexpectedly returned")
	}

	cancel()
	synctest.Wait()
	if rt.err() != context.Canceled {
		t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err())
	}
}
4463
// TestTransportContentLengthWithoutBody checks responses that declare a
// Content-Length but send no body. Reading the body must fail with
// io.ErrUnexpectedEOF for a non-zero length and succeed for a zero length.
// In both cases res.ContentLength must reflect the declared value.
func TestTransportContentLengthWithoutBody(t *testing.T) {
	for _, test := range []struct {
		name              string
		contentLength     string
		wantBody          string
		wantErr           error
		wantContentLength int64
	}{
		{
			name:              "non-zero content length",
			contentLength:     "42",
			wantErr:           io.ErrUnexpectedEOF,
			wantContentLength: 42,
		},
		{
			name:              "zero content length",
			contentLength:     "0",
			wantErr:           nil,
			wantContentLength: 0,
		},
	} {
		synctestSubtest(t, test.name, func(t testing.TB) {
			// The handler reads contentLength when it runs. It is assigned
			// below, before the request is sent.
			contentLength := ""
			ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Length", contentLength)
			})
			tr := newTransport(t)

			contentLength = test.contentLength

			req, _ := http.NewRequest("GET", ts.URL, nil)
			res, err := tr.RoundTrip(req)
			if err != nil {
				t.Fatal(err)
			}
			defer res.Body.Close()
			body, err := io.ReadAll(res.Body)

			if err != test.wantErr {
				t.Errorf("Expected error %v, got: %v", test.wantErr, err)
			}
			if len(body) > 0 {
				t.Errorf("Expected empty body, got: %v", body)
			}
			if res.ContentLength != test.wantContentLength {
				t.Errorf("Expected content length %d, got: %d", test.wantContentLength, res.ContentLength)
			}
		})
	}
}
4514
// TestTransportCloseResponseBodyWhileRequestBodyHangs checks that closing the
// response body returns even though the request body (one end of a net.Pipe
// that nobody writes to) is still blocked. The handler sends and flushes
// headers before reading the body.
func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) {
	synctestTest(t, testTransportCloseResponseBodyWhileRequestBodyHangs)
}
func testTransportCloseResponseBodyWhileRequestBodyHangs(t testing.TB) {
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.(http.Flusher).Flush()
		io.Copy(io.Discard, r.Body)
	})

	tr := newTransport(t)

	pr, pw := net.Pipe()
	req, err := http.NewRequest("GET", ts.URL, pr)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}

	// This Close must not hang on the still-blocked request body.
	res.Body.Close()
	pw.Close()
}
4540
// TestTransport300ResponseBody checks that the body of a 300 response, which
// the server sends only after flushing the headers, can be read in full while
// the request body (an unwritten net.Pipe end) is still open.
func TestTransport300ResponseBody(t *testing.T) { synctestTest(t, testTransport300ResponseBody) }
func testTransport300ResponseBody(t testing.TB) {
	// reqc is closed once the client has the response headers; the handler
	// waits for it before writing the body.
	reqc := make(chan struct{})
	body := []byte("response body")
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(300)
		w.(http.Flusher).Flush()
		<-reqc
		w.Write(body)
	})

	tr := newTransport(t)

	pr, pw := net.Pipe()
	req, err := http.NewRequest("GET", ts.URL, pr)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	close(reqc)
	got, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("error reading response body: %v", err)
	}
	if !bytes.Equal(got, body) {
		t.Errorf("got response body %q, want %q", string(got), string(body))
	}
	res.Body.Close()
	pw.Close()
}
4574
// TestTransportWriteByteTimeout dials one end of a net.Pipe whose other end is
// discarded. Nothing ever reads it, so writes never make progress. With
// WriteByteTimeout set to 1ms, Get must fail with os.ErrDeadlineExceeded.
func TestTransportWriteByteTimeout(t *testing.T) {
	ts := newTestServer(t, nil, func(s *http.Server) {
		s.Protocols = protocols("h2c")
	})
	tr := newTransport(t)
	tr.Protocols = protocols("h2c")
	tr.Dial = func(network, addr string) (net.Conn, error) {
		_, c := net.Pipe()
		return c, nil
	}
	tr.HTTP2.WriteByteTimeout = 1 * time.Millisecond
	defer tr.CloseIdleConnections()
	c := &http.Client{Transport: tr}

	_, err := c.Get(ts.URL)
	if !errors.Is(err, os.ErrDeadlineExceeded) {
		t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err)
	}
}
4594
// slowWriteConn wraps a net.Conn so that, while a write deadline is set, each
// multi-byte Write transfers only its first byte and then reports
// os.ErrDeadlineExceeded. SetWriteDeadline only records whether a deadline is
// set; it is not forwarded to the wrapped conn.
type slowWriteConn struct {
	net.Conn
	hasWriteDeadline bool
}

func (c *slowWriteConn) SetWriteDeadline(t time.Time) error {
	c.hasWriteDeadline = !t.IsZero()
	return nil
}

func (c *slowWriteConn) Write(b []byte) (n int, err error) {
	// Without a deadline, or for single-byte writes, behave normally.
	if !c.hasWriteDeadline || len(b) <= 1 {
		return c.Conn.Write(b)
	}
	if n, err = c.Conn.Write(b[:1]); err != nil {
		return n, err
	}
	return n, fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded)
}
4615
// TestTransportSlowWrites POSTs a 1 MiB body over slowWriteConn connections
// with WriteByteTimeout set to 1ms. While a write deadline is set, every
// multi-byte write moves one byte and then fails with os.ErrDeadlineExceeded.
// The request must still succeed, presumably because the Transport tolerates
// deadline errors as long as each write makes progress.
func TestTransportSlowWrites(t *testing.T) { synctestTest(t, testTransportSlowWrites) }
func testTransportSlowWrites(t testing.TB) {
	ts := newTestServer(t, nil, func(s *http.Server) {
		s.Protocols = protocols("h2c")
	})
	tr := newTransport(t)
	tr.Protocols = protocols("h2c")
	tr.Dial = func(network, addr string) (net.Conn, error) {
		c, err := net.Dial(network, addr)
		return &slowWriteConn{Conn: c}, err
	}
	tr.HTTP2.WriteByteTimeout = 1 * time.Millisecond
	c := &http.Client{Transport: tr}

	const bodySize = 1 << 20
	resp, err := c.Post(ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize))
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()
}
4637
// TestTransportClosesConnAfterGoAwayNoStreams: the server's GOAWAY has last
// stream 0, so the in-flight request was not processed.
func TestTransportClosesConnAfterGoAwayNoStreams(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportClosesConnAfterGoAway(t, 0)
	})
}

// TestTransportClosesConnAfterGoAwayLastStream: the server's GOAWAY has last
// stream 1, so the in-flight request is still answered.
func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportClosesConnAfterGoAway(t, 1)
	})
}

// testTransportClosesConnAfterGoAway sends one request, then has the server
// send GOAWAY with the given lastStream and close its write side. If
// lastStream is 0, RoundTrip must fail. Otherwise the server answers the
// request first and RoundTrip must succeed. In both cases the ClientConn
// must close its net.Conn.
func testTransportClosesConnAfterGoAway(t testing.TB, lastStream uint32) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeGoAway(lastStream, ErrCodeNo, nil)

	if lastStream > 0 {
		// The request's stream is covered by the GOAWAY, so answer it.
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  true,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
			),
		})
	}

	tc.closeWrite()
	err := rt.err()
	if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
		t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
	}
	if !tc.isClosed() {
		t.Errorf("ClientConn did not close its net.Conn, expected it to")
	}
}
4686
// slowCloser is an always-empty request body whose Close announces itself on
// closing and then blocks until the test closes closed.
type slowCloser struct {
	closing chan struct{}
	closed  chan struct{}
}

// Read always reports end of input.
func (sc *slowCloser) Read(_ []byte) (n int, err error) {
	err = io.EOF
	return n, err
}

// Close signals that it has started, then waits to be released.
func (sc *slowCloser) Close() error {
	close(sc.closing)
	<-sc.closed
	return nil
}
4701
4702 func TestTransportSlowClose(t *testing.T) {
4703 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4704 })
4705
4706 client := ts.Client()
4707 body := &slowCloser{
4708 closing: make(chan struct{}),
4709 closed: make(chan struct{}),
4710 }
4711
4712 reqc := make(chan struct{})
4713 go func() {
4714 defer close(reqc)
4715 res, err := client.Post(ts.URL, "text/plain", body)
4716 if err != nil {
4717 t.Error(err)
4718 }
4719 res.Body.Close()
4720 }()
4721 defer func() {
4722 close(body.closed)
4723 <-reqc
4724 }()
4725
4726 <-body.closing
4727
4728 res, err := client.Get(ts.URL)
4729 if err != nil {
4730 t.Fatal(err)
4731 }
4732 res.Body.Close()
4733 }
4734
// TestTransportDialTLSContext checks that cancelling a request's context while
// the TLS handshake is blocked fetching a client certificate makes RoundTrip
// fail with an error wrapping context.Canceled.
func TestTransportDialTLSContext(t *testing.T) {
	blockCh := make(chan struct{})
	serverTLSConfigFunc := func(ts *httptest.Server) {
		ts.Config.TLSConfig = &tls.Config{
			// Requesting a client certificate makes the client's
			// GetClientCertificate hook run during the handshake.
			ClientAuth: tls.RequestClientCert,
		}
	}
	ts := newTestServer(t,
		func(w http.ResponseWriter, r *http.Request) {},
		serverTLSConfigFunc,
	)
	tr := newTransport(t)
	tr.TLSClientConfig = &tls.Config{
		GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
			// Tell the test the handshake has reached this point, then block
			// until the handshake's context is done.
			close(blockCh)
			<-cri.Context().Done()
			return nil, cri.Context().Err()
		},
		InsecureSkipVerify: true,
	}
	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	req = req.WithContext(ctx)
	errCh := make(chan error)
	go func() {
		defer close(errCh)
		res, err := tr.RoundTrip(req)
		if err != nil {
			errCh <- err
			return
		}
		res.Body.Close()
	}()

	// Wait until the handshake is blocked in GetClientCertificate.
	<-blockCh

	cancel()

	// A successful RoundTrip closes errCh without sending, yielding nil here.
	err = <-errCh
	if err == nil {
		t.Fatal("cancelling context during client certificate fetch did not error as expected")
		return
	}
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("unexpected error returned after cancellation: %v", err)
	}
}
4790
4791
4792
4793
4794
// TestDialRaceResumesDial starts request 1, whose TLS dial blocks in
// GetClientCertificate. It then starts request 2 and cancels request 1. It
// expects exactly one error (request 1's context.Canceled) and request 2 to
// succeed. Once blockCh is closed, later GetClientCertificate calls return an
// empty certificate immediately. Currently skipped; see go.dev/issue/77908.
func TestDialRaceResumesDial(t *testing.T) {
	t.Skip("https://go.dev/issue/77908: test fails when using an http.Transport")
	blockCh := make(chan struct{})
	serverTLSConfigFunc := func(ts *httptest.Server) {
		ts.Config.TLSConfig = &tls.Config{
			// Requesting a client certificate makes the client's
			// GetClientCertificate hook run during the handshake.
			ClientAuth: tls.RequestClientCert,
		}
	}
	ts := newTestServer(t,
		func(w http.ResponseWriter, r *http.Request) {},
		serverTLSConfigFunc,
	)
	tr := newTransport(t)
	tr.TLSClientConfig = &tls.Config{
		GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
			select {
			case <-blockCh:
				// Not the first handshake: proceed immediately.
				return &tls.Certificate{}, nil
			default:
			}
			// First handshake: signal the test, then block until its context is done.
			close(blockCh)
			<-cri.Context().Done()
			return nil, cri.Context().Err()
		},
		InsecureSkipVerify: true,
	}
	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}

	ctx1, cancel1 := context.WithCancel(context.Background())
	defer cancel1()
	req1 := req.WithContext(ctx1)
	ctx2 := t.Context()
	req2 := req.WithContext(ctx2)
	errCh := make(chan error)
	go func() {
		res, err := tr.RoundTrip(req1)
		if err != nil {
			errCh <- err
			return
		}
		res.Body.Close()
	}()
	successCh := make(chan struct{})
	go func() {
		// Start request 2 only once request 1's handshake is blocked.
		<-blockCh
		res, err := tr.RoundTrip(req2)
		if err != nil {
			errCh <- err
			return
		}
		res.Body.Close()

		close(successCh)
	}()

	<-blockCh

	cancel1()

	// The first error must be request 1's cancellation.
	err = <-errCh
	if err == nil {
		t.Fatal("cancelling context during client certificate fetch did not error as expected")
		return
	}
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("unexpected error returned after cancellation: %v", err)
	}
	// Request 2 must then succeed rather than report a second error.
	select {
	case err := <-errCh:
		t.Fatalf("unexpected second error: %v", err)
	case <-successCh:
	}
}
4877
4878 func TestTransportDataAfter1xxHeader(t *testing.T) { synctestTest(t, testTransportDataAfter1xxHeader) }
4879 func testTransportDataAfter1xxHeader(t testing.TB) {
4880
4881 log.SetOutput(io.Discard)
4882 defer log.SetOutput(os.Stderr)
4883
4884
4885 tc := newTestClientConn(t)
4886 tc.greet()
4887
4888 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
4889 rt := tc.roundTrip(req)
4890
4891 tc.wantFrameType(FrameHeaders)
4892 tc.writeHeaders(HeadersFrameParam{
4893 StreamID: rt.streamID(),
4894 EndHeaders: true,
4895 EndStream: false,
4896 BlockFragment: tc.makeHeaderBlockFragment(
4897 ":status", "100",
4898 ),
4899 })
4900 tc.writeData(rt.streamID(), true, []byte{0})
4901 err := rt.err()
4902 if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
4903 t.Errorf("RoundTrip error: %v; want ErrCodeProtocol", err)
4904 }
4905 tc.wantFrameType(FrameRSTStream)
4906 }
4907
4908 func TestIssue66763Race(t *testing.T) {
4909 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {},
4910 func(s *http.Server) {
4911 s.Protocols = protocols("h2c")
4912 })
4913 tr := newTransport(t)
4914 tr.IdleConnTimeout = 1 * time.Nanosecond
4915 tr.Protocols = protocols("h2c")
4916
4917 donec := make(chan struct{})
4918 go func() {
4919
4920
4921
4922 conn, err := tr.NewClientConn(t.Context(), "http", ts.URL)
4923 close(donec)
4924 if err == nil {
4925 conn.Close()
4926 }
4927 }()
4928
4929
4930
4931 <-donec
4932 }
4933
4934
4935
4936 func TestIssue67671(t *testing.T) {
4937 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {},
4938 func(s *http.Server) {
4939 s.Protocols = protocols("h2c")
4940 })
4941 tr := newTransport(t)
4942 tr.Protocols = protocols("h2c")
4943 req, _ := http.NewRequest("GET", ts.URL, nil)
4944 req.Close = true
4945 for range 2 {
4946 res, err := tr.RoundTrip(req)
4947 if err != nil {
4948 t.Fatal(err)
4949 }
4950 res.Body.Close()
4951 }
4952 }
4953
4954 func TestTransport1xxLimits(t *testing.T) {
4955 for _, test := range []struct {
4956 name string
4957 opt any
4958 ctxfn func(context.Context) context.Context
4959 hcount int
4960 limited bool
4961 }{{
4962 name: "default",
4963 hcount: 10,
4964 limited: false,
4965 }, {
4966 name: "MaxResponseHeaderBytes",
4967 opt: func(tr *http.Transport) {
4968 tr.MaxResponseHeaderBytes = 10000
4969 },
4970 hcount: 10,
4971 limited: true,
4972 }, {
4973 name: "limit by client trace",
4974 ctxfn: func(ctx context.Context) context.Context {
4975 count := 0
4976 return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
4977 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
4978 count++
4979 if count >= 10 {
4980 return errors.New("too many 1xx")
4981 }
4982 return nil
4983 },
4984 })
4985 },
4986 hcount: 10,
4987 limited: true,
4988 }, {
4989 name: "limit disabled by client trace",
4990 opt: func(tr *http.Transport) {
4991 tr.MaxResponseHeaderBytes = 10000
4992 },
4993 ctxfn: func(ctx context.Context) context.Context {
4994 return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
4995 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
4996 return nil
4997 },
4998 })
4999 },
5000 hcount: 20,
5001 limited: false,
5002 }} {
5003 synctestSubtest(t, test.name, func(t testing.TB) {
5004 tc := newTestClientConn(t, test.opt)
5005 tc.greet()
5006
5007 ctx := context.Background()
5008 if test.ctxfn != nil {
5009 ctx = test.ctxfn(ctx)
5010 }
5011 req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
5012 rt := tc.roundTrip(req)
5013 tc.wantFrameType(FrameHeaders)
5014
5015 for i := 0; i < test.hcount; i++ {
5016 if fr, err := tc.fr.ReadFrame(); err != os.ErrDeadlineExceeded {
5017 t.Fatalf("after writing %v 1xx headers: read %v, %v; want idle", i, fr, err)
5018 }
5019 tc.writeHeaders(HeadersFrameParam{
5020 StreamID: rt.streamID(),
5021 EndHeaders: true,
5022 EndStream: false,
5023 BlockFragment: tc.makeHeaderBlockFragment(
5024 ":status", "103",
5025 "x-field", strings.Repeat("a", 1000),
5026 ),
5027 })
5028 }
5029 if test.limited {
5030 tc.wantFrameType(FrameRSTStream)
5031 } else {
5032 tc.wantIdle()
5033 }
5034 })
5035 }
5036 }
5037
5038
5039
5040 func TestTransportSendPingWithReset(t *testing.T) { synctestTest(t, testTransportSendPingWithReset) }
5041 func testTransportSendPingWithReset(t testing.TB) {
5042 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
5043 h2.StrictMaxConcurrentRequests = true
5044 })
5045
5046 const maxConcurrent = 3
5047 tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
5048
5049
5050 var rts []*testRoundTrip
5051 for i := range maxConcurrent + 1 {
5052 req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5053 rt := tc.roundTrip(req)
5054 if i >= maxConcurrent {
5055 tc.wantIdle()
5056 continue
5057 }
5058 tc.wantFrameType(FrameHeaders)
5059 rts = append(rts, rt)
5060 }
5061
5062
5063 rts[0].cancel()
5064 tc.wantRSTStream(rts[0].streamID(), ErrCodeCancel)
5065 pf := readFrame[*PingFrame](t, tc)
5066 tc.wantIdle()
5067
5068
5069 rts[1].cancel()
5070 tc.wantRSTStream(rts[1].streamID(), ErrCodeCancel)
5071 tc.wantIdle()
5072
5073
5074
5075 tc.writePing(true, pf.Data)
5076 tc.wantFrameType(FrameHeaders)
5077 tc.wantIdle()
5078 }
5079
5080
5081
5082
5083
5084 func TestTransportNoPingAfterResetWithFrames(t *testing.T) {
5085 synctestTest(t, testTransportNoPingAfterResetWithFrames)
5086 }
5087 func testTransportNoPingAfterResetWithFrames(t testing.TB) {
5088 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
5089 h2.StrictMaxConcurrentRequests = true
5090 })
5091
5092 const maxConcurrent = 1
5093 tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
5094
5095
5096
5097 req1 := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5098 rt1 := tc.roundTrip(req1)
5099 tc.wantFrameType(FrameHeaders)
5100 tc.writeHeaders(HeadersFrameParam{
5101 StreamID: rt1.streamID(),
5102 EndHeaders: true,
5103 BlockFragment: tc.makeHeaderBlockFragment(
5104 ":status", "200",
5105 ),
5106 })
5107 rt1.wantStatus(200)
5108
5109
5110
5111 req2 := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5112 rt2 := tc.roundTrip(req2)
5113 tc.wantIdle()
5114
5115
5116
5117 rt1.cancel()
5118 tc.wantRSTStream(rt1.streamID(), ErrCodeCancel)
5119 tc.wantFrameType(FrameHeaders)
5120
5121
5122
5123
5124 rt2.cancel()
5125 tc.wantRSTStream(rt2.streamID(), ErrCodeCancel)
5126 tc.wantFrameType(FramePing)
5127 }
5128
5129
5130
5131 func TestTransportSendNoMoreThanOnePingWithReset(t *testing.T) {
5132 synctestTest(t, testTransportSendNoMoreThanOnePingWithReset)
5133 }
5134 func testTransportSendNoMoreThanOnePingWithReset(t testing.TB) {
5135 tc := newTestClientConn(t)
5136 tc.greet()
5137
5138 makeAndResetRequest := func() {
5139 t.Helper()
5140 ctx, cancel := context.WithCancel(context.Background())
5141 req := Must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil))
5142 rt := tc.roundTrip(req)
5143 tc.wantFrameType(FrameHeaders)
5144 cancel()
5145 tc.wantRSTStream(rt.streamID(), ErrCodeCancel)
5146 }
5147
5148
5149
5150 makeAndResetRequest()
5151 pf1 := readFrame[*PingFrame](t, tc)
5152 tc.wantIdle()
5153
5154
5155
5156
5157
5158 makeAndResetRequest()
5159 tc.wantIdle()
5160
5161
5162
5163 tc.writeHeaders(HeadersFrameParam{
5164 StreamID: 1,
5165 EndHeaders: true,
5166 EndStream: true,
5167 BlockFragment: tc.makeHeaderBlockFragment(
5168 ":status", "200",
5169 ),
5170 })
5171 tc.wantIdle()
5172
5173
5174
5175
5176 makeAndResetRequest()
5177 tc.wantIdle()
5178
5179
5180 tc.writePing(true, pf1.Data)
5181 tc.wantIdle()
5182
5183
5184
5185
5186 makeAndResetRequest()
5187 tc.wantIdle()
5188
5189
5190 tc.writeHeaders(HeadersFrameParam{
5191 StreamID: 3,
5192 EndHeaders: true,
5193 EndStream: true,
5194 BlockFragment: tc.makeHeaderBlockFragment(
5195 ":status", "200",
5196 ),
5197 })
5198 tc.wantIdle()
5199
5200
5201
5202 makeAndResetRequest()
5203 tc.wantFrameType(FramePing)
5204 }
5205
5206 func TestTransportConnBecomesUnresponsive(t *testing.T) {
5207 synctestTest(t, testTransportConnBecomesUnresponsive)
5208 }
5209 func testTransportConnBecomesUnresponsive(t testing.TB) {
5210
5211
5212
5213 tt := newTestTransport(t)
5214
5215 const maxConcurrent = 3
5216
5217 t.Logf("first request opens a new connection and succeeds")
5218 req1 := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5219 rt1 := tt.roundTrip(req1)
5220 tc1 := tt.getConn()
5221 tc1.wantFrameType(FrameSettings)
5222 tc1.wantFrameType(FrameWindowUpdate)
5223 hf1 := readFrame[*HeadersFrame](t, tc1)
5224 tc1.writeSettings(Setting{SettingMaxConcurrentStreams, maxConcurrent})
5225 tc1.wantFrameType(FrameSettings)
5226 tc1.writeHeaders(HeadersFrameParam{
5227 StreamID: hf1.StreamID,
5228 EndHeaders: true,
5229 EndStream: true,
5230 BlockFragment: tc1.makeHeaderBlockFragment(
5231 ":status", "200",
5232 ),
5233 })
5234 rt1.wantStatus(200)
5235 rt1.response().Body.Close()
5236
5237
5238
5239
5240 for i := range maxConcurrent {
5241 t.Logf("request %v receives no response and is canceled", i)
5242 ctx, cancel := context.WithCancel(context.Background())
5243 req := Must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil))
5244 tt.roundTrip(req)
5245 if tt.hasConn() {
5246 t.Fatalf("new connection created; expect existing conn to be reused")
5247 }
5248 tc1.wantFrameType(FrameHeaders)
5249 cancel()
5250 tc1.wantFrameType(FrameRSTStream)
5251 if i == 0 {
5252 tc1.wantFrameType(FramePing)
5253 }
5254 tc1.wantIdle()
5255 }
5256
5257
5258
5259 req2 := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5260 rt2 := tt.roundTrip(req2)
5261 tc2 := tt.getConn()
5262 tc2.wantFrameType(FrameSettings)
5263 tc2.wantFrameType(FrameWindowUpdate)
5264 hf := readFrame[*HeadersFrame](t, tc2)
5265 tc2.writeSettings(Setting{SettingMaxConcurrentStreams, maxConcurrent})
5266 tc2.wantFrameType(FrameSettings)
5267 tc2.writeHeaders(HeadersFrameParam{
5268 StreamID: hf.StreamID,
5269 EndHeaders: true,
5270 EndStream: true,
5271 BlockFragment: tc2.makeHeaderBlockFragment(
5272 ":status", "200",
5273 ),
5274 })
5275 rt2.wantStatus(200)
5276 rt2.response().Body.Close()
5277 }
5278
5279
5280
5281
5282
5283
5284 func newTestTransportWithUnusedConn(t testing.TB, opts ...any) *testTransport {
5285 tt := newTestTransport(t, opts...)
5286
5287 waitc := make(chan struct{})
5288 dialContext := tt.tr1.DialContext
5289 tt.tr1.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
5290 <-waitc
5291 return dialContext(ctx, network, address)
5292 }
5293
5294 req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5295 rt := tt.roundTrip(req)
5296 rt.cancel()
5297 if rt.err() == nil {
5298 t.Fatalf("RoundTrip still running after request is canceled")
5299 }
5300
5301 close(waitc)
5302 synctest.Wait()
5303 return tt
5304 }
5305
5306
5307 func TestTransportUnusedConnOK(t *testing.T) { synctestTest(t, testTransportUnusedConnOK) }
5308 func testTransportUnusedConnOK(t testing.TB) {
5309 tt := newTestTransportWithUnusedConn(t)
5310
5311 req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5312 tc := tt.getConn()
5313 tc.wantFrameType(FrameSettings)
5314 tc.wantFrameType(FrameWindowUpdate)
5315
5316
5317
5318 rt := tt.roundTrip(req)
5319 tc.wantHeaders(wantHeader{
5320 streamID: 1,
5321 endStream: true,
5322 header: http.Header{
5323 ":authority": []string{"dummy.tld"},
5324 ":method": []string{"GET"},
5325 ":path": []string{"/"},
5326 },
5327 })
5328
5329 tc.writeSettings()
5330 tc.writeSettingsAck()
5331 tc.wantFrameType(FrameSettings)
5332
5333 tc.writeHeaders(HeadersFrameParam{
5334 StreamID: 1,
5335 EndHeaders: true,
5336 EndStream: true,
5337 BlockFragment: tc.makeHeaderBlockFragment(
5338 ":status", "200",
5339 ),
5340 })
5341 rt.wantStatus(200)
5342 rt.wantBody(nil)
5343 }
5344
5345
5346 func TestTransportUnusedConnImmediateFailureUsed(t *testing.T) {
5347 synctestTest(t, testTransportUnusedConnImmediateFailureUsed)
5348 }
5349 func testTransportUnusedConnImmediateFailureUsed(t testing.TB) {
5350 tt := newTestTransportWithUnusedConn(t)
5351
5352
5353 tc1 := tt.getConn()
5354 tc1.closeWrite()
5355
5356
5357
5358
5359 req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5360 rt := tt.roundTrip(req)
5361 if err := rt.err(); err == nil || errors.Is(err, ErrNoCachedConn) {
5362 t.Fatalf("RoundTrip with broken conn: got %v, want an error other than ErrNoCachedConn", err)
5363 }
5364
5365
5366
5367
5368 _ = tt.roundTrip(req)
5369 tc2 := tt.getConn()
5370 tc2.wantFrameType(FrameSettings)
5371 tc2.wantFrameType(FrameWindowUpdate)
5372 tc2.wantFrameType(FrameHeaders)
5373 }
5374
5375
5376 func TestTransportUnusedConnIdleTimoutBeforeUse(t *testing.T) {
5377 synctestTest(t, testTransportUnusedConnIdleTimoutBeforeUse)
5378 }
5379 func testTransportUnusedConnIdleTimoutBeforeUse(t testing.TB) {
5380 tt := newTestTransportWithUnusedConn(t, func(t1 *http.Transport) {
5381 t1.IdleConnTimeout = 1 * time.Second
5382 })
5383
5384 _ = tt.getConn()
5385
5386
5387 time.Sleep(2 * time.Second)
5388 synctest.Wait()
5389
5390
5391
5392
5393
5394 req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5395 _ = tt.roundTrip(req)
5396 tc2 := tt.getConn()
5397 tc2.wantFrameType(FrameSettings)
5398 tc2.wantFrameType(FrameWindowUpdate)
5399 tc2.wantFrameType(FrameHeaders)
5400 }
5401
5402
5403
5404 func TestTransportTLSNextProtoConnImmediateFailureUnused(t *testing.T) {
5405 synctestTest(t, testTransportTLSNextProtoConnImmediateFailureUnused)
5406 }
5407 func testTransportTLSNextProtoConnImmediateFailureUnused(t testing.TB) {
5408 tt := newTestTransportWithUnusedConn(t, func(t1 *http.Transport) {
5409 t1.IdleConnTimeout = 1 * time.Second
5410 })
5411
5412
5413 tc1 := tt.getConn()
5414 tc1.closeWrite()
5415
5416
5417
5418 time.Sleep(10 * time.Second)
5419
5420
5421
5422
5423 req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5424 _ = tt.roundTrip(req)
5425 tc2 := tt.getConn()
5426 tc2.wantFrameType(FrameSettings)
5427 tc2.wantFrameType(FrameWindowUpdate)
5428 tc2.wantFrameType(FrameHeaders)
5429 }
5430
5431 func TestExtendedConnectClientWithServerSupport(t *testing.T) {
5432 t.Skip("https://go.dev/issue/53208 -- net/http needs to support the :protocol header")
5433 SetDisableExtendedConnectProtocol(t, false)
5434 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
5435 if r.Header.Get(":protocol") != "extended-connect" {
5436 t.Fatalf("unexpected :protocol header received")
5437 }
5438 t.Log(io.Copy(w, r.Body))
5439 })
5440 tr := newTransport(t)
5441 pr, pw := io.Pipe()
5442 pwDone := make(chan struct{})
5443 req, _ := http.NewRequest("CONNECT", ts.URL, pr)
5444 req.Header.Set(":protocol", "extended-connect")
5445 req.Header.Set("X-A", "A")
5446 req.Header.Set("X-B", "B")
5447 req.Header.Set("X-C", "C")
5448 go func() {
5449 pw.Write([]byte("hello, extended connect"))
5450 pw.Close()
5451 close(pwDone)
5452 }()
5453
5454 res, err := tr.RoundTrip(req)
5455 if err != nil {
5456 t.Fatal(err)
5457 }
5458 body, err := io.ReadAll(res.Body)
5459 if err != nil {
5460 t.Fatal(err)
5461 }
5462 if !bytes.Equal(body, []byte("hello, extended connect")) {
5463 t.Fatal("unexpected body received")
5464 }
5465 }
5466
5467 func TestExtendedConnectClientWithoutServerSupport(t *testing.T) {
5468 t.Skip("https://go.dev/issue/53208 -- net/http needs to support the :protocol header")
5469 SetDisableExtendedConnectProtocol(t, true)
5470 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
5471 io.Copy(w, r.Body)
5472 })
5473 tr := newTransport(t)
5474 pr, pw := io.Pipe()
5475 pwDone := make(chan struct{})
5476 req, _ := http.NewRequest("CONNECT", ts.URL, pr)
5477 req.Header.Set(":protocol", "extended-connect")
5478 req.Header.Set("X-A", "A")
5479 req.Header.Set("X-B", "B")
5480 req.Header.Set("X-C", "C")
5481 go func() {
5482 pw.Write([]byte("hello, extended connect"))
5483 pw.Close()
5484 close(pwDone)
5485 }()
5486
5487 _, err := tr.RoundTrip(req)
5488 if !errors.Is(err, ErrExtendedConnectNotSupported) {
5489 t.Fatalf("expected error errExtendedConnectNotSupported, got: %v", err)
5490 }
5491 }
5492
5493
5494
5495 func TestExtendedConnectReadFrameError(t *testing.T) {
5496 synctestTest(t, testExtendedConnectReadFrameError)
5497 }
5498 func testExtendedConnectReadFrameError(t testing.TB) {
5499 t.Skip("https://go.dev/issue/53208 -- net/http needs to support the :protocol header")
5500 tc := newTestClientConn(t)
5501 tc.wantFrameType(FrameSettings)
5502 tc.wantFrameType(FrameWindowUpdate)
5503
5504 req, _ := http.NewRequest("CONNECT", "https://dummy.tld/", nil)
5505 req.Header.Set(":protocol", "extended-connect")
5506 rt := tc.roundTrip(req)
5507 tc.wantIdle()
5508
5509 tc.closeWrite()
5510 if !rt.done() {
5511 t.Fatalf("after connection closed: RoundTrip still running; want done")
5512 }
5513 if rt.err() == nil {
5514 t.Fatalf("after connection closed: RoundTrip succeeded; want error")
5515 }
5516 }
5517
View as plain text