Source file
src/net/http/clientserver_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bytes"
11 "compress/gzip"
12 "context"
13 "crypto/rand"
14 "crypto/sha1"
15 "crypto/tls"
16 "fmt"
17 "hash"
18 "io"
19 "log"
20 "maps"
21 "net"
22 . "net/http"
23 "net/http/httptest"
24 "net/http/httptrace"
25 "net/http/httputil"
26 "net/textproto"
27 "net/url"
28 "os"
29 "reflect"
30 "runtime"
31 "slices"
32 "strings"
33 "sync"
34 "sync/atomic"
35 "testing"
36 "time"
37 )
38
// testMode selects the protocol and transport a client/server test uses.
type testMode string

// Supported test modes (see newClientServerTest for how each is started).
const (
	http1Mode  testMode = "h1"     // HTTP/1.1 over plain TCP
	https1Mode testMode = "https1" // HTTP/1.1 over TLS
	http2Mode  testMode = "h2"     // HTTP/2 over TLS
)

// testNotParallelOpt is the type of the testNotParallel option to run.
type testNotParallelOpt struct{}

// testNotParallel, passed as an option to run, disables parallel execution.
var testNotParallel = testNotParallelOpt{}

// TBRun is a testing.TB that can also start subtests whose callback receives
// the same concrete type T.
type TBRun[T any] interface {
	testing.TB
	Run(string, func(T)) bool
}
57
58
59
60
61
62
63
64
65 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
66 t.Helper()
67 modes := []testMode{http1Mode, http2Mode}
68 parallel := true
69 for _, opt := range opts {
70 switch opt := opt.(type) {
71 case []testMode:
72 modes = opt
73 case testNotParallelOpt:
74 parallel = false
75 default:
76 t.Fatalf("unknown option type %T", opt)
77 }
78 }
79 if t, ok := any(t).(*testing.T); ok && parallel {
80 setParallel(t)
81 }
82 for _, mode := range modes {
83 t.Run(string(mode), func(t T) {
84 t.Helper()
85 if t, ok := any(t).(*testing.T); ok && parallel {
86 setParallel(t)
87 }
88 t.Cleanup(func() {
89 afterTest(t)
90 })
91 f(t, mode)
92 })
93 }
94 }
95
// clientServerTest bundles a running httptest.Server with the client and
// transport configured to talk to it.
type clientServerTest struct {
	t  testing.TB
	h2 bool
	h  Handler
	ts *httptest.Server
	tr *Transport
	c  *Client
}

// close drops the client's idle connections and then shuts down the server.
func (t *clientServerTest) close() {
	t.tr.CloseIdleConnections()
	t.ts.Close()
}

// getURL GETs u with the test client and returns the full response body as a
// string, failing the test on any request or read error.
func (t *clientServerTest) getURL(u string) string {
	res, err := t.c.Get(u)
	if err != nil {
		t.t.Fatal(err)
	}
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.t.Fatal(err)
	}
	return string(body)
}

// scheme reports the server URL scheme: "https" when running HTTP/2 and
// "http" otherwise.
func (t *clientServerTest) scheme() string {
	scheme := "http"
	if t.h2 {
		scheme = "https"
	}
	return scheme
}
129
130 var optQuietLog = func(ts *httptest.Server) {
131 ts.Config.ErrorLog = quietLog
132 }
133
134 func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
135 return func(ts *httptest.Server) {
136 ts.Config.ErrorLog = lg
137 }
138 }
139
140
141
142
143
144
145
146
147
148
149
150
151 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
152 if mode == http2Mode {
153 CondSkipHTTP2(t)
154 }
155 cst := &clientServerTest{
156 t: t,
157 h2: mode == http2Mode,
158 h: h,
159 }
160 cst.ts = httptest.NewUnstartedServer(h)
161
162 var transportFuncs []func(*Transport)
163 for _, opt := range opts {
164 switch opt := opt.(type) {
165 case func(*Transport):
166 transportFuncs = append(transportFuncs, opt)
167 case func(*httptest.Server):
168 opt(cst.ts)
169 default:
170 t.Fatalf("unhandled option type %T", opt)
171 }
172 }
173
174 if cst.ts.Config.ErrorLog == nil {
175 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
176 }
177
178 switch mode {
179 case http1Mode:
180 cst.ts.Start()
181 case https1Mode:
182 cst.ts.StartTLS()
183 case http2Mode:
184 ExportHttp2ConfigureServer(cst.ts.Config, nil)
185 cst.ts.TLS = cst.ts.Config.TLSConfig
186 cst.ts.StartTLS()
187 default:
188 t.Fatalf("unknown test mode %v", mode)
189 }
190 cst.c = cst.ts.Client()
191 cst.tr = cst.c.Transport.(*Transport)
192 if mode == http2Mode {
193 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
194 t.Fatal(err)
195 }
196 }
197 for _, f := range transportFuncs {
198 f(cst.tr)
199 }
200 t.Cleanup(func() {
201 cst.close()
202 })
203 return cst
204 }
205
// testLogWriter is an io.Writer that forwards each write to t.Logf as a
// "server log:" line with surrounding whitespace trimmed.
type testLogWriter struct {
	t testing.TB
}

