Source file
src/net/http/clientserver_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bytes"
11 "compress/gzip"
12 "context"
13 "crypto/rand"
14 "crypto/sha1"
15 "crypto/tls"
16 "fmt"
17 "hash"
18 "io"
19 "log"
20 "net"
21 . "net/http"
22 "net/http/httptest"
23 "net/http/httptrace"
24 "net/http/httputil"
25 "net/textproto"
26 "net/url"
27 "os"
28 "reflect"
29 "runtime"
30 "slices"
31 "strings"
32 "sync"
33 "sync/atomic"
34 "testing"
35 "time"
36 )
37
// testMode selects the protocol configuration a client/server test runs under.
type testMode string

// Protocol configurations understood by run and newClientServerTest.
const (
	// http1Mode serves plain-text HTTP/1.1.
	http1Mode testMode = "h1"
	// https1Mode serves HTTP/1.1 over TLS.
	https1Mode testMode = "https1"
	// http2Mode serves HTTP/2 over TLS.
	http2Mode testMode = "h2"
)

// testNotParallelOpt is the type of the testNotParallel option.
type testNotParallelOpt struct{}

// testNotParallel, when passed to run, stops run from marking the test and
// its per-mode subtests as parallel.
var testNotParallel = testNotParallelOpt{}
51
// TBRun is a testing.TB that can also start named subtests whose function
// receives the same concrete type T (for example *testing.T or *testing.B).
type TBRun[T any] interface {
	testing.TB
	Run(name string, f func(T)) bool
}
56
57
58
59
60
61
62
63
64 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
65 t.Helper()
66 modes := []testMode{http1Mode, http2Mode}
67 parallel := true
68 for _, opt := range opts {
69 switch opt := opt.(type) {
70 case []testMode:
71 modes = opt
72 case testNotParallelOpt:
73 parallel = false
74 default:
75 t.Fatalf("unknown option type %T", opt)
76 }
77 }
78 if t, ok := any(t).(*testing.T); ok && parallel {
79 setParallel(t)
80 }
81 for _, mode := range modes {
82 t.Run(string(mode), func(t T) {
83 t.Helper()
84 if t, ok := any(t).(*testing.T); ok && parallel {
85 setParallel(t)
86 }
87 t.Cleanup(func() {
88 afterTest(t)
89 })
90 f(t, mode)
91 })
92 }
93 }
94
// clientServerTest bundles a test server with a client configured to reach
// it. Instances are built by newClientServerTest.
type clientServerTest struct {
	// t is the test that owns the server.
	t testing.TB
	// h2 reports whether the server was set up in http2Mode.
	h2 bool
	// h is the handler the server was created with.
	h Handler
	// ts is the test server.
	ts *httptest.Server
	// tr is the Transport underlying c.
	tr *Transport
	// c is the client obtained from ts.
	c *Client
}
103
104 func (t *clientServerTest) close() {
105 t.tr.CloseIdleConnections()
106 t.ts.Close()
107 }
108
109 func (t *clientServerTest) getURL(u string) string {
110 res, err := t.c.Get(u)
111 if err != nil {
112 t.t.Fatal(err)
113 }
114 defer res.Body.Close()
115 slurp, err := io.ReadAll(res.Body)
116 if err != nil {
117 t.t.Fatal(err)
118 }
119 return string(slurp)
120 }
121
122 func (t *clientServerTest) scheme() string {
123 if t.h2 {
124 return "https"
125 }
126 return "http"
127 }
128
129 var optQuietLog = func(ts *httptest.Server) {
130 ts.Config.ErrorLog = quietLog
131 }
132
// optWithServerLog returns a newClientServerTest option that installs lg as
// the server's error logger.
func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
	setLog := func(s *httptest.Server) {
		s.Config.ErrorLog = lg
	}
	return setLog
}
138
139
140
141
142
143
144
145
146
147
148
149
150 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
151 if mode == http2Mode {
152 CondSkipHTTP2(t)
153 }
154 cst := &clientServerTest{
155 t: t,
156 h2: mode == http2Mode,
157 h: h,
158 }
159 cst.ts = httptest.NewUnstartedServer(h)
160
161 var transportFuncs []func(*Transport)
162 for _, opt := range opts {
163 switch opt := opt.(type) {
164 case func(*Transport):
165 transportFuncs = append(transportFuncs, opt)
166 case func(*httptest.Server):
167 opt(cst.ts)
168 default:
169 t.Fatalf("unhandled option type %T", opt)
170 }
171 }
172
173 if cst.ts.Config.ErrorLog == nil {
174 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
175 }
176
177 switch mode {
178 case http1Mode:
179 cst.ts.Start()
180 case https1Mode:
181 cst.ts.StartTLS()
182 case http2Mode:
183 ExportHttp2ConfigureServer(cst.ts.Config, nil)
184 cst.ts.TLS = cst.ts.Config.TLSConfig
185 cst.ts.StartTLS()
186 default:
187 t.Fatalf("unknown test mode %v", mode)
188 }
189 cst.c = cst.ts.Client()
190 cst.tr = cst.c.Transport.(*Transport)
191 if mode == http2Mode {
192 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
193 t.Fatal(err)
194 }
195 }
196 for _, f := range transportFuncs {
197 f(cst.tr)
198 }
199 t.Cleanup(func() {
200 cst.close()
201 })
202 return cst
203 }
204
// testLogWriter is an io.Writer that forwards every write to t.Logf.
// newClientServerTest uses it as the default server error log.
type testLogWriter struct {
	t testing.TB
}

