// Source file: src/net/http/clientserver_test.go
7 package http_test
8
9 import (
10 "bytes"
11 "compress/gzip"
12 "context"
13 "crypto/rand"
14 "crypto/sha1"
15 "crypto/tls"
16 "fmt"
17 "hash"
18 "io"
19 "log"
20 "maps"
21 "net"
22 . "net/http"
23 "net/http/httptest"
24 "net/http/httptrace"
25 "net/http/httputil"
26 "net/textproto"
27 "net/url"
28 "os"
29 "reflect"
30 "runtime"
31 "slices"
32 "strings"
33 "sync"
34 "sync/atomic"
35 "testing"
36 "testing/synctest"
37 "time"
38 )

// testMode selects the protocol configuration a client/server test runs
// under. Its string value is used as the subtest name by run.
type testMode string

const (
	http1Mode            testMode = "h1"            // HTTP/1.1 over plain TCP
	https1Mode           testMode = "https1"        // HTTP/1.1 over TLS
	http2Mode            testMode = "h2"            // HTTP/2 over TLS
	http2UnencryptedMode testMode = "h2unencrypted" // HTTP/2 without TLS
)
48
49 func (m testMode) Scheme() string {
50 switch m {
51 case http1Mode, http2UnencryptedMode:
52 return "http"
53 case https1Mode, http2Mode:
54 return "https"
55 }
56 panic("unknown testMode")
57 }

// testNotParallelOpt is the type of testNotParallel, an option for run that
// stops it from marking the test and its per-mode subtests as parallel.
type testNotParallelOpt struct{}

var testNotParallel = testNotParallelOpt{}

// TBRun is a testing.TB that can also start subtests whose callback receives
// the same concrete type T (for example *testing.T or *testing.B).
type TBRun[T any] interface {
	testing.TB
	Run(name string, f func(T)) bool
}
69
70
71
72
73
74
75
76
77 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
78 t.Helper()
79 modes := []testMode{http1Mode, http2Mode}
80 parallel := true
81 for _, opt := range opts {
82 switch opt := opt.(type) {
83 case []testMode:
84 modes = opt
85 case testNotParallelOpt:
86 parallel = false
87 default:
88 t.Fatalf("unknown option type %T", opt)
89 }
90 }
91 if t, ok := any(t).(*testing.T); ok && parallel {
92 setParallel(t)
93 }
94 for _, mode := range modes {
95 t.Run(string(mode), func(t T) {
96 t.Helper()
97 if t, ok := any(t).(*testing.T); ok && parallel {
98 setParallel(t)
99 }
100 t.Cleanup(func() {
101 afterTest(t)
102 })
103 f(t, mode)
104 })
105 }
106 }
107
108
109
110
111 func runSynctest(t *testing.T, f func(t *testing.T, mode testMode), opts ...any) {
112 run(t, func(t *testing.T, mode testMode) {
113 synctest.Test(t, func(t *testing.T) {
114 f(t, mode)
115 })
116 }, opts...)
117 }
118
119 type clientServerTest struct {
120 t testing.TB
121 h2 bool
122 h Handler
123 ts *httptest.Server
124 tr *Transport
125 c *Client
126 li *fakeNetListener
127 }
128
129 func (t *clientServerTest) close() {
130 t.tr.CloseIdleConnections()
131 t.ts.Close()
132 }
133
134 func (t *clientServerTest) getURL(u string) string {
135 res, err := t.c.Get(u)
136 if err != nil {
137 t.t.Fatal(err)
138 }
139 defer res.Body.Close()
140 slurp, err := io.ReadAll(res.Body)
141 if err != nil {
142 t.t.Fatal(err)
143 }
144 return string(slurp)
145 }
146
147 func (t *clientServerTest) scheme() string {
148 if t.h2 {
149 return "https"
150 }
151 return "http"
152 }
153
154 var optQuietLog = func(ts *httptest.Server) {
155 ts.Config.ErrorLog = quietLog
156 }

// optWithServerLog returns a newClientServerTest option that makes the test
// server write its error log to lg.
func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
	setLog := func(srv *httptest.Server) { srv.Config.ErrorLog = lg }
	return setLog
}

// fakeNetOpt is the type of optFakeNet. A dedicated named type means the
// option matches only values of this type. The previous new(struct{})
// sentinel was a pointer to a zero-size variable, and per the Go spec such
// pointers may compare equal to other zero-size pointers, so an unrelated
// *struct{} value could be mistaken for it.
type fakeNetOpt struct{}

// optFakeNet is a newClientServerTest option that makes the server listen on a
// fakeNetListener and the client dial through it instead of a real network.
var optFakeNet = fakeNetOpt{}
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
181 if mode == http2Mode {
182 CondSkipHTTP2(t)
183 }
184 cst := &clientServerTest{
185 t: t,
186 h2: mode == http2Mode,
187 h: h,
188 }
189
190 var transportFuncs []func(*Transport)
191
192 if idx := slices.Index(opts, any(optFakeNet)); idx >= 0 {
193 opts = slices.Delete(opts, idx, idx+1)
194 cst.li = fakeNetListen()
195 cst.ts = &httptest.Server{
196 Config: &Server{Handler: h},
197 Listener: cst.li,
198 }
199 transportFuncs = append(transportFuncs, func(tr *Transport) {
200 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
201 return cst.li.connect(), nil
202 }
203 })
204 } else {
205 cst.ts = httptest.NewUnstartedServer(h)
206 }
207
208 if mode == http2UnencryptedMode {
209 p := &Protocols{}
210 p.SetUnencryptedHTTP2(true)
211 cst.ts.Config.Protocols = p
212 }
213
214 for _, opt := range opts {
215 switch opt := opt.(type) {
216 case func(*Transport):
217 transportFuncs = append(transportFuncs, opt)
218 case func(*httptest.Server):
219 opt(cst.ts)
220 case func(*Server):
221 opt(cst.ts.Config)
222 default:
223 t.Fatalf("unhandled option type %T", opt)
224 }
225 }
226
227 if cst.ts.Config.ErrorLog == nil {
228 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
229 }
230
231 switch mode {
232 case http1Mode:
233 cst.ts.Start()
234 case https1Mode:
235 cst.ts.StartTLS()
236 case http2UnencryptedMode:
237 ExportHttp2ConfigureServer(cst.ts.Config, nil)
238 cst.ts.Start()
239 case http2Mode:
240 ExportHttp2ConfigureServer(cst.ts.Config, nil)
241 cst.ts.TLS = cst.ts.Config.TLSConfig
242 cst.ts.StartTLS()
243 default:
244 t.Fatalf("unknown test mode %v", mode)
245 }
246 cst.c = cst.ts.Client()
247 cst.tr = cst.c.Transport.(*Transport)
248 if mode == http2Mode || mode == http2UnencryptedMode {
249 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
250 t.Fatal(err)
251 }
252 }
253 for _, f := range transportFuncs {
254 f(cst.tr)
255 }
256
257 if mode == http2UnencryptedMode {
258 p := &Protocols{}
259 p.SetUnencryptedHTTP2(true)
260 cst.tr.Protocols = p
261 }
262
263 t.Cleanup(func() {
264 cst.close()
265 })
266 return cst
267 }