// Write logs b through the test log and always reports that all of b was
// written, with a nil error.
func (w testLogWriter) Write(b []byte) (int, error) {
	msg := strings.TrimSpace(string(b))
	w.t.Logf("server log: %v", msg)
	return len(b), nil
}
214
215
216 func TestNewClientServerTest(t *testing.T) {
217 run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode})
218 }
219 func testNewClientServerTest(t *testing.T, mode testMode) {
220 var got struct {
221 sync.Mutex
222 proto string
223 hasTLS bool
224 }
225 h := HandlerFunc(func(w ResponseWriter, r *Request) {
226 got.Lock()
227 defer got.Unlock()
228 got.proto = r.Proto
229 got.hasTLS = r.TLS != nil
230 })
231 cst := newClientServerTest(t, mode, h)
232 if _, err := cst.c.Head(cst.ts.URL); err != nil {
233 t.Fatal(err)
234 }
235 var wantProto string
236 var wantTLS bool
237 switch mode {
238 case http1Mode:
239 wantProto = "HTTP/1.1"
240 wantTLS = false
241 case https1Mode:
242 wantProto = "HTTP/1.1"
243 wantTLS = true
244 case http2Mode:
245 wantProto = "HTTP/2.0"
246 wantTLS = true
247 }
248 if got.proto != wantProto {
249 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
250 }
251 if got.hasTLS != wantTLS {
252 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
253 }
254 }
255
256 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
257 func testChunkedResponseHeaders(t *testing.T, mode testMode) {
258 log.SetOutput(io.Discard)
259 defer log.SetOutput(os.Stderr)
260 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
261 w.Header().Set("Content-Length", "intentional gibberish")
262 w.(Flusher).Flush()
263 fmt.Fprintf(w, "I am a chunked response.")
264 }))
265
266 res, err := cst.c.Get(cst.ts.URL)
267 if err != nil {
268 t.Fatalf("Get error: %v", err)
269 }
270 defer res.Body.Close()
271 if g, e := res.ContentLength, int64(-1); g != e {
272 t.Errorf("expected ContentLength of %d; got %d", e, g)
273 }
274 wantTE := []string{"chunked"}
275 if mode == http2Mode {
276 wantTE = nil
277 }
278 if !slices.Equal(res.TransferEncoding, wantTE) {
279 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
280 }
281 if got, haveCL := res.Header["Content-Length"]; haveCL {
282 t.Errorf("Unexpected Content-Length: %q", got)
283 }
284 }
285
// reqFunc issues one request to url using c.
type reqFunc func(c *Client, url string) (*Response, error)

// h12Compare describes a handler that is run under both HTTP/1.1 and HTTP/2
// so the two responses can be compared after normalization.
type h12Compare struct {
	// Handler serves the request on both servers.
	Handler func(ResponseWriter, *Request)

	// ReqFunc issues the request; nil means (*Client).Get.
	ReqFunc reqFunc

	// CheckResponse, if set, is called for each protocol after the
	// responses have been normalized and compared.
	CheckResponse func(proto string, res *Response)

	// EarlyCheckResponse, if set, is called for each protocol before
	// normalization.
	EarlyCheckResponse func(proto string, res *Response)

	// Opts are passed through to newClientServerTest for both servers.
	Opts []any
}