// Write logs b with surrounding whitespace trimmed and always reports the
// whole of b as written, with no error.
func (w testLogWriter) Write(b []byte) (int, error) {
	msg := strings.TrimSpace(string(b))
	w.t.Logf("server log: %v", msg)
	return len(b), nil
}
213
214
215 func TestNewClientServerTest(t *testing.T) {
216 run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode})
217 }
218 func testNewClientServerTest(t *testing.T, mode testMode) {
219 var got struct {
220 sync.Mutex
221 proto string
222 hasTLS bool
223 }
224 h := HandlerFunc(func(w ResponseWriter, r *Request) {
225 got.Lock()
226 defer got.Unlock()
227 got.proto = r.Proto
228 got.hasTLS = r.TLS != nil
229 })
230 cst := newClientServerTest(t, mode, h)
231 if _, err := cst.c.Head(cst.ts.URL); err != nil {
232 t.Fatal(err)
233 }
234 var wantProto string
235 var wantTLS bool
236 switch mode {
237 case http1Mode:
238 wantProto = "HTTP/1.1"
239 wantTLS = false
240 case https1Mode:
241 wantProto = "HTTP/1.1"
242 wantTLS = true
243 case http2Mode:
244 wantProto = "HTTP/2.0"
245 wantTLS = true
246 }
247 if got.proto != wantProto {
248 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
249 }
250 if got.hasTLS != wantTLS {
251 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
252 }
253 }
254
255 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
256 func testChunkedResponseHeaders(t *testing.T, mode testMode) {
257 log.SetOutput(io.Discard)
258 defer log.SetOutput(os.Stderr)
259 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
260 w.Header().Set("Content-Length", "intentional gibberish")
261 w.(Flusher).Flush()
262 fmt.Fprintf(w, "I am a chunked response.")
263 }))
264
265 res, err := cst.c.Get(cst.ts.URL)
266 if err != nil {
267 t.Fatalf("Get error: %v", err)
268 }
269 defer res.Body.Close()
270 if g, e := res.ContentLength, int64(-1); g != e {
271 t.Errorf("expected ContentLength of %d; got %d", e, g)
272 }
273 wantTE := []string{"chunked"}
274 if mode == http2Mode {
275 wantTE = nil
276 }
277 if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
278 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
279 }
280 if got, haveCL := res.Header["Content-Length"]; haveCL {
281 t.Errorf("Unexpected Content-Length: %q", got)
282 }
283 }
284
// reqFunc issues a request to target using client. Method expressions such
// as (*Client).Get and (*Client).Head have this shape.
type reqFunc func(client *Client, target string) (*Response, error)
286
287
288
289 type h12Compare struct {
290 Handler func(ResponseWriter, *Request)
291 ReqFunc reqFunc
292 CheckResponse func(proto string, res *Response)
293 EarlyCheckResponse func(proto string, res *Response)
294 Opts []any
295 }
296
297 func (tt h12Compare) reqFunc() reqFunc {
298 if tt.ReqFunc == nil {
299 return (*Client).Get
300 }
301 return tt.ReqFunc
302 }
303
304 func (tt h12Compare) run(t *testing.T) {
305 setParallel(t)
306 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
307 defer cst1.close()
308 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
309 defer cst2.close()
310
311 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
312 if err != nil {
313 t.Errorf("HTTP/1 request: %v", err)
314 return
315 }
316 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
317 if err != nil {
318 t.Errorf("HTTP/2 request: %v", err)
319 return
320 }
321
322 if fn := tt.EarlyCheckResponse; fn != nil {
323 fn("HTTP/1.1", res1)
324 fn("HTTP/2.0", res2)
325 }
326
327 tt.normalizeRes(t, res1, "HTTP/1.1")
328 tt.normalizeRes(t, res2, "HTTP/2.0")
329 res1body, res2body := res1.Body, res2.Body
330
331 eres1 := mostlyCopy(res1)
332 eres2 := mostlyCopy(res2)
333 if !reflect.DeepEqual(eres1, eres2) {
334 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
335 cst1.ts.URL, eres1, cst2.ts.URL, eres2)
336 }
337 if !reflect.DeepEqual(res1body, res2body) {
338 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
339 }
340 if fn := tt.CheckResponse; fn != nil {
341 res1.Body, res2.Body = res1body, res2body
342 fn("HTTP/1.1", res1)
343 fn("HTTP/2.0", res2)
344 }
345 }
346
// mostlyCopy returns a shallow copy of r with Body, TransferEncoding, TLS and
// Request cleared, leaving the remaining fields for reflect.DeepEqual
// comparison. r itself is not modified.
func mostlyCopy(r *Response) *Response {
	cp := new(Response)
	*cp = *r
	cp.Body, cp.TransferEncoding, cp.TLS, cp.Request = nil, nil, nil, nil
	return cp
}
355
// slurpResult replaces a response body that has already been read in full.
// It replays the bytes through the embedded ReadCloser and keeps the error
// the original read ended with.
type slurpResult struct {
	io.ReadCloser
	body []byte
	err  error
}