// testLogWriter is an io.Writer that forwards each write to the test log.
type testLogWriter struct {
	t testing.TB
}

// Write logs b, with surrounding whitespace trimmed and a "server log: "
// prefix, and always reports the whole of b as written.
func (w testLogWriter) Write(b []byte) (int, error) {
	msg := strings.TrimSpace(string(b))
	w.t.Logf("server log: %v", msg)
	return len(b), nil
}
277
278
279 func TestNewClientServerTest(t *testing.T) {
280 modes := []testMode{http1Mode, https1Mode, http2Mode}
281 t.Run("realnet", func(t *testing.T) {
282 run(t, func(t *testing.T, mode testMode) {
283 testNewClientServerTest(t, mode)
284 }, modes)
285 })
286 t.Run("synctest", func(t *testing.T) {
287 runSynctest(t, func(t *testing.T, mode testMode) {
288 testNewClientServerTest(t, mode, optFakeNet)
289 }, modes)
290 })
291 }
292 func testNewClientServerTest(t *testing.T, mode testMode, opts ...any) {
293 var got struct {
294 sync.Mutex
295 proto string
296 hasTLS bool
297 }
298 h := HandlerFunc(func(w ResponseWriter, r *Request) {
299 got.Lock()
300 defer got.Unlock()
301 got.proto = r.Proto
302 got.hasTLS = r.TLS != nil
303 })
304 cst := newClientServerTest(t, mode, h, opts...)
305 if _, err := cst.c.Head(cst.ts.URL); err != nil {
306 t.Fatal(err)
307 }
308 var wantProto string
309 var wantTLS bool
310 switch mode {
311 case http1Mode:
312 wantProto = "HTTP/1.1"
313 wantTLS = false
314 case https1Mode:
315 wantProto = "HTTP/1.1"
316 wantTLS = true
317 case http2Mode:
318 wantProto = "HTTP/2.0"
319 wantTLS = true
320 }
321 if got.proto != wantProto {
322 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
323 }
324 if got.hasTLS != wantTLS {
325 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
326 }
327 }
328
329 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
330 func testChunkedResponseHeaders(t *testing.T, mode testMode) {
331 log.SetOutput(io.Discard)
332 defer log.SetOutput(os.Stderr)
333 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
334 w.Header().Set("Content-Length", "intentional gibberish")
335 w.(Flusher).Flush()
336 fmt.Fprintf(w, "I am a chunked response.")
337 }))
338
339 res, err := cst.c.Get(cst.ts.URL)
340 if err != nil {
341 t.Fatalf("Get error: %v", err)
342 }
343 defer res.Body.Close()
344 if g, e := res.ContentLength, int64(-1); g != e {
345 t.Errorf("expected ContentLength of %d; got %d", e, g)
346 }
347 wantTE := []string{"chunked"}
348 if mode == http2Mode {
349 wantTE = nil
350 }
351 if !slices.Equal(res.TransferEncoding, wantTE) {
352 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
353 }
354 if got, haveCL := res.Header["Content-Length"]; haveCL {
355 t.Errorf("Unexpected Content-Length: %q", got)
356 }
357 }

// reqFunc issues a request for url using c; (*Client).Get and (*Client).Head
// both have this shape.
type reqFunc func(c *Client, url string) (*Response, error)
360
361
362
363 type h12Compare struct {
364 Handler func(ResponseWriter, *Request)
365 ReqFunc reqFunc
366 CheckResponse func(proto string, res *Response)
367 EarlyCheckResponse func(proto string, res *Response)
368 Opts []any
369 }
370
371 func (tt h12Compare) reqFunc() reqFunc {
372 if tt.ReqFunc == nil {
373 return (*Client).Get
374 }
375 return tt.ReqFunc
376 }
377
378 func (tt h12Compare) run(t *testing.T) {
379 setParallel(t)
380 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
381 defer cst1.close()
382 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
383 defer cst2.close()
384
385 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
386 if err != nil {
387 t.Errorf("HTTP/1 request: %v", err)
388 return
389 }
390 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
391 if err != nil {
392 t.Errorf("HTTP/2 request: %v", err)
393 return
394 }
395
396 if fn := tt.EarlyCheckResponse; fn != nil {
397 fn("HTTP/1.1", res1)
398 fn("HTTP/2.0", res2)
399 }
400
401 tt.normalizeRes(t, res1, "HTTP/1.1")
402 tt.normalizeRes(t, res2, "HTTP/2.0")
403 res1body, res2body := res1.Body, res2.Body
404
405 eres1 := mostlyCopy(res1)
406 eres2 := mostlyCopy(res2)
407 if !reflect.DeepEqual(eres1, eres2) {
408 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
409 cst1.ts.URL, eres1, cst2.ts.URL, eres2)
410 }
411 if !reflect.DeepEqual(res1body, res2body) {
412 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
413 }
414 if fn := tt.CheckResponse; fn != nil {
415 res1.Body, res2.Body = res1body, res2body
416 fn("HTTP/1.1", res1)
417 fn("HTTP/2.0", res2)
418 }
419 }

// mostlyCopy returns a shallow copy of r with Body, TransferEncoding, TLS and
// Request cleared, so the remaining fields of two responses can be compared
// with reflect.DeepEqual. r itself is not modified.
func mostlyCopy(r *Response) *Response {
	dup := new(Response)
	*dup = *r
	dup.Body, dup.TransferEncoding = nil, nil
	dup.TLS, dup.Request = nil, nil
	return dup
}

// slurpResult stands in for a response body that has already been read in
// full. It still reads back the slurped bytes, and it records those bytes and
// the read error so tests can compare or inspect them.
type slurpResult struct {
	io.ReadCloser
	body []byte
	err  error
}