// reqFunc returns the configured ReqFunc, falling back to (*Client).Get.
func (tt h12Compare) reqFunc() reqFunc {
	if f := tt.ReqFunc; f != nil {
		return f
	}
	return (*Client).Get
}
304
305 func (tt h12Compare) run(t *testing.T) {
306 setParallel(t)
307 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
308 defer cst1.close()
309 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
310 defer cst2.close()
311
312 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
313 if err != nil {
314 t.Errorf("HTTP/1 request: %v", err)
315 return
316 }
317 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
318 if err != nil {
319 t.Errorf("HTTP/2 request: %v", err)
320 return
321 }
322
323 if fn := tt.EarlyCheckResponse; fn != nil {
324 fn("HTTP/1.1", res1)
325 fn("HTTP/2.0", res2)
326 }
327
328 tt.normalizeRes(t, res1, "HTTP/1.1")
329 tt.normalizeRes(t, res2, "HTTP/2.0")
330 res1body, res2body := res1.Body, res2.Body
331
332 eres1 := mostlyCopy(res1)
333 eres2 := mostlyCopy(res2)
334 if !reflect.DeepEqual(eres1, eres2) {
335 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
336 cst1.ts.URL, eres1, cst2.ts.URL, eres2)
337 }
338 if !reflect.DeepEqual(res1body, res2body) {
339 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
340 }
341 if fn := tt.CheckResponse; fn != nil {
342 res1.Body, res2.Body = res1body, res2body
343 fn("HTTP/1.1", res1)
344 fn("HTTP/2.0", res2)
345 }
346 }
347
// mostlyCopy returns a shallow copy of r with Body, TransferEncoding, TLS
// and Request cleared. These are the fields excluded when h12Compare compares
// HTTP/1 and HTTP/2 responses with reflect.DeepEqual. r itself is not
// modified.
func mostlyCopy(r *Response) *Response {
	c := new(Response)
	*c = *r
	c.Body, c.TransferEncoding, c.TLS, c.Request = nil, nil, nil, nil
	return c
}
356
357 type slurpResult struct {
358 io.ReadCloser
359 body []byte
360 err error
361 }
362
363 func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
364
365 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
366 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
367 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
368 } else {
369 t.Errorf("got %q response; want %q", res.Proto, wantProto)
370 }
371 slurp, err := io.ReadAll(res.Body)
372
373 res.Body.Close()
374 res.Body = slurpResult{
375 ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
376 body: slurp,
377 err: err,
378 }
379 for i, v := range res.Header["Date"] {
380 res.Header["Date"][i] = strings.Repeat("x", len(v))
381 }
382 if res.Request == nil {
383 t.Errorf("for %s, no request", wantProto)
384 }
385 if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
386 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
387 }
388 }
389
390
391 func TestH12_HeadContentLengthNoBody(t *testing.T) {
392 h12Compare{
393 ReqFunc: (*Client).Head,
394 Handler: func(w ResponseWriter, r *Request) {
395 },
396 }.run(t)
397 }
398
399 func TestH12_HeadContentLengthSmallBody(t *testing.T) {
400 h12Compare{
401 ReqFunc: (*Client).Head,
402 Handler: func(w ResponseWriter, r *Request) {
403 io.WriteString(w, "small")
404 },
405 }.run(t)
406 }
407
408 func TestH12_HeadContentLengthLargeBody(t *testing.T) {
409 h12Compare{
410 ReqFunc: (*Client).Head,
411 Handler: func(w ResponseWriter, r *Request) {
412 chunk := strings.Repeat("x", 512<<10)
413 for i := 0; i < 10; i++ {
414 io.WriteString(w, chunk)
415 }
416 },
417 }.run(t)
418 }
419
420 func TestH12_200NoBody(t *testing.T) {
421 h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
422 }
423
424 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
425 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
426 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
427
428 func testH12_noBody(t *testing.T, status int) {
429 h12Compare{Handler: func(w ResponseWriter, r *Request) {
430 w.WriteHeader(status)
431 }}.run(t)
432 }
433
434 func TestH12_SmallBody(t *testing.T) {
435 h12Compare{Handler: func(w ResponseWriter, r *Request) {
436 io.WriteString(w, "small body")
437 }}.run(t)
438 }
439
440 func TestH12_ExplicitContentLength(t *testing.T) {
441 h12Compare{Handler: func(w ResponseWriter, r *Request) {
442 w.Header().Set("Content-Length", "3")
443 io.WriteString(w, "foo")
444 }}.run(t)
445 }
446
447 func TestH12_FlushBeforeBody(t *testing.T) {
448 h12Compare{Handler: func(w ResponseWriter, r *Request) {
449 w.(Flusher).Flush()
450 io.WriteString(w, "foo")
451 }}.run(t)
452 }
453
454 func TestH12_FlushMidBody(t *testing.T) {
455 h12Compare{Handler: func(w ResponseWriter, r *Request) {
456 io.WriteString(w, "foo")
457 w.(Flusher).Flush()
458 io.WriteString(w, "bar")
459 }}.run(t)
460 }
461
462 func TestH12_Head_ExplicitLen(t *testing.T) {
463 h12Compare{
464 ReqFunc: (*Client).Head,
465 Handler: func(w ResponseWriter, r *Request) {
466 if r.Method != "HEAD" {
467 t.Errorf("unexpected method %q", r.Method)
468 }
469 w.Header().Set("Content-Length", "1235")
470 },
471 }.run(t)
472 }
473
474 func TestH12_Head_ImplicitLen(t *testing.T) {
475 h12Compare{
476 ReqFunc: (*Client).Head,
477 Handler: func(w ResponseWriter, r *Request) {
478 if r.Method != "HEAD" {
479 t.Errorf("unexpected method %q", r.Method)
480 }
481 io.WriteString(w, "foo")
482 },
483 }.run(t)
484 }
485
486 func TestH12_HandlerWritesTooLittle(t *testing.T) {
487 h12Compare{
488 Handler: func(w ResponseWriter, r *Request) {
489 w.Header().Set("Content-Length", "3")
490 io.WriteString(w, "12")
491 },
492 CheckResponse: func(proto string, res *Response) {
493 sr, ok := res.Body.(slurpResult)
494 if !ok {
495 t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
496 return
497 }
498 if sr.err != io.ErrUnexpectedEOF {
499 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
500 }
501 if string(sr.body) != "12" {
502 t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
503 }
504 },
505 }.run(t)
506 }
507
508
509
510
511
512
513
514 func TestHandlerWritesTooMuch(t *testing.T) { run(t, testHandlerWritesTooMuch) }
515 func testHandlerWritesTooMuch(t *testing.T, mode testMode) {
516 wantBody := []byte("123")
517 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
518 rc := NewResponseController(w)
519 w.Header().Set("Content-Length", fmt.Sprintf("%v", len(wantBody)))
520 rc.Flush()
521 w.Write(wantBody)
522 rc.Flush()
523 n, err := io.WriteString(w, "x")
524 if err == nil {
525 err = rc.Flush()
526 }
527
528 if err == nil {
529 t.Errorf("for proto %q, final write = %v, %v; want _, some error", r.Proto, n, err)
530 }
531 }))
532
533 res, err := cst.c.Get(cst.ts.URL)
534 if err != nil {
535 t.Fatal(err)
536 }
537 defer res.Body.Close()
538
539 gotBody, _ := io.ReadAll(res.Body)
540 if !bytes.Equal(gotBody, wantBody) {
541 t.Fatalf("got response body: %q; want %q", gotBody, wantBody)
542 }
543 }
544
545
546
547 func TestH12_AutoGzip(t *testing.T) {
548 h12Compare{
549 Handler: func(w ResponseWriter, r *Request) {
550 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
551 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
552 }
553 w.Header().Set("Content-Encoding", "gzip")
554 gz := gzip.NewWriter(w)
555 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
556 gz.Close()
557 },
558 }.run(t)
559 }
560
561 func TestH12_AutoGzip_Disabled(t *testing.T) {
562 h12Compare{
563 Opts: []any{
564 func(tr *Transport) { tr.DisableCompression = true },
565 },
566 Handler: func(w ResponseWriter, r *Request) {
567 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
568 if ae := r.Header.Get("Accept-Encoding"); ae != "" {
569 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
570 }
571 },
572 }.run(t)
573 }
574
575
576
577
578 func Test304Responses(t *testing.T) { run(t, test304Responses) }
579 func test304Responses(t *testing.T, mode testMode) {
580 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
581 w.WriteHeader(StatusNotModified)
582 _, err := w.Write([]byte("illegal body"))
583 if err != ErrBodyNotAllowed {
584 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
585 }
586 }))
587 defer cst.close()
588 res, err := cst.c.Get(cst.ts.URL)
589 if err != nil {
590 t.Fatal(err)
591 }
592 if len(res.TransferEncoding) > 0 {
593 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
594 }
595 body, err := io.ReadAll(res.Body)
596 if err != nil {
597 t.Error(err)
598 }
599 if len(body) > 0 {
600 t.Errorf("got unexpected body %q", string(body))
601 }
602 }
603
604 func TestH12_ServerEmptyContentLength(t *testing.T) {
605 h12Compare{
606 Handler: func(w ResponseWriter, r *Request) {
607 w.Header()["Content-Type"] = []string{""}
608 io.WriteString(w, "<html><body>hi</body></html>")
609 },
610 }.run(t)
611 }
612
613 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
614 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
615 }
616
617 func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
618 h12requestContentLength(t, func() io.Reader { return nil }, 0)
619 }
620
621 func TestH12_RequestContentLength_Unknown(t *testing.T) {
622 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
623 }
624
625 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
626 h12Compare{
627 Handler: func(w ResponseWriter, r *Request) {
628 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
629 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
630 },
631 ReqFunc: func(c *Client, url string) (*Response, error) {
632 return c.Post(url, "text/plain", bodyfn())
633 },
634 CheckResponse: func(proto string, res *Response) {
635 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
636 t.Errorf("Proto %q got length %q; want %q", proto, got, want)
637 }
638 },
639 }.run(t)
640 }
641
642
643
644 func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
645 func testCancelRequestMidBody(t *testing.T, mode testMode) {
646 unblock := make(chan bool)
647 didFlush := make(chan bool, 1)
648 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
649 io.WriteString(w, "Hello")
650 w.(Flusher).Flush()
651 didFlush <- true
652 <-unblock
653 io.WriteString(w, ", world.")
654 }))
655 defer close(unblock)
656
657 req, _ := NewRequest("GET", cst.ts.URL, nil)
658 cancel := make(chan struct{})
659 req.Cancel = cancel
660
661 res, err := cst.c.Do(req)
662 if err != nil {
663 t.Fatal(err)
664 }
665 defer res.Body.Close()
666 <-didFlush
667
668
669
670 firstRead := make([]byte, 10)
671 n, err := res.Body.Read(firstRead)
672 if err != nil {
673 t.Fatal(err)
674 }
675 firstRead = firstRead[:n]
676
677 close(cancel)
678
679 rest, err := io.ReadAll(res.Body)
680 all := string(firstRead) + string(rest)
681 if all != "Hello" {
682 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
683 }
684 if err != ExportErrRequestCanceled {
685 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
686 }
687 }
688
689
690 func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
691 func testTrailersClientToServer(t *testing.T, mode testMode) {
692 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
693 slurp, err := io.ReadAll(r.Body)
694 if err != nil {
695 t.Errorf("Server reading request body: %v", err)
696 }
697 if string(slurp) != "foo" {
698 t.Errorf("Server read request body %q; want foo", slurp)
699 }
700 if r.Trailer == nil {
701 io.WriteString(w, "nil Trailer")
702 } else {
703 decl := slices.Sorted(maps.Keys(r.Trailer))
704 fmt.Fprintf(w, "decl: %v, vals: %s, %s",
705 decl,
706 r.Trailer.Get("Client-Trailer-A"),
707 r.Trailer.Get("Client-Trailer-B"))
708 }
709 }))
710
711 var req *Request
712 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
713 eofReaderFunc(func() {
714 req.Trailer["Client-Trailer-A"] = []string{"valuea"}
715 }),
716 strings.NewReader("foo"),
717 eofReaderFunc(func() {
718 req.Trailer["Client-Trailer-B"] = []string{"valueb"}
719 }),
720 ))
721 req.Trailer = Header{
722 "Client-Trailer-A": nil,
723 "Client-Trailer-B": nil,
724 }
725 req.ContentLength = -1
726 res, err := cst.c.Do(req)
727 if err != nil {
728 t.Fatal(err)
729 }
730 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
731 t.Error(err)
732 }
733 }
734
735
736 func TestTrailersServerToClient(t *testing.T) {
737 run(t, func(t *testing.T, mode testMode) {
738 testTrailersServerToClient(t, mode, false)
739 })
740 }
741 func TestTrailersServerToClientFlush(t *testing.T) {
742 run(t, func(t *testing.T, mode testMode) {
743 testTrailersServerToClient(t, mode, true)
744 })
745 }
746
747 func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
748 const body = "Some body"
749 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
750 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
751 w.Header().Add("Trailer", "Server-Trailer-C")
752
753 io.WriteString(w, body)
754 if flush {
755 w.(Flusher).Flush()
756 }
757
758
759
760
761
762 w.Header().Set("Server-Trailer-A", "valuea")
763 w.Header().Set("Server-Trailer-C", "valuec")
764 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
765 }))
766
767 res, err := cst.c.Get(cst.ts.URL)
768 if err != nil {
769 t.Fatal(err)
770 }
771
772 wantHeader := Header{
773 "Content-Type": {"text/plain; charset=utf-8"},
774 }
775 wantLen := -1
776 if mode == http2Mode && !flush {
777
778
779
780
781
782 wantLen = len(body)
783 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
784 }
785 if res.ContentLength != int64(wantLen) {
786 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
787 }
788
789 delete(res.Header, "Date")
790 if !reflect.DeepEqual(res.Header, wantHeader) {
791 t.Errorf("Header = %v; want %v", res.Header, wantHeader)
792 }
793
794 if got, want := res.Trailer, (Header{
795 "Server-Trailer-A": nil,
796 "Server-Trailer-B": nil,
797 "Server-Trailer-C": nil,
798 }); !reflect.DeepEqual(got, want) {
799 t.Errorf("Trailer before body read = %v; want %v", got, want)
800 }
801
802 if err := wantBody(res, nil, body); err != nil {
803 t.Fatal(err)
804 }
805
806 if got, want := res.Trailer, (Header{
807 "Server-Trailer-A": {"valuea"},
808 "Server-Trailer-B": nil,
809 "Server-Trailer-C": {"valuec"},
810 }); !reflect.DeepEqual(got, want) {
811 t.Errorf("Trailer after body read = %v; want %v", got, want)
812 }
813 }
814
815
816 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
817 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
818 const body = "Some body"
819 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
820 io.WriteString(w, body)
821 }))
822 res, err := cst.c.Get(cst.ts.URL)
823 if err != nil {
824 t.Fatal(err)
825 }
826 res.Body.Close()
827 data, err := io.ReadAll(res.Body)
828 if len(data) != 0 || err == nil {
829 t.Fatalf("ReadAll returned %q, %v; want error", data, err)
830 }
831 }
832
833 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
834 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
835 const reqBody = "some request body"
836 const resBody = "some response body"
837 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
838 var wg sync.WaitGroup
839 wg.Add(2)
840 didRead := make(chan bool, 1)
841
842 go func() {
843 defer wg.Done()
844 data, err := io.ReadAll(r.Body)
845 if string(data) != reqBody {
846 t.Errorf("Handler read %q; want %q", data, reqBody)
847 }
848 if err != nil {
849 t.Errorf("Handler Read: %v", err)
850 }
851 didRead <- true
852 }()
853
854 go func() {
855 defer wg.Done()
856 if mode != http2Mode {
857
858
859
860
861 <-didRead
862 }
863 io.WriteString(w, resBody)
864 }()
865 wg.Wait()
866 }))
867 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
868 req.Header.Add("Expect", "100-continue")
869 res, err := cst.c.Do(req)
870 if err != nil {
871 t.Fatal(err)
872 }
873 data, err := io.ReadAll(res.Body)
874 defer res.Body.Close()
875 if err != nil {
876 t.Fatal(err)
877 }
878 if string(data) != resBody {
879 t.Errorf("read %q; want %q", data, resBody)
880 }
881 }
882
883 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
884 func testConnectRequest(t *testing.T, mode testMode) {
885 gotc := make(chan *Request, 1)
886 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
887 gotc <- r
888 }))
889
890 u, err := url.Parse(cst.ts.URL)
891 if err != nil {
892 t.Fatal(err)
893 }
894
895 tests := []struct {
896 req *Request
897 want string
898 }{
899 {
900 req: &Request{
901 Method: "CONNECT",
902 Header: Header{},
903 URL: u,
904 },
905 want: u.Host,
906 },
907 {
908 req: &Request{
909 Method: "CONNECT",
910 Header: Header{},
911 URL: u,
912 Host: "example.com:123",
913 },
914 want: "example.com:123",
915 },
916 }
917
918 for i, tt := range tests {
919 res, err := cst.c.Do(tt.req)
920 if err != nil {
921 t.Errorf("%d. RoundTrip = %v", i, err)
922 continue
923 }
924 res.Body.Close()
925 req := <-gotc
926 if req.Method != "CONNECT" {
927 t.Errorf("method = %q; want CONNECT", req.Method)
928 }
929 if req.Host != tt.want {
930 t.Errorf("Host = %q; want %q", req.Host, tt.want)
931 }
932 if req.URL.Host != tt.want {
933 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
934 }
935 }
936 }
937
938 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
939 func testTransportUserAgent(t *testing.T, mode testMode) {
940 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
941 fmt.Fprintf(w, "%q", r.Header["User-Agent"])
942 }))
943
944 either := func(a, b string) string {
945 if mode == http2Mode {
946 return b
947 }
948 return a
949 }
950
951 tests := []struct {
952 setup func(*Request)
953 want string
954 }{
955 {
956 func(r *Request) {},
957 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
958 },
959 {
960 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
961 `["foo/1.2.3"]`,
962 },
963 {
964 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
965 `["single"]`,
966 },
967 {
968 func(r *Request) { r.Header.Set("User-Agent", "") },
969 `[]`,
970 },
971 {
972 func(r *Request) { r.Header["User-Agent"] = nil },
973 `[]`,
974 },
975 }
976 for i, tt := range tests {
977 req, _ := NewRequest("GET", cst.ts.URL, nil)
978 tt.setup(req)
979 res, err := cst.c.Do(req)
980 if err != nil {
981 t.Errorf("%d. RoundTrip = %v", i, err)
982 continue
983 }
984 slurp, err := io.ReadAll(res.Body)
985 res.Body.Close()
986 if err != nil {
987 t.Errorf("%d. read body = %v", i, err)
988 continue
989 }
990 if string(slurp) != tt.want {
991 t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
992 }
993 }
994 }
995
996 func TestStarRequestMethod(t *testing.T) {
997 for _, method := range []string{"FOO", "OPTIONS"} {
998 t.Run(method, func(t *testing.T) {
999 run(t, func(t *testing.T, mode testMode) {
1000 testStarRequest(t, method, mode)
1001 })
1002 })
1003 }
1004 }
1005 func testStarRequest(t *testing.T, method string, mode testMode) {
1006 gotc := make(chan *Request, 1)
1007 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1008 w.Header().Set("foo", "bar")
1009 gotc <- r
1010 w.(Flusher).Flush()
1011 }))
1012
1013 u, err := url.Parse(cst.ts.URL)
1014 if err != nil {
1015 t.Fatal(err)
1016 }
1017 u.Path = "*"
1018
1019 req := &Request{
1020 Method: method,
1021 Header: Header{},
1022 URL: u,
1023 }
1024
1025 res, err := cst.c.Do(req)
1026 if err != nil {
1027 t.Fatalf("RoundTrip = %v", err)
1028 }
1029 res.Body.Close()
1030
1031 wantFoo := "bar"
1032 wantLen := int64(-1)
1033 if method == "OPTIONS" {
1034 wantFoo = ""
1035 wantLen = 0
1036 }
1037 if res.StatusCode != 200 {
1038 t.Errorf("status code = %v; want %d", res.Status, 200)
1039 }
1040 if res.ContentLength != wantLen {
1041 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
1042 }
1043 if got := res.Header.Get("foo"); got != wantFoo {
1044 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
1045 }
1046 select {
1047 case req = <-gotc:
1048 default:
1049 req = nil
1050 }
1051 if req == nil {
1052 if method != "OPTIONS" {
1053 t.Fatalf("handler never got request")
1054 }
1055 return
1056 }
1057 if req.Method != method {
1058 t.Errorf("method = %q; want %q", req.Method, method)
1059 }
1060 if req.URL.Path != "*" {
1061 t.Errorf("URL.Path = %q; want *", req.URL.Path)
1062 }
1063 if req.RequestURI != "*" {
1064 t.Errorf("RequestURI = %q; want *", req.RequestURI)
1065 }
1066 }
1067
1068
1069 func TestTransportDiscardsUnneededConns(t *testing.T) {
1070 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
1071 }
1072 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
1073 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1074 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
1075 }))
1076 defer cst.close()
1077
1078 var numOpen, numClose int32
1079
1080 tlsConfig := &tls.Config{InsecureSkipVerify: true}
1081 tr := &Transport{
1082 TLSClientConfig: tlsConfig,
1083 DialTLS: func(_, addr string) (net.Conn, error) {
1084 time.Sleep(10 * time.Millisecond)
1085 rc, err := net.Dial("tcp", addr)
1086 if err != nil {
1087 return nil, err
1088 }
1089 atomic.AddInt32(&numOpen, 1)
1090 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
1091 return tls.Client(c, tlsConfig), nil
1092 },
1093 }
1094 if err := ExportHttp2ConfigureTransport(tr); err != nil {
1095 t.Fatal(err)
1096 }
1097 defer tr.CloseIdleConnections()
1098
1099 c := &Client{Transport: tr}
1100
1101 const N = 10
1102 gotBody := make(chan string, N)
1103 var wg sync.WaitGroup
1104 for i := 0; i < N; i++ {
1105 wg.Add(1)
1106 go func() {
1107 defer wg.Done()
1108 resp, err := c.Get(cst.ts.URL)
1109 if err != nil {
1110
1111
1112 time.Sleep(10 * time.Millisecond)
1113 resp, err = c.Get(cst.ts.URL)
1114 if err != nil {
1115 t.Errorf("Get: %v", err)
1116 return
1117 }
1118 }
1119 defer resp.Body.Close()
1120 slurp, err := io.ReadAll(resp.Body)
1121 if err != nil {
1122 t.Error(err)
1123 }
1124 gotBody <- string(slurp)
1125 }()
1126 }
1127 wg.Wait()
1128 close(gotBody)
1129
1130 var last string
1131 for got := range gotBody {
1132 if last == "" {
1133 last = got
1134 continue
1135 }
1136 if got != last {
1137 t.Errorf("Response body changed: %q -> %q", last, got)
1138 }
1139 }
1140
1141 var open, close int32
1142 for i := 0; i < 150; i++ {
1143 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
1144 if open < 1 {
1145 t.Fatalf("open = %d; want at least", open)
1146 }
1147 if close == open-1 {
1148
1149 return
1150 }
1151 time.Sleep(10 * time.Millisecond)
1152 }
1153 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
1154 }
1155
1156
1157 func TestTransportGCRequest(t *testing.T) {
1158 run(t, func(t *testing.T, mode testMode) {
1159 t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
1160 t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
1161 })
1162 }
1163 func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
1164 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1165 io.ReadAll(r.Body)
1166 if body {
1167 io.WriteString(w, "Hello.")
1168 }
1169 }))
1170
1171 didGC := make(chan struct{})
1172 (func() {
1173 body := strings.NewReader("some body")
1174 req, _ := NewRequest("POST", cst.ts.URL, body)
1175 runtime.SetFinalizer(req, func(*Request) { close(didGC) })
1176 res, err := cst.c.Do(req)
1177 if err != nil {
1178 t.Fatal(err)
1179 }
1180 if _, err := io.ReadAll(res.Body); err != nil {
1181 t.Fatal(err)
1182 }
1183 if err := res.Body.Close(); err != nil {
1184 t.Fatal(err)
1185 }
1186 })()
1187 for {
1188 select {
1189 case <-didGC:
1190 return
1191 case <-time.After(1 * time.Millisecond):
1192 runtime.GC()
1193 }
1194 }
1195 }
1196
1197 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
1198 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
1199 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1200 fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
1201 }), optQuietLog)
1202 cst.tr.DisableKeepAlives = true
1203
1204 tests := []struct {
1205 key, val string
1206 ok bool
1207 }{
1208 {"Foo", "capital-key", true},
1209 {"Foo", "foo\x00bar", false},
1210 {"Foo", "two\nlines", false},
1211 {"bogus\nkey", "v", false},
1212 {"A space", "v", false},
1213 {"имя", "v", false},
1214 {"name", "валю", true},
1215 {"", "v", false},
1216 {"k", "", true},
1217 }
1218 for _, tt := range tests {
1219 dialedc := make(chan bool, 1)
1220 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
1221 dialedc <- true
1222 return net.Dial(netw, addr)
1223 }
1224 req, _ := NewRequest("GET", cst.ts.URL, nil)
1225 req.Header[tt.key] = []string{tt.val}
1226 res, err := cst.c.Do(req)
1227 var body []byte
1228 if err == nil {
1229 body, _ = io.ReadAll(res.Body)
1230 res.Body.Close()
1231 }
1232 var dialed bool
1233 select {
1234 case <-dialedc:
1235 dialed = true
1236 default:
1237 }
1238
1239 if !tt.ok && dialed {
1240 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
1241 } else if (err == nil) != tt.ok {
1242 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
1243 }
1244 }
1245 }
1246
1247 func TestInterruptWithPanic(t *testing.T) {
1248 run(t, func(t *testing.T, mode testMode) {
1249 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
1250 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
1251 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
1252 }, testNotParallel)
1253 }
1254 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
1255 const msg = "hello"
1256
1257 testDone := make(chan struct{})
1258 defer close(testDone)
1259
1260 var errorLog lockedBytesBuffer
1261 gotHeaders := make(chan bool, 1)
1262 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1263 io.WriteString(w, msg)
1264 w.(Flusher).Flush()
1265
1266 select {
1267 case <-gotHeaders:
1268 case <-testDone:
1269 }
1270 panic(panicValue)
1271 }), func(ts *httptest.Server) {
1272 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1273 })
1274 res, err := cst.c.Get(cst.ts.URL)
1275 if err != nil {
1276 t.Fatal(err)
1277 }
1278 gotHeaders <- true
1279 defer res.Body.Close()
1280 slurp, err := io.ReadAll(res.Body)
1281 if string(slurp) != msg {
1282 t.Errorf("client read %q; want %q", slurp, msg)
1283 }
1284 if err == nil {
1285 t.Errorf("client read all successfully; want some error")
1286 }
1287 logOutput := func() string {
1288 errorLog.Lock()
1289 defer errorLog.Unlock()
1290 return errorLog.String()
1291 }
1292 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
1293
1294 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
1295 gotLog := logOutput()
1296 if !wantStackLogged {
1297 if gotLog == "" {
1298 return true
1299 }
1300 t.Fatalf("want no log output; got: %s", gotLog)
1301 }
1302 if gotLog == "" {
1303 if d > 0 {
1304 t.Logf("wanted a stack trace logged; got nothing after %v", d)
1305 }
1306 return false
1307 }
1308 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
1309 if d > 0 {
1310 t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
1311 }
1312 return false
1313 }
1314 return true
1315 })
1316 }
1317
// lockedBytesBuffer is a bytes.Buffer paired with a mutex. Write takes the
// lock itself. Readers of the accumulated contents must take the lock
// explicitly (Lock/Unlock are promoted from the embedded Mutex) before
// calling String or other Buffer methods.
type lockedBytesBuffer struct {
	sync.Mutex
	bytes.Buffer
}