// String formats the captured body and read error for test failure output.
func (sr slurpResult) String() string {
	return fmt.Sprintf("body %q; err %v", sr.body, sr.err)
}
363
364 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
365 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
366 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
367 } else {
368 t.Errorf("got %q response; want %q", res.Proto, wantProto)
369 }
370 slurp, err := io.ReadAll(res.Body)
371
372 res.Body.Close()
373 res.Body = slurpResult{
374 ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
375 body: slurp,
376 err: err,
377 }
378 for i, v := range res.Header["Date"] {
379 res.Header["Date"][i] = strings.Repeat("x", len(v))
380 }
381 if res.Request == nil {
382 t.Errorf("for %s, no request", wantProto)
383 }
384 if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
385 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
386 }
387 }
388
389
390 func TestH12_HeadContentLengthNoBody(t *testing.T) {
391 h12Compare{
392 ReqFunc: (*Client).Head,
393 Handler: func(w ResponseWriter, r *Request) {
394 },
395 }.run(t)
396 }
397
398 func TestH12_HeadContentLengthSmallBody(t *testing.T) {
399 h12Compare{
400 ReqFunc: (*Client).Head,
401 Handler: func(w ResponseWriter, r *Request) {
402 io.WriteString(w, "small")
403 },
404 }.run(t)
405 }
406
407 func TestH12_HeadContentLengthLargeBody(t *testing.T) {
408 h12Compare{
409 ReqFunc: (*Client).Head,
410 Handler: func(w ResponseWriter, r *Request) {
411 chunk := strings.Repeat("x", 512<<10)
412 for i := 0; i < 10; i++ {
413 io.WriteString(w, chunk)
414 }
415 },
416 }.run(t)
417 }
418
419 func TestH12_200NoBody(t *testing.T) {
420 h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
421 }
422
423 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
424 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
425 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
426
427 func testH12_noBody(t *testing.T, status int) {
428 h12Compare{Handler: func(w ResponseWriter, r *Request) {
429 w.WriteHeader(status)
430 }}.run(t)
431 }
432
433 func TestH12_SmallBody(t *testing.T) {
434 h12Compare{Handler: func(w ResponseWriter, r *Request) {
435 io.WriteString(w, "small body")
436 }}.run(t)
437 }
438
439 func TestH12_ExplicitContentLength(t *testing.T) {
440 h12Compare{Handler: func(w ResponseWriter, r *Request) {
441 w.Header().Set("Content-Length", "3")
442 io.WriteString(w, "foo")
443 }}.run(t)
444 }
445
446 func TestH12_FlushBeforeBody(t *testing.T) {
447 h12Compare{Handler: func(w ResponseWriter, r *Request) {
448 w.(Flusher).Flush()
449 io.WriteString(w, "foo")
450 }}.run(t)
451 }
452
453 func TestH12_FlushMidBody(t *testing.T) {
454 h12Compare{Handler: func(w ResponseWriter, r *Request) {
455 io.WriteString(w, "foo")
456 w.(Flusher).Flush()
457 io.WriteString(w, "bar")
458 }}.run(t)
459 }
460
461 func TestH12_Head_ExplicitLen(t *testing.T) {
462 h12Compare{
463 ReqFunc: (*Client).Head,
464 Handler: func(w ResponseWriter, r *Request) {
465 if r.Method != "HEAD" {
466 t.Errorf("unexpected method %q", r.Method)
467 }
468 w.Header().Set("Content-Length", "1235")
469 },
470 }.run(t)
471 }
472
473 func TestH12_Head_ImplicitLen(t *testing.T) {
474 h12Compare{
475 ReqFunc: (*Client).Head,
476 Handler: func(w ResponseWriter, r *Request) {
477 if r.Method != "HEAD" {
478 t.Errorf("unexpected method %q", r.Method)
479 }
480 io.WriteString(w, "foo")
481 },
482 }.run(t)
483 }
484
485 func TestH12_HandlerWritesTooLittle(t *testing.T) {
486 h12Compare{
487 Handler: func(w ResponseWriter, r *Request) {
488 w.Header().Set("Content-Length", "3")
489 io.WriteString(w, "12")
490 },
491 CheckResponse: func(proto string, res *Response) {
492 sr, ok := res.Body.(slurpResult)
493 if !ok {
494 t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
495 return
496 }
497 if sr.err != io.ErrUnexpectedEOF {
498 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
499 }
500 if string(sr.body) != "12" {
501 t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
502 }
503 },
504 }.run(t)
505 }
506
507
508
509
510
511
512
513 func TestHandlerWritesTooMuch(t *testing.T) { run(t, testHandlerWritesTooMuch) }
514 func testHandlerWritesTooMuch(t *testing.T, mode testMode) {
515 wantBody := []byte("123")
516 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
517 rc := NewResponseController(w)
518 w.Header().Set("Content-Length", fmt.Sprintf("%v", len(wantBody)))
519 rc.Flush()
520 w.Write(wantBody)
521 rc.Flush()
522 n, err := io.WriteString(w, "x")
523 if err == nil {
524 err = rc.Flush()
525 }
526
527 if err == nil {
528 t.Errorf("for proto %q, final write = %v, %v; want _, some error", r.Proto, n, err)
529 }
530 }))
531
532 res, err := cst.c.Get(cst.ts.URL)
533 if err != nil {
534 t.Fatal(err)
535 }
536 defer res.Body.Close()
537
538 gotBody, _ := io.ReadAll(res.Body)
539 if !bytes.Equal(gotBody, wantBody) {
540 t.Fatalf("got response body: %q; want %q", gotBody, wantBody)
541 }
542 }
543
544
545
546 func TestH12_AutoGzip(t *testing.T) {
547 h12Compare{
548 Handler: func(w ResponseWriter, r *Request) {
549 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
550 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
551 }
552 w.Header().Set("Content-Encoding", "gzip")
553 gz := gzip.NewWriter(w)
554 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
555 gz.Close()
556 },
557 }.run(t)
558 }
559
560 func TestH12_AutoGzip_Disabled(t *testing.T) {
561 h12Compare{
562 Opts: []any{
563 func(tr *Transport) { tr.DisableCompression = true },
564 },
565 Handler: func(w ResponseWriter, r *Request) {
566 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
567 if ae := r.Header.Get("Accept-Encoding"); ae != "" {
568 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
569 }
570 },
571 }.run(t)
572 }
573
574
575
576
577 func Test304Responses(t *testing.T) { run(t, test304Responses) }
578 func test304Responses(t *testing.T, mode testMode) {
579 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
580 w.WriteHeader(StatusNotModified)
581 _, err := w.Write([]byte("illegal body"))
582 if err != ErrBodyNotAllowed {
583 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
584 }
585 }))
586 defer cst.close()
587 res, err := cst.c.Get(cst.ts.URL)
588 if err != nil {
589 t.Fatal(err)
590 }
591 if len(res.TransferEncoding) > 0 {
592 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
593 }
594 body, err := io.ReadAll(res.Body)
595 if err != nil {
596 t.Error(err)
597 }
598 if len(body) > 0 {
599 t.Errorf("got unexpected body %q", string(body))
600 }
601 }
602
603 func TestH12_ServerEmptyContentLength(t *testing.T) {
604 h12Compare{
605 Handler: func(w ResponseWriter, r *Request) {
606 w.Header()["Content-Type"] = []string{""}
607 io.WriteString(w, "<html><body>hi</body></html>")
608 },
609 }.run(t)
610 }
611
612 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
613 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
614 }
615
616 func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
617 h12requestContentLength(t, func() io.Reader { return nil }, 0)
618 }
619
620 func TestH12_RequestContentLength_Unknown(t *testing.T) {
621 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
622 }
623
624 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
625 h12Compare{
626 Handler: func(w ResponseWriter, r *Request) {
627 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
628 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
629 },
630 ReqFunc: func(c *Client, url string) (*Response, error) {
631 return c.Post(url, "text/plain", bodyfn())
632 },
633 CheckResponse: func(proto string, res *Response) {
634 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
635 t.Errorf("Proto %q got length %q; want %q", proto, got, want)
636 }
637 },
638 }.run(t)
639 }
640
641
642
643 func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
644 func testCancelRequestMidBody(t *testing.T, mode testMode) {
645 unblock := make(chan bool)
646 didFlush := make(chan bool, 1)
647 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
648 io.WriteString(w, "Hello")
649 w.(Flusher).Flush()
650 didFlush <- true
651 <-unblock
652 io.WriteString(w, ", world.")
653 }))
654 defer close(unblock)
655
656 req, _ := NewRequest("GET", cst.ts.URL, nil)
657 cancel := make(chan struct{})
658 req.Cancel = cancel
659
660 res, err := cst.c.Do(req)
661 if err != nil {
662 t.Fatal(err)
663 }
664 defer res.Body.Close()
665 <-didFlush
666
667
668
669 firstRead := make([]byte, 10)
670 n, err := res.Body.Read(firstRead)
671 if err != nil {
672 t.Fatal(err)
673 }
674 firstRead = firstRead[:n]
675
676 close(cancel)
677
678 rest, err := io.ReadAll(res.Body)
679 all := string(firstRead) + string(rest)
680 if all != "Hello" {
681 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
682 }
683 if err != ExportErrRequestCanceled {
684 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
685 }
686 }
687
688
689 func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
690 func testTrailersClientToServer(t *testing.T, mode testMode) {
691 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
692 var decl []string
693 for k := range r.Trailer {
694 decl = append(decl, k)
695 }
696 slices.Sort(decl)
697
698 slurp, err := io.ReadAll(r.Body)
699 if err != nil {
700 t.Errorf("Server reading request body: %v", err)
701 }
702 if string(slurp) != "foo" {
703 t.Errorf("Server read request body %q; want foo", slurp)
704 }
705 if r.Trailer == nil {
706 io.WriteString(w, "nil Trailer")
707 } else {
708 fmt.Fprintf(w, "decl: %v, vals: %s, %s",
709 decl,
710 r.Trailer.Get("Client-Trailer-A"),
711 r.Trailer.Get("Client-Trailer-B"))
712 }
713 }))
714
715 var req *Request
716 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
717 eofReaderFunc(func() {
718 req.Trailer["Client-Trailer-A"] = []string{"valuea"}
719 }),
720 strings.NewReader("foo"),
721 eofReaderFunc(func() {
722 req.Trailer["Client-Trailer-B"] = []string{"valueb"}
723 }),
724 ))
725 req.Trailer = Header{
726 "Client-Trailer-A": nil,
727 "Client-Trailer-B": nil,
728 }
729 req.ContentLength = -1
730 res, err := cst.c.Do(req)
731 if err != nil {
732 t.Fatal(err)
733 }
734 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
735 t.Error(err)
736 }
737 }
738
739
740 func TestTrailersServerToClient(t *testing.T) {
741 run(t, func(t *testing.T, mode testMode) {
742 testTrailersServerToClient(t, mode, false)
743 })
744 }
745 func TestTrailersServerToClientFlush(t *testing.T) {
746 run(t, func(t *testing.T, mode testMode) {
747 testTrailersServerToClient(t, mode, true)
748 })
749 }
750
751 func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
752 const body = "Some body"
753 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
754 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
755 w.Header().Add("Trailer", "Server-Trailer-C")
756
757 io.WriteString(w, body)
758 if flush {
759 w.(Flusher).Flush()
760 }
761
762
763
764
765
766 w.Header().Set("Server-Trailer-A", "valuea")
767 w.Header().Set("Server-Trailer-C", "valuec")
768 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
769 }))
770
771 res, err := cst.c.Get(cst.ts.URL)
772 if err != nil {
773 t.Fatal(err)
774 }
775
776 wantHeader := Header{
777 "Content-Type": {"text/plain; charset=utf-8"},
778 }
779 wantLen := -1
780 if mode == http2Mode && !flush {
781
782
783
784
785
786 wantLen = len(body)
787 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
788 }
789 if res.ContentLength != int64(wantLen) {
790 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
791 }
792
793 delete(res.Header, "Date")
794 if !reflect.DeepEqual(res.Header, wantHeader) {
795 t.Errorf("Header = %v; want %v", res.Header, wantHeader)
796 }
797
798 if got, want := res.Trailer, (Header{
799 "Server-Trailer-A": nil,
800 "Server-Trailer-B": nil,
801 "Server-Trailer-C": nil,
802 }); !reflect.DeepEqual(got, want) {
803 t.Errorf("Trailer before body read = %v; want %v", got, want)
804 }
805
806 if err := wantBody(res, nil, body); err != nil {
807 t.Fatal(err)
808 }
809
810 if got, want := res.Trailer, (Header{
811 "Server-Trailer-A": {"valuea"},
812 "Server-Trailer-B": nil,
813 "Server-Trailer-C": {"valuec"},
814 }); !reflect.DeepEqual(got, want) {
815 t.Errorf("Trailer after body read = %v; want %v", got, want)
816 }
817 }
818
819
820 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
821 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
822 const body = "Some body"
823 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
824 io.WriteString(w, body)
825 }))
826 res, err := cst.c.Get(cst.ts.URL)
827 if err != nil {
828 t.Fatal(err)
829 }
830 res.Body.Close()
831 data, err := io.ReadAll(res.Body)
832 if len(data) != 0 || err == nil {
833 t.Fatalf("ReadAll returned %q, %v; want error", data, err)
834 }
835 }
836
837 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
838 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
839 const reqBody = "some request body"
840 const resBody = "some response body"
841 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
842 var wg sync.WaitGroup
843 wg.Add(2)
844 didRead := make(chan bool, 1)
845
846 go func() {
847 defer wg.Done()
848 data, err := io.ReadAll(r.Body)
849 if string(data) != reqBody {
850 t.Errorf("Handler read %q; want %q", data, reqBody)
851 }
852 if err != nil {
853 t.Errorf("Handler Read: %v", err)
854 }
855 didRead <- true
856 }()
857
858 go func() {
859 defer wg.Done()
860 if mode != http2Mode {
861
862
863
864
865 <-didRead
866 }
867 io.WriteString(w, resBody)
868 }()
869 wg.Wait()
870 }))
871 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
872 req.Header.Add("Expect", "100-continue")
873 res, err := cst.c.Do(req)
874 if err != nil {
875 t.Fatal(err)
876 }
877 data, err := io.ReadAll(res.Body)
878 defer res.Body.Close()
879 if err != nil {
880 t.Fatal(err)
881 }
882 if string(data) != resBody {
883 t.Errorf("read %q; want %q", data, resBody)
884 }
885 }
886
887 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
888 func testConnectRequest(t *testing.T, mode testMode) {
889 gotc := make(chan *Request, 1)
890 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
891 gotc <- r
892 }))
893
894 u, err := url.Parse(cst.ts.URL)
895 if err != nil {
896 t.Fatal(err)
897 }
898
899 tests := []struct {
900 req *Request
901 want string
902 }{
903 {
904 req: &Request{
905 Method: "CONNECT",
906 Header: Header{},
907 URL: u,
908 },
909 want: u.Host,
910 },
911 {
912 req: &Request{
913 Method: "CONNECT",
914 Header: Header{},
915 URL: u,
916 Host: "example.com:123",
917 },
918 want: "example.com:123",
919 },
920 }
921
922 for i, tt := range tests {
923 res, err := cst.c.Do(tt.req)
924 if err != nil {
925 t.Errorf("%d. RoundTrip = %v", i, err)
926 continue
927 }
928 res.Body.Close()
929 req := <-gotc
930 if req.Method != "CONNECT" {
931 t.Errorf("method = %q; want CONNECT", req.Method)
932 }
933 if req.Host != tt.want {
934 t.Errorf("Host = %q; want %q", req.Host, tt.want)
935 }
936 if req.URL.Host != tt.want {
937 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
938 }
939 }
940 }
941
942 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
943 func testTransportUserAgent(t *testing.T, mode testMode) {
944 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
945 fmt.Fprintf(w, "%q", r.Header["User-Agent"])
946 }))
947
948 either := func(a, b string) string {
949 if mode == http2Mode {
950 return b
951 }
952 return a
953 }
954
955 tests := []struct {
956 setup func(*Request)
957 want string
958 }{
959 {
960 func(r *Request) {},
961 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
962 },
963 {
964 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
965 `["foo/1.2.3"]`,
966 },
967 {
968 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
969 `["single"]`,
970 },
971 {
972 func(r *Request) { r.Header.Set("User-Agent", "") },
973 `[]`,
974 },
975 {
976 func(r *Request) { r.Header["User-Agent"] = nil },
977 `[]`,
978 },
979 }
980 for i, tt := range tests {
981 req, _ := NewRequest("GET", cst.ts.URL, nil)
982 tt.setup(req)
983 res, err := cst.c.Do(req)
984 if err != nil {
985 t.Errorf("%d. RoundTrip = %v", i, err)
986 continue
987 }
988 slurp, err := io.ReadAll(res.Body)
989 res.Body.Close()
990 if err != nil {
991 t.Errorf("%d. read body = %v", i, err)
992 continue
993 }
994 if string(slurp) != tt.want {
995 t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
996 }
997 }
998 }
999
1000 func TestStarRequestMethod(t *testing.T) {
1001 for _, method := range []string{"FOO", "OPTIONS"} {
1002 t.Run(method, func(t *testing.T) {
1003 run(t, func(t *testing.T, mode testMode) {
1004 testStarRequest(t, method, mode)
1005 })
1006 })
1007 }
1008 }
1009 func testStarRequest(t *testing.T, method string, mode testMode) {
1010 gotc := make(chan *Request, 1)
1011 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1012 w.Header().Set("foo", "bar")
1013 gotc <- r
1014 w.(Flusher).Flush()
1015 }))
1016
1017 u, err := url.Parse(cst.ts.URL)
1018 if err != nil {
1019 t.Fatal(err)
1020 }
1021 u.Path = "*"
1022
1023 req := &Request{
1024 Method: method,
1025 Header: Header{},
1026 URL: u,
1027 }
1028
1029 res, err := cst.c.Do(req)
1030 if err != nil {
1031 t.Fatalf("RoundTrip = %v", err)
1032 }
1033 res.Body.Close()
1034
1035 wantFoo := "bar"
1036 wantLen := int64(-1)
1037 if method == "OPTIONS" {
1038 wantFoo = ""
1039 wantLen = 0
1040 }
1041 if res.StatusCode != 200 {
1042 t.Errorf("status code = %v; want %d", res.Status, 200)
1043 }
1044 if res.ContentLength != wantLen {
1045 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
1046 }
1047 if got := res.Header.Get("foo"); got != wantFoo {
1048 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
1049 }
1050 select {
1051 case req = <-gotc:
1052 default:
1053 req = nil
1054 }
1055 if req == nil {
1056 if method != "OPTIONS" {
1057 t.Fatalf("handler never got request")
1058 }
1059 return
1060 }
1061 if req.Method != method {
1062 t.Errorf("method = %q; want %q", req.Method, method)
1063 }
1064 if req.URL.Path != "*" {
1065 t.Errorf("URL.Path = %q; want *", req.URL.Path)
1066 }
1067 if req.RequestURI != "*" {
1068 t.Errorf("RequestURI = %q; want *", req.RequestURI)
1069 }
1070 }
1071
1072
1073 func TestTransportDiscardsUnneededConns(t *testing.T) {
1074 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
1075 }
1076 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
1077 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1078 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
1079 }))
1080 defer cst.close()
1081
1082 var numOpen, numClose int32
1083
1084 tlsConfig := &tls.Config{InsecureSkipVerify: true}
1085 tr := &Transport{
1086 TLSClientConfig: tlsConfig,
1087 DialTLS: func(_, addr string) (net.Conn, error) {
1088 time.Sleep(10 * time.Millisecond)
1089 rc, err := net.Dial("tcp", addr)
1090 if err != nil {
1091 return nil, err
1092 }
1093 atomic.AddInt32(&numOpen, 1)
1094 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
1095 return tls.Client(c, tlsConfig), nil
1096 },
1097 }
1098 if err := ExportHttp2ConfigureTransport(tr); err != nil {
1099 t.Fatal(err)
1100 }
1101 defer tr.CloseIdleConnections()
1102
1103 c := &Client{Transport: tr}
1104
1105 const N = 10
1106 gotBody := make(chan string, N)
1107 var wg sync.WaitGroup
1108 for i := 0; i < N; i++ {
1109 wg.Add(1)
1110 go func() {
1111 defer wg.Done()
1112 resp, err := c.Get(cst.ts.URL)
1113 if err != nil {
1114
1115
1116 time.Sleep(10 * time.Millisecond)
1117 resp, err = c.Get(cst.ts.URL)
1118 if err != nil {
1119 t.Errorf("Get: %v", err)
1120 return
1121 }
1122 }
1123 defer resp.Body.Close()
1124 slurp, err := io.ReadAll(resp.Body)
1125 if err != nil {
1126 t.Error(err)
1127 }
1128 gotBody <- string(slurp)
1129 }()
1130 }
1131 wg.Wait()
1132 close(gotBody)
1133
1134 var last string
1135 for got := range gotBody {
1136 if last == "" {
1137 last = got
1138 continue
1139 }
1140 if got != last {
1141 t.Errorf("Response body changed: %q -> %q", last, got)
1142 }
1143 }
1144
1145 var open, close int32
1146 for i := 0; i < 150; i++ {
1147 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
1148 if open < 1 {
1149 t.Fatalf("open = %d; want at least", open)
1150 }
1151 if close == open-1 {
1152
1153 return
1154 }
1155 time.Sleep(10 * time.Millisecond)
1156 }
1157 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
1158 }
1159
1160
1161 func TestTransportGCRequest(t *testing.T) {
1162 run(t, func(t *testing.T, mode testMode) {
1163 t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
1164 t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
1165 })
1166 }
1167 func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
1168 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1169 io.ReadAll(r.Body)
1170 if body {
1171 io.WriteString(w, "Hello.")
1172 }
1173 }))
1174
1175 didGC := make(chan struct{})
1176 (func() {
1177 body := strings.NewReader("some body")
1178 req, _ := NewRequest("POST", cst.ts.URL, body)
1179 runtime.SetFinalizer(req, func(*Request) { close(didGC) })
1180 res, err := cst.c.Do(req)
1181 if err != nil {
1182 t.Fatal(err)
1183 }
1184 if _, err := io.ReadAll(res.Body); err != nil {
1185 t.Fatal(err)
1186 }
1187 if err := res.Body.Close(); err != nil {
1188 t.Fatal(err)
1189 }
1190 })()
1191 for {
1192 select {
1193 case <-didGC:
1194 return
1195 case <-time.After(1 * time.Millisecond):
1196 runtime.GC()
1197 }
1198 }
1199 }
1200
1201 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
1202 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
1203 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1204 fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
1205 }), optQuietLog)
1206 cst.tr.DisableKeepAlives = true
1207
1208 tests := []struct {
1209 key, val string
1210 ok bool
1211 }{
1212 {"Foo", "capital-key", true},
1213 {"Foo", "foo\x00bar", false},
1214 {"Foo", "two\nlines", false},
1215 {"bogus\nkey", "v", false},
1216 {"A space", "v", false},
1217 {"имя", "v", false},
1218 {"name", "валю", true},
1219 {"", "v", false},
1220 {"k", "", true},
1221 }
1222 for _, tt := range tests {
1223 dialedc := make(chan bool, 1)
1224 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
1225 dialedc <- true
1226 return net.Dial(netw, addr)
1227 }
1228 req, _ := NewRequest("GET", cst.ts.URL, nil)
1229 req.Header[tt.key] = []string{tt.val}
1230 res, err := cst.c.Do(req)
1231 var body []byte
1232 if err == nil {
1233 body, _ = io.ReadAll(res.Body)
1234 res.Body.Close()
1235 }
1236 var dialed bool
1237 select {
1238 case <-dialedc:
1239 dialed = true
1240 default:
1241 }
1242
1243 if !tt.ok && dialed {
1244 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
1245 } else if (err == nil) != tt.ok {
1246 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
1247 }
1248 }
1249 }
1250
1251 func TestInterruptWithPanic(t *testing.T) {
1252 run(t, func(t *testing.T, mode testMode) {
1253 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
1254 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
1255 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
1256 }, testNotParallel)
1257 }
1258 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
1259 const msg = "hello"
1260
1261 testDone := make(chan struct{})
1262 defer close(testDone)
1263
1264 var errorLog lockedBytesBuffer
1265 gotHeaders := make(chan bool, 1)
1266 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1267 io.WriteString(w, msg)
1268 w.(Flusher).Flush()
1269
1270 select {
1271 case <-gotHeaders:
1272 case <-testDone:
1273 }
1274 panic(panicValue)
1275 }), func(ts *httptest.Server) {
1276 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1277 })
1278 res, err := cst.c.Get(cst.ts.URL)
1279 if err != nil {
1280 t.Fatal(err)
1281 }
1282 gotHeaders <- true
1283 defer res.Body.Close()
1284 slurp, err := io.ReadAll(res.Body)
1285 if string(slurp) != msg {
1286 t.Errorf("client read %q; want %q", slurp, msg)
1287 }
1288 if err == nil {
1289 t.Errorf("client read all successfully; want some error")
1290 }
1291 logOutput := func() string {
1292 errorLog.Lock()
1293 defer errorLog.Unlock()
1294 return errorLog.String()
1295 }
1296 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
1297
1298 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
1299 gotLog := logOutput()
1300 if !wantStackLogged {
1301 if gotLog == "" {
1302 return true
1303 }
1304 t.Fatalf("want no log output; got: %s", gotLog)
1305 }
1306 if gotLog == "" {
1307 if d > 0 {
1308 t.Logf("wanted a stack trace logged; got nothing after %v", d)
1309 }
1310 return false
1311 }
1312 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
1313 if d > 0 {
1314 t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
1315 }
1316 return false
1317 }
1318 return true
1319 })
1320 }
1321
// lockedBytesBuffer is a bytes.Buffer paired with a mutex so it can be written
// from several goroutines, for example as the sink of a server's ErrorLog.
// Only Write takes the lock automatically; code that reads the embedded
// Buffer must hold the embedded Mutex itself.
type lockedBytesBuffer struct {
	sync.Mutex
	bytes.Buffer
}