// String describes the slurped body and the error encountered reading it.
func (sr slurpResult) String() string {
	return fmt.Sprintf("body %q; err %v", sr.body, sr.err)
}
437
438 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
439 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
440 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
441 } else {
442 t.Errorf("got %q response; want %q", res.Proto, wantProto)
443 }
444 slurp, err := io.ReadAll(res.Body)
445
446 res.Body.Close()
447 res.Body = slurpResult{
448 ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
449 body: slurp,
450 err: err,
451 }
452 for i, v := range res.Header["Date"] {
453 res.Header["Date"][i] = strings.Repeat("x", len(v))
454 }
455 if res.Request == nil {
456 t.Errorf("for %s, no request", wantProto)
457 }
458 if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
459 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
460 }
461 }
462
463
464 func TestH12_HeadContentLengthNoBody(t *testing.T) {
465 h12Compare{
466 ReqFunc: (*Client).Head,
467 Handler: func(w ResponseWriter, r *Request) {
468 },
469 }.run(t)
470 }
471
472 func TestH12_HeadContentLengthSmallBody(t *testing.T) {
473 h12Compare{
474 ReqFunc: (*Client).Head,
475 Handler: func(w ResponseWriter, r *Request) {
476 io.WriteString(w, "small")
477 },
478 }.run(t)
479 }
480
481 func TestH12_HeadContentLengthLargeBody(t *testing.T) {
482 h12Compare{
483 ReqFunc: (*Client).Head,
484 Handler: func(w ResponseWriter, r *Request) {
485 chunk := strings.Repeat("x", 512<<10)
486 for i := 0; i < 10; i++ {
487 io.WriteString(w, chunk)
488 }
489 },
490 }.run(t)
491 }
492
493 func TestH12_200NoBody(t *testing.T) {
494 h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
495 }
496
497 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
498 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
499 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
500
501 func testH12_noBody(t *testing.T, status int) {
502 h12Compare{Handler: func(w ResponseWriter, r *Request) {
503 w.WriteHeader(status)
504 }}.run(t)
505 }
506
507 func TestH12_SmallBody(t *testing.T) {
508 h12Compare{Handler: func(w ResponseWriter, r *Request) {
509 io.WriteString(w, "small body")
510 }}.run(t)
511 }
512
513 func TestH12_ExplicitContentLength(t *testing.T) {
514 h12Compare{Handler: func(w ResponseWriter, r *Request) {
515 w.Header().Set("Content-Length", "3")
516 io.WriteString(w, "foo")
517 }}.run(t)
518 }
519
520 func TestH12_FlushBeforeBody(t *testing.T) {
521 h12Compare{Handler: func(w ResponseWriter, r *Request) {
522 w.(Flusher).Flush()
523 io.WriteString(w, "foo")
524 }}.run(t)
525 }
526
527 func TestH12_FlushMidBody(t *testing.T) {
528 h12Compare{Handler: func(w ResponseWriter, r *Request) {
529 io.WriteString(w, "foo")
530 w.(Flusher).Flush()
531 io.WriteString(w, "bar")
532 }}.run(t)
533 }
534
535 func TestH12_Head_ExplicitLen(t *testing.T) {
536 h12Compare{
537 ReqFunc: (*Client).Head,
538 Handler: func(w ResponseWriter, r *Request) {
539 if r.Method != "HEAD" {
540 t.Errorf("unexpected method %q", r.Method)
541 }
542 w.Header().Set("Content-Length", "1235")
543 },
544 }.run(t)
545 }
546
547 func TestH12_Head_ImplicitLen(t *testing.T) {
548 h12Compare{
549 ReqFunc: (*Client).Head,
550 Handler: func(w ResponseWriter, r *Request) {
551 if r.Method != "HEAD" {
552 t.Errorf("unexpected method %q", r.Method)
553 }
554 io.WriteString(w, "foo")
555 },
556 }.run(t)
557 }
558
559 func TestH12_HandlerWritesTooLittle(t *testing.T) {
560 h12Compare{
561 Handler: func(w ResponseWriter, r *Request) {
562 w.Header().Set("Content-Length", "3")
563 io.WriteString(w, "12")
564 },
565 CheckResponse: func(proto string, res *Response) {
566 sr, ok := res.Body.(slurpResult)
567 if !ok {
568 t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
569 return
570 }
571 if sr.err != io.ErrUnexpectedEOF {
572 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
573 }
574 if string(sr.body) != "12" {
575 t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
576 }
577 },
578 }.run(t)
579 }
580
581
582
583
584
585
586
587 func TestHandlerWritesTooMuch(t *testing.T) { run(t, testHandlerWritesTooMuch) }
588 func testHandlerWritesTooMuch(t *testing.T, mode testMode) {
589 wantBody := []byte("123")
590 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
591 rc := NewResponseController(w)
592 w.Header().Set("Content-Length", fmt.Sprintf("%v", len(wantBody)))
593 rc.Flush()
594 w.Write(wantBody)
595 rc.Flush()
596 n, err := io.WriteString(w, "x")
597 if err == nil {
598 err = rc.Flush()
599 }
600
601 if err == nil {
602 t.Errorf("for proto %q, final write = %v, %v; want _, some error", r.Proto, n, err)
603 }
604 }))
605
606 res, err := cst.c.Get(cst.ts.URL)
607 if err != nil {
608 t.Fatal(err)
609 }
610 defer res.Body.Close()
611
612 gotBody, _ := io.ReadAll(res.Body)
613 if !bytes.Equal(gotBody, wantBody) {
614 t.Fatalf("got response body: %q; want %q", gotBody, wantBody)
615 }
616 }
617
618
619
620 func TestH12_AutoGzip(t *testing.T) {
621 h12Compare{
622 Handler: func(w ResponseWriter, r *Request) {
623 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
624 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
625 }
626 w.Header().Set("Content-Encoding", "gzip")
627 gz := gzip.NewWriter(w)
628 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
629 gz.Close()
630 },
631 }.run(t)
632 }
633
634 func TestH12_AutoGzip_Disabled(t *testing.T) {
635 h12Compare{
636 Opts: []any{
637 func(tr *Transport) { tr.DisableCompression = true },
638 },
639 Handler: func(w ResponseWriter, r *Request) {
640 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
641 if ae := r.Header.Get("Accept-Encoding"); ae != "" {
642 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
643 }
644 },
645 }.run(t)
646 }
647
648
649
650
651 func Test304Responses(t *testing.T) { run(t, test304Responses) }
652 func test304Responses(t *testing.T, mode testMode) {
653 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
654 w.WriteHeader(StatusNotModified)
655 _, err := w.Write([]byte("illegal body"))
656 if err != ErrBodyNotAllowed {
657 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
658 }
659 }))
660 defer cst.close()
661 res, err := cst.c.Get(cst.ts.URL)
662 if err != nil {
663 t.Fatal(err)
664 }
665 if len(res.TransferEncoding) > 0 {
666 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
667 }
668 body, err := io.ReadAll(res.Body)
669 if err != nil {
670 t.Error(err)
671 }
672 if len(body) > 0 {
673 t.Errorf("got unexpected body %q", string(body))
674 }
675 }
676
677 func TestH12_ServerEmptyContentLength(t *testing.T) {
678 h12Compare{
679 Handler: func(w ResponseWriter, r *Request) {
680 w.Header()["Content-Type"] = []string{""}
681 io.WriteString(w, "<html><body>hi</body></html>")
682 },
683 }.run(t)
684 }
685
686 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
687 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
688 }
689
690 func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
691 h12requestContentLength(t, func() io.Reader { return nil }, 0)
692 }
693
694 func TestH12_RequestContentLength_Unknown(t *testing.T) {
695 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
696 }
697
698 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
699 h12Compare{
700 Handler: func(w ResponseWriter, r *Request) {
701 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
702 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
703 },
704 ReqFunc: func(c *Client, url string) (*Response, error) {
705 return c.Post(url, "text/plain", bodyfn())
706 },
707 CheckResponse: func(proto string, res *Response) {
708 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
709 t.Errorf("Proto %q got length %q; want %q", proto, got, want)
710 }
711 },
712 }.run(t)
713 }
714
715
716
717 func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
718 func testCancelRequestMidBody(t *testing.T, mode testMode) {
719 unblock := make(chan bool)
720 didFlush := make(chan bool, 1)
721 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
722 io.WriteString(w, "Hello")
723 w.(Flusher).Flush()
724 didFlush <- true
725 <-unblock
726 io.WriteString(w, ", world.")
727 }))
728 defer close(unblock)
729
730 req, _ := NewRequest("GET", cst.ts.URL, nil)
731 cancel := make(chan struct{})
732 req.Cancel = cancel
733
734 res, err := cst.c.Do(req)
735 if err != nil {
736 t.Fatal(err)
737 }
738 defer res.Body.Close()
739 <-didFlush
740
741
742
743 firstRead := make([]byte, 10)
744 n, err := res.Body.Read(firstRead)
745 if err != nil {
746 t.Fatal(err)
747 }
748 firstRead = firstRead[:n]
749
750 close(cancel)
751
752 rest, err := io.ReadAll(res.Body)
753 all := string(firstRead) + string(rest)
754 if all != "Hello" {
755 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
756 }
757 if err != ExportErrRequestCanceled {
758 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
759 }
760 }
761
762
763 func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
764 func testTrailersClientToServer(t *testing.T, mode testMode) {
765 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
766 slurp, err := io.ReadAll(r.Body)
767 if err != nil {
768 t.Errorf("Server reading request body: %v", err)
769 }
770 if string(slurp) != "foo" {
771 t.Errorf("Server read request body %q; want foo", slurp)
772 }
773 if r.Trailer == nil {
774 io.WriteString(w, "nil Trailer")
775 } else {
776 decl := slices.Sorted(maps.Keys(r.Trailer))
777 fmt.Fprintf(w, "decl: %v, vals: %s, %s",
778 decl,
779 r.Trailer.Get("Client-Trailer-A"),
780 r.Trailer.Get("Client-Trailer-B"))
781 }
782 }))
783
784 var req *Request
785 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
786 eofReaderFunc(func() {
787 req.Trailer["Client-Trailer-A"] = []string{"valuea"}
788 }),
789 strings.NewReader("foo"),
790 eofReaderFunc(func() {
791 req.Trailer["Client-Trailer-B"] = []string{"valueb"}
792 }),
793 ))
794 req.Trailer = Header{
795 "Client-Trailer-A": nil,
796 "Client-Trailer-B": nil,
797 }
798 req.ContentLength = -1
799 res, err := cst.c.Do(req)
800 if err != nil {
801 t.Fatal(err)
802 }
803 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
804 t.Error(err)
805 }
806 }
807
808
809 func TestTrailersServerToClient(t *testing.T) {
810 run(t, func(t *testing.T, mode testMode) {
811 testTrailersServerToClient(t, mode, false)
812 })
813 }
814 func TestTrailersServerToClientFlush(t *testing.T) {
815 run(t, func(t *testing.T, mode testMode) {
816 testTrailersServerToClient(t, mode, true)
817 })
818 }
819
820 func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
821 const body = "Some body"
822 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
823 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
824 w.Header().Add("Trailer", "Server-Trailer-C")
825
826 io.WriteString(w, body)
827 if flush {
828 w.(Flusher).Flush()
829 }
830
831
832
833
834
835 w.Header().Set("Server-Trailer-A", "valuea")
836 w.Header().Set("Server-Trailer-C", "valuec")
837 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
838 }))
839
840 res, err := cst.c.Get(cst.ts.URL)
841 if err != nil {
842 t.Fatal(err)
843 }
844
845 wantHeader := Header{
846 "Content-Type": {"text/plain; charset=utf-8"},
847 }
848 wantLen := -1
849 if mode == http2Mode && !flush {
850
851
852
853
854
855 wantLen = len(body)
856 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
857 }
858 if res.ContentLength != int64(wantLen) {
859 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
860 }
861
862 delete(res.Header, "Date")
863 if !reflect.DeepEqual(res.Header, wantHeader) {
864 t.Errorf("Header = %v; want %v", res.Header, wantHeader)
865 }
866
867 if got, want := res.Trailer, (Header{
868 "Server-Trailer-A": nil,
869 "Server-Trailer-B": nil,
870 "Server-Trailer-C": nil,
871 }); !reflect.DeepEqual(got, want) {
872 t.Errorf("Trailer before body read = %v; want %v", got, want)
873 }
874
875 if err := wantBody(res, nil, body); err != nil {
876 t.Fatal(err)
877 }
878
879 if got, want := res.Trailer, (Header{
880 "Server-Trailer-A": {"valuea"},
881 "Server-Trailer-B": nil,
882 "Server-Trailer-C": {"valuec"},
883 }); !reflect.DeepEqual(got, want) {
884 t.Errorf("Trailer after body read = %v; want %v", got, want)
885 }
886 }
887
888
889 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
890 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
891 const body = "Some body"
892 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
893 io.WriteString(w, body)
894 }))
895 res, err := cst.c.Get(cst.ts.URL)
896 if err != nil {
897 t.Fatal(err)
898 }
899 res.Body.Close()
900 data, err := io.ReadAll(res.Body)
901 if len(data) != 0 || err == nil {
902 t.Fatalf("ReadAll returned %q, %v; want error", data, err)
903 }
904 }
905
906 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
907 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
908 const reqBody = "some request body"
909 const resBody = "some response body"
910 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
911 var wg sync.WaitGroup
912 wg.Add(2)
913 didRead := make(chan bool, 1)
914
915 go func() {
916 defer wg.Done()
917 data, err := io.ReadAll(r.Body)
918 if string(data) != reqBody {
919 t.Errorf("Handler read %q; want %q", data, reqBody)
920 }
921 if err != nil {
922 t.Errorf("Handler Read: %v", err)
923 }
924 didRead <- true
925 }()
926
927 go func() {
928 defer wg.Done()
929 if mode != http2Mode {
930
931
932
933
934 <-didRead
935 }
936 io.WriteString(w, resBody)
937 }()
938 wg.Wait()
939 }))
940 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
941 req.Header.Add("Expect", "100-continue")
942 res, err := cst.c.Do(req)
943 if err != nil {
944 t.Fatal(err)
945 }
946 data, err := io.ReadAll(res.Body)
947 defer res.Body.Close()
948 if err != nil {
949 t.Fatal(err)
950 }
951 if string(data) != resBody {
952 t.Errorf("read %q; want %q", data, resBody)
953 }
954 }
955
956 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
957 func testConnectRequest(t *testing.T, mode testMode) {
958 gotc := make(chan *Request, 1)
959 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
960 gotc <- r
961 }))
962
963 u, err := url.Parse(cst.ts.URL)
964 if err != nil {
965 t.Fatal(err)
966 }
967
968 tests := []struct {
969 req *Request
970 want string
971 }{
972 {
973 req: &Request{
974 Method: "CONNECT",
975 Header: Header{},
976 URL: u,
977 },
978 want: u.Host,
979 },
980 {
981 req: &Request{
982 Method: "CONNECT",
983 Header: Header{},
984 URL: u,
985 Host: "example.com:123",
986 },
987 want: "example.com:123",
988 },
989 }
990
991 for i, tt := range tests {
992 res, err := cst.c.Do(tt.req)
993 if err != nil {
994 t.Errorf("%d. RoundTrip = %v", i, err)
995 continue
996 }
997 res.Body.Close()
998 req := <-gotc
999 if req.Method != "CONNECT" {
1000 t.Errorf("method = %q; want CONNECT", req.Method)
1001 }
1002 if req.Host != tt.want {
1003 t.Errorf("Host = %q; want %q", req.Host, tt.want)
1004 }
1005 if req.URL.Host != tt.want {
1006 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
1007 }
1008 }
1009 }
1010
1011 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
1012 func testTransportUserAgent(t *testing.T, mode testMode) {
1013 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1014 fmt.Fprintf(w, "%q", r.Header["User-Agent"])
1015 }))
1016
1017 either := func(a, b string) string {
1018 if mode == http2Mode {
1019 return b
1020 }
1021 return a
1022 }
1023
1024 tests := []struct {
1025 setup func(*Request)
1026 want string
1027 }{
1028 {
1029 func(r *Request) {},
1030 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
1031 },
1032 {
1033 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
1034 `["foo/1.2.3"]`,
1035 },
1036 {
1037 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
1038 `["single"]`,
1039 },
1040 {
1041 func(r *Request) { r.Header.Set("User-Agent", "") },
1042 `[]`,
1043 },
1044 {
1045 func(r *Request) { r.Header["User-Agent"] = nil },
1046 `[]`,
1047 },
1048 }
1049 for i, tt := range tests {
1050 req, _ := NewRequest("GET", cst.ts.URL, nil)
1051 tt.setup(req)
1052 res, err := cst.c.Do(req)
1053 if err != nil {
1054 t.Errorf("%d. RoundTrip = %v", i, err)
1055 continue
1056 }
1057 slurp, err := io.ReadAll(res.Body)
1058 res.Body.Close()
1059 if err != nil {
1060 t.Errorf("%d. read body = %v", i, err)
1061 continue
1062 }
1063 if string(slurp) != tt.want {
1064 t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
1065 }
1066 }
1067 }
1068
1069 func TestStarRequestMethod(t *testing.T) {
1070 for _, method := range []string{"FOO", "OPTIONS"} {
1071 t.Run(method, func(t *testing.T) {
1072 run(t, func(t *testing.T, mode testMode) {
1073 testStarRequest(t, method, mode)
1074 })
1075 })
1076 }
1077 }
1078 func testStarRequest(t *testing.T, method string, mode testMode) {
1079 gotc := make(chan *Request, 1)
1080 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1081 w.Header().Set("foo", "bar")
1082 gotc <- r
1083 w.(Flusher).Flush()
1084 }))
1085
1086 u, err := url.Parse(cst.ts.URL)
1087 if err != nil {
1088 t.Fatal(err)
1089 }
1090 u.Path = "*"
1091
1092 req := &Request{
1093 Method: method,
1094 Header: Header{},
1095 URL: u,
1096 }
1097
1098 res, err := cst.c.Do(req)
1099 if err != nil {
1100 t.Fatalf("RoundTrip = %v", err)
1101 }
1102 res.Body.Close()
1103
1104 wantFoo := "bar"
1105 wantLen := int64(-1)
1106 if method == "OPTIONS" {
1107 wantFoo = ""
1108 wantLen = 0
1109 }
1110 if res.StatusCode != 200 {
1111 t.Errorf("status code = %v; want %d", res.Status, 200)
1112 }
1113 if res.ContentLength != wantLen {
1114 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
1115 }
1116 if got := res.Header.Get("foo"); got != wantFoo {
1117 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
1118 }
1119 select {
1120 case req = <-gotc:
1121 default:
1122 req = nil
1123 }
1124 if req == nil {
1125 if method != "OPTIONS" {
1126 t.Fatalf("handler never got request")
1127 }
1128 return
1129 }
1130 if req.Method != method {
1131 t.Errorf("method = %q; want %q", req.Method, method)
1132 }
1133 if req.URL.Path != "*" {
1134 t.Errorf("URL.Path = %q; want *", req.URL.Path)
1135 }
1136 if req.RequestURI != "*" {
1137 t.Errorf("RequestURI = %q; want *", req.RequestURI)
1138 }
1139 }
1140
1141
1142 func TestTransportDiscardsUnneededConns(t *testing.T) {
1143 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
1144 }
1145 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
1146 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1147 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
1148 }))
1149 defer cst.close()
1150
1151 var numOpen, numClose int32
1152
1153 tlsConfig := &tls.Config{InsecureSkipVerify: true}
1154 tr := &Transport{
1155 TLSClientConfig: tlsConfig,
1156 DialTLS: func(_, addr string) (net.Conn, error) {
1157 time.Sleep(10 * time.Millisecond)
1158 rc, err := net.Dial("tcp", addr)
1159 if err != nil {
1160 return nil, err
1161 }
1162 atomic.AddInt32(&numOpen, 1)
1163 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
1164 return tls.Client(c, tlsConfig), nil
1165 },
1166 }
1167 if err := ExportHttp2ConfigureTransport(tr); err != nil {
1168 t.Fatal(err)
1169 }
1170 defer tr.CloseIdleConnections()
1171
1172 c := &Client{Transport: tr}
1173
1174 const N = 10
1175 gotBody := make(chan string, N)
1176 var wg sync.WaitGroup
1177 for i := 0; i < N; i++ {
1178 wg.Add(1)
1179 go func() {
1180 defer wg.Done()
1181 resp, err := c.Get(cst.ts.URL)
1182 if err != nil {
1183
1184
1185 time.Sleep(10 * time.Millisecond)
1186 resp, err = c.Get(cst.ts.URL)
1187 if err != nil {
1188 t.Errorf("Get: %v", err)
1189 return
1190 }
1191 }
1192 defer resp.Body.Close()
1193 slurp, err := io.ReadAll(resp.Body)
1194 if err != nil {
1195 t.Error(err)
1196 }
1197 gotBody <- string(slurp)
1198 }()
1199 }
1200 wg.Wait()
1201 close(gotBody)
1202
1203 var last string
1204 for got := range gotBody {
1205 if last == "" {
1206 last = got
1207 continue
1208 }
1209 if got != last {
1210 t.Errorf("Response body changed: %q -> %q", last, got)
1211 }
1212 }
1213
1214 var open, close int32
1215 for i := 0; i < 150; i++ {
1216 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
1217 if open < 1 {
1218 t.Fatalf("open = %d; want at least", open)
1219 }
1220 if close == open-1 {
1221
1222 return
1223 }
1224 time.Sleep(10 * time.Millisecond)
1225 }
1226 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
1227 }
1228
1229
1230 func TestTransportGCRequest(t *testing.T) {
1231 run(t, func(t *testing.T, mode testMode) {
1232 t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
1233 t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
1234 })
1235 }
1236 func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
1237 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1238 io.ReadAll(r.Body)
1239 if body {
1240 io.WriteString(w, "Hello.")
1241 }
1242 }))
1243
1244 didGC := make(chan struct{})
1245 (func() {
1246 body := strings.NewReader("some body")
1247 req, _ := NewRequest("POST", cst.ts.URL, body)
1248 runtime.AddCleanup(req, func(ch chan struct{}) { close(ch) }, didGC)
1249 res, err := cst.c.Do(req)
1250 if err != nil {
1251 t.Fatal(err)
1252 }
1253 if _, err := io.ReadAll(res.Body); err != nil {
1254 t.Fatal(err)
1255 }
1256 if err := res.Body.Close(); err != nil {
1257 t.Fatal(err)
1258 }
1259 })()
1260 for {
1261 select {
1262 case <-didGC:
1263 return
1264 case <-time.After(1 * time.Millisecond):
1265 runtime.GC()
1266 }
1267 }
1268 }
1269
1270 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
1271 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
1272 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1273 fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
1274 }), optQuietLog)
1275 cst.tr.DisableKeepAlives = true
1276
1277 tests := []struct {
1278 key, val string
1279 ok bool
1280 }{
1281 {"Foo", "capital-key", true},
1282 {"Foo", "foo\x00bar", false},
1283 {"Foo", "two\nlines", false},
1284 {"bogus\nkey", "v", false},
1285 {"A space", "v", false},
1286 {"имя", "v", false},
1287 {"name", "валю", true},
1288 {"", "v", false},
1289 {"k", "", true},
1290 }
1291 for _, tt := range tests {
1292 dialedc := make(chan bool, 1)
1293 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
1294 dialedc <- true
1295 return net.Dial(netw, addr)
1296 }
1297 req, _ := NewRequest("GET", cst.ts.URL, nil)
1298 req.Header[tt.key] = []string{tt.val}
1299 res, err := cst.c.Do(req)
1300 var body []byte
1301 if err == nil {
1302 body, _ = io.ReadAll(res.Body)
1303 res.Body.Close()
1304 }
1305 var dialed bool
1306 select {
1307 case <-dialedc:
1308 dialed = true
1309 default:
1310 }
1311
1312 if !tt.ok && dialed {
1313 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
1314 } else if (err == nil) != tt.ok {
1315 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
1316 }
1317 }
1318 }
1319
1320 func TestInterruptWithPanic(t *testing.T) {
1321 run(t, func(t *testing.T, mode testMode) {
1322 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
1323 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
1324 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
1325 }, testNotParallel)
1326 }
1327 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
1328 const msg = "hello"
1329
1330 testDone := make(chan struct{})
1331 defer close(testDone)
1332
1333 var errorLog lockedBytesBuffer
1334 gotHeaders := make(chan bool, 1)
1335 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1336 io.WriteString(w, msg)
1337 w.(Flusher).Flush()
1338
1339 select {
1340 case <-gotHeaders:
1341 case <-testDone:
1342 }
1343 panic(panicValue)
1344 }), func(ts *httptest.Server) {
1345 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1346 })
1347 res, err := cst.c.Get(cst.ts.URL)
1348 if err != nil {
1349 t.Fatal(err)
1350 }
1351 gotHeaders <- true
1352 defer res.Body.Close()
1353 slurp, err := io.ReadAll(res.Body)
1354 if string(slurp) != msg {
1355 t.Errorf("client read %q; want %q", slurp, msg)
1356 }
1357 if err == nil {
1358 t.Errorf("client read all successfully; want some error")
1359 }
1360 logOutput := func() string {
1361 errorLog.Lock()
1362 defer errorLog.Unlock()
1363 return errorLog.String()
1364 }
1365 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
1366
1367 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
1368 gotLog := logOutput()
1369 if !wantStackLogged {
1370 if gotLog == "" {
1371 return true
1372 }
1373 t.Fatalf("want no log output; got: %s", gotLog)
1374 }
1375 if gotLog == "" {
1376 if d > 0 {
1377 t.Logf("wanted a stack trace logged; got nothing after %v", d)
1378 }
1379 return false
1380 }
1381 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
1382 if d > 0 {
1383 t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
1384 }
1385 return false
1386 }
1387 return true
1388 })
1389 }
1390
1391 type lockedBytesBuffer struct {
1392 sync.Mutex
1393 bytes.Buffer
1394 }
1395
1396 func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
1397 b.Lock()
1398 defer b.Unlock()
1399 return b.Buffer.Write(p)
1400 }
1401
1402
1403 func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1404 h12Compare{
1405 Handler: func(w ResponseWriter, r *Request) {
1406 h := w.Header()
1407 h.Set("Content-Encoding", "gzip")
1408 h.Set("Content-Length", "23")
1409 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1410 },
1411 EarlyCheckResponse: func(proto string, res *Response) {
1412 if !res.Uncompressed {
1413 t.Errorf("%s: expected Uncompressed to be set", proto)
1414 }
1415 dump, err := httputil.DumpResponse(res, true)
1416 if err != nil {
1417 t.Errorf("%s: DumpResponse: %v", proto, err)
1418 return
1419 }
1420 if strings.Contains(string(dump), "Connection: close") {
1421 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1422 }
1423 if !strings.Contains(string(dump), "FOO") {
1424 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1425 }
1426 },
1427 }.run(t)
1428 }
1429
1430
1431 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
1432 func testCloseIdleConnections(t *testing.T, mode testMode) {
1433 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1434 w.Header().Set("X-Addr", r.RemoteAddr)
1435 }))
1436 get := func() string {
1437 res, err := cst.c.Get(cst.ts.URL)
1438 if err != nil {
1439 t.Fatal(err)
1440 }
1441 res.Body.Close()
1442 v := res.Header.Get("X-Addr")
1443 if v == "" {
1444 t.Fatal("didn't get X-Addr")
1445 }
1446 return v
1447 }
1448 a1 := get()
1449 cst.tr.CloseIdleConnections()
1450 a2 := get()
1451 if a1 == a2 {
1452 t.Errorf("didn't close connection")
1453 }
1454 }
1455
1456 type noteCloseConn struct {
1457 net.Conn
1458 closeFunc func()
1459 }
1460
1461 func (x noteCloseConn) Close() error {
1462 x.closeFunc()
1463 return x.Conn.Close()
1464 }
1465
1466 type testErrorReader struct{ t *testing.T }
1467
1468 func (r testErrorReader) Read(p []byte) (n int, err error) {
1469 r.t.Error("unexpected Read call")
1470 return 0, io.EOF
1471 }
1472
1473 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
1474 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
1475 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1476 w.WriteHeader(StatusUnauthorized)
1477 }))
1478
1479
1480 cst.tr.ExpectContinueTimeout = 10 * time.Second
1481
1482 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1483 if err != nil {
1484 t.Fatal(err)
1485 }
1486 req.ContentLength = 0
1487 req.Header.Set("Expect", "100-continue")
1488 res, err := cst.tr.RoundTrip(req)
1489 if err != nil {
1490 t.Fatal(err)
1491 }
1492 defer res.Body.Close()
1493 if res.StatusCode != StatusUnauthorized {
1494 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1495 }
1496 }
1497
1498 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
1499 func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
1500 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1501 w.Header().Set("Foo", "Bar")
1502 w.Header().Set("Trailer:Foo", "Baz")
1503 w.(Flusher).Flush()
1504 w.Header().Add("Trailer:Foo", "Baz2")
1505 w.Header().Set("Trailer:Bar", "Quux")
1506 }))
1507 res, err := cst.c.Get(cst.ts.URL)
1508 if err != nil {
1509 t.Fatal(err)
1510 }
1511 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1512 t.Fatal(err)
1513 }
1514 res.Body.Close()
1515 delete(res.Header, "Date")
1516 delete(res.Header, "Content-Type")
1517
1518 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1519 t.Errorf("Header = %#v; want %#v", res.Header, want)
1520 }
1521 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1522 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1523 }
1524 }
1525
1526 func TestBadResponseAfterReadingBody(t *testing.T) {
1527 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
1528 }
1529 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
1530 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1531 _, err := io.Copy(io.Discard, r.Body)
1532 if err != nil {
1533 t.Fatal(err)
1534 }
1535 c, _, err := w.(Hijacker).Hijack()
1536 if err != nil {
1537 t.Fatal(err)
1538 }
1539 defer c.Close()
1540 fmt.Fprintln(c, "some bogus crap")
1541 }))
1542
1543 closes := 0
1544 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1545 if err == nil {
1546 res.Body.Close()
1547 t.Fatal("expected an error to be returned from Post")
1548 }
1549 if closes != 1 {
1550 t.Errorf("closes = %d; want 1", closes)
1551 }
1552 }
1553
1554 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
1555 func testWriteHeader0(t *testing.T, mode testMode) {
1556 gotpanic := make(chan bool, 1)
1557 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1558 defer close(gotpanic)
1559 defer func() {
1560 if e := recover(); e != nil {
1561 got := fmt.Sprintf("%T, %v", e, e)
1562 want := "string, invalid WriteHeader code 0"
1563 if got != want {
1564 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1565 }
1566 gotpanic <- true
1567
1568
1569
1570
1571 w.WriteHeader(503)
1572 }
1573 }()
1574 w.WriteHeader(0)
1575 }))
1576 res, err := cst.c.Get(cst.ts.URL)
1577 if err != nil {
1578 t.Fatal(err)
1579 }
1580 if res.StatusCode != 503 {
1581 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1582 }
1583 if !<-gotpanic {
1584 t.Error("expected panic in handler")
1585 }
1586 }
1587
1588
1589
1590 func TestWriteHeaderNoCodeCheck(t *testing.T) {
1591 run(t, func(t *testing.T, mode testMode) {
1592 testWriteHeaderAfterWrite(t, mode, false)
1593 })
1594 }
1595 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
1596 testWriteHeaderAfterWrite(t, http1Mode, true)
1597 }
1598 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
1599 var errorLog lockedBytesBuffer
1600 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1601 if hijack {
1602 conn, _, _ := w.(Hijacker).Hijack()
1603 defer conn.Close()
1604 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1605 w.WriteHeader(0)
1606 conn.Write([]byte("bar"))
1607 return
1608 }
1609 io.WriteString(w, "foo")
1610 w.(Flusher).Flush()
1611 w.WriteHeader(0)
1612 io.WriteString(w, "bar")
1613 }), func(ts *httptest.Server) {
1614 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1615 })
1616 res, err := cst.c.Get(cst.ts.URL)
1617 if err != nil {
1618 t.Fatal(err)
1619 }
1620 defer res.Body.Close()
1621 body, err := io.ReadAll(res.Body)
1622 if err != nil {
1623 t.Fatal(err)
1624 }
1625 if got, want := string(body), "foobar"; got != want {
1626 t.Errorf("got = %q; want %q", got, want)
1627 }
1628
1629
1630 if mode == http2Mode {
1631
1632
1633 return
1634 }
1635 gotLog := strings.TrimSpace(errorLog.String())
1636 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1637 if hijack {
1638 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1639 }
1640 if !strings.HasPrefix(gotLog, wantLog) {
1641 t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1642 }
1643 }
1644
1645 func TestBidiStreamReverseProxy(t *testing.T) {
1646 run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
1647 }
1648 func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
1649 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1650 if _, err := io.Copy(w, r.Body); err != nil {
1651 log.Printf("bidi backend copy: %v", err)
1652 }
1653 }))
1654
1655 backURL, err := url.Parse(backend.ts.URL)
1656 if err != nil {
1657 t.Fatal(err)
1658 }
1659 rp := httputil.NewSingleHostReverseProxy(backURL)
1660 rp.Transport = backend.tr
1661 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1662 rp.ServeHTTP(w, r)
1663 }))
1664
1665 bodyRes := make(chan any, 1)
1666 pr, pw := io.Pipe()
1667 req, _ := NewRequest("PUT", proxy.ts.URL, pr)
1668 const size = 4 << 20
1669 go func() {
1670 h := sha1.New()
1671 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
1672 go pw.Close()
1673 if err != nil {
1674 t.Errorf("body copy: %v", err)
1675 bodyRes <- err
1676 } else {
1677 bodyRes <- h
1678 }
1679 }()
1680 res, err := backend.c.Do(req)
1681 if err != nil {
1682 t.Fatal(err)
1683 }
1684 defer res.Body.Close()
1685 hgot := sha1.New()
1686 n, err := io.Copy(hgot, res.Body)
1687 if err != nil {
1688 t.Fatal(err)
1689 }
1690 if n != size {
1691 t.Fatalf("got %d bytes; want %d", n, size)
1692 }
1693 select {
1694 case v := <-bodyRes:
1695 switch v := v.(type) {
1696 default:
1697 t.Fatalf("body copy: %v", err)
1698 case hash.Hash:
1699 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
1700 t.Errorf("written bytes didn't match received bytes")
1701 }
1702 }
1703 case <-time.After(10 * time.Second):
1704 t.Fatal("timeout")
1705 }
1706
1707 }
1708
1709
1710 func TestH12_WebSocketUpgrade(t *testing.T) {
1711 h12Compare{
1712 Handler: func(w ResponseWriter, r *Request) {
1713 h := w.Header()
1714 h.Set("Foo", "bar")
1715 },
1716 ReqFunc: func(c *Client, url string) (*Response, error) {
1717 req, _ := NewRequest("GET", url, nil)
1718 req.Header.Set("Connection", "Upgrade")
1719 req.Header.Set("Upgrade", "WebSocket")
1720 return c.Do(req)
1721 },
1722 EarlyCheckResponse: func(proto string, res *Response) {
1723 if res.Proto != "HTTP/1.1" {
1724 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1725 }
1726 res.Proto = "HTTP/IGNORE"
1727 },
1728 }.run(t)
1729 }
1730
1731 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
1732 func testIdentityTransferEncoding(t *testing.T, mode testMode) {
1733 const body = "body"
1734 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1735 gotBody, _ := io.ReadAll(r.Body)
1736 if got, want := string(gotBody), body; got != want {
1737 t.Errorf("got request body = %q; want %q", got, want)
1738 }
1739 w.Header().Set("Transfer-Encoding", "identity")
1740 w.WriteHeader(StatusOK)
1741 w.(Flusher).Flush()
1742 io.WriteString(w, body)
1743 }))
1744 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
1745 res, err := cst.c.Do(req)
1746 if err != nil {
1747 t.Fatal(err)
1748 }
1749 defer res.Body.Close()
1750 gotBody, err := io.ReadAll(res.Body)
1751 if err != nil {
1752 t.Fatal(err)
1753 }
1754 if got, want := string(gotBody), body; got != want {
1755 t.Errorf("got response body = %q; want %q", got, want)
1756 }
1757 }
1758
1759 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
1760 func testEarlyHintsRequest(t *testing.T, mode testMode) {
1761 var wg sync.WaitGroup
1762 wg.Add(1)
1763 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1764 h := w.Header()
1765
1766 h.Add("Content-Length", "123")
1767 h.Add("Link", "</style.css>; rel=preload; as=style")
1768 h.Add("Link", "</script.js>; rel=preload; as=script")
1769 w.WriteHeader(StatusEarlyHints)
1770
1771 wg.Wait()
1772
1773 h.Add("Link", "</foo.js>; rel=preload; as=script")
1774 w.WriteHeader(StatusEarlyHints)
1775
1776 w.Write([]byte("Hello"))
1777 }))
1778
1779 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1780 t.Helper()
1781
1782 if len(expected) != len(got) {
1783 t.Errorf("got %d expected %d", len(got), len(expected))
1784 }
1785
1786 for i := range expected {
1787 if expected[i] != got[i] {
1788 t.Errorf("got %q expected %q", got[i], expected[i])
1789 }
1790 }
1791 }
1792
1793 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
1794 t.Helper()
1795
1796 for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
1797 if v, ok := header[h]; ok {
1798 t.Errorf("%s is %q; must not be sent", h, v)
1799 }
1800 }
1801 }
1802
1803 var respCounter uint8
1804 trace := &httptrace.ClientTrace{
1805 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1806 switch respCounter {
1807 case 0:
1808 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1809 checkExcludedHeaders(t, header)
1810
1811 wg.Done()
1812 case 1:
1813 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1814 checkExcludedHeaders(t, header)
1815
1816 default:
1817 t.Error("Unexpected 1xx response")
1818 }
1819
1820 respCounter++
1821
1822 return nil
1823 },
1824 }
1825 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
1826
1827 res, err := cst.c.Do(req)
1828 if err != nil {
1829 t.Fatal(err)
1830 }
1831 defer res.Body.Close()
1832
1833 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1834 if cl := res.Header.Get("Content-Length"); cl != "123" {
1835 t.Errorf("Content-Length is %q; want 123", cl)
1836 }
1837
1838 body, _ := io.ReadAll(res.Body)
1839 if string(body) != "Hello" {
1840 t.Errorf("Read body %q; want Hello", body)
1841 }
1842 }
1843
View as plain text