// Write appends p to the buffer while holding the mutex. It is safe to call
// from multiple goroutines, for example as a log.Logger's output.
func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
	b.Mutex.Lock()
	defer b.Mutex.Unlock()
	return b.Buffer.Write(p)
}
1328
1329
1330 func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1331 h12Compare{
1332 Handler: func(w ResponseWriter, r *Request) {
1333 h := w.Header()
1334 h.Set("Content-Encoding", "gzip")
1335 h.Set("Content-Length", "23")
1336 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1337 },
1338 EarlyCheckResponse: func(proto string, res *Response) {
1339 if !res.Uncompressed {
1340 t.Errorf("%s: expected Uncompressed to be set", proto)
1341 }
1342 dump, err := httputil.DumpResponse(res, true)
1343 if err != nil {
1344 t.Errorf("%s: DumpResponse: %v", proto, err)
1345 return
1346 }
1347 if strings.Contains(string(dump), "Connection: close") {
1348 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1349 }
1350 if !strings.Contains(string(dump), "FOO") {
1351 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1352 }
1353 },
1354 }.run(t)
1355 }
1356
1357
1358 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
1359 func testCloseIdleConnections(t *testing.T, mode testMode) {
1360 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1361 w.Header().Set("X-Addr", r.RemoteAddr)
1362 }))
1363 get := func() string {
1364 res, err := cst.c.Get(cst.ts.URL)
1365 if err != nil {
1366 t.Fatal(err)
1367 }
1368 res.Body.Close()
1369 v := res.Header.Get("X-Addr")
1370 if v == "" {
1371 t.Fatal("didn't get X-Addr")
1372 }
1373 return v
1374 }
1375 a1 := get()
1376 cst.tr.CloseIdleConnections()
1377 a2 := get()
1378 if a1 == a2 {
1379 t.Errorf("didn't close connection")
1380 }
1381 }
1382
// noteCloseConn wraps a net.Conn and invokes closeFunc each time Close is
// called, before closing the underlying connection. closeFunc must be non-nil.
type noteCloseConn struct {
	net.Conn
	closeFunc func()
}