// Write appends p to the buffer while holding the embedded mutex.
func (lb *lockedBytesBuffer) Write(p []byte) (n int, err error) {
	lb.Mutex.Lock()
	defer lb.Mutex.Unlock()
	n, err = lb.Buffer.Write(p)
	return n, err
}
1332
1333
1334 func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1335 h12Compare{
1336 Handler: func(w ResponseWriter, r *Request) {
1337 h := w.Header()
1338 h.Set("Content-Encoding", "gzip")
1339 h.Set("Content-Length", "23")
1340 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1341 },
1342 EarlyCheckResponse: func(proto string, res *Response) {
1343 if !res.Uncompressed {
1344 t.Errorf("%s: expected Uncompressed to be set", proto)
1345 }
1346 dump, err := httputil.DumpResponse(res, true)
1347 if err != nil {
1348 t.Errorf("%s: DumpResponse: %v", proto, err)
1349 return
1350 }
1351 if strings.Contains(string(dump), "Connection: close") {
1352 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1353 }
1354 if !strings.Contains(string(dump), "FOO") {
1355 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1356 }
1357 },
1358 }.run(t)
1359 }
1360
1361
1362 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
1363 func testCloseIdleConnections(t *testing.T, mode testMode) {
1364 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1365 w.Header().Set("X-Addr", r.RemoteAddr)
1366 }))
1367 get := func() string {
1368 res, err := cst.c.Get(cst.ts.URL)
1369 if err != nil {
1370 t.Fatal(err)
1371 }
1372 res.Body.Close()
1373 v := res.Header.Get("X-Addr")
1374 if v == "" {
1375 t.Fatal("didn't get X-Addr")
1376 }
1377 return v
1378 }
1379 a1 := get()
1380 cst.tr.CloseIdleConnections()
1381 a2 := get()
1382 if a1 == a2 {
1383 t.Errorf("didn't close connection")
1384 }
1385 }
1386
// noteCloseConn wraps a net.Conn and invokes closeFunc each time Close is
// called, before closing the wrapped connection.
type noteCloseConn struct {
	net.Conn
	closeFunc func()
}