// Close runs the notification hook and then closes the wrapped connection,
// returning its error.
func (c noteCloseConn) Close() error {
	notify := c.closeFunc
	notify()
	return c.Conn.Close()
}
1392
// testErrorReader is an io.Reader that must never be read. Any Read records
// a test error on t, which lets a test assert that a request body is not
// consumed.
type testErrorReader struct{ t *testing.T }

// Read reports "unexpected Read call" on the owning test and returns
// (0, io.EOF).
func (er testErrorReader) Read(p []byte) (n int, err error) {
	er.t.Error("unexpected Read call")
	n, err = 0, io.EOF
	return n, err
}
1399
1400 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
1401 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
1402 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1403 w.WriteHeader(StatusUnauthorized)
1404 }))
1405
1406
1407 cst.tr.ExpectContinueTimeout = 10 * time.Second
1408
1409 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1410 if err != nil {
1411 t.Fatal(err)
1412 }
1413 req.ContentLength = 0
1414 req.Header.Set("Expect", "100-continue")
1415 res, err := cst.tr.RoundTrip(req)
1416 if err != nil {
1417 t.Fatal(err)
1418 }
1419 defer res.Body.Close()
1420 if res.StatusCode != StatusUnauthorized {
1421 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1422 }
1423 }
1424
1425 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
1426 func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
1427 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1428 w.Header().Set("Foo", "Bar")
1429 w.Header().Set("Trailer:Foo", "Baz")
1430 w.(Flusher).Flush()
1431 w.Header().Add("Trailer:Foo", "Baz2")
1432 w.Header().Set("Trailer:Bar", "Quux")
1433 }))
1434 res, err := cst.c.Get(cst.ts.URL)
1435 if err != nil {
1436 t.Fatal(err)
1437 }
1438 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1439 t.Fatal(err)
1440 }
1441 res.Body.Close()
1442 delete(res.Header, "Date")
1443 delete(res.Header, "Content-Type")
1444
1445 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1446 t.Errorf("Header = %#v; want %#v", res.Header, want)
1447 }
1448 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1449 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1450 }
1451 }
1452
1453 func TestBadResponseAfterReadingBody(t *testing.T) {
1454 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
1455 }
1456 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
1457 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1458 _, err := io.Copy(io.Discard, r.Body)
1459 if err != nil {
1460 t.Fatal(err)
1461 }
1462 c, _, err := w.(Hijacker).Hijack()
1463 if err != nil {
1464 t.Fatal(err)
1465 }
1466 defer c.Close()
1467 fmt.Fprintln(c, "some bogus crap")
1468 }))
1469
1470 closes := 0
1471 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1472 if err == nil {
1473 res.Body.Close()
1474 t.Fatal("expected an error to be returned from Post")
1475 }
1476 if closes != 1 {
1477 t.Errorf("closes = %d; want 1", closes)
1478 }
1479 }
1480
1481 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
1482 func testWriteHeader0(t *testing.T, mode testMode) {
1483 gotpanic := make(chan bool, 1)
1484 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1485 defer close(gotpanic)
1486 defer func() {
1487 if e := recover(); e != nil {
1488 got := fmt.Sprintf("%T, %v", e, e)
1489 want := "string, invalid WriteHeader code 0"
1490 if got != want {
1491 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1492 }
1493 gotpanic <- true
1494
1495
1496
1497
1498 w.WriteHeader(503)
1499 }
1500 }()
1501 w.WriteHeader(0)
1502 }))
1503 res, err := cst.c.Get(cst.ts.URL)
1504 if err != nil {
1505 t.Fatal(err)
1506 }
1507 if res.StatusCode != 503 {
1508 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1509 }
1510 if !<-gotpanic {
1511 t.Error("expected panic in handler")
1512 }
1513 }
1514
1515
1516
1517 func TestWriteHeaderNoCodeCheck(t *testing.T) {
1518 run(t, func(t *testing.T, mode testMode) {
1519 testWriteHeaderAfterWrite(t, mode, false)
1520 })
1521 }
1522 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
1523 testWriteHeaderAfterWrite(t, http1Mode, true)
1524 }
1525 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
1526 var errorLog lockedBytesBuffer
1527 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1528 if hijack {
1529 conn, _, _ := w.(Hijacker).Hijack()
1530 defer conn.Close()
1531 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1532 w.WriteHeader(0)
1533 conn.Write([]byte("bar"))
1534 return
1535 }
1536 io.WriteString(w, "foo")
1537 w.(Flusher).Flush()
1538 w.WriteHeader(0)
1539 io.WriteString(w, "bar")
1540 }), func(ts *httptest.Server) {
1541 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1542 })
1543 res, err := cst.c.Get(cst.ts.URL)
1544 if err != nil {
1545 t.Fatal(err)
1546 }
1547 defer res.Body.Close()
1548 body, err := io.ReadAll(res.Body)
1549 if err != nil {
1550 t.Fatal(err)
1551 }
1552 if got, want := string(body), "foobar"; got != want {
1553 t.Errorf("got = %q; want %q", got, want)
1554 }
1555
1556
1557 if mode == http2Mode {
1558
1559
1560 return
1561 }
1562 gotLog := strings.TrimSpace(errorLog.String())
1563 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1564 if hijack {
1565 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1566 }
1567 if !strings.HasPrefix(gotLog, wantLog) {
1568 t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1569 }
1570 }
1571
1572 func TestBidiStreamReverseProxy(t *testing.T) {
1573 run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
1574 }
1575 func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
1576 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1577 if _, err := io.Copy(w, r.Body); err != nil {
1578 log.Printf("bidi backend copy: %v", err)
1579 }
1580 }))
1581
1582 backURL, err := url.Parse(backend.ts.URL)
1583 if err != nil {
1584 t.Fatal(err)
1585 }
1586 rp := httputil.NewSingleHostReverseProxy(backURL)
1587 rp.Transport = backend.tr
1588 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1589 rp.ServeHTTP(w, r)
1590 }))
1591
1592 bodyRes := make(chan any, 1)
1593 pr, pw := io.Pipe()
1594 req, _ := NewRequest("PUT", proxy.ts.URL, pr)
1595 const size = 4 << 20
1596 go func() {
1597 h := sha1.New()
1598 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
1599 go pw.Close()
1600 if err != nil {
1601 t.Errorf("body copy: %v", err)
1602 bodyRes <- err
1603 } else {
1604 bodyRes <- h
1605 }
1606 }()
1607 res, err := backend.c.Do(req)
1608 if err != nil {
1609 t.Fatal(err)
1610 }
1611 defer res.Body.Close()
1612 hgot := sha1.New()
1613 n, err := io.Copy(hgot, res.Body)
1614 if err != nil {
1615 t.Fatal(err)
1616 }
1617 if n != size {
1618 t.Fatalf("got %d bytes; want %d", n, size)
1619 }
1620 select {
1621 case v := <-bodyRes:
1622 switch v := v.(type) {
1623 default:
1624 t.Fatalf("body copy: %v", err)
1625 case hash.Hash:
1626 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
1627 t.Errorf("written bytes didn't match received bytes")
1628 }
1629 }
1630 case <-time.After(10 * time.Second):
1631 t.Fatal("timeout")
1632 }
1633
1634 }
1635
1636
1637 func TestH12_WebSocketUpgrade(t *testing.T) {
1638 h12Compare{
1639 Handler: func(w ResponseWriter, r *Request) {
1640 h := w.Header()
1641 h.Set("Foo", "bar")
1642 },
1643 ReqFunc: func(c *Client, url string) (*Response, error) {
1644 req, _ := NewRequest("GET", url, nil)
1645 req.Header.Set("Connection", "Upgrade")
1646 req.Header.Set("Upgrade", "WebSocket")
1647 return c.Do(req)
1648 },
1649 EarlyCheckResponse: func(proto string, res *Response) {
1650 if res.Proto != "HTTP/1.1" {
1651 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1652 }
1653 res.Proto = "HTTP/IGNORE"
1654 },
1655 }.run(t)
1656 }
1657
1658 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
1659 func testIdentityTransferEncoding(t *testing.T, mode testMode) {
1660 const body = "body"
1661 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1662 gotBody, _ := io.ReadAll(r.Body)
1663 if got, want := string(gotBody), body; got != want {
1664 t.Errorf("got request body = %q; want %q", got, want)
1665 }
1666 w.Header().Set("Transfer-Encoding", "identity")
1667 w.WriteHeader(StatusOK)
1668 w.(Flusher).Flush()
1669 io.WriteString(w, body)
1670 }))
1671 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
1672 res, err := cst.c.Do(req)
1673 if err != nil {
1674 t.Fatal(err)
1675 }
1676 defer res.Body.Close()
1677 gotBody, err := io.ReadAll(res.Body)
1678 if err != nil {
1679 t.Fatal(err)
1680 }
1681 if got, want := string(gotBody), body; got != want {
1682 t.Errorf("got response body = %q; want %q", got, want)
1683 }
1684 }
1685
1686 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
1687 func testEarlyHintsRequest(t *testing.T, mode testMode) {
1688 var wg sync.WaitGroup
1689 wg.Add(1)
1690 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1691 h := w.Header()
1692
1693 h.Add("Content-Length", "123")
1694 h.Add("Link", "</style.css>; rel=preload; as=style")
1695 h.Add("Link", "</script.js>; rel=preload; as=script")
1696 w.WriteHeader(StatusEarlyHints)
1697
1698 wg.Wait()
1699
1700 h.Add("Link", "</foo.js>; rel=preload; as=script")
1701 w.WriteHeader(StatusEarlyHints)
1702
1703 w.Write([]byte("Hello"))
1704 }))
1705
1706 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1707 t.Helper()
1708
1709 if len(expected) != len(got) {
1710 t.Errorf("got %d expected %d", len(got), len(expected))
1711 }
1712
1713 for i := range expected {
1714 if expected[i] != got[i] {
1715 t.Errorf("got %q expected %q", got[i], expected[i])
1716 }
1717 }
1718 }
1719
1720 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
1721 t.Helper()
1722
1723 for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
1724 if v, ok := header[h]; ok {
1725 t.Errorf("%s is %q; must not be sent", h, v)
1726 }
1727 }
1728 }
1729
1730 var respCounter uint8
1731 trace := &httptrace.ClientTrace{
1732 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1733 switch respCounter {
1734 case 0:
1735 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1736 checkExcludedHeaders(t, header)
1737
1738 wg.Done()
1739 case 1:
1740 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1741 checkExcludedHeaders(t, header)
1742
1743 default:
1744 t.Error("Unexpected 1xx response")
1745 }
1746
1747 respCounter++
1748
1749 return nil
1750 },
1751 }
1752 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
1753
1754 res, err := cst.c.Do(req)
1755 if err != nil {
1756 t.Fatal(err)
1757 }
1758 defer res.Body.Close()
1759
1760 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1761 if cl := res.Header.Get("Content-Length"); cl != "123" {
1762 t.Errorf("Content-Length is %q; want 123", cl)
1763 }
1764
1765 body, _ := io.ReadAll(res.Body)
1766 if string(body) != "Hello" {
1767 t.Errorf("Read body %q; want Hello", body)
1768 }
1769 }
1770
View as plain text