// Close runs the closeFunc hook and then returns the result of closing the
// underlying connection.
func (nc noteCloseConn) Close() error {
	notify, underlying := nc.closeFunc, nc.Conn
	notify()
	return underlying.Close()
}
1396
1397 type testErrorReader struct{ t *testing.T }
1398
1399 func (r testErrorReader) Read(p []byte) (n int, err error) {
1400 r.t.Error("unexpected Read call")
1401 return 0, io.EOF
1402 }
1403
1404 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
1405 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
1406 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1407 w.WriteHeader(StatusUnauthorized)
1408 }))
1409
1410
1411 cst.tr.ExpectContinueTimeout = 10 * time.Second
1412
1413 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1414 if err != nil {
1415 t.Fatal(err)
1416 }
1417 req.ContentLength = 0
1418 req.Header.Set("Expect", "100-continue")
1419 res, err := cst.tr.RoundTrip(req)
1420 if err != nil {
1421 t.Fatal(err)
1422 }
1423 defer res.Body.Close()
1424 if res.StatusCode != StatusUnauthorized {
1425 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1426 }
1427 }
1428
1429 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
1430 func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
1431 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1432 w.Header().Set("Foo", "Bar")
1433 w.Header().Set("Trailer:Foo", "Baz")
1434 w.(Flusher).Flush()
1435 w.Header().Add("Trailer:Foo", "Baz2")
1436 w.Header().Set("Trailer:Bar", "Quux")
1437 }))
1438 res, err := cst.c.Get(cst.ts.URL)
1439 if err != nil {
1440 t.Fatal(err)
1441 }
1442 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1443 t.Fatal(err)
1444 }
1445 res.Body.Close()
1446 delete(res.Header, "Date")
1447 delete(res.Header, "Content-Type")
1448
1449 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1450 t.Errorf("Header = %#v; want %#v", res.Header, want)
1451 }
1452 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1453 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1454 }
1455 }
1456
1457 func TestBadResponseAfterReadingBody(t *testing.T) {
1458 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
1459 }
1460 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
1461 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1462 _, err := io.Copy(io.Discard, r.Body)
1463 if err != nil {
1464 t.Fatal(err)
1465 }
1466 c, _, err := w.(Hijacker).Hijack()
1467 if err != nil {
1468 t.Fatal(err)
1469 }
1470 defer c.Close()
1471 fmt.Fprintln(c, "some bogus crap")
1472 }))
1473
1474 closes := 0
1475 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1476 if err == nil {
1477 res.Body.Close()
1478 t.Fatal("expected an error to be returned from Post")
1479 }
1480 if closes != 1 {
1481 t.Errorf("closes = %d; want 1", closes)
1482 }
1483 }
1484
1485 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
1486 func testWriteHeader0(t *testing.T, mode testMode) {
1487 gotpanic := make(chan bool, 1)
1488 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1489 defer close(gotpanic)
1490 defer func() {
1491 if e := recover(); e != nil {
1492 got := fmt.Sprintf("%T, %v", e, e)
1493 want := "string, invalid WriteHeader code 0"
1494 if got != want {
1495 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1496 }
1497 gotpanic <- true
1498
1499
1500
1501
1502 w.WriteHeader(503)
1503 }
1504 }()
1505 w.WriteHeader(0)
1506 }))
1507 res, err := cst.c.Get(cst.ts.URL)
1508 if err != nil {
1509 t.Fatal(err)
1510 }
1511 if res.StatusCode != 503 {
1512 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1513 }
1514 if !<-gotpanic {
1515 t.Error("expected panic in handler")
1516 }
1517 }
1518
1519
1520
1521 func TestWriteHeaderNoCodeCheck(t *testing.T) {
1522 run(t, func(t *testing.T, mode testMode) {
1523 testWriteHeaderAfterWrite(t, mode, false)
1524 })
1525 }
1526 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
1527 testWriteHeaderAfterWrite(t, http1Mode, true)
1528 }
1529 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
1530 var errorLog lockedBytesBuffer
1531 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1532 if hijack {
1533 conn, _, _ := w.(Hijacker).Hijack()
1534 defer conn.Close()
1535 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1536 w.WriteHeader(0)
1537 conn.Write([]byte("bar"))
1538 return
1539 }
1540 io.WriteString(w, "foo")
1541 w.(Flusher).Flush()
1542 w.WriteHeader(0)
1543 io.WriteString(w, "bar")
1544 }), func(ts *httptest.Server) {
1545 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1546 })
1547 res, err := cst.c.Get(cst.ts.URL)
1548 if err != nil {
1549 t.Fatal(err)
1550 }
1551 defer res.Body.Close()
1552 body, err := io.ReadAll(res.Body)
1553 if err != nil {
1554 t.Fatal(err)
1555 }
1556 if got, want := string(body), "foobar"; got != want {
1557 t.Errorf("got = %q; want %q", got, want)
1558 }
1559
1560
1561 if mode == http2Mode {
1562
1563
1564 return
1565 }
1566 gotLog := strings.TrimSpace(errorLog.String())
1567 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1568 if hijack {
1569 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1570 }
1571 if !strings.HasPrefix(gotLog, wantLog) {
1572 t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1573 }
1574 }
1575
1576 func TestBidiStreamReverseProxy(t *testing.T) {
1577 run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
1578 }
1579 func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
1580 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1581 if _, err := io.Copy(w, r.Body); err != nil {
1582 log.Printf("bidi backend copy: %v", err)
1583 }
1584 }))
1585
1586 backURL, err := url.Parse(backend.ts.URL)
1587 if err != nil {
1588 t.Fatal(err)
1589 }
1590 rp := httputil.NewSingleHostReverseProxy(backURL)
1591 rp.Transport = backend.tr
1592 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1593 rp.ServeHTTP(w, r)
1594 }))
1595
1596 bodyRes := make(chan any, 1)
1597 pr, pw := io.Pipe()
1598 req, _ := NewRequest("PUT", proxy.ts.URL, pr)
1599 const size = 4 << 20
1600 go func() {
1601 h := sha1.New()
1602 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
1603 go pw.Close()
1604 if err != nil {
1605 bodyRes <- err
1606 } else {
1607 bodyRes <- h
1608 }
1609 }()
1610 res, err := backend.c.Do(req)
1611 if err != nil {
1612 t.Fatal(err)
1613 }
1614 defer res.Body.Close()
1615 hgot := sha1.New()
1616 n, err := io.Copy(hgot, res.Body)
1617 if err != nil {
1618 t.Fatal(err)
1619 }
1620 if n != size {
1621 t.Fatalf("got %d bytes; want %d", n, size)
1622 }
1623 select {
1624 case v := <-bodyRes:
1625 switch v := v.(type) {
1626 default:
1627 t.Fatalf("body copy: %v", err)
1628 case hash.Hash:
1629 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
1630 t.Errorf("written bytes didn't match received bytes")
1631 }
1632 }
1633 case <-time.After(10 * time.Second):
1634 t.Fatal("timeout")
1635 }
1636
1637 }
1638
1639
1640 func TestH12_WebSocketUpgrade(t *testing.T) {
1641 h12Compare{
1642 Handler: func(w ResponseWriter, r *Request) {
1643 h := w.Header()
1644 h.Set("Foo", "bar")
1645 },
1646 ReqFunc: func(c *Client, url string) (*Response, error) {
1647 req, _ := NewRequest("GET", url, nil)
1648 req.Header.Set("Connection", "Upgrade")
1649 req.Header.Set("Upgrade", "WebSocket")
1650 return c.Do(req)
1651 },
1652 EarlyCheckResponse: func(proto string, res *Response) {
1653 if res.Proto != "HTTP/1.1" {
1654 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1655 }
1656 res.Proto = "HTTP/IGNORE"
1657 },
1658 }.run(t)
1659 }
1660
1661 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
1662 func testIdentityTransferEncoding(t *testing.T, mode testMode) {
1663 const body = "body"
1664 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1665 gotBody, _ := io.ReadAll(r.Body)
1666 if got, want := string(gotBody), body; got != want {
1667 t.Errorf("got request body = %q; want %q", got, want)
1668 }
1669 w.Header().Set("Transfer-Encoding", "identity")
1670 w.WriteHeader(StatusOK)
1671 w.(Flusher).Flush()
1672 io.WriteString(w, body)
1673 }))
1674 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
1675 res, err := cst.c.Do(req)
1676 if err != nil {
1677 t.Fatal(err)
1678 }
1679 defer res.Body.Close()
1680 gotBody, err := io.ReadAll(res.Body)
1681 if err != nil {
1682 t.Fatal(err)
1683 }
1684 if got, want := string(gotBody), body; got != want {
1685 t.Errorf("got response body = %q; want %q", got, want)
1686 }
1687 }
1688
1689 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
1690 func testEarlyHintsRequest(t *testing.T, mode testMode) {
1691 var wg sync.WaitGroup
1692 wg.Add(1)
1693 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1694 h := w.Header()
1695
1696 h.Add("Content-Length", "123")
1697 h.Add("Link", "</style.css>; rel=preload; as=style")
1698 h.Add("Link", "</script.js>; rel=preload; as=script")
1699 w.WriteHeader(StatusEarlyHints)
1700
1701 wg.Wait()
1702
1703 h.Add("Link", "</foo.js>; rel=preload; as=script")
1704 w.WriteHeader(StatusEarlyHints)
1705
1706 w.Write([]byte("Hello"))
1707 }))
1708
1709 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1710 t.Helper()
1711
1712 if len(expected) != len(got) {
1713 t.Errorf("got %d expected %d", len(got), len(expected))
1714 }
1715
1716 for i := range expected {
1717 if expected[i] != got[i] {
1718 t.Errorf("got %q expected %q", got[i], expected[i])
1719 }
1720 }
1721 }
1722
1723 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
1724 t.Helper()
1725
1726 for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
1727 if v, ok := header[h]; ok {
1728 t.Errorf("%s is %q; must not be sent", h, v)
1729 }
1730 }
1731 }
1732
1733 var respCounter uint8
1734 trace := &httptrace.ClientTrace{
1735 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1736 switch respCounter {
1737 case 0:
1738 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1739 checkExcludedHeaders(t, header)
1740
1741 wg.Done()
1742 case 1:
1743 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1744 checkExcludedHeaders(t, header)
1745
1746 default:
1747 t.Error("Unexpected 1xx response")
1748 }
1749
1750 respCounter++
1751
1752 return nil
1753 },
1754 }
1755 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
1756
1757 res, err := cst.c.Do(req)
1758 if err != nil {
1759 t.Fatal(err)
1760 }
1761 defer res.Body.Close()
1762
1763 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1764 if cl := res.Header.Get("Content-Length"); cl != "123" {
1765 t.Errorf("Content-Length is %q; want 123", cl)
1766 }
1767
1768 body, _ := io.ReadAll(res.Body)
1769 if string(body) != "Hello" {
1770 t.Errorf("Read body %q; want Hello", body)
1771 }
1772 }
1773
View as plain text