Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "io"
26 "log"
27 mrand "math/rand"
28 "net"
29 . "net/http"
30 "net/http/httptest"
31 "net/http/httptrace"
32 "net/http/httputil"
33 "net/http/internal/testcert"
34 "net/textproto"
35 "net/url"
36 "os"
37 "reflect"
38 "runtime"
39 "slices"
40 "strconv"
41 "strings"
42 "sync"
43 "sync/atomic"
44 "testing"
45 "testing/iotest"
46 "time"
47
48 "golang.org/x/net/http/httpguts"
49 )
50
51
52
53
54
55 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
56 if r.FormValue("close") == "true" {
57 w.Header().Set("Connection", "close")
58 }
59 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
60 w.Write([]byte(r.RemoteAddr))
61
62
63
64 if c, ok := ResponseWriterConnForTesting(w); ok {
65 fmt.Fprintf(w, ", %T %p", c, c)
66 }
67 })
68
69
70 type testCloseConn struct {
71 net.Conn
72 set *testConnSet
73 }
74
75 func (c *testCloseConn) Close() error {
76 c.set.remove(c)
77 return c.Conn.Close()
78 }
79
80
81
82 type testConnSet struct {
83 t *testing.T
84 mu sync.Mutex
85 closed map[net.Conn]bool
86 list []net.Conn
87 }
88
89 func (tcs *testConnSet) insert(c net.Conn) {
90 tcs.mu.Lock()
91 defer tcs.mu.Unlock()
92 tcs.closed[c] = false
93 tcs.list = append(tcs.list, c)
94 }
95
96 func (tcs *testConnSet) remove(c net.Conn) {
97 tcs.mu.Lock()
98 defer tcs.mu.Unlock()
99 tcs.closed[c] = true
100 }
101
102
103 func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
104 connSet := &testConnSet{
105 t: t,
106 closed: make(map[net.Conn]bool),
107 }
108 dial := func(n, addr string) (net.Conn, error) {
109 c, err := net.Dial(n, addr)
110 if err != nil {
111 return nil, err
112 }
113 tc := &testCloseConn{c, connSet}
114 connSet.insert(tc)
115 return tc, nil
116 }
117 return connSet, dial
118 }
119
120 func (tcs *testConnSet) check(t *testing.T) {
121 tcs.mu.Lock()
122 defer tcs.mu.Unlock()
123 for i := 4; i >= 0; i-- {
124 for i, c := range tcs.list {
125 if tcs.closed[c] {
126 continue
127 }
128 if i != 0 {
129
130
131 tcs.mu.Unlock()
132 time.Sleep(50 * time.Millisecond)
133 tcs.mu.Lock()
134 continue
135 }
136 t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
137 }
138 }
139 }
140
// TestReuseRequest verifies that a single *Request value can be sent through
// the same Client twice, with both responses' bodies closing cleanly.
func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
func testReuseRequest(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Write([]byte("{}"))
	})
	ts := newClientServerTest(t, mode, handler).ts
	client := ts.Client()

	req, _ := NewRequest("GET", ts.URL, nil)
	for attempt := 0; attempt < 2; attempt++ {
		res, err := client.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		if err := res.Body.Close(); err != nil {
			t.Fatal(err)
		}
	}
}
167
168
169
// TestTransportKeepAlives checks that two sequential GETs share a connection
// when keep-alives are enabled, and use different connections when
// DisableKeepAlives is set.
func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
func testTransportKeepAlives(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts

	c := ts.Client()
	for _, disableKeepAlive := range []bool{false, true} {
		c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
		fetch := func(n int) string {
			res, err := c.Get(ts.URL)
			if err != nil {
				t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
			}
			// Release the response on every path, including a failed read.
			defer res.Body.Close()
			body, err := io.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
			}
			return string(body)
		}

		body1 := fetch(1)
		body2 := fetch(2)

		// hostPortHandler echoes the client's address, so the bodies differ
		// when the two requests came from different connections.
		bodiesDiffer := body1 != body2
		if bodiesDiffer != disableKeepAlive {
			t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
				disableKeepAlive, bodiesDiffer, body1, body2)
		}
	}
}
199
200 func TestTransportConnectionCloseOnResponse(t *testing.T) {
201 run(t, testTransportConnectionCloseOnResponse)
202 }
203 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
204 ts := newClientServerTest(t, mode, hostPortHandler).ts
205
206 connSet, testDial := makeTestDial(t)
207
208 c := ts.Client()
209 tr := c.Transport.(*Transport)
210 tr.Dial = testDial
211
212 for _, connectionClose := range []bool{false, true} {
213 fetch := func(n int) string {
214 req := new(Request)
215 var err error
216 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
217 if err != nil {
218 t.Fatalf("URL parse error: %v", err)
219 }
220 req.Method = "GET"
221 req.Proto = "HTTP/1.1"
222 req.ProtoMajor = 1
223 req.ProtoMinor = 1
224
225 res, err := c.Do(req)
226 if err != nil {
227 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
228 }
229 defer res.Body.Close()
230 body, err := io.ReadAll(res.Body)
231 if err != nil {
232 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
233 }
234 return string(body)
235 }
236
237 body1 := fetch(1)
238 body2 := fetch(2)
239 bodiesDiffer := body1 != body2
240 if bodiesDiffer != connectionClose {
241 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
242 connectionClose, bodiesDiffer, body1, body2)
243 }
244
245 tr.CloseIdleConnections()
246 }
247
248 connSet.check(t)
249 }
250
251
252
253
254
255
256
// TestTransportConnectionCloseOnRequest checks that setting Request.Close is
// visible to the handler as r.Close, reported back in X-Saw-Close. It also
// checks that two sequential requests use two distinct connections when
// Close is set and one shared connection otherwise. At the end, every
// connection the test dialer opened must have been closed.
func TestTransportConnectionCloseOnRequest(t *testing.T) {
	run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
}
func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts

	connSet, testDial := makeTestDial(t)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.Dial = testDial
	for _, reqClose := range []bool{false, true} {
		fetch := func(n int) string {
			req := new(Request)
			var err error
			req.URL, err = url.Parse(ts.URL)
			if err != nil {
				t.Fatalf("URL parse error: %v", err)
			}
			req.Method = "GET"
			req.Proto = "HTTP/1.1"
			req.ProtoMajor = 1
			req.ProtoMinor = 1
			req.Close = reqClose

			res, err := c.Do(req)
			if err != nil {
				t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
			}
			defer res.Body.Close()
			if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
				// Report the value actually expected, not its negation.
				t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
					reqClose, got, want)
			}
			body, err := io.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
			}
			return string(body)
		}

		body1 := fetch(1)
		body2 := fetch(2)

		// hostPortHandler echoes the client address, so differing bodies mean
		// the server saw two connections.
		got := 1
		if body1 != body2 {
			got++
		}
		want := 1
		if reqClose {
			want = 2
		}
		if got != want {
			t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
				reqClose, got, want, body1, body2)
		}

		tr.CloseIdleConnections()
	}

	connSet.check(t)
}
318
319
320
321
// TestTransportConnectionCloseOnRequestDisableKeepAlive checks that, with
// DisableKeepAlives set on the Transport, the handler observes r.Close as
// true, reported back in the X-Saw-Close header.
func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
	run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
}
func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts

	client := ts.Client()
	tr := client.Transport.(*Transport)
	tr.DisableKeepAlives = true

	res, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if saw := res.Header.Get("X-Saw-Close"); saw != "true" {
		t.Errorf("handler didn't see Connection: close ")
	}
}
340
341
342
// TestTransportRespectRequestWantsClose checks, for every combination of
// Transport.DisableKeepAlives and Request.Close, that the Transport writes
// exactly one "Connection: close" header field when either is set and none
// otherwise. Written header fields are observed via httptrace.
func TestTransportRespectRequestWantsClose(t *testing.T) {
	run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
}
func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
	for _, disable := range []bool{false, true} {
		for _, reqClose := range []bool{false, true} {
			disable, reqClose := disable, reqClose
			name := fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", disable, reqClose)
			t.Run(name, func(t *testing.T) {
				ts := newClientServerTest(t, mode, hostPortHandler).ts

				client := ts.Client()
				client.Transport.(*Transport).DisableKeepAlives = disable
				req, err := NewRequest("GET", ts.URL, nil)
				if err != nil {
					t.Fatal(err)
				}

				closeHeaders := 0
				trace := &httptrace.ClientTrace{
					WroteHeaderField: func(key string, field []string) {
						if key == "Connection" && httpguts.HeaderValuesContainsToken(field, "close") {
							closeHeaders++
						}
					},
				}
				req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
				req.Close = reqClose

				res, err := client.Do(req)
				if err != nil {
					t.Fatal(err)
				}
				defer res.Body.Close()

				want := disable || reqClose
				if closeHeaders > 1 || (closeHeaders == 1) != want {
					t.Errorf("expecting want:%v, got 'Connection: close':%d", want, closeHeaders)
				}
			})
		}
	}
}
393
// TestTransportIdleCacheKeys checks the Transport's idle-connection cache
// keys at three points:
//   - none before any request;
//   - exactly one, "|http|" plus the server's listener address, after a
//     fully read GET;
//   - none again after CloseIdleConnections.
func TestTransportIdleCacheKeys(t *testing.T) {
	run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
}
func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Errorf("Before any request expected %d idle conn cache keys; got %d", e, g)
	}

	resp, err := c.Get(ts.URL)
	if err != nil {
		// resp is nil when err != nil; continuing would dereference it.
		t.Fatal(err)
	}
	// The body is read to EOF before the idle-key check below.
	if _, err := io.ReadAll(resp.Body); err != nil {
		t.Fatalf("ReadAll: %v", err)
	}
	resp.Body.Close()

	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
	}

	if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
		t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
	}

	tr.CloseIdleConnections()
	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
	}
}
426
427
428
// TestTransportReadToEndReusesConn checks that three sequential GETs share a
// single client connection when their bodies are read to the end. It covers
// both a Content-Length response and a chunked (flushed-before-write)
// response. The handler counts requests per RemoteAddr.
func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
	const msg = "foobar"

	var addrSeen map[string]int
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addrSeen[r.RemoteAddr]++
		if r.URL.Path == "/chunked/" {
			w.WriteHeader(200)
			w.(Flusher).Flush()
		} else {
			w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
			w.WriteHeader(200)
		}
		w.Write([]byte(msg))
	})).ts

	for pi, path := range []string{"/content-length/", "/chunked/"} {
		wantLen := []int{len(msg), -1}[pi]
		addrSeen = make(map[string]int)
		for i := 0; i < 3; i++ {
			res, err := ts.Client().Get(ts.URL + path)
			if err != nil {
				t.Errorf("Get %s: %v", path, err)
				continue
			}

			if res.ContentLength != int64(wantLen) {
				t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
			}
			got, err := io.ReadAll(res.Body)
			// Close now rather than deferring: a defer inside this loop would
			// keep every body open until the test function returns.
			res.Body.Close()
			if string(got) != msg || err != nil {
				t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
			}
		}
		if len(addrSeen) != 1 {
			t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
		}
	}
}
475
// TestTransportMaxPerHostIdleConns starts three concurrent requests that the
// handler holds open, then releases them one at a time. It checks that the
// number of idle connections kept for the host grows to, but does not
// exceed, MaxIdleConnsPerHost.
func TestTransportMaxPerHostIdleConns(t *testing.T) {
	run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
}
func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
	// stop is closed when the test returns. This releases handlers waiting on
	// resch and request goroutines waiting to report on donech.
	stop := make(chan struct{})
	defer close(stop)

	// Each handler invocation announces itself on gotReq, then blocks until
	// the test supplies its response body on resch.
	resch := make(chan string)
	gotReq := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReq <- true
		var msg string
		select {
		case <-stop:
			return
		case msg = <-resch:
		}
		_, err := w.Write([]byte(msg))
		if err != nil {
			t.Errorf("Write: %v", err)
			return
		}
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)
	maxIdleConnsPerHost := 2
	tr.MaxIdleConnsPerHost = maxIdleConnsPerHost

	// doReq performs one GET. When it finishes, it sends t.Failed() on
	// donech, after the body has been read and closed.
	donech := make(chan bool)
	doReq := func() {
		defer func() {
			select {
			case <-stop:
				return
			case donech <- t.Failed():
			}
		}()
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Error(err)
			return
		}
		defer resp.Body.Close()
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("ReadAll: %v", err)
			return
		}
	}
	go doReq()
	<-gotReq
	go doReq()
	<-gotReq
	go doReq()
	<-gotReq

	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
	}

	resch <- "res1"
	<-donech
	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
	}
	addr := ts.Listener.Addr().String()
	cacheKey := "|http|" + addr
	if keys[0] != cacheKey {
		t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
	}
	if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
		t.Errorf("after first response, expected %d idle conns; got %d", e, g)
	}

	resch <- "res2"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
		t.Errorf("after second response, idle conns = %d; want %d", g, w)
	}

	// The third connection must not push the idle count past the limit.
	resch <- "res3"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
		t.Errorf("after third response, idle conns = %d; want %d", g, w)
	}
}
564
// TestTransportMaxConnsPerHostIncludeDialInProgress checks that a dial still
// in progress counts toward MaxConnsPerHost. With the limit at 1, a second
// request must not start dialing while the first request's dial is stalled.
func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
	run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
}
func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := w.Write([]byte("foo"))
		if err != nil {
			// Handlers run on server goroutines, where t.Fatalf must not be called.
			t.Errorf("Write: %v", err)
		}
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Each dial announces itself on dialStarted and then blocks until the
	// test releases it through stallDial.
	dialStarted := make(chan struct{})
	stallDial := make(chan struct{})
	tr.Dial = func(network, addr string) (net.Conn, error) {
		dialStarted <- struct{}{}
		<-stallDial
		return net.Dial(network, addr)
	}

	tr.DisableKeepAlives = true
	tr.MaxConnsPerHost = 1

	preDial := make(chan struct{})
	reqComplete := make(chan struct{})
	doReq := func(reqID string) {
		// Signal completion on every path. The test goroutine waits on
		// reqComplete and would otherwise hang after a failed request.
		defer func() { reqComplete <- struct{}{} }()

		req, _ := NewRequest("GET", ts.URL, nil)
		trace := &httptrace.ClientTrace{
			GetConn: func(hostPort string) {
				preDial <- struct{}{}
			},
		}
		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
		resp, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("unexpected error for request %s: %v", reqID, err)
			return // resp is nil; reading its Body would panic
		}
		defer resp.Body.Close()
		_, err = io.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("unexpected error for request %s: %v", reqID, err)
		}
	}

	go doReq("req1")
	<-preDial
	<-dialStarted

	// req1's dial is now stalled. req2 must not begin dialing yet.
	go doReq("req2")
	<-preDial
	select {
	case <-dialStarted:
		t.Error("req2 dial started while req1 dial in progress")
		return
	default:
	}

	// Let req1's dial proceed and wait for req1 to finish.
	stallDial <- struct{}{}
	<-reqComplete

	// Only now does req2 dial; release it and wait for it to finish.
	<-dialStarted
	stallDial <- struct{}{}
	<-reqComplete
}
632
// TestTransportMaxConnsPerHost sends ten concurrent requests with
// MaxConnsPerHost set to 1. It checks that only one connection was dialed
// and obtained, and for TLS that only one handshake ran. It then closes that
// connection and checks that exactly one more dial happens for the next
// request.
func TestTransportMaxConnsPerHost(t *testing.T) {
	run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
}
func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
	CondSkipHTTP2(t)

	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := w.Write([]byte("foo"))
		if err != nil {
			// Handlers run on server goroutines, where t.Fatalf must not be called.
			t.Errorf("Write: %v", err)
		}
	})

	ts := newClientServerTest(t, mode, h).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxConnsPerHost = 1

	var mu sync.Mutex
	var conns []net.Conn // guarded by mu; every successfully dialed conn
	var dialCnt, gotConnCnt, tlsHandshakeCnt int32
	tr.Dial = func(network, addr string) (net.Conn, error) {
		atomic.AddInt32(&dialCnt, 1)
		conn, err := net.Dial(network, addr)
		if err != nil {
			// Don't record a failed dial: closing a nil conn below would panic.
			return nil, err
		}
		mu.Lock()
		defer mu.Unlock()
		conns = append(conns, conn)
		return conn, nil
	}

	// doReq runs on several goroutines at once, so it reports failures with
	// t.Errorf; t.Fatalf may only be called from the test goroutine.
	doReq := func() {
		trace := &httptrace.ClientTrace{
			GotConn: func(connInfo httptrace.GotConnInfo) {
				if !connInfo.Reused {
					atomic.AddInt32(&gotConnCnt, 1)
				}
			},
			TLSHandshakeStart: func() {
				atomic.AddInt32(&tlsHandshakeCnt, 1)
			},
		}
		req, _ := NewRequest("GET", ts.URL, nil)
		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

		resp, err := c.Do(req)
		if err != nil {
			t.Errorf("request failed: %v", err)
			return
		}
		defer resp.Body.Close()
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("read body failed: %v", err)
		}
	}

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			doReq()
		}()
	}
	wg.Wait()

	expected := int32(tr.MaxConnsPerHost)
	if got := atomic.LoadInt32(&dialCnt); got != expected {
		t.Errorf("round 1: too many dials: %d != %d", got, expected)
	}
	if got := atomic.LoadInt32(&gotConnCnt); got != expected {
		t.Errorf("round 1: too many get connections: %d != %d", got, expected)
	}
	if got := atomic.LoadInt32(&tlsHandshakeCnt); ts.TLS != nil && got != expected {
		t.Errorf("round 1: too many tls handshakes: %d != %d", got, expected)
	}

	if t.Failed() {
		t.FailNow()
	}

	// Close every dialed connection; the next request must dial exactly once more.
	mu.Lock()
	for _, conn := range conns {
		conn.Close()
	}
	conns = nil
	mu.Unlock()
	tr.CloseIdleConnections()

	doReq()
	expected++
	if got := atomic.LoadInt32(&dialCnt); got != expected {
		t.Errorf("round 2: too many dials: %d != %d", got, expected)
	}
	if got := atomic.LoadInt32(&gotConnCnt); got != expected {
		t.Errorf("round 2: too many get connections: %d != %d", got, expected)
	}
	if got := atomic.LoadInt32(&tlsHandshakeCnt); ts.TLS != nil && got != expected {
		t.Errorf("round 2: too many tls handshakes: %d != %d", got, expected)
	}
}
733
// TestTransportMaxConnsPerHostDialCancellation checks, with MaxConnsPerHost
// set to 1, that a request whose context is canceled through the
// pending-dial test hook fails with context.Canceled. It also checks that a
// following request to the same host still succeeds.
func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
	run(t, testTransportMaxConnsPerHostDialCancellation,
		testNotParallel,
		[]testMode{http1Mode, https1Mode, http2Mode},
	)
}

func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
	CondSkipHTTP2(t)

	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := w.Write([]byte("foo"))
		if err != nil {
			// Handlers run on server goroutines, where t.Fatalf must not be called.
			t.Errorf("Write: %v", err)
		}
	})

	cst := newClientServerTest(t, mode, h)
	defer cst.close()
	ts := cst.ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxConnsPerHost = 1

	// Install cancel as a pending-dial hook. NOTE(review): judging by the
	// hook's name, this cancels ctx while the first request's dial is
	// pending; confirm against SetPendingDialHooks.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	SetPendingDialHooks(cancel, nil)
	defer SetPendingDialHooks(nil, nil)

	req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
	_, err := c.Do(req)
	if !errors.Is(err, context.Canceled) {
		t.Errorf("expected error %v, got %v", context.Canceled, err)
	}

	// Remove the hook; a fresh request to the same host must succeed.
	SetPendingDialHooks(nil, nil)
	req, _ = NewRequest("GET", ts.URL, nil)
	resp, err := c.Do(req)
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}
	defer resp.Body.Close()
	_, err = io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body failed: %v", err)
	}
}
783
// TestTransportRemovesDeadIdleConnections checks that after the server
// closes all client connections, the Transport's idle connection keys go
// away and a subsequent request still succeeds.
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
	run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
}
func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// doReq POSTs to the server and fails the test on any error or non-200
	// status; name labels the request in messages.
	doReq := func(name string) {
		res, err := c.Post(ts.URL, "", nil)
		if err != nil {
			t.Fatalf("%s: %v", name, err)
		}
		// Defer the Close before the status check so a non-200 response does
		// not leak its body.
		defer res.Body.Close()
		if res.StatusCode != 200 {
			t.Fatalf("%s: %v", name, res.Status)
		}
		slurp, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: %v", name, err)
		}
		t.Logf("%s: ok (%q)", name, slurp)
	}

	doReq("first")
	keys1 := tr.IdleConnKeysForTesting()

	ts.CloseClientConnections()

	// Wait until the Transport reports no idle connection keys.
	// NOTE(review): waitCondition presumably re-invokes the callback (with
	// the given 10ms interval) until it returns true; confirm.
	var keys2 []string
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		keys2 = tr.IdleConnKeysForTesting()
		if len(keys2) != 0 {
			if d > 0 {
				t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
			}
			return false
		}
		return true
	})

	doReq("second")
}
832
833
834
// TestTransportServerClosingUnexpectedly checks that two sequential GETs
// share a connection. It then closes the Transport's connections through the
// ExportCloseTransportConnsAbruptly test hook and checks that a later GET
// still succeeds (retrying on error) on a different connection.
func TestTransportServerClosingUnexpectedly(t *testing.T) {
	run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
}
func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts
	c := ts.Client()

	// fetch GETs ts.URL and returns the body; n labels the request. On error
	// it logs, sleeps, and tries again while retries remain; when none
	// remain, the error is fatal.
	fetch := func(n, retries int) string {
		condFatalf := func(format string, arg ...any) {
			if retries <= 0 {
				t.Fatalf(format, arg...)
			}
			t.Logf("retrying shortly after expected error: "+format, arg...)
			time.Sleep(time.Second / time.Duration(retries))
		}
		for retries >= 0 {
			retries--
			res, err := c.Get(ts.URL)
			if err != nil {
				condFatalf("error in req #%d, GET: %v", n, err)
				continue
			}
			body, err := io.ReadAll(res.Body)
			// Close before checking err so a failed read doesn't leak the body.
			res.Body.Close()
			if err != nil {
				condFatalf("error in req #%d, ReadAll: %v", n, err)
				continue
			}
			return string(body)
		}
		panic("unreachable")
	}

	body1 := fetch(1, 0)
	body2 := fetch(2, 0)

	ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))

	body3 := fetch(3, 5)

	// hostPortHandler echoes the client address: equal bodies mean a shared
	// connection, differing ones mean a new connection.
	if body1 != body2 {
		t.Errorf("expected body1 and body2 to be equal")
	}
	if body2 == body3 {
		t.Errorf("expected body2 and body3 to be different")
	}
}
889
890
891
// TestStressSurpriseServerCloses sends many concurrent GETs to a handler
// that hijacks and closes the connection after writing each response.
// Request errors are tolerated; the test only requires that every request
// finishes. It is skipped in -short mode.
func TestStressSurpriseServerCloses(t *testing.T) {
	run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
}
func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in short mode")
	}
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		hdr := w.Header()
		hdr.Set("Content-Length", "5")
		hdr.Set("Content-Type", "text/plain")
		w.Write([]byte("Hello"))
		w.(Flusher).Flush()
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Flush()
		conn.Close()
	})
	ts := newClientServerTest(t, mode, handler).ts
	client := ts.Client()

	const (
		numClients    = 20
		reqsPerClient = 25
	)

	// worker issues reqsPerClient GETs, closing each successful response.
	var wg sync.WaitGroup
	worker := func() {
		defer wg.Done()
		for n := 0; n < reqsPerClient; n++ {
			res, err := client.Get(ts.URL)
			if err != nil {
				continue
			}
			res.Body.Close()
		}
	}

	wg.Add(numClients)
	for n := 0; n < numClients; n++ {
		go worker()
	}
	wg.Wait()
}
943
944
945
// TestTransportHeadResponses sends two HEAD requests to a handler that
// declares Content-Length: 123. It checks that the header and
// Response.ContentLength carry that value while the body is empty.
func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
func testTransportHeadResponses(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "HEAD" {
			panic("expected HEAD; got " + r.Method)
		}
		w.Header().Set("Content-Length", "123")
		w.WriteHeader(200)
	})).ts
	c := ts.Client()

	for i := 0; i < 2; i++ {
		res, err := c.Head(ts.URL)
		if err != nil {
			t.Errorf("error on loop %d: %v", i, err)
			continue
		}
		if e, g := "123", res.Header.Get("Content-Length"); e != g {
			t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
		}
		if e, g := int64(123), res.ContentLength; e != g {
			t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
		}
		if all, err := io.ReadAll(res.Body); err != nil {
			t.Errorf("loop %d: Body ReadAll: %v", i, err)
		} else if len(all) != 0 {
			t.Errorf("Bogus body %q", all)
		}
		// Release the response before the next iteration.
		res.Body.Close()
	}
}
976
977
978
// TestTransportHeadChunkedResponse checks that two HEAD requests to a
// handler that sets "Transfer-Encoding: chunked" come from the same client
// address (echoed by the handler in x-client-ipport), i.e. share a
// connection.
func TestTransportHeadChunkedResponse(t *testing.T) {
	run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
}
func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "HEAD" {
			panic("expected HEAD; got " + r.Method)
		}
		hdr := w.Header()
		hdr.Set("Transfer-Encoding", "chunked")
		hdr.Set("x-client-ipport", r.RemoteAddr)
		w.WriteHeader(200)
	})
	ts := newClientServerTest(t, mode, handler).ts
	client := ts.Client()

	// The read-loop hook sends on didRead, and each request waits for that
	// signal before continuing. NOTE(review): presumably this ensures the
	// connection is ready for reuse before the next request; confirm against
	// SetReadLoopBeforeNextReadHook.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	head := func(n int) *Response {
		res, err := client.Head(ts.URL)
		<-didRead
		if err != nil {
			t.Fatalf("request %d error: %v", n, err)
		}
		return res
	}
	res1 := head(1)
	res2 := head(2)

	v1 := res1.Header.Get("x-client-ipport")
	v2 := res2.Header.Get("x-client-ipport")
	if v1 != v2 {
		t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
	}
}
1016
// roundTripTests lists the Accept-Encoding scenarios used by
// testRoundTripGzip. Each entry has three fields:
//   - accept: the request's Accept-Encoding header ("" means leave it unset);
//   - expectAccept: the value the handler must receive;
//   - compressed: whether the test itself must gunzip the response body.
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// Header unset: the handler sees "gzip", yet the body is read as plain
	// text, so the Transport must decompress it.
	{accept: "", expectAccept: "gzip", compressed: false},
	// A non-gzip encoding is passed through untouched.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit "gzip" leaves decompression to the caller.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1029
1030
// TestRoundTripGzip exercises Transport.RoundTrip for each roundTripTests
// case. It checks:
//   - the Accept-Encoding the handler receives;
//   - the (possibly gunzipped) body;
//   - that RoundTrip did not mutate the request's header;
//   - the response's Content-Encoding.
func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
func testRoundTripGzip(t *testing.T, mode testMode) {
	const responseBody = "test response body"
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		accept := req.Header.Get("Accept-Encoding")
		if expect := req.FormValue("expect_accept"); accept != expect {
			t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
				req.FormValue("testnum"), accept, expect)
		}
		if accept == "gzip" {
			rw.Header().Set("Content-Encoding", "gzip")
			gz := gzip.NewWriter(rw)
			gz.Write([]byte(responseBody))
			gz.Close()
		} else {
			rw.Header().Set("Content-Encoding", accept)
			rw.Write([]byte(responseBody))
		}
	})).ts
	tr := ts.Client().Transport.(*Transport)

	for i, test := range roundTripTests {
		// The test number and expected header travel in the query string so
		// the handler can verify them.
		req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
		if test.accept != "" {
			req.Header.Set("Accept-Encoding", test.accept)
		}
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("%d. RoundTrip: %v", i, err)
			continue
		}
		var body []byte
		if test.compressed {
			var r *gzip.Reader
			r, err = gzip.NewReader(res.Body)
			if err != nil {
				t.Errorf("%d. gzip NewReader: %v", i, err)
				res.Body.Close()
				continue
			}
			body, err = io.ReadAll(r)
		} else {
			body, err = io.ReadAll(res.Body)
		}
		// Close the body for both branches.
		res.Body.Close()
		if err != nil {
			t.Errorf("%d. Error: %q", i, err)
			continue
		}
		if g, e := string(body), responseBody; g != e {
			t.Errorf("%d. body = %q; want %q", i, g, e)
		}
		if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
			t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
		}
		if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
			t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
		}
	}
}
1092
// TestTransportGzip checks the Transport's handling of gzip-encoded
// responses, for both chunked and Content-Length ("chunked=0") responses:
//   - a partial read of a large body followed by Close;
//   - a full read of a small body, with the Content-Encoding header removed
//     from the response;
//   - Read errors after EOF and after Close.
//
// It also checks that HEAD requests are sent without Accept-Encoding.
func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
func testTransportGzip(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("https://go.dev/issue/56020")
	}
	const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
	const nRandBytes = 1024 * 1024
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		if req.Method == "HEAD" {
			if g := req.Header.Get("Accept-Encoding"); g != "" {
				t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
			}
			return
		}
		if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
			t.Errorf("Accept-Encoding = %q, want %q", g, e)
		}
		rw.Header().Set("Content-Encoding", "gzip")

		// For chunked=0, buffer the gzip stream so the handler can set
		// Content-Length before writing it out.
		var w io.Writer = rw
		var buf bytes.Buffer
		if req.FormValue("chunked") == "0" {
			w = &buf
			defer io.Copy(rw, &buf)
			defer func() {
				rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
			}()
		}
		gz := gzip.NewWriter(w)
		gz.Write([]byte(testString))
		if req.FormValue("body") == "large" {
			io.CopyN(gz, rand.Reader, nRandBytes)
		}
		gz.Close()
	})).ts
	c := ts.Client()

	for _, chunked := range []string{"1", "0"} {
		// Large body: read only the prefix, then Close mid-stream.
		res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
		if err != nil {
			t.Fatalf("large get: %v", err)
		}
		buf := make([]byte, len(testString))
		n, err := io.ReadFull(res.Body, buf)
		if err != nil {
			t.Fatalf("partial read of large response: size=%d, %v", n, err)
		}
		if e, g := testString, string(buf); e != g {
			t.Errorf("partial read got %q, expected %q", g, e)
		}
		res.Body.Close()

		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
		}

		// Small body: read fully.
		res, err = c.Get(ts.URL + "/?chunked=" + chunked)
		if err != nil {
			t.Fatal(err)
		}
		body, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatal(err)
		}
		if g, e := string(body), testString; g != e {
			t.Fatalf("body = %q; want %q", g, e)
		}
		if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
			t.Fatalf("Content-Encoding = %q; want %q", g, e)
		}

		// Reads past EOF and after Close must both fail.
		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
		}
		res.Body.Close()
		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected Read error after Close; got %d, %v", n, err)
		}
	}

	// HEAD requests must not advertise gzip (checked in the handler).
	res, err := c.Head(ts.URL)
	if err != nil {
		t.Fatalf("Head: %v", err)
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Errorf("Head status=%d; want=200", res.StatusCode)
	}
}
1188
1189
1190
1191 type transport100ContinueTest struct {
1192 t *testing.T
1193
1194 reqdone chan struct{}
1195 resp *Response
1196 respErr error
1197
1198 conn net.Conn
1199 reader *bufio.Reader
1200 }
1201
// transport100ContinueTestBody is the body of the PUT request issued by
// newTransport100ContinueTest.
const (
	transport100ContinueTestBody = "request body"
)
1203
1204
1205
// newTransport100ContinueTest sets up one "Expect: 100-continue" exchange.
// From a goroutine, it sends a PUT of transport100ContinueTestBody to a
// local listener, using a fresh Transport with the given
// ExpectContinueTimeout. It then accepts the connection and reads the
// request headers. The caller plays the server with respond, wantBodySent,
// and wantRequestDone.
//
// Registered cleanups wait for the request goroutine to finish, close idle
// connections, and fail the test if the Transport sent any bytes the test
// did not consume.
func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
	ln := newLocalListener(t)
	defer ln.Close()

	test := &transport100ContinueTest{
		t:       t,
		reqdone: make(chan struct{}),
	}

	tr := &Transport{
		ExpectContinueTimeout: timeout,
	}
	go func() {
		defer close(test.reqdone)
		body := strings.NewReader(transport100ContinueTestBody)
		req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
		req.Header.Set("Expect", "100-continue")
		req.ContentLength = int64(len(transport100ContinueTestBody))
		test.resp, test.respErr = tr.RoundTrip(req)
		// resp is nil when RoundTrip fails. Leave the error for
		// wantRequestDone to report instead of panicking in this goroutine.
		if test.respErr == nil {
			test.resp.Body.Close()
		}
	}()

	c, err := ln.Accept()
	if err != nil {
		t.Fatalf("Accept: %v", err)
	}
	t.Cleanup(func() {
		c.Close()
	})
	br := bufio.NewReader(c)
	_, err = ReadRequest(br)
	if err != nil {
		t.Fatalf("ReadRequest: %v", err)
	}
	test.conn = c
	test.reader = br
	t.Cleanup(func() {
		<-test.reqdone
		tr.CloseIdleConnections()
		got, _ := io.ReadAll(test.reader)
		if len(got) > 0 {
			t.Fatalf("Transport sent unexpected bytes: %q", got)
		}
	})

	return test
}
1253
1254
1255 func (test *transport100ContinueTest) respond(lines ...string) {
1256 for _, line := range lines {
1257 if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
1258 test.t.Fatalf("Write: %v", err)
1259 }
1260 }
1261 if _, err := test.conn.Write([]byte("\r\n")); err != nil {
1262 test.t.Fatalf("Write: %v", err)
1263 }
1264 }
1265
1266
1267 func (test *transport100ContinueTest) wantBodySent() {
1268 got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
1269 if err != nil {
1270 test.t.Fatalf("unexpected error reading body: %v", err)
1271 }
1272 if got, want := string(got), transport100ContinueTestBody; got != want {
1273 test.t.Fatalf("unexpected body: got %q, want %q", got, want)
1274 }
1275 }
1276
1277
1278 func (test *transport100ContinueTest) wantRequestDone(want int) {
1279 <-test.reqdone
1280 if test.respErr != nil {
1281 test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
1282 }
1283 if got := test.resp.StatusCode; got != want {
1284 test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
1285 }
1286 }
1287
1288 func TestTransportExpect100ContinueSent(t *testing.T) {
1289 test := newTransport100ContinueTest(t, 1*time.Hour)
1290
1291 test.respond("HTTP/1.1 100 Continue")
1292 test.wantBodySent()
1293 test.respond("HTTP/1.1 200", "Content-Length: 0")
1294 test.wantRequestDone(200)
1295 }
1296
1297 func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
1298 test := newTransport100ContinueTest(t, 1*time.Hour)
1299
1300 test.respond("HTTP/1.1 200", "Content-Length: 0")
1301 test.wantBodySent()
1302 test.wantRequestDone(200)
1303 }
1304
1305 func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
1306 test := newTransport100ContinueTest(t, 1*time.Hour)
1307
1308 test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
1309 test.wantRequestDone(200)
1310 }
1311
1312 func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
1313 test := newTransport100ContinueTest(t, 1*time.Hour)
1314
1315 test.respond("HTTP/1.1 500", "Content-Length: 0")
1316 test.wantBodySent()
1317 test.wantRequestDone(500)
1318 }
1319
1320 func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
1321 test := newTransport100ContinueTest(t, 5*time.Millisecond)
1322 test.wantBodySent()
1323 test.respond("HTTP/1.1 200", "Content-Length: 0")
1324 test.wantRequestDone(200)
1325 }
1326
1327 func TestSOCKS5Proxy(t *testing.T) {
1328 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1329 }
1330 func testSOCKS5Proxy(t *testing.T, mode testMode) {
1331 ch := make(chan string, 1)
1332 l := newLocalListener(t)
1333 defer l.Close()
1334 defer close(ch)
1335 proxy := func(t *testing.T) {
1336 s, err := l.Accept()
1337 if err != nil {
1338 t.Errorf("socks5 proxy Accept(): %v", err)
1339 return
1340 }
1341 defer s.Close()
1342 var buf [22]byte
1343 if _, err := io.ReadFull(s, buf[:3]); err != nil {
1344 t.Errorf("socks5 proxy initial read: %v", err)
1345 return
1346 }
1347 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1348 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1349 return
1350 }
1351 if _, err := s.Write([]byte{5, 0}); err != nil {
1352 t.Errorf("socks5 proxy initial write: %v", err)
1353 return
1354 }
1355 if _, err := io.ReadFull(s, buf[:4]); err != nil {
1356 t.Errorf("socks5 proxy second read: %v", err)
1357 return
1358 }
1359 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1360 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1361 return
1362 }
1363 var ipLen int
1364 switch buf[3] {
1365 case 1:
1366 ipLen = net.IPv4len
1367 case 4:
1368 ipLen = net.IPv6len
1369 default:
1370 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1371 return
1372 }
1373 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1374 t.Errorf("socks5 proxy address read: %v", err)
1375 return
1376 }
1377 ip := net.IP(buf[4 : ipLen+4])
1378 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1379 copy(buf[:3], []byte{5, 0, 0})
1380 if _, err := s.Write(buf[:ipLen+6]); err != nil {
1381 t.Errorf("socks5 proxy connect write: %v", err)
1382 return
1383 }
1384 ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1385
1386
1387 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1388 targetConn, err := net.Dial("tcp", targetHost)
1389 if err != nil {
1390 t.Errorf("net.Dial failed")
1391 return
1392 }
1393 go io.Copy(targetConn, s)
1394 io.Copy(s, targetConn)
1395 targetConn.Close()
1396 }
1397
1398 pu, err := url.Parse("socks5://" + l.Addr().String())
1399 if err != nil {
1400 t.Fatal(err)
1401 }
1402
1403 sentinelHeader := "X-Sentinel"
1404 sentinelValue := "12345"
1405 h := HandlerFunc(func(w ResponseWriter, r *Request) {
1406 w.Header().Set(sentinelHeader, sentinelValue)
1407 })
1408 for _, useTLS := range []bool{false, true} {
1409 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1410 ts := newClientServerTest(t, mode, h).ts
1411 go proxy(t)
1412 c := ts.Client()
1413 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1414 r, err := c.Head(ts.URL)
1415 if err != nil {
1416 t.Fatal(err)
1417 }
1418 if r.Header.Get(sentinelHeader) != sentinelValue {
1419 t.Errorf("Failed to retrieve sentinel value")
1420 }
1421 got := <-ch
1422 ts.Close()
1423 tsu, err := url.Parse(ts.URL)
1424 if err != nil {
1425 t.Fatal(err)
1426 }
1427 want := "proxy for " + tsu.Host
1428 if got != want {
1429 t.Errorf("got %q, want %q", got, want)
1430 }
1431 })
1432 }
1433 }
1434
1435 func TestTransportProxy(t *testing.T) {
1436 defer afterTest(t)
1437 testCases := []struct{ siteMode, proxyMode testMode }{
1438 {http1Mode, http1Mode},
1439 {http1Mode, https1Mode},
1440 {https1Mode, http1Mode},
1441 {https1Mode, https1Mode},
1442 }
1443 for _, testCase := range testCases {
1444 siteMode := testCase.siteMode
1445 proxyMode := testCase.proxyMode
1446 t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
1447 siteCh := make(chan *Request, 1)
1448 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1449 siteCh <- r
1450 })
1451 proxyCh := make(chan *Request, 1)
1452 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1453 proxyCh <- r
1454
1455 if r.Method == "CONNECT" {
1456 hijacker, ok := w.(Hijacker)
1457 if !ok {
1458 t.Errorf("hijack not allowed")
1459 return
1460 }
1461 clientConn, _, err := hijacker.Hijack()
1462 if err != nil {
1463 t.Errorf("hijacking failed")
1464 return
1465 }
1466 res := &Response{
1467 StatusCode: StatusOK,
1468 Proto: "HTTP/1.1",
1469 ProtoMajor: 1,
1470 ProtoMinor: 1,
1471 Header: make(Header),
1472 }
1473
1474 targetConn, err := net.Dial("tcp", r.URL.Host)
1475 if err != nil {
1476 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1477 return
1478 }
1479
1480 if err := res.Write(clientConn); err != nil {
1481 t.Errorf("Writing 200 OK failed: %v", err)
1482 return
1483 }
1484
1485 go io.Copy(targetConn, clientConn)
1486 go func() {
1487 io.Copy(clientConn, targetConn)
1488 targetConn.Close()
1489 }()
1490 }
1491 })
1492 ts := newClientServerTest(t, siteMode, h1).ts
1493 proxy := newClientServerTest(t, proxyMode, h2).ts
1494
1495 pu, err := url.Parse(proxy.URL)
1496 if err != nil {
1497 t.Fatal(err)
1498 }
1499
1500
1501
1502
1503 c := proxy.Client()
1504 if siteMode == https1Mode {
1505 c = ts.Client()
1506 }
1507
1508 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1509 if _, err := c.Head(ts.URL); err != nil {
1510 t.Error(err)
1511 }
1512 got := <-proxyCh
1513 c.Transport.(*Transport).CloseIdleConnections()
1514 ts.Close()
1515 proxy.Close()
1516 if siteMode == https1Mode {
1517
1518 if got.Method != "CONNECT" {
1519 t.Errorf("Wrong method for secure proxying: %q", got.Method)
1520 }
1521 gotHost := got.URL.Host
1522 pu, err := url.Parse(ts.URL)
1523 if err != nil {
1524 t.Fatal("Invalid site URL")
1525 }
1526 if wantHost := pu.Host; gotHost != wantHost {
1527 t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
1528 }
1529
1530
1531 next := <-siteCh
1532 if next.Method != "HEAD" {
1533 t.Errorf("Wrong method at destination: %s", next.Method)
1534 }
1535 if nextURL := next.URL.String(); nextURL != "/" {
1536 t.Errorf("Wrong URL at destination: %s", nextURL)
1537 }
1538 } else {
1539 if got.Method != "HEAD" {
1540 t.Errorf("Wrong method for destination: %q", got.Method)
1541 }
1542 gotURL := got.URL.String()
1543 wantURL := ts.URL + "/"
1544 if gotURL != wantURL {
1545 t.Errorf("Got URL %q, want %q", gotURL, wantURL)
1546 }
1547 }
1548 })
1549 }
1550 }
1551
1552 func TestOnProxyConnectResponse(t *testing.T) {
1553
1554 var tcases = []struct {
1555 proxyStatusCode int
1556 err error
1557 }{
1558 {
1559 StatusOK,
1560 nil,
1561 },
1562 {
1563 StatusForbidden,
1564 errors.New("403"),
1565 },
1566 }
1567 for _, tcase := range tcases {
1568 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1569
1570 })
1571
1572 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1573
1574 if r.Method == "CONNECT" {
1575 if tcase.proxyStatusCode != StatusOK {
1576 w.WriteHeader(tcase.proxyStatusCode)
1577 return
1578 }
1579 hijacker, ok := w.(Hijacker)
1580 if !ok {
1581 t.Errorf("hijack not allowed")
1582 return
1583 }
1584 clientConn, _, err := hijacker.Hijack()
1585 if err != nil {
1586 t.Errorf("hijacking failed")
1587 return
1588 }
1589 res := &Response{
1590 StatusCode: StatusOK,
1591 Proto: "HTTP/1.1",
1592 ProtoMajor: 1,
1593 ProtoMinor: 1,
1594 Header: make(Header),
1595 }
1596
1597 targetConn, err := net.Dial("tcp", r.URL.Host)
1598 if err != nil {
1599 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1600 return
1601 }
1602
1603 if err := res.Write(clientConn); err != nil {
1604 t.Errorf("Writing 200 OK failed: %v", err)
1605 return
1606 }
1607
1608 go io.Copy(targetConn, clientConn)
1609 go func() {
1610 io.Copy(clientConn, targetConn)
1611 targetConn.Close()
1612 }()
1613 }
1614 })
1615 ts := newClientServerTest(t, https1Mode, h1).ts
1616 proxy := newClientServerTest(t, https1Mode, h2).ts
1617
1618 pu, err := url.Parse(proxy.URL)
1619 if err != nil {
1620 t.Fatal(err)
1621 }
1622
1623 c := proxy.Client()
1624
1625 var (
1626 dials atomic.Int32
1627 closes atomic.Int32
1628 )
1629 c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1630 conn, err := net.Dial(network, addr)
1631 if err != nil {
1632 return nil, err
1633 }
1634 dials.Add(1)
1635 return noteCloseConn{
1636 Conn: conn,
1637 closeFunc: func() {
1638 closes.Add(1)
1639 },
1640 }, nil
1641 }
1642
1643 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1644 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1645 if proxyURL.String() != pu.String() {
1646 t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1647 }
1648
1649 if "https://"+connectReq.URL.String() != ts.URL {
1650 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1651 }
1652 return tcase.err
1653 }
1654 wantCloses := int32(0)
1655 if _, err := c.Head(ts.URL); err != nil {
1656 wantCloses = 1
1657 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1658 t.Errorf("got %v, want %v", err, tcase.err)
1659 }
1660 } else {
1661 if tcase.err != nil {
1662 t.Errorf("got %v, want nil", err)
1663 }
1664 }
1665 if got, want := dials.Load(), int32(1); got != want {
1666 t.Errorf("got %v dials, want %v", got, want)
1667 }
1668
1669 if got, want := closes.Load(), wantCloses; got != want {
1670 t.Errorf("got %v closes, want %v", got, want)
1671 }
1672 }
1673 }
1674
1675
1676
1677 func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
1678 cancelc := make(chan struct{})
1679 SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
1680 ctx, cancel := context.WithCancel(ctx)
1681 go func() {
1682 select {
1683 case <-cancelc:
1684 case <-ctx.Done():
1685 }
1686 cancel()
1687 }()
1688 return ctx, cancel
1689 })
1690
1691 defer afterTest(t)
1692
1693 ln := newLocalListener(t)
1694 defer ln.Close()
1695 listenerDone := make(chan struct{})
1696 go func() {
1697 defer close(listenerDone)
1698 c, err := ln.Accept()
1699 if err != nil {
1700 t.Errorf("Accept: %v", err)
1701 return
1702 }
1703 defer c.Close()
1704
1705 br := bufio.NewReader(c)
1706 cr, err := ReadRequest(br)
1707 if err != nil {
1708 t.Errorf("proxy server failed to read CONNECT request")
1709 return
1710 }
1711 if cr.Method != "CONNECT" {
1712 t.Errorf("unexpected method %q", cr.Method)
1713 return
1714 }
1715
1716
1717
1718
1719 close(cancelc)
1720 var buf [1]byte
1721 _, err = br.Read(buf[:])
1722 if err != io.EOF {
1723 t.Errorf("proxy server Read err = %v; want EOF", err)
1724 }
1725 return
1726 }()
1727
1728 c := &Client{
1729 Transport: &Transport{
1730 Proxy: func(*Request) (*url.URL, error) {
1731 return url.Parse("http://" + ln.Addr().String())
1732 },
1733 },
1734 }
1735 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1736 if err != nil {
1737 t.Fatal(err)
1738 }
1739 _, err = c.Do(req)
1740 if err == nil {
1741 t.Errorf("unexpected Get success")
1742 }
1743
1744
1745
1746
1747 <-listenerDone
1748 }
1749
1750
1751 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1752 defer afterTest(t)
1753
1754 var errDial = errors.New("some dial error")
1755
1756 tr := &Transport{
1757 Proxy: func(*Request) (*url.URL, error) {
1758 return url.Parse("http://proxy.fake.tld/")
1759 },
1760 Dial: func(string, string) (net.Conn, error) {
1761 return nil, errDial
1762 },
1763 }
1764 defer tr.CloseIdleConnections()
1765
1766 c := &Client{Transport: tr}
1767 req, _ := NewRequest("GET", "http://fake.tld", nil)
1768 res, err := c.Do(req)
1769 if err == nil {
1770 res.Body.Close()
1771 t.Fatal("wanted a non-nil error")
1772 }
1773
1774 uerr, ok := err.(*url.Error)
1775 if !ok {
1776 t.Fatalf("got %T, want *url.Error", err)
1777 }
1778 oe, ok := uerr.Err.(*net.OpError)
1779 if !ok {
1780 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1781 }
1782 want := &net.OpError{
1783 Op: "proxyconnect",
1784 Net: "tcp",
1785 Err: errDial,
1786 }
1787 if !reflect.DeepEqual(oe, want) {
1788 t.Errorf("Got error %#v; want %#v", oe, want)
1789 }
1790 }
1791
1792
1793
1794
1795
1796 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1797 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
1798 }
1799 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1800 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1801 defer proxy.Close()
1802 c := proxy.Client()
1803
1804 tr := c.Transport.(*Transport)
1805 tr.Proxy = func(*Request) (*url.URL, error) {
1806 u, _ := url.Parse(proxy.URL)
1807 u.User = url.UserPassword("aladdin", "opensesame")
1808 return u, nil
1809 }
1810 h := tr.ProxyConnectHeader
1811 if h == nil {
1812 h = make(Header)
1813 }
1814 tr.ProxyConnectHeader = h.Clone()
1815
1816 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1817 if err != nil {
1818 t.Fatal(err)
1819 }
1820 _, err = c.Do(req)
1821 if err == nil {
1822 t.Errorf("unexpected Get success")
1823 }
1824
1825 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
1826 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
1827 }
1828 }
1829
1830
1831
1832
1833
1834 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
1835 func testTransportGzipRecursive(t *testing.T, mode testMode) {
1836 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1837 w.Header().Set("Content-Encoding", "gzip")
1838 w.Write(rgz)
1839 })).ts
1840
1841 c := ts.Client()
1842 res, err := c.Get(ts.URL)
1843 if err != nil {
1844 t.Fatal(err)
1845 }
1846 body, err := io.ReadAll(res.Body)
1847 if err != nil {
1848 t.Fatal(err)
1849 }
1850 if !bytes.Equal(body, rgz) {
1851 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
1852 body, rgz)
1853 }
1854 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1855 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1856 }
1857 }
1858
1859
1860
1861 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
1862 func testTransportGzipShort(t *testing.T, mode testMode) {
1863 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1864 w.Header().Set("Content-Encoding", "gzip")
1865 w.Write([]byte{0x1f, 0x8b})
1866 })).ts
1867
1868 c := ts.Client()
1869 res, err := c.Get(ts.URL)
1870 if err != nil {
1871 t.Fatal(err)
1872 }
1873 defer res.Body.Close()
1874 _, err = io.ReadAll(res.Body)
1875 if err == nil {
1876 t.Fatal("Expect an error from reading a body.")
1877 }
1878 if err != io.ErrUnexpectedEOF {
1879 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
1880 }
1881 }
1882
1883
// waitNumGoroutine polls runtime.NumGoroutine until the count is at most nmax.
// Between polls it sleeps 50ms and runs a GC, giving up after ten such
// retries. It returns the last observed count, which may still exceed nmax.
func waitNumGoroutine(nmax int) int {
	count := runtime.NumGoroutine()
	for retry := 0; retry < 10; retry++ {
		if count <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		count = runtime.NumGoroutine()
	}
	return count
}
1893
1894
1895 func TestTransportPersistConnLeak(t *testing.T) {
1896 run(t, testTransportPersistConnLeak, testNotParallel)
1897 }
1898 func testTransportPersistConnLeak(t *testing.T, mode testMode) {
1899 if mode == http2Mode {
1900 t.Skip("flaky in HTTP/2")
1901 }
1902
1903
1904 const numReq = 25
1905 gotReqCh := make(chan bool, numReq)
1906 unblockCh := make(chan bool, numReq)
1907 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1908 gotReqCh <- true
1909 <-unblockCh
1910 w.Header().Set("Content-Length", "0")
1911 w.WriteHeader(204)
1912 })).ts
1913 c := ts.Client()
1914 tr := c.Transport.(*Transport)
1915
1916 n0 := runtime.NumGoroutine()
1917
1918 didReqCh := make(chan bool, numReq)
1919 failed := make(chan bool, numReq)
1920 for i := 0; i < numReq; i++ {
1921 go func() {
1922 res, err := c.Get(ts.URL)
1923 didReqCh <- true
1924 if err != nil {
1925 t.Logf("client fetch error: %v", err)
1926 failed <- true
1927 return
1928 }
1929 res.Body.Close()
1930 }()
1931 }
1932
1933
1934 for i := 0; i < numReq; i++ {
1935 select {
1936 case <-gotReqCh:
1937
1938 case <-failed:
1939
1940
1941 }
1942 }
1943
1944 nhigh := runtime.NumGoroutine()
1945
1946
1947 close(unblockCh)
1948
1949
1950 for i := 0; i < numReq; i++ {
1951 <-didReqCh
1952 }
1953
1954 tr.CloseIdleConnections()
1955 nfinal := waitNumGoroutine(n0 + 5)
1956
1957 growth := nfinal - n0
1958
1959
1960
1961 if int(growth) > 5 {
1962 t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
1963 t.Error("too many new goroutines")
1964 }
1965 }
1966
1967
1968
1969 func TestTransportPersistConnLeakShortBody(t *testing.T) {
1970 run(t, testTransportPersistConnLeakShortBody, testNotParallel)
1971 }
1972 func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
1973 if mode == http2Mode {
1974 t.Skip("flaky in HTTP/2")
1975 }
1976
1977
1978 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1979 })).ts
1980 c := ts.Client()
1981 tr := c.Transport.(*Transport)
1982
1983 n0 := runtime.NumGoroutine()
1984 body := []byte("Hello")
1985 for i := 0; i < 20; i++ {
1986 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
1987 if err != nil {
1988 t.Fatal(err)
1989 }
1990 req.ContentLength = int64(len(body) - 2)
1991 _, err = c.Do(req)
1992 if err == nil {
1993 t.Fatal("Expect an error from writing too long of a body.")
1994 }
1995 }
1996 nhigh := runtime.NumGoroutine()
1997 tr.CloseIdleConnections()
1998 nfinal := waitNumGoroutine(n0 + 5)
1999
2000 growth := nfinal - n0
2001
2002
2003
2004 t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
2005 if int(growth) > 5 {
2006 t.Error("too many new goroutines")
2007 }
2008 }
2009
2010
// countedConn wraps a net.Conn so that a finalizer can be attached to a value
// owned by countingDialer.
type countedConn struct {
	net.Conn
}

// countingDialer dials with dialer and keeps two counts:
//   - total: connections it has produced;
//   - live: how many of the returned wrappers have not yet been finalized.
//
// Tests use it to detect connections that are never released.
type countingDialer struct {
	dialer net.Dialer
	mu     sync.Mutex
	total  int64
	live   int64
}

// DialContext dials through d.dialer. On success it returns the connection
// wrapped in a *countedConn whose finalizer decrements the live count.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}

	wrapped := &countedConn{Conn: conn}

	d.mu.Lock()
	d.total++
	d.live++
	runtime.SetFinalizer(wrapped, d.decrement)
	d.mu.Unlock()
	return wrapped, nil
}

// decrement is the finalizer for countedConn values; it records that one
// wrapper has been collected.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	d.live--
	d.mu.Unlock()
}

// Read reports the number of successful dials and how many of the returned
// connections have not yet been finalized.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	total, live = d.total, d.live
	d.mu.Unlock()
	return total, live
}
2051
2052 func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
2053 run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
2054 }
2055 func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
2056 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2057
2058 conn, _, err := w.(Hijacker).Hijack()
2059 if err != nil {
2060 t.Errorf("Hijack failed unexpectedly: %v", err)
2061 return
2062 }
2063 conn.Close()
2064 })).ts
2065
2066 var d countingDialer
2067 c := ts.Client()
2068 c.Transport.(*Transport).DialContext = d.DialContext
2069
2070 body := []byte("Hello")
2071 for i := 0; ; i++ {
2072 total, live := d.Read()
2073 if live < total {
2074 break
2075 }
2076 if i >= 1<<12 {
2077 t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
2078 }
2079
2080 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2081 if err != nil {
2082 t.Fatal(err)
2083 }
2084 _, err = c.Do(req)
2085 if err == nil {
2086 t.Fatal("expected broken connection")
2087 }
2088
2089 runtime.GC()
2090 }
2091 }
2092
// countedContext wraps a context.Context so that a finalizer can be attached
// to a value owned by contextCounter.
type countedContext struct {
	context.Context
}

// contextCounter counts contexts returned by Track that have not yet been
// finalized. Tests use it to detect contexts that are retained.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx in a new *countedContext and counts it as live. A finalizer
// lowers the count again once the wrapper is collected.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	wrapped := &countedContext{Context: ctx}
	cc.mu.Lock()
	cc.live++
	runtime.SetFinalizer(wrapped, cc.decrement)
	cc.mu.Unlock()
	return wrapped
}

// decrement is the finalizer for countedContext values.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	cc.live--
	cc.mu.Unlock()
}

// Read returns how many tracked contexts have not yet been finalized.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	live = cc.live
	cc.mu.Unlock()
	return live
}
2123
2124 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
2125 run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
2126 }
2127 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
2128 if mode == http2Mode {
2129 t.Skip("https://go.dev/issue/56021")
2130 }
2131
2132 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2133 runtime.Gosched()
2134 w.WriteHeader(StatusOK)
2135 })).ts
2136
2137 c := ts.Client()
2138 c.Transport.(*Transport).MaxConnsPerHost = 1
2139
2140 ctx := context.Background()
2141 body := []byte("Hello")
2142 doPosts := func(cc *contextCounter) {
2143 var wg sync.WaitGroup
2144 for n := 64; n > 0; n-- {
2145 wg.Add(1)
2146 go func() {
2147 defer wg.Done()
2148
2149 ctx := cc.Track(ctx)
2150 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2151 if err != nil {
2152 t.Error(err)
2153 }
2154
2155 _, err = c.Do(req.WithContext(ctx))
2156 if err != nil {
2157 t.Errorf("Do failed with error: %v", err)
2158 }
2159 }()
2160 }
2161 wg.Wait()
2162 }
2163
2164 var initialCC contextCounter
2165 doPosts(&initialCC)
2166
2167
2168
2169
2170 var flushCC contextCounter
2171 for i := 0; ; i++ {
2172 live := initialCC.Read()
2173 if live == 0 {
2174 break
2175 }
2176 if i >= 100 {
2177 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2178 }
2179 doPosts(&flushCC)
2180 runtime.GC()
2181 }
2182 }
2183
2184
2185 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
2186 func testTransportIdleConnCrash(t *testing.T, mode testMode) {
2187 var tr *Transport
2188
2189 unblockCh := make(chan bool, 1)
2190 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2191 <-unblockCh
2192 tr.CloseIdleConnections()
2193 })).ts
2194 c := ts.Client()
2195 tr = c.Transport.(*Transport)
2196
2197 didreq := make(chan bool)
2198 go func() {
2199 res, err := c.Get(ts.URL)
2200 if err != nil {
2201 t.Error(err)
2202 } else {
2203 res.Body.Close()
2204 }
2205 didreq <- true
2206 }()
2207 unblockCh <- true
2208 <-didreq
2209 }
2210
2211
2212
2213
2214
2215 func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2216 func testIssue3644(t *testing.T, mode testMode) {
2217 const numFoos = 5000
2218 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2219 w.Header().Set("Connection", "close")
2220 for i := 0; i < numFoos; i++ {
2221 w.Write([]byte("foo "))
2222 }
2223 })).ts
2224 c := ts.Client()
2225 res, err := c.Get(ts.URL)
2226 if err != nil {
2227 t.Fatal(err)
2228 }
2229 defer res.Body.Close()
2230 bs, err := io.ReadAll(res.Body)
2231 if err != nil {
2232 t.Fatal(err)
2233 }
2234 if len(bs) != numFoos*len("foo ") {
2235 t.Errorf("unexpected response length")
2236 }
2237 }
2238
2239
2240
2241 func TestIssue3595(t *testing.T) {
2242
2243 run(t, testIssue3595, testNotParallel)
2244 }
2245 func testIssue3595(t *testing.T, mode testMode) {
2246 runTimeSensitiveTest(t, []time.Duration{
2247 1 * time.Millisecond,
2248 5 * time.Millisecond,
2249 10 * time.Millisecond,
2250 50 * time.Millisecond,
2251 100 * time.Millisecond,
2252 500 * time.Millisecond,
2253 time.Second,
2254 5 * time.Second,
2255 }, func(t *testing.T, timeout time.Duration) error {
2256 SetRSTAvoidanceDelay(t, timeout)
2257 t.Logf("set RST avoidance delay to %v", timeout)
2258
2259 const deniedMsg = "sorry, denied."
2260 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2261 Error(w, deniedMsg, StatusUnauthorized)
2262 }))
2263
2264
2265 defer cst.close()
2266 ts := cst.ts
2267 c := ts.Client()
2268
2269 res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2270 if err != nil {
2271 return fmt.Errorf("Post: %v", err)
2272 }
2273 got, err := io.ReadAll(res.Body)
2274 if err != nil {
2275 return fmt.Errorf("Body ReadAll: %v", err)
2276 }
2277 t.Logf("server response:\n%s", got)
2278 if !strings.Contains(string(got), deniedMsg) {
2279
2280
2281 t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2282 }
2283 return nil
2284 })
2285 }
2286
2287
2288
2289 func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2290 func testChunkedNoContent(t *testing.T, mode testMode) {
2291 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2292 w.WriteHeader(StatusNoContent)
2293 })).ts
2294
2295 c := ts.Client()
2296 for _, closeBody := range []bool{true, false} {
2297 const n = 4
2298 for i := 1; i <= n; i++ {
2299 res, err := c.Get(ts.URL)
2300 if err != nil {
2301 t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2302 } else {
2303 if closeBody {
2304 res.Body.Close()
2305 }
2306 }
2307 }
2308 }
2309 }
2310
2311 func TestTransportConcurrency(t *testing.T) {
2312 run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
2313 }
2314 func testTransportConcurrency(t *testing.T, mode testMode) {
2315
2316 maxProcs, numReqs := 16, 500
2317 if testing.Short() {
2318 maxProcs, numReqs = 4, 50
2319 }
2320 defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
2321 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2322 fmt.Fprintf(w, "%v", r.FormValue("echo"))
2323 })).ts
2324
2325 var wg sync.WaitGroup
2326 wg.Add(numReqs)
2327
2328
2329
2330
2331
2332
2333
2334 SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
2335 defer SetPendingDialHooks(nil, nil)
2336
2337 c := ts.Client()
2338 reqs := make(chan string)
2339 defer close(reqs)
2340
2341 for i := 0; i < maxProcs*2; i++ {
2342 go func() {
2343 for req := range reqs {
2344 res, err := c.Get(ts.URL + "/?echo=" + req)
2345 if err != nil {
2346 if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
2347
2348
2349 t.Logf("error on req %s: %v", req, err)
2350 t.Logf("(see https://go.dev/issue/52168)")
2351 } else {
2352 t.Errorf("error on req %s: %v", req, err)
2353 }
2354 wg.Done()
2355 continue
2356 }
2357 all, err := io.ReadAll(res.Body)
2358 if err != nil {
2359 t.Errorf("read error on req %s: %v", req, err)
2360 } else if string(all) != req {
2361 t.Errorf("body of req %s = %q; want %q", req, all, req)
2362 }
2363 res.Body.Close()
2364 wg.Done()
2365 }
2366 }()
2367 }
2368 for i := 0; i < numReqs; i++ {
2369 reqs <- fmt.Sprintf("request-%d", i)
2370 }
2371 wg.Wait()
2372 }
2373
2374 func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
2375 func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2376 mux := NewServeMux()
2377 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2378 io.Copy(w, neverEnding('a'))
2379 })
2380 ts := newClientServerTest(t, mode, mux).ts
2381
2382 connc := make(chan net.Conn, 1)
2383 c := ts.Client()
2384 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2385 conn, err := net.Dial(n, addr)
2386 if err != nil {
2387 return nil, err
2388 }
2389 select {
2390 case connc <- conn:
2391 default:
2392 }
2393 return conn, nil
2394 }
2395
2396 res, err := c.Get(ts.URL + "/get")
2397 if err != nil {
2398 t.Fatalf("Error issuing GET: %v", err)
2399 }
2400 defer res.Body.Close()
2401
2402 conn := <-connc
2403 conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2404 _, err = io.Copy(io.Discard, res.Body)
2405 if err == nil {
2406 t.Errorf("Unexpected successful copy")
2407 }
2408 }
2409
// TestIssue4191_InfiniteGetToPutTimeout (regression test for issue 4191, per
// its name) checks the following: a PUT whose request body is streamed from
// the body of a never-finishing GET response fails, rather than succeeds,
// once the absolute connection deadline installed in Dial expires.
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
}

func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
	const debug = false // set to true to log connection traffic and progress
	mux := NewServeMux()
	// /get copies neverEnding('a') into the response. That helper is defined
	// elsewhere; per its name it presumably never reaches EOF.
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	// /put discards whatever request body it receives.
	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
		defer r.Body.Close()
		io.Copy(io.Discard, r.Body)
	})
	ts := newClientServerTest(t, mode, mux).ts
	timeout := 100 * time.Millisecond

	c := ts.Client()
	// Every client connection gets an absolute deadline of `timeout` from the
	// moment it is dialed, so any transfer still running after that must fail.
	// The closure reads `timeout` at dial time, so increasing it below affects
	// later dials.
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		conn.SetDeadline(time.Now().Add(timeout))
		if debug {
			conn = NewLoggingConn("client", conn)
		}
		return conn, nil
	}

	getFailed := false
	nRuns := 5
	if testing.Short() {
		nRuns = 1
	}
	for i := 0; i < nRuns; i++ {
		if debug {
			println("run", i+1, "of", nRuns)
		}
		sres, err := c.Get(ts.URL + "/get")
		if err != nil {
			if !getFailed {
				// The GET itself can fail if the deadline fires too early
				// (e.g. on a slow machine). Redo this run once with a 10x
				// longer timeout before treating it as an error.
				getFailed = true
				t.Logf("increasing timeout")
				i--
				timeout *= 10
				continue
			}
			t.Errorf("Error issuing GET: %v", err)
			break
		}
		// Feed the GET response body into a PUT. Since that body does not end,
		// the PUT can only finish by hitting a connection deadline, so it must
		// return an error.
		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
		_, err = c.Do(req)
		if err == nil {
			sres.Body.Close()
			t.Errorf("Unexpected successful PUT")
			break
		}
		sres.Body.Close()
	}
	if debug {
		println("tests complete; waiting for handlers to finish")
	}
	ts.Close()
}
2475
2476 func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
2477 func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
2478 if testing.Short() {
2479 t.Skip("skipping timeout test in -short mode")
2480 }
2481
2482 timeout := 2 * time.Millisecond
2483 retry := true
2484 for retry && !t.Failed() {
2485 var srvWG sync.WaitGroup
2486 inHandler := make(chan bool, 1)
2487 mux := NewServeMux()
2488 mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
2489 inHandler <- true
2490 srvWG.Done()
2491 })
2492 mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
2493 inHandler <- true
2494 <-r.Context().Done()
2495 srvWG.Done()
2496 })
2497 ts := newClientServerTest(t, mode, mux).ts
2498
2499 c := ts.Client()
2500 c.Transport.(*Transport).ResponseHeaderTimeout = timeout
2501
2502 retry = false
2503 srvWG.Add(3)
2504 tests := []struct {
2505 path string
2506 wantTimeout bool
2507 }{
2508 {path: "/fast"},
2509 {path: "/slow", wantTimeout: true},
2510 {path: "/fast"},
2511 }
2512 for i, tt := range tests {
2513 req, _ := NewRequest("GET", ts.URL+tt.path, nil)
2514 req = req.WithT(t)
2515 res, err := c.Do(req)
2516 <-inHandler
2517 if err != nil {
2518 uerr, ok := err.(*url.Error)
2519 if !ok {
2520 t.Errorf("error is not a url.Error; got: %#v", err)
2521 continue
2522 }
2523 nerr, ok := uerr.Err.(net.Error)
2524 if !ok {
2525 t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
2526 continue
2527 }
2528 if !nerr.Timeout() {
2529 t.Errorf("want timeout error; got: %q", nerr)
2530 continue
2531 }
2532 if !tt.wantTimeout {
2533 if !retry {
2534
2535 t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
2536 timeout *= 2
2537 retry = true
2538 }
2539 }
2540 if !strings.Contains(err.Error(), "timeout awaiting response headers") {
2541 t.Errorf("%d. unexpected error: %v", i, err)
2542 }
2543 continue
2544 }
2545 if tt.wantTimeout {
2546 t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
2547 continue
2548 }
2549 if res.StatusCode != 200 {
2550 t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
2551 }
2552 }
2553
2554 srvWG.Wait()
2555 ts.Close()
2556 }
2557 }
2558
2559
// cancelTest describes one mechanism for canceling a request. It lets the
// same test body run against each cancellation mechanism.
type cancelTest struct {
	// mode is the client/server test mode the test runs in.
	mode testMode
	// newReq prepares req so that cancel can cancel it, for example by
	// attaching a Cancel channel or a cancelable context.
	newReq func(req *Request) *Request
	// cancel cancels req, which is the request returned by newReq.
	cancel func(tr *Transport, req *Request)
	// checkErr reports a test error if err is not the error expected from a
	// canceled request. when names the operation that produced err.
	checkErr func(when string, err error)
}
2566
2567
2568 func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2569 t.Run("TransportCancel", func(t *testing.T) {
2570 f(t, cancelTest{
2571 mode: mode,
2572 newReq: func(req *Request) *Request {
2573 return req
2574 },
2575 cancel: func(tr *Transport, req *Request) {
2576 tr.CancelRequest(req)
2577 },
2578 checkErr: func(when string, err error) {
2579 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2580 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2581 }
2582 },
2583 })
2584 })
2585 }
2586
2587
// runCancelTestChannel runs f with a cancelTest that cancels requests by
// closing the Request.Cancel channel attached by newReq. The expected error
// is errRequestCanceled or errRequestCanceledConn.
func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
	cancelc := make(chan struct{})
	// A channel may be closed only once, but cancel may be called repeatedly.
	// testTransportCancelRequestInDial calls it twice, and
	// testTransportCancelRequestInDo calls it on every poll.
	cancelOnce := sync.OnceFunc(func() { close(cancelc) })
	f(t, cancelTest{
		mode: mode,
		newReq: func(req *Request) *Request {
			req.Cancel = cancelc
			return req
		},
		cancel: func(tr *Transport, req *Request) {
			cancelOnce()
		},
		checkErr: func(when string, err error) {
			if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
				t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
			}
		},
	})
}
2607
2608
// runCancelTestContext runs f with a cancelTest that cancels requests by
// canceling the context attached by newReq. The expected error is
// context.Canceled.
func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
	ctx, cancel := context.WithCancel(context.Background())
	f(t, cancelTest{
		mode: mode,
		newReq: func(req *Request) *Request {
			return req.WithContext(ctx)
		},
		cancel: func(tr *Transport, req *Request) {
			cancel()
		},
		checkErr: func(when string, err error) {
			if !errors.Is(err, context.Canceled) {
				t.Errorf("%v error = %v, want context.Canceled", when, err)
			}
		},
	})
}
2626
// runCancelTest runs f once per cancellation mechanism, as a named subtest,
// in each test mode chosen by run. The mechanisms are:
//   - "TransportCancel": Transport.CancelRequest (http1Mode only)
//   - "RequestCancel": the Request.Cancel channel
//   - "ContextCancel": the request context
//
// opts are forwarded to run, for example a []testMode restricting the modes.
func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
	run(t, func(t *testing.T, mode testMode) {
		if mode == http1Mode {
			t.Run("TransportCancel", func(t *testing.T) {
				runCancelTestTransport(t, mode, f)
			})
		}
		t.Run("RequestCancel", func(t *testing.T) {
			runCancelTestChannel(t, mode, f)
		})
		t.Run("ContextCancel", func(t *testing.T) {
			runCancelTestContext(t, mode, f)
		})
	}, opts...)
}
2642
// TestTransportCancelRequest cancels a request after its headers and part of
// its body have arrived. It checks two things:
//   - further body reads fail with the expected cancellation error and return
//     no extra data;
//   - the Transport eventually stops counting the request as pending.
func TestTransportCancelRequest(t *testing.T) {
	runCancelTest(t, testTransportCancelRequest)
}

func testTransportCancelRequest(t *testing.T, test cancelTest) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	// The handler sends and flushes msg, then blocks until the test returns.
	// This keeps the response body open while the test cancels.
	unblockc := make(chan bool)
	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	req = test.newReq(req)
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	test.cancel(tr, req)

	// After cancellation, reading the rest of the body must fail with the
	// cancellation error and yield no additional bytes.
	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	test.checkErr("Body.Read", err)
	if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// Cleanup of the canceled request presumably completes asynchronously,
	// so poll until the Transport reports no pending requests.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
2696
// testTransportCancelRequestInDo checks that canceling a request while c.Do
// is running in another goroutine makes Do return. body, which may be nil,
// is used as the request body. Do's result is ignored; only its return
// matters.
func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	unblockc := make(chan bool)
	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	donec := make(chan bool)
	req, _ := NewRequest("GET", ts.URL, body)
	req = test.newReq(req)
	go func() {
		defer close(donec)
		c.Do(req)
	}()

	// This unbuffered send completes only once the handler receives it,
	// which shows the request has reached the server.
	unblockc <- true
	// Cancel on every poll until the Do goroutine has returned.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		test.cancel(tr, req)
		select {
		case <-donec:
			return true
		default:
			if d > 0 {
				t.Logf("Do of canceled request has not returned after %v", d)
			}
			return false
		}
	})
}
2732
// TestTransportCancelRequestInDo runs testTransportCancelRequestInDo with no
// request body.
func TestTransportCancelRequestInDo(t *testing.T) {
	runCancelTest(t, func(t *testing.T, test cancelTest) {
		testTransportCancelRequestInDo(t, test, nil)
	})
}

// TestTransportCancelRequestWithBodyInDo runs testTransportCancelRequestInDo
// with a one-byte request body.
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
	runCancelTest(t, func(t *testing.T, test cancelTest) {
		testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
	})
}
2744
// TestTransportCancelRequestInDial checks that canceling a request while the
// Transport is blocked inside Dial makes Do return the expected cancellation
// error. Dial stays blocked until the test returns, so Do must return
// without waiting for the dial to finish.
func TestTransportCancelRequestInDial(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestInDial)
}

func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// eventLog records the order of events. The final check compares it with
	// the expected sequence, and the hang detector below prints it.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// unblockDial is closed only when the test returns. Dial therefore stays
	// blocked through the whole cancellation sequence.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			// Rendezvous with the test goroutine; a false value aborts.
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	// Wait until Dial is running, so the cancellation below happens mid-dial.
	inDial <- true

	eventLog.Printf("canceling")
	// Cancel twice, presumably to check that a repeated cancel is harmless.
	test.cancel(tr, req)
	test.cancel(tr, req)

	if d, ok := t.Deadline(); ok {
		// Panic with the event log shortly before the test deadline. A hang
		// then produces a useful report instead of a bare timeout.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2807
2808
// TestTransportCancelRequestWithBody is TestTransportCancelRequest for a POST
// that carries a request body. Canceling after part of the response body has
// arrived must make further reads fail and leave no pending requests.
func TestTransportCancelRequestWithBody(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestWithBody)
}

func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	// The handler sends and flushes msg, then blocks until the test returns.
	unblockc := make(chan struct{})
	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
	req = test.newReq(req)

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	test.cancel(tr, req)

	// The remainder of the body must fail with the cancellation error and
	// contain no further bytes.
	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	test.checkErr("Body.Read", err)
	if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// Cleanup presumably completes asynchronously; poll until the Transport
	// reports no pending requests.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
2863
// TestTransportCancelRequestBeforeDo checks that a request canceled before Do
// is called fails with the expected cancellation error.
//
// Only the Request.Cancel channel and context mechanisms are exercised here.
// Transport.CancelRequest is left out, presumably because it can cancel only
// a request the Transport has already started.
func TestTransportCancelRequestBeforeDo(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		t.Run("RequestCancel", func(t *testing.T) {
			runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
		})
		t.Run("ContextCancel", func(t *testing.T) {
			runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
		})
	})
}

func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
	unblockc := make(chan bool)
	cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	}))
	defer close(unblockc)

	c := cst.ts.Client()

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = test.newReq(req)
	// Cancel before the request has been sent at all.
	test.cancel(cst.tr, req)

	_, err := c.Do(req)
	test.checkErr("Do", err)
}
2891
2892
// TestTransportCancelRequestBeforeResponseHeaders checks the following: a
// request that is canceled after it has started going over the wire, but
// before any response arrives, makes RoundTrip fail with the expected
// cancellation error.
func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
}

func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
	defer afterTest(t)

	// Dial returns the client end of an in-memory net.Pipe and hands the
	// server end to the test, which plays the server by hand and never replies.
	serverConnCh := make(chan net.Conn, 1)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			cc, sc := net.Pipe()
			serverConnCh <- sc
			return cc, nil
		},
	}
	defer tr.CloseIdleConnections()
	errc := make(chan error, 1)
	req, _ := NewRequest("GET", "http://example.com/", nil)
	req = test.newReq(req)
	go func() {
		_, err := tr.RoundTrip(req)
		errc <- err
	}()

	// Read the start of the request line to make sure the request is being
	// written before it is canceled.
	sc := <-serverConnCh
	verb := make([]byte, 3)
	if _, err := io.ReadFull(sc, verb); err != nil {
		t.Errorf("Error reading HTTP verb from server: %v", err)
	}
	if string(verb) != "GET" {
		t.Errorf("server received %q; want GET", verb)
	}
	defer sc.Close()

	test.cancel(tr, req)

	err := <-errc
	if err == nil {
		t.Fatalf("unexpected success from RoundTrip")
	}
	test.checkErr("RoundTrip", err)
}
2934
2935
2936
2937
// TestTransportCloseResponseBody checks two things about closing a response
// body while the server is still streaming it: Close succeeds, and the
// server's subsequent writes fail.
func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }

func testTransportCloseResponseBody(t *testing.T, mode testMode) {
	writeErr := make(chan error, 1)
	msg := []byte("young\n")
	// The handler writes and flushes msg in a loop until a write fails, then
	// reports that error on writeErr.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		for {
			_, err := w.Write(msg)
			if err != nil {
				writeErr <- err
				return
			}
			w.(Flusher).Flush()
		}
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	defer tr.CancelRequest(req)

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}

	// Read a few copies of msg to confirm the stream is flowing.
	const repeats = 3
	buf := make([]byte, len(msg)*repeats)
	want := bytes.Repeat(msg, repeats)

	_, err = io.ReadFull(res.Body, buf)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf, want) {
		t.Fatalf("read %q; want %q", buf, want)
	}

	// Close mid-stream; the handler must then see a write error.
	if err := res.Body.Close(); err != nil {
		t.Errorf("Close = %v", err)
	}

	if err := <-writeErr; err == nil {
		t.Errorf("expected non-nil write error")
	}
}
2984
// fooProto is a RoundTripper for the made-up "foo" scheme. It never touches
// the network. Every request gets a 200 OK whose body names the requested URL.
type fooProto struct{}

// RoundTrip returns a canned 200 OK response with an empty header and the
// body "You wanted <req.URL>". It never returns an error.
func (fooProto) RoundTrip(req *Request) (*Response, error) {
	body := "You wanted " + req.URL.String()
	return &Response{
		StatusCode: 200,
		Status:     "200 OK",
		Header:     Header{},
		Body:       io.NopCloser(strings.NewReader(body)),
	}, nil
}
2996
2997 func TestTransportAltProto(t *testing.T) {
2998 defer afterTest(t)
2999 tr := &Transport{}
3000 c := &Client{Transport: tr}
3001 tr.RegisterProtocol("foo", fooProto{})
3002 res, err := c.Get("foo://bar.com/path")
3003 if err != nil {
3004 t.Fatal(err)
3005 }
3006 bodyb, err := io.ReadAll(res.Body)
3007 if err != nil {
3008 t.Fatal(err)
3009 }
3010 body := string(bodyb)
3011 if e := "You wanted foo://bar.com/path"; body != e {
3012 t.Errorf("got response %q, want %q", body, e)
3013 }
3014 }
3015
3016 func TestTransportNoHost(t *testing.T) {
3017 defer afterTest(t)
3018 tr := &Transport{}
3019 _, err := tr.RoundTrip(&Request{
3020 Header: make(Header),
3021 URL: &url.URL{
3022 Scheme: "http",
3023 },
3024 })
3025 want := "http: no Host in request URL"
3026 if got := fmt.Sprint(err); got != want {
3027 t.Errorf("error = %v; want %q", err, want)
3028 }
3029 }
3030
3031
// TestTransportEmptyMethod checks that a Request whose Method is the empty
// string is written on the wire as a GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = "" // the empty method must be treated as GET
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if wire := string(dump); !strings.Contains(wire, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
3043
// TestTransportSocketLateBinding checks that a request waiting on a new dial
// uses an idle connection that becomes free first, instead of waiting for
// its own dial. It asserts that /bar is served from the same client address
// as /foo.
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }

func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// /foo reports the client address in a header, flushes the headers, then
	// blocks until the test releases it through fooGate.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	// /bar just reports the client address.
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// Dial proceeds only when a value arrives on dialGate. While it waits, it
	// repeatedly signals on dialing. Closing dialGate makes a waiting dial fail.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Allow exactly one real dial, used by the /foo request.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// Once /bar is waiting, let /foo finish. This frees /foo's connection
		// so /bar can use it.
		if mode == http2Mode {
			// In HTTP/2 mode /bar is expected to share /foo's connection, so a
			// second dial is an error. Give one a brief chance to show up.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			// HTTP/1: wait until /bar has started (and is stuck in) a second dial.
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3123
3124
3125 func TestTransportReading100Continue(t *testing.T) {
3126 defer afterTest(t)
3127
3128 const numReqs = 5
3129 reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
3130 reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
3131
3132 send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
3133 defer w.Close()
3134 defer r.Close()
3135 br := bufio.NewReader(r)
3136 n := 0
3137 for {
3138 n++
3139 req, err := ReadRequest(br)
3140 if err == io.EOF {
3141 return
3142 }
3143 if err != nil {
3144 t.Error(err)
3145 return
3146 }
3147 slurp, err := io.ReadAll(req.Body)
3148 if err != nil {
3149 t.Errorf("Server request body slurp: %v", err)
3150 return
3151 }
3152 id := req.Header.Get("Request-Id")
3153 resCode := req.Header.Get("X-Want-Response-Code")
3154 if resCode == "" {
3155 resCode = "100 Continue"
3156 if string(slurp) != reqBody(n) {
3157 t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
3158 }
3159 }
3160 body := fmt.Sprintf("Response number %d", n)
3161 v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
3162 Date: Thu, 28 Feb 2013 17:55:41 GMT
3163
3164 HTTP/1.1 200 OK
3165 Content-Type: text/html
3166 Echo-Request-Id: %s
3167 Content-Length: %d
3168
3169 %s`, resCode, id, len(body), body), "\n", "\r\n", -1))
3170 w.Write(v)
3171 if id == reqID(numReqs) {
3172 return
3173 }
3174 }
3175
3176 }
3177
3178 tr := &Transport{
3179 Dial: func(n, addr string) (net.Conn, error) {
3180 sr, sw := io.Pipe()
3181 cr, cw := io.Pipe()
3182 conn := &rwTestConn{
3183 Reader: cr,
3184 Writer: sw,
3185 closeFunc: func() error {
3186 sw.Close()
3187 cw.Close()
3188 return nil
3189 },
3190 }
3191 go send100Response(cw, sr)
3192 return conn, nil
3193 },
3194 DisableKeepAlives: false,
3195 }
3196 defer tr.CloseIdleConnections()
3197 c := &Client{Transport: tr}
3198
3199 testResponse := func(req *Request, name string, wantCode int) {
3200 t.Helper()
3201 res, err := c.Do(req)
3202 if err != nil {
3203 t.Fatalf("%s: Do: %v", name, err)
3204 }
3205 if res.StatusCode != wantCode {
3206 t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
3207 }
3208 if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
3209 t.Errorf("%s: response id %q != request id %q", name, idBack, id)
3210 }
3211 _, err = io.ReadAll(res.Body)
3212 if err != nil {
3213 t.Fatalf("%s: Slurp error: %v", name, err)
3214 }
3215 }
3216
3217
3218 for i := 1; i <= numReqs; i++ {
3219 req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
3220 req.Header.Set("Request-Id", reqID(i))
3221 testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
3222 }
3223 }
3224
3225
3226
// TestTransportIgnore1xxResponses checks two things for HTTP/1. An
// unsolicited 1xx response (code 123) is reported to the Got1xxResponse
// trace hook. The final 200 response that follows it is the one returned.
func TestTransportIgnore1xxResponses(t *testing.T) {
	run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
}

func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
	// The handler hijacks the connection and writes the raw 1xx and 200
	// responses itself.
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
		buf.Flush()
		conn.Close()
	}))
	cst.tr.DisableKeepAlives = true

	// got collects the trace hook output followed by the serialized final
	// response.
	var got strings.Builder

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
			return nil
		},
	}))
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	res.Write(&got)
	want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
	if got.String() != want {
		t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
	}
}
3260
3261 func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses) }
3262 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3263 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3264 w.Header().Add("X-Header", strings.Repeat("a", 100))
3265 for i := 0; i < 10; i++ {
3266 w.WriteHeader(123)
3267 }
3268 w.WriteHeader(204)
3269 }))
3270 cst.tr.DisableKeepAlives = true
3271 cst.tr.MaxResponseHeaderBytes = 1000
3272
3273 res, err := cst.c.Get(cst.ts.URL)
3274 if err == nil {
3275 res.Body.Close()
3276 t.Fatalf("RoundTrip succeeded; want error")
3277 }
3278 for _, want := range []string{
3279 "response headers exceeded",
3280 "too many 1xx",
3281 "header list too large",
3282 } {
3283 if strings.Contains(err.Error(), want) {
3284 return
3285 }
3286 }
3287 t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err)
3288 }
3289
3290 func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) {
3291 run(t, testTransportDoesNotLimitDelivered1xxResponses)
3292 }
3293 func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) {
3294 if mode == http2Mode {
3295 t.Skip("skip until x/net/http2 updated")
3296 }
3297 const num1xx = 10
3298 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3299 w.Header().Add("X-Header", strings.Repeat("a", 100))
3300 for i := 0; i < 10; i++ {
3301 w.WriteHeader(123)
3302 }
3303 w.WriteHeader(204)
3304 }))
3305 cst.tr.DisableKeepAlives = true
3306 cst.tr.MaxResponseHeaderBytes = 1000
3307
3308 got1xx := 0
3309 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3310 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3311 got1xx++
3312 return nil
3313 },
3314 })
3315 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
3316 res, err := cst.c.Do(req)
3317 if err != nil {
3318 t.Fatal(err)
3319 }
3320 res.Body.Close()
3321 if got1xx != num1xx {
3322 t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx)
3323 }
3324 }
3325
3326
3327
// TestTransportTreat101Terminal checks that a 101 Switching Protocols
// response is returned to the caller as the final response. Other 1xx
// responses are skipped, but 101 is not: the 204 the server writes after it
// must not replace it.
func TestTransportTreat101Terminal(t *testing.T) {
	run(t, testTransportTreat101Terminal, []testMode{http1Mode})
}

func testTransportTreat101Terminal(t *testing.T, mode testMode) {
	// The handler hijacks the connection and writes a raw 101 followed by a
	// raw 204.
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
		buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
		buf.Flush()
		conn.Close()
	}))
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.StatusCode != StatusSwitchingProtocols {
		t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
	}
}
3348
// proxyFromEnvTest is one case for the proxy-from-environment tests.
type proxyFromEnvTest struct {
	req string // URL to request; "" means "http://example.com"

	// Environment variable values to set for the case.
	env      string // HTTP_PROXY / http_proxy
	httpsenv string // HTTPS_PROXY / https_proxy
	noenv    string // NO_PROXY / no_proxy
	reqmeth  string // REQUEST_METHOD

	want    string // expected proxy URL as formatted with %s ("<nil>" for none)
	wanterr error  // expected error, compared by its %v text
}

// String describes the case as space-separated key="value" pairs. Only the
// non-empty environment settings appear, and the request URL always comes
// last (defaulting to "http://example.com").
func (t proxyFromEnvTest) String() string {
	var parts []string
	add := func(key, val string) {
		parts = append(parts, fmt.Sprintf("%s=%q", key, val))
	}
	if t.env != "" {
		add("http_proxy", t.env)
	}
	if t.httpsenv != "" {
		add("https_proxy", t.httpsenv)
	}
	if t.noenv != "" {
		add("no_proxy", t.noenv)
	}
	if t.reqmeth != "" {
		add("request_method", t.reqmeth)
	}
	target := t.req
	if target == "" {
		target = "http://example.com"
	}
	add("req", target)
	return strings.TrimSpace(strings.Join(parts, " "))
}
3391
// proxyFromEnvTests are the cases shared by TestProxyFromEnvironment and
// TestProxyFromEnvironmentLowerCase. Each case sets the proxy environment
// variables from its fields and checks the proxy chosen for its request URL.
var proxyFromEnvTests = []proxyFromEnvTest{
	// A proxy value without a scheme becomes http://. Explicit schemes
	// (http, https, socks5, socks5h) are kept as given.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// An http:// request uses HTTP_PROXY even when HTTPS_PROXY is also set.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https:// request uses HTTPS_PROXY. Without a scheme that proxy is
	// still reached over http://; an explicit https:// is preserved.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// When REQUEST_METHOD is set, as it is in a CGI environment, HTTP_PROXY
	// is refused with an error instead of being used.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy variables set: no proxy is used.
	{want: "<nil>"},

	// NO_PROXY matching:
	//   - an exact host, or a subdomain of it, bypasses the proxy;
	//   - a leading-dot entry does not match the bare domain;
	//   - a plain suffix of the name ("ample.com") does not match;
	//   - an unrelated domain does not match.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3422
3423 func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
3424 t.Helper()
3425 reqURL := tt.req
3426 if reqURL == "" {
3427 reqURL = "http://example.com"
3428 }
3429 req, _ := NewRequest("GET", reqURL, nil)
3430 url, err := proxyForRequest(req)
3431 if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
3432 t.Errorf("%v: got error = %q, want %q", tt, g, e)
3433 return
3434 }
3435 if got := fmt.Sprintf("%s", url); got != tt.want {
3436 t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
3437 }
3438 }
3439
3440 func TestProxyFromEnvironment(t *testing.T) {
3441 ResetProxyEnv()
3442 defer ResetProxyEnv()
3443 for _, tt := range proxyFromEnvTests {
3444 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3445 os.Setenv("HTTP_PROXY", tt.env)
3446 os.Setenv("HTTPS_PROXY", tt.httpsenv)
3447 os.Setenv("NO_PROXY", tt.noenv)
3448 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3449 ResetCachedEnvironment()
3450 return ProxyFromEnvironment(req)
3451 })
3452 }
3453 }
3454
3455 func TestProxyFromEnvironmentLowerCase(t *testing.T) {
3456 ResetProxyEnv()
3457 defer ResetProxyEnv()
3458 for _, tt := range proxyFromEnvTests {
3459 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3460 os.Setenv("http_proxy", tt.env)
3461 os.Setenv("https_proxy", tt.httpsenv)
3462 os.Setenv("no_proxy", tt.noenv)
3463 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3464 ResetCachedEnvironment()
3465 return ProxyFromEnvironment(req)
3466 })
3467 }
3468 }
3469
// TestIdleConnChannelLeak checks that requests to many distinct host names
// leave no entries in the Transport's idle-connection wait map. It runs the
// check both with and without keep-alives.
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}

func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// Run with testNotParallel: the test installs a read-loop hook through
	// SetReadLoopBeforeNextReadHook, which presumably affects every Transport.
	var mu sync.Mutex
	// NOTE(review): n counts handler calls but is never read — confirm it
	// can be dropped.
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// didRead gets a value each time a connection's read loop reaches the
	// hook. It is buffered to hold nReqs signals.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Route every host name to the same test server.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	// Each request uses a different host name, presumably so each gets its
	// own connection key. Repeat with keep-alives disabled, then enabled.
	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// NOTE(review): the response body is not closed here (the handler
			// writes no body) — confirm this is intentional.
		}

		// Wait until each connection's read loop has passed the hook before
		// checking the map, so per-request bookkeeping has had a chance to
		// finish.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3525
3526
3527
3528
// TestTransportClosesRequestBody checks that the Transport closes the request
// body of a POST exactly once by the time the request has completed.
func TestTransportClosesRequestBody(t *testing.T) {
	run(t, testTransportClosesRequestBody, []testMode{http1Mode})
}

func testTransportClosesRequestBody(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.Copy(io.Discard, r.Body)
	})).ts

	c := ts.Client()

	// countCloseReader is defined elsewhere; it presumably increments *closes
	// each time Close is called.
	closes := 0

	res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if closes != 1 {
		t.Errorf("closes = %d; want 1", closes)
	}
}
3550
// TestTransportTLSHandshakeTimeout checks Transport.TLSHandshakeTimeout
// against a server that accepts the TCP connection but never speaks TLS.
// The HTTPS request must be aborted, and the error must be a *url.Error
// wrapping a net.Error that reports Timeout() and mentions
// "handshake timeout".
func TestTransportTLSHandshakeTimeout(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	ln := newLocalListener(t)
	defer ln.Close()
	testdonec := make(chan struct{})
	defer close(testdonec)

	// Accept one connection and hold it silently until the test finishes.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		<-testdonec
		c.Close()
	}()

	// Whatever the URL says, dial the silent listener.
	tr := &Transport{
		Dial: func(_, _ string) (net.Conn, error) {
			return net.Dial("tcp", ln.Addr().String())
		},
		TLSHandshakeTimeout: 250 * time.Millisecond,
	}
	cl := &Client{Transport: tr}
	_, err := cl.Get("https://dummy.tld/")
	if err == nil {
		t.Error("expected error")
		return
	}
	ue, ok := err.(*url.Error)
	if !ok {
		t.Errorf("expected url.Error; got %#v", err)
		return
	}
	ne, ok := ue.Err.(net.Error)
	if !ok {
		t.Errorf("expected net.Error; got %#v", err)
		return
	}
	if !ne.Timeout() {
		t.Errorf("expected timeout error; got %v", err)
	}
	if !strings.Contains(err.Error(), "handshake timeout") {
		t.Errorf("expected 'handshake timeout' in error; got %v", err)
	}
}
3600
3601
3602 func TestTLSServerClosesConnection(t *testing.T) {
3603 run(t, testTLSServerClosesConnection, []testMode{https1Mode})
3604 }
3605 func testTLSServerClosesConnection(t *testing.T, mode testMode) {
3606 closedc := make(chan bool, 1)
3607 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3608 if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
3609 conn, _, _ := w.(Hijacker).Hijack()
3610 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3611 conn.Close()
3612 closedc <- true
3613 return
3614 }
3615 fmt.Fprintf(w, "hello")
3616 })).ts
3617
3618 c := ts.Client()
3619 tr := c.Transport.(*Transport)
3620
3621 var nSuccess = 0
3622 var errs []error
3623 const trials = 20
3624 for i := 0; i < trials; i++ {
3625 tr.CloseIdleConnections()
3626 res, err := c.Get(ts.URL + "/keep-alive-then-die")
3627 if err != nil {
3628 t.Fatal(err)
3629 }
3630 <-closedc
3631 slurp, err := io.ReadAll(res.Body)
3632 if err != nil {
3633 t.Fatal(err)
3634 }
3635 if string(slurp) != "foo" {
3636 t.Errorf("Got %q, want foo", slurp)
3637 }
3638
3639
3640
3641 res, err = c.Get(ts.URL + "/")
3642 if err != nil {
3643 errs = append(errs, err)
3644 continue
3645 }
3646 slurp, err = io.ReadAll(res.Body)
3647 if err != nil {
3648 errs = append(errs, err)
3649 continue
3650 }
3651 nSuccess++
3652 }
3653 if nSuccess > 0 {
3654 t.Logf("successes = %d of %d", nSuccess, trials)
3655 } else {
3656 t.Errorf("All runs failed:")
3657 }
3658 for _, err := range errs {
3659 t.Logf(" err: %v", err)
3660 }
3661 }
3662
3663
3664
3665
// byteFromChanReader is an io.Reader that yields the bytes sent on the channel
// one at a time. It blocks until a byte arrives and reports io.EOF once the
// channel is closed and drained.
type byteFromChanReader chan byte

// Read stores at most one byte from the channel into p. An empty p returns
// (0, nil) without receiving from the channel.
func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		if b, ok := <-c; ok {
			p[0] = b
			return 1, nil
		}
		return 0, io.EOF
	}
	return 0, nil
}
3679
3680
3681
3682
3683
3684
3685
3686 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
3687 run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
3688 }
3689 func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
3690 defer func(d time.Duration) {
3691 *MaxWriteWaitBeforeConnReuse = d
3692 }(*MaxWriteWaitBeforeConnReuse)
3693 *MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
3694 var sconn struct {
3695 sync.Mutex
3696 c net.Conn
3697 }
3698 var getOkay bool
3699 var copying sync.WaitGroup
3700 closeConn := func() {
3701 sconn.Lock()
3702 defer sconn.Unlock()
3703 if sconn.c != nil {
3704 sconn.c.Close()
3705 sconn.c = nil
3706 if !getOkay {
3707 t.Logf("Closed server connection")
3708 }
3709 }
3710 }
3711 defer func() {
3712 closeConn()
3713 copying.Wait()
3714 }()
3715
3716 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3717 if r.Method == "GET" {
3718 io.WriteString(w, "bar")
3719 return
3720 }
3721 conn, _, _ := w.(Hijacker).Hijack()
3722 sconn.Lock()
3723 sconn.c = conn
3724 sconn.Unlock()
3725 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3726
3727 copying.Add(1)
3728 go func() {
3729 io.Copy(io.Discard, conn)
3730 copying.Done()
3731 }()
3732 })).ts
3733 c := ts.Client()
3734
3735 const bodySize = 256 << 10
3736 finalBit := make(byteFromChanReader, 1)
3737 req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
3738 req.ContentLength = bodySize
3739 res, err := c.Do(req)
3740 if err := wantBody(res, err, "foo"); err != nil {
3741 t.Errorf("POST response: %v", err)
3742 }
3743
3744 res, err = c.Get(ts.URL)
3745 if err := wantBody(res, err, "bar"); err != nil {
3746 t.Errorf("GET response: %v", err)
3747 return
3748 }
3749 getOkay = true
3750 finalBit <- 'x'
3751 close(finalBit)
3752 }
3753
3754
3755
3756 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3757 func testTransportIssue10457(t *testing.T, mode testMode) {
3758 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3759
3760
3761
3762
3763
3764 conn, _, _ := w.(Hijacker).Hijack()
3765 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3766 conn.Close()
3767 })).ts
3768 c := ts.Client()
3769
3770 res, err := c.Get(ts.URL)
3771 if err != nil {
3772 t.Fatalf("Get: %v", err)
3773 }
3774 defer res.Body.Close()
3775
3776
3777
3778
3779 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3780 t.Errorf("Foo header = %q; want %q", got, want)
3781 }
3782 }
3783
// closerFunc adapts an ordinary function to the io.Closer interface.
type closerFunc func() error

// Close calls f and returns its result.
func (f closerFunc) Close() error {
	return f()
}

// writerFuncConn is a net.Conn that sends every Write to the write callback.
// All other methods go to the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

// Write delegates to c.write rather than the embedded connection.
func (c writerFuncConn) Write(p []byte) (n int, err error) {
	return c.write(p)
}
3794
3795
3796
3797
3798
3799
3800
3801
3802
3803
3804
3805
3806
3807 func TestRetryRequestsOnError(t *testing.T) {
3808 run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
3809 }
3810 func testRetryRequestsOnError(t *testing.T, mode testMode) {
3811 newRequest := func(method, urlStr string, body io.Reader) *Request {
3812 req, err := NewRequest(method, urlStr, body)
3813 if err != nil {
3814 t.Fatal(err)
3815 }
3816 return req
3817 }
3818
3819 testCases := []struct {
3820 name string
3821 failureN int
3822 failureErr error
3823
3824
3825
3826 req func() *Request
3827 reqString string
3828 }{
3829 {
3830 name: "IdempotentNoBodySomeWritten",
3831
3832
3833 failureN: 1,
3834
3835 failureErr: ExportErrServerClosedIdle,
3836 req: func() *Request {
3837 return newRequest("GET", "http://fake.golang", nil)
3838 },
3839 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3840 },
3841 {
3842 name: "IdempotentGetBodySomeWritten",
3843
3844
3845 failureN: 1,
3846
3847 failureErr: ExportErrServerClosedIdle,
3848 req: func() *Request {
3849 return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
3850 },
3851 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3852 },
3853 {
3854 name: "NothingWrittenNoBody",
3855
3856
3857 failureN: 0,
3858 failureErr: errors.New("second write fails"),
3859 req: func() *Request {
3860 return newRequest("DELETE", "http://fake.golang", nil)
3861 },
3862 reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3863 },
3864 {
3865 name: "NothingWrittenGetBody",
3866
3867
3868 failureN: 0,
3869 failureErr: errors.New("second write fails"),
3870
3871
3872 req: func() *Request {
3873 return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
3874 },
3875 reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3876 },
3877 }
3878
3879 for _, tc := range testCases {
3880 t.Run(tc.name, func(t *testing.T) {
3881 var (
3882 mu sync.Mutex
3883 logbuf strings.Builder
3884 )
3885 logf := func(format string, args ...any) {
3886 mu.Lock()
3887 defer mu.Unlock()
3888 fmt.Fprintf(&logbuf, format, args...)
3889 logbuf.WriteByte('\n')
3890 }
3891
3892 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3893 logf("Handler")
3894 w.Header().Set("X-Status", "ok")
3895 })).ts
3896
3897 var writeNumAtomic int32
3898 c := ts.Client()
3899 c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
3900 logf("Dial")
3901 c, err := net.Dial(network, ts.Listener.Addr().String())
3902 if err != nil {
3903 logf("Dial error: %v", err)
3904 return nil, err
3905 }
3906 return &writerFuncConn{
3907 Conn: c,
3908 write: func(p []byte) (n int, err error) {
3909 if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
3910 logf("intentional write failure")
3911 return tc.failureN, tc.failureErr
3912 }
3913 logf("Write(%q)", p)
3914 return c.Write(p)
3915 },
3916 }, nil
3917 }
3918
3919 SetRoundTripRetried(func() {
3920 logf("Retried.")
3921 })
3922 defer SetRoundTripRetried(nil)
3923
3924 for i := 0; i < 3; i++ {
3925 t0 := time.Now()
3926 req := tc.req()
3927 res, err := c.Do(req)
3928 if err != nil {
3929 if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
3930 mu.Lock()
3931 got := logbuf.String()
3932 mu.Unlock()
3933 t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
3934 }
3935 t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
3936 }
3937 res.Body.Close()
3938 if res.Request != req {
3939 t.Errorf("Response.Request != original request; want identical Request")
3940 }
3941 }
3942
3943 mu.Lock()
3944 got := logbuf.String()
3945 mu.Unlock()
3946 want := fmt.Sprintf(`Dial
3947 Write("%s")
3948 Handler
3949 intentional write failure
3950 Retried.
3951 Dial
3952 Write("%s")
3953 Handler
3954 Write("%s")
3955 Handler
3956 `, tc.reqString, tc.reqString, tc.reqString)
3957 if got != want {
3958 t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
3959 }
3960 })
3961 }
3962 }
3963
3964
3965 func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
3966 func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
3967 readBody := make(chan error, 1)
3968 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3969 _, err := io.ReadAll(r.Body)
3970 readBody <- err
3971 })).ts
3972 c := ts.Client()
3973 fakeErr := errors.New("fake error")
3974 didClose := make(chan bool, 1)
3975 req, _ := NewRequest("POST", ts.URL, struct {
3976 io.Reader
3977 io.Closer
3978 }{
3979 io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
3980 closerFunc(func() error {
3981 select {
3982 case didClose <- true:
3983 default:
3984 }
3985 return nil
3986 }),
3987 })
3988 res, err := c.Do(req)
3989 if res != nil {
3990 defer res.Body.Close()
3991 }
3992 if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
3993 t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
3994 }
3995 if err := <-readBody; err == nil {
3996 t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
3997 }
3998 select {
3999 case <-didClose:
4000 default:
4001 t.Errorf("didn't see Body.Close")
4002 }
4003 }
4004
4005 func TestTransportDialTLS(t *testing.T) {
4006 run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
4007 }
4008 func testTransportDialTLS(t *testing.T, mode testMode) {
4009 var mu sync.Mutex
4010 var gotReq, didDial bool
4011
4012 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4013 mu.Lock()
4014 gotReq = true
4015 mu.Unlock()
4016 })).ts
4017 c := ts.Client()
4018 c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
4019 mu.Lock()
4020 didDial = true
4021 mu.Unlock()
4022 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4023 if err != nil {
4024 return nil, err
4025 }
4026 return c, c.Handshake()
4027 }
4028
4029 res, err := c.Get(ts.URL)
4030 if err != nil {
4031 t.Fatal(err)
4032 }
4033 res.Body.Close()
4034 mu.Lock()
4035 if !gotReq {
4036 t.Error("didn't get request")
4037 }
4038 if !didDial {
4039 t.Error("didn't use dial hook")
4040 }
4041 }
4042
4043 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
4044 func testTransportDialContext(t *testing.T, mode testMode) {
4045 ctxKey := "some-key"
4046 ctxValue := "some-value"
4047 var (
4048 mu sync.Mutex
4049 gotReq bool
4050 gotCtxValue any
4051 )
4052
4053 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4054 mu.Lock()
4055 gotReq = true
4056 mu.Unlock()
4057 })).ts
4058 c := ts.Client()
4059 c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4060 mu.Lock()
4061 gotCtxValue = ctx.Value(ctxKey)
4062 mu.Unlock()
4063 return net.Dial(netw, addr)
4064 }
4065
4066 req, err := NewRequest("GET", ts.URL, nil)
4067 if err != nil {
4068 t.Fatal(err)
4069 }
4070 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4071 res, err := c.Do(req.WithContext(ctx))
4072 if err != nil {
4073 t.Fatal(err)
4074 }
4075 res.Body.Close()
4076 mu.Lock()
4077 if !gotReq {
4078 t.Error("didn't get request")
4079 }
4080 if got, want := gotCtxValue, ctxValue; got != want {
4081 t.Errorf("got context with value %v, want %v", got, want)
4082 }
4083 }
4084
4085 func TestTransportDialTLSContext(t *testing.T) {
4086 run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
4087 }
4088 func testTransportDialTLSContext(t *testing.T, mode testMode) {
4089 ctxKey := "some-key"
4090 ctxValue := "some-value"
4091 var (
4092 mu sync.Mutex
4093 gotReq bool
4094 gotCtxValue any
4095 )
4096
4097 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4098 mu.Lock()
4099 gotReq = true
4100 mu.Unlock()
4101 })).ts
4102 c := ts.Client()
4103 c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4104 mu.Lock()
4105 gotCtxValue = ctx.Value(ctxKey)
4106 mu.Unlock()
4107 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4108 if err != nil {
4109 return nil, err
4110 }
4111 return c, c.HandshakeContext(ctx)
4112 }
4113
4114 req, err := NewRequest("GET", ts.URL, nil)
4115 if err != nil {
4116 t.Fatal(err)
4117 }
4118 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4119 res, err := c.Do(req.WithContext(ctx))
4120 if err != nil {
4121 t.Fatal(err)
4122 }
4123 res.Body.Close()
4124 mu.Lock()
4125 if !gotReq {
4126 t.Error("didn't get request")
4127 }
4128 if got, want := gotCtxValue, ctxValue; got != want {
4129 t.Errorf("got context with value %v, want %v", got, want)
4130 }
4131 }
4132
4133
4134
4135 func TestRoundTripReturnsProxyError(t *testing.T) {
4136 badProxy := func(*Request) (*url.URL, error) {
4137 return nil, errors.New("errorMessage")
4138 }
4139
4140 tr := &Transport{Proxy: badProxy}
4141
4142 req, _ := NewRequest("GET", "http://example.com", nil)
4143
4144 _, err := tr.RoundTrip(req)
4145
4146 if err == nil {
4147 t.Error("Expected proxy error to be returned by RoundTrip")
4148 }
4149 }
4150
4151
4152 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4153 tr := &Transport{}
4154 wantIdle := func(when string, n int) bool {
4155 got := tr.IdleConnCountForTesting("http", "example.com")
4156 if got == n {
4157 return true
4158 }
4159 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4160 return false
4161 }
4162 wantIdle("start", 0)
4163 if !tr.PutIdleTestConn("http", "example.com") {
4164 t.Fatal("put failed")
4165 }
4166 if !tr.PutIdleTestConn("http", "example.com") {
4167 t.Fatal("second put failed")
4168 }
4169 wantIdle("after put", 2)
4170 tr.CloseIdleConnections()
4171 if !tr.IsIdleForTesting() {
4172 t.Error("should be idle after CloseIdleConnections")
4173 }
4174 wantIdle("after close idle", 0)
4175 if tr.PutIdleTestConn("http", "example.com") {
4176 t.Fatal("put didn't fail")
4177 }
4178 wantIdle("after second put", 0)
4179
4180 tr.QueueForIdleConnForTesting()
4181 if tr.IsIdleForTesting() {
4182 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4183 }
4184 if !tr.PutIdleTestConn("http", "example.com") {
4185 t.Fatal("after re-activation")
4186 }
4187 wantIdle("after final put", 1)
4188 }
4189
4190
4191
4192 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4193 tr := &Transport{}
4194 wantIdle := func(when string, n int) bool {
4195 got := tr.IdleConnCountForTesting("https", "example.com:443")
4196 if got == n {
4197 return true
4198 }
4199 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4200 return false
4201 }
4202 wantIdle("start", 0)
4203 alt := funcRoundTripper(func() {})
4204 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4205 t.Fatal("put failed")
4206 }
4207 wantIdle("after put", 1)
4208 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4209 GotConn: func(httptrace.GotConnInfo) {
4210
4211 t.Error("GotConn called")
4212 },
4213 })
4214 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4215 _, err := tr.RoundTrip(req)
4216 if err != errFakeRoundTrip {
4217 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4218 }
4219 wantIdle("after round trip", 1)
4220 }
4221
4222 func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
4223 run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
4224 }
4225 func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
4226 if testing.Short() {
4227 t.Skip("skipping in short mode")
4228 }
4229
4230 timeout := 1 * time.Millisecond
4231 retry := true
4232 for retry {
4233 trFunc := func(tr *Transport) {
4234 tr.MaxConnsPerHost = 1
4235 tr.MaxIdleConnsPerHost = 1
4236 tr.IdleConnTimeout = timeout
4237 }
4238 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
4239
4240 retry = false
4241 tooShort := func(err error) bool {
4242 if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
4243 return false
4244 }
4245 if !retry {
4246 t.Helper()
4247 t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
4248 timeout *= 2
4249 retry = true
4250 cst.close()
4251 }
4252 return true
4253 }
4254
4255 if _, err := cst.c.Get(cst.ts.URL); err != nil {
4256 if tooShort(err) {
4257 continue
4258 }
4259 t.Fatalf("got error: %s", err)
4260 }
4261
4262 time.Sleep(10 * timeout)
4263 if _, err := cst.c.Get(cst.ts.URL); err != nil {
4264 if tooShort(err) {
4265 continue
4266 }
4267 t.Fatalf("got error: %s", err)
4268 }
4269 }
4270 }
4271
4272
4273
4274
4275
4276 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
4277 func testTransportRangeAndGzip(t *testing.T, mode testMode) {
4278 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4279 if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
4280 t.Error("Transport advertised gzip support in the Accept header")
4281 }
4282 if r.Header.Get("Range") == "" {
4283 t.Error("no Range in request")
4284 }
4285 })).ts
4286 c := ts.Client()
4287
4288 req, _ := NewRequest("GET", ts.URL, nil)
4289 req.Header.Set("Range", "bytes=7-11")
4290 res, err := c.Do(req)
4291 if err != nil {
4292 t.Fatal(err)
4293 }
4294 res.Body.Close()
4295 }
4296
4297
4298 func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
4299 func testTransportResponseCancelRace(t *testing.T, mode testMode) {
4300 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4301
4302 var b [1024]byte
4303 w.Write(b[:])
4304 })).ts
4305 tr := ts.Client().Transport.(*Transport)
4306
4307 req, err := NewRequest("GET", ts.URL, nil)
4308 if err != nil {
4309 t.Fatal(err)
4310 }
4311 res, err := tr.RoundTrip(req)
4312 if err != nil {
4313 t.Fatal(err)
4314 }
4315
4316
4317
4318 if _, err := io.Copy(io.Discard, res.Body); err != nil {
4319 t.Fatal(err)
4320 }
4321
4322 req2, err := NewRequest("GET", ts.URL, nil)
4323 if err != nil {
4324 t.Fatal(err)
4325 }
4326 tr.CancelRequest(req)
4327 res, err = tr.RoundTrip(req2)
4328 if err != nil {
4329 t.Fatal(err)
4330 }
4331 res.Body.Close()
4332 }
4333
4334
4335 func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
4336 run(t, testTransportContentEncodingCaseInsensitive)
4337 }
4338 func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
4339 for _, ce := range []string{"gzip", "GZIP"} {
4340 ce := ce
4341 t.Run(ce, func(t *testing.T) {
4342 const encodedString = "Hello Gopher"
4343 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4344 w.Header().Set("Content-Encoding", ce)
4345 gz := gzip.NewWriter(w)
4346 gz.Write([]byte(encodedString))
4347 gz.Close()
4348 })).ts
4349
4350 res, err := ts.Client().Get(ts.URL)
4351 if err != nil {
4352 t.Fatal(err)
4353 }
4354
4355 body, err := io.ReadAll(res.Body)
4356 res.Body.Close()
4357 if err != nil {
4358 t.Fatal(err)
4359 }
4360
4361 if string(body) != encodedString {
4362 t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
4363 }
4364 })
4365 }
4366 }
4367
4368
4369 func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
4370 run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
4371 }
4372 func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
4373 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
4374 func(tr *Transport) {
4375 tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
4376
4377 return &funcConn{
4378 read: func([]byte) (int, error) {
4379 return 0, errors.New("error")
4380 },
4381 write: func([]byte) (int, error) {
4382 return 0, errors.New("error")
4383 },
4384 }, nil
4385 }
4386 },
4387 ).ts
4388
4389
4390
4391
4392
4393 SetEnterRoundTripHook(func() {
4394 time.Sleep(1 * time.Millisecond)
4395 })
4396 defer SetEnterRoundTripHook(nil)
4397 var closes int
4398 _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
4399 if err == nil {
4400 t.Fatalf("expected request to fail, but it did not")
4401 }
4402 if closes != 1 {
4403 t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
4404 }
4405 }
4406
4407
4408
4409
// logWritesConn is a fake net.Conn. Each Write is recorded as a string and
// then forwarded to w. Reads come from the io.Reader received on rch (taken on
// first use), and Close is a no-op. The embedded net.Conn is nil unless the
// caller sets it, so other methods must not be called.
type logWritesConn struct {
	net.Conn

	w io.Writer // destination of written bytes

	rch <-chan io.Reader // supplies the reader for Read
	r   io.Reader        // reader taken from rch on the first Read

	mu     sync.Mutex
	writes []string // each Write's bytes, in order; guarded by mu
}

// Write appends p to c.writes under c.mu and forwards it to c.w.
func (c *logWritesConn) Write(p []byte) (int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writes = append(c.writes, string(p))
	return c.w.Write(p)
}

// Read serves bytes from the reader received on c.rch, waiting for it on the
// first call. c.r is not guarded by mu; this assumes a single reader goroutine.
func (c *logWritesConn) Read(p []byte) (int, error) {
	src := c.r
	if src == nil {
		src = <-c.rch
		c.r = src
	}
	return src.Read(p)
}

// Close does nothing and always succeeds.
func (c *logWritesConn) Close() error { return nil }
4437
4438
4439 func TestTransportFlushesBodyChunks(t *testing.T) {
4440 defer afterTest(t)
4441 resBody := make(chan io.Reader, 1)
4442 connr, connw := io.Pipe()
4443 lw := &logWritesConn{
4444 rch: resBody,
4445 w: connw,
4446 }
4447 tr := &Transport{
4448 Dial: func(network, addr string) (net.Conn, error) {
4449 return lw, nil
4450 },
4451 }
4452 bodyr, bodyw := io.Pipe()
4453 go func() {
4454 defer bodyw.Close()
4455 for i := 0; i < 3; i++ {
4456 fmt.Fprintf(bodyw, "num%d\n", i)
4457 }
4458 }()
4459 resc := make(chan *Response)
4460 go func() {
4461 req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
4462 req.Header.Set("User-Agent", "x")
4463 res, err := tr.RoundTrip(req)
4464 if err != nil {
4465 t.Errorf("RoundTrip: %v", err)
4466 close(resc)
4467 return
4468 }
4469 resc <- res
4470
4471 }()
4472
4473 req, err := ReadRequest(bufio.NewReader(connr))
4474 if err != nil {
4475 t.Fatal(err)
4476 }
4477 io.Copy(io.Discard, req.Body)
4478
4479
4480 resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
4481 res, ok := <-resc
4482 if !ok {
4483 return
4484 }
4485 defer res.Body.Close()
4486
4487 want := []string{
4488 "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
4489 "5\r\nnum0\n\r\n",
4490 "5\r\nnum1\n\r\n",
4491 "5\r\nnum2\n\r\n",
4492 "0\r\n\r\n",
4493 }
4494 if !slices.Equal(lw.writes, want) {
4495 t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
4496 }
4497 }
4498
4499
4500 func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
4501 func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
4502 gotReq := make(chan struct{})
4503 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4504 close(gotReq)
4505 }))
4506
4507 pr, pw := io.Pipe()
4508 req, err := NewRequest("POST", cst.ts.URL, pr)
4509 if err != nil {
4510 t.Fatal(err)
4511 }
4512 gotRes := make(chan struct{})
4513 go func() {
4514 defer close(gotRes)
4515 res, err := cst.tr.RoundTrip(req)
4516 if err != nil {
4517 t.Error(err)
4518 return
4519 }
4520 res.Body.Close()
4521 }()
4522
4523 <-gotReq
4524 pw.Close()
4525 <-gotRes
4526 }
4527
// wgReadCloser is an io.ReadCloser whose first Close calls wg.Done. Any later
// Close returns net.ErrClosed and leaves wg untouched. Close is not safe for
// concurrent use.
type wgReadCloser struct {
	io.Reader
	wg     *sync.WaitGroup
	closed bool // set by the first Close
}

func (c *wgReadCloser) Close() error {
	if !c.closed {
		c.closed = true
		c.wg.Done()
		return nil
	}
	return net.ErrClosed
}
4542
4543
4544 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4545
4546 run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
4547 }
4548 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4549 if testing.Short() {
4550 t.Skip("skipping in short mode")
4551 }
4552
4553 runTimeSensitiveTest(t, []time.Duration{
4554 1 * time.Millisecond,
4555 5 * time.Millisecond,
4556 10 * time.Millisecond,
4557 50 * time.Millisecond,
4558 100 * time.Millisecond,
4559 500 * time.Millisecond,
4560 time.Second,
4561 5 * time.Second,
4562 }, func(t *testing.T, timeout time.Duration) error {
4563 SetRSTAvoidanceDelay(t, timeout)
4564 t.Logf("set RST avoidance delay to %v", timeout)
4565
4566 const contentLengthLimit = 1024 * 1024
4567 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4568 if r.ContentLength >= contentLengthLimit {
4569 w.WriteHeader(StatusBadRequest)
4570 r.Body.Close()
4571 return
4572 }
4573 w.WriteHeader(StatusOK)
4574 }))
4575
4576
4577 defer cst.close()
4578 ts := cst.ts
4579 c := ts.Client()
4580
4581 count := 100
4582
4583 bigBody := strings.Repeat("a", contentLengthLimit*2)
4584 var wg sync.WaitGroup
4585 defer wg.Wait()
4586 getBody := func() (io.ReadCloser, error) {
4587 wg.Add(1)
4588 body := &wgReadCloser{
4589 Reader: strings.NewReader(bigBody),
4590 wg: &wg,
4591 }
4592 return body, nil
4593 }
4594
4595 for i := 0; i < count; i++ {
4596 reqBody, _ := getBody()
4597 req, err := NewRequest("PUT", ts.URL, reqBody)
4598 if err != nil {
4599 reqBody.Close()
4600 t.Fatal(err)
4601 }
4602 req.ContentLength = int64(len(bigBody))
4603 req.GetBody = getBody
4604
4605 resp, err := c.Do(req)
4606 if err != nil {
4607 return fmt.Errorf("Do %d: %v", i, err)
4608 } else {
4609 resp.Body.Close()
4610 if resp.StatusCode != 400 {
4611 t.Errorf("Expected status code 400, got %v", resp.Status)
4612 }
4613 }
4614 }
4615 return nil
4616 })
4617 }
4618
4619 func TestTransportAutomaticHTTP2(t *testing.T) {
4620 testTransportAutoHTTP(t, &Transport{}, true)
4621 }
4622
4623 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4624 testTransportAutoHTTP(t, &Transport{
4625 ForceAttemptHTTP2: true,
4626 TLSClientConfig: new(tls.Config),
4627 }, true)
4628 }
4629
4630
4631 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4632 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4633 }
4634
4635 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4636 testTransportAutoHTTP(t, &Transport{
4637 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4638 }, false)
4639 }
4640
4641 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4642 testTransportAutoHTTP(t, &Transport{
4643 TLSClientConfig: new(tls.Config),
4644 }, false)
4645 }
4646
4647 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4648 testTransportAutoHTTP(t, &Transport{
4649 ExpectContinueTimeout: 1 * time.Second,
4650 }, true)
4651 }
4652
4653 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4654 var d net.Dialer
4655 testTransportAutoHTTP(t, &Transport{
4656 Dial: d.Dial,
4657 }, false)
4658 }
4659
4660 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4661 var d net.Dialer
4662 testTransportAutoHTTP(t, &Transport{
4663 DialContext: d.DialContext,
4664 }, false)
4665 }
4666
4667 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
4668 testTransportAutoHTTP(t, &Transport{
4669 DialTLS: func(network, addr string) (net.Conn, error) {
4670 panic("unused")
4671 },
4672 }, false)
4673 }
4674
4675 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
4676 CondSkipHTTP2(t)
4677 _, err := tr.RoundTrip(new(Request))
4678 if err == nil {
4679 t.Error("expected error from RoundTrip")
4680 }
4681 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
4682 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
4683 }
4684 }
4685
4686
4687
4688
4689
4690
4691
4692
4693 func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
4694 run(t, testTransportReuseConnEmptyResponseBody)
4695 }
4696 func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
4697 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4698 w.Header().Set("X-Addr", r.RemoteAddr)
4699
4700 }))
4701 n := 100
4702 if testing.Short() {
4703 n = 10
4704 }
4705 var firstAddr string
4706 for i := 0; i < n; i++ {
4707 res, err := cst.c.Get(cst.ts.URL)
4708 if err != nil {
4709 log.Fatal(err)
4710 }
4711 addr := res.Header.Get("X-Addr")
4712 if i == 0 {
4713 firstAddr = addr
4714 } else if addr != firstAddr {
4715 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
4716 }
4717 res.Body.Close()
4718 }
4719 }
4720
4721
// TestNoCrashReturningTransportAltConn cancels a request while its DialTLS is
// in progress. The dial then completes and negotiates the custom "foo"
// TLSNextProto protocol only after Do has returned. The test checks that:
//   - Do fails with a *url.Error wrapping ExportErrRequestCanceledConn;
//   - the "foo" TLSNextProto constructor is still invoked;
//   - the RoundTripper it returns is never used;
//   - the pending-dial hooks balance, so wg.Wait returns.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// NOTE(review): SetPendingDialHooks presumably fires its first func when a
	// dial starts and the second when it finishes. wg then lets the test wait
	// for the abandoned dial to finish — confirm against the hook's definition.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	// TLS server offering only the "foo" protocol. It completes one handshake
	// and keeps the connection open until the test returns.
	testDone := make(chan struct{})
	defer close(testDone)
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	// The request is canceled via req.Cancel from inside DialTLS below.
	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			// Signals that the alternate RoundTripper was built. That
			// RoundTripper must never be asked to handle a request.
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		// Only DialTLS may be used for https.
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			// Cancel the request mid-dial, then hold the finished conn
			// until Do has returned. The conn is therefore handed over only
			// after the request was abandoned.
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Let DialTLS return, then wait for the alt RoundTripper to be built and
	// for the pending dial to be accounted for.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
4803
4804 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
4805 run(t, func(t *testing.T, mode testMode) {
4806 testTransportReuseConnection_Gzip(t, mode, true)
4807 })
4808 }
4809
4810 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
4811 run(t, func(t *testing.T, mode testMode) {
4812 testTransportReuseConnection_Gzip(t, mode, false)
4813 })
4814 }
4815
4816
4817 func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
4818 addr := make(chan string, 2)
4819 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4820 addr <- r.RemoteAddr
4821 w.Header().Set("Content-Encoding", "gzip")
4822 if chunked {
4823 w.(Flusher).Flush()
4824 }
4825 w.Write(rgz)
4826 })).ts
4827 c := ts.Client()
4828
4829 trace := &httptrace.ClientTrace{
4830 GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
4831 GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
4832 PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) },
4833 ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
4834 ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
4835 }
4836 ctx := httptrace.WithClientTrace(context.Background(), trace)
4837
4838 for i := 0; i < 2; i++ {
4839 req, _ := NewRequest("GET", ts.URL, nil)
4840 req = req.WithContext(ctx)
4841 res, err := c.Do(req)
4842 if err != nil {
4843 t.Fatal(err)
4844 }
4845 buf := make([]byte, len(rgz))
4846 if n, err := io.ReadFull(res.Body, buf); err != nil {
4847 t.Errorf("%d. ReadFull = %v, %v", i, n, err)
4848 }
4849
4850
4851
4852 }
4853 a1, a2 := <-addr, <-addr
4854 if a1 != a2 {
4855 t.Fatalf("didn't reuse connection")
4856 }
4857 }
4858
4859 func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
4860 func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
4861 if mode == http2Mode {
4862 t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
4863 }
4864 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4865 if r.URL.Path == "/long" {
4866 w.Header().Set("Long", strings.Repeat("a", 1<<20))
4867 }
4868 })).ts
4869 c := ts.Client()
4870 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
4871
4872 if res, err := c.Get(ts.URL); err != nil {
4873 t.Fatal(err)
4874 } else {
4875 res.Body.Close()
4876 }
4877
4878 res, err := c.Get(ts.URL + "/long")
4879 if err == nil {
4880 defer res.Body.Close()
4881 var n int64
4882 for k, vv := range res.Header {
4883 for _, v := range vv {
4884 n += int64(len(k)) + int64(len(v))
4885 }
4886 }
4887 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
4888 }
4889 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
4890 t.Errorf("got error: %v; want %q", err, want)
4891 }
4892 }
4893
4894 func TestTransportEventTrace(t *testing.T) {
4895 run(t, func(t *testing.T, mode testMode) {
4896 testTransportEventTrace(t, mode, false)
4897 }, testNotParallel)
4898 }
4899
4900
4901 func TestTransportEventTrace_NoHooks(t *testing.T) {
4902 run(t, func(t *testing.T, mode testMode) {
4903 testTransportEventTrace(t, mode, true)
4904 }, testNotParallel)
4905 }
4906
// testTransportEventTrace sends a POST with "Expect: 100-continue" to the fake
// host dns-is-faked.golang. That host is resolved to the test server through
// the nettrace.LookupIPAltResolverKey context value. Every
// httptrace.ClientTrace callback is appended to a text log, and the test
// asserts which events appear and how often.
//
// If noHooks is true, the trace is replaced by an empty ClientTrace and only
// the round trip itself is checked. Otherwise a second GET is made with the
// same trace afterwards, and GetConn must have been logged twice in total.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// Signaled by the WroteRequest hook. Unless noHooks, the handler waits for
	// it before replying, so WroteRequest is logged before the response.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The follow-up GET below gets an empty response.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	cst.tr.ExpectContinueTimeout = 1 * time.Second

	// logf appends one line to buf. It is guarded by mu because trace hooks
	// may be invoked from other goroutines.
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Fake resolver: map dns-is-faked.golang to the test server's IP and flag
	// any other lookup as an error.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Clear every hook in place; the same *ClientTrace is still attached.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// With no hooks installed there is no trace log to inspect.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		// The header-field expectations are only checked outside HTTP/2 mode.
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// Second request: a GET with the same trace. GetConn must fire again.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}
}
5088
5089 func TestTransportEventTraceTLSVerify(t *testing.T) {
5090 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
5091 }
5092 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
5093 var mu sync.Mutex
5094 var buf strings.Builder
5095 logf := func(format string, args ...any) {
5096 mu.Lock()
5097 defer mu.Unlock()
5098 fmt.Fprintf(&buf, format, args...)
5099 buf.WriteByte('\n')
5100 }
5101
5102 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5103 t.Error("Unexpected request")
5104 }), func(ts *httptest.Server) {
5105 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
5106 logf("%s", p)
5107 return len(p), nil
5108 }), "", 0)
5109 }).ts
5110
5111 certpool := x509.NewCertPool()
5112 certpool.AddCert(ts.Certificate())
5113
5114 c := &Client{Transport: &Transport{
5115 TLSClientConfig: &tls.Config{
5116 ServerName: "dns-is-faked.golang",
5117 RootCAs: certpool,
5118 },
5119 }}
5120
5121 trace := &httptrace.ClientTrace{
5122 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
5123 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5124 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
5125 },
5126 }
5127
5128 req, _ := NewRequest("GET", ts.URL, nil)
5129 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5130 _, err := c.Do(req)
5131 if err == nil {
5132 t.Error("Expected request to fail TLS verification")
5133 }
5134
5135 mu.Lock()
5136 got := buf.String()
5137 mu.Unlock()
5138
5139 wantOnce := func(sub string) {
5140 if strings.Count(got, sub) != 1 {
5141 t.Errorf("expected substring %q exactly once in output.", sub)
5142 }
5143 }
5144
5145 wantOnce("TLSHandshakeStart")
5146 wantOnce("TLSHandshakeDone")
5147 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
5148
5149 if t.Failed() {
5150 t.Errorf("Output:\n%s", got)
5151 }
5152 }
5153
// isDNSHijacked reports whether the system resolver returns any address for
// "dns-should-not-resolve.golang", a name that should not exist. The lookup
// runs at most once per process and its result is cached. Lookup errors are
// ignored; only a non-empty answer counts as hijacked.
var isDNSHijacked = sync.OnceValue(func() bool {
	resolved, _ := net.LookupHost("dns-should-not-resolve.golang")
	return len(resolved) > 0
})

// skipIfDNSHijacked skips t when isDNSHijacked reports true, i.e. when the
// local resolver answers for names that should fail to resolve.
func skipIfDNSHijacked(t *testing.T) {
	if !isDNSHijacked() {
		return
	}
	t.Skip("skipping; test requires non-hijacking DNS server")
}
5167
// TestTransportEventTraceRealDNS uses the real resolver on a name that should
// not resolve. It checks that the request fails, that DNSStart and DNSDone
// fire with DNSDone reporting no addresses, and that no Connect hooks fire.
// It is skipped when skipIfDNSHijacked detects a resolver that answers for
// such names.
func TestTransportEventTraceRealDNS(t *testing.T) {
	skipIfDNSHijacked(t)
	defer afterTest(t)
	tr := &Transport{}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// logf appends one line to buf, guarded by mu because hooks may be
	// invoked from other goroutines.
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
	trace := &httptrace.ClientTrace{
		DNSStart:     func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
		DNSDone:      func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
		ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
		ConnectDone:  func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
	}
	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))

	resp, err := c.Do(req)
	if err == nil {
		resp.Body.Close()
		t.Fatal("expected error during DNS lookup")
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantSub := func(sub string) {
		if !strings.Contains(got, sub) {
			t.Errorf("expected substring %q in output.", sub)
		}
	}
	wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
	wantSub("DNSDone: {Addrs:[] Err:")
	// Resolution failed, so no dial should have been attempted.
	if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
		t.Errorf("should not see Connect events")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}
}
5217
5218
// TestTransportRejectsAlphaPort checks that a URL whose port contains
// non-digit characters is rejected with a *url.Error whose inner error names
// the offending port.
func TestTransportRejectsAlphaPort(t *testing.T) {
	res, err := Get("http://dummy.tld:123foo/bar")
	switch ue, ok := err.(*url.Error); {
	case err == nil:
		res.Body.Close()
		t.Fatal("unexpected success")
	case !ok:
		t.Fatalf("got %#v; want *url.Error", err)
	default:
		const want = `invalid port ":123foo" after host`
		if got := ue.Err.Error(); got != want {
			t.Errorf("got error %q; want %q", got, want)
		}
	}
}
5235
5236
5237
5238 func TestTLSHandshakeTrace(t *testing.T) {
5239 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
5240 }
5241 func testTLSHandshakeTrace(t *testing.T, mode testMode) {
5242 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
5243
5244 var mu sync.Mutex
5245 var start, done bool
5246 trace := &httptrace.ClientTrace{
5247 TLSHandshakeStart: func() {
5248 mu.Lock()
5249 defer mu.Unlock()
5250 start = true
5251 },
5252 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5253 mu.Lock()
5254 defer mu.Unlock()
5255 done = true
5256 if err != nil {
5257 t.Fatal("Expected error to be nil but was:", err)
5258 }
5259 },
5260 }
5261
5262 c := ts.Client()
5263 req, err := NewRequest("GET", ts.URL, nil)
5264 if err != nil {
5265 t.Fatal("Unable to construct test request:", err)
5266 }
5267 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
5268
5269 r, err := c.Do(req)
5270 if err != nil {
5271 t.Fatal("Unexpected error making request:", err)
5272 }
5273 r.Body.Close()
5274 mu.Lock()
5275 defer mu.Unlock()
5276 if !start {
5277 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
5278 }
5279 if !done {
5280 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5281 }
5282 }
5283
// TestTransportMaxIdleConns checks that Transport.MaxIdleConns caps the total
// number of idle connections across hosts.
func TestTransportMaxIdleConns(t *testing.T) {
	run(t, testTransportMaxIdleConns, []testMode{http1Mode})
}

// testTransportMaxIdleConns sets MaxIdleConns to 4. It makes one request to
// each of five distinct hostnames (host-0 … host-4 under dns-is-faked.golang),
// all resolved to the test server. It checks the idle-connection keys after
// the fourth request and again after the fifth, when host-0's connection is
// expected to have been evicted.
func testTransportMaxIdleConns(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Empty response; only the connection matters.
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxIdleConns = 4

	ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	// Resolve every hostname to the test server's IP via the internal
	// nettrace alternate-resolver context key.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	// hitHost makes one GET to host-<n>.dns-is-faked.golang and closes the
	// body, leaving the connection idle. port is all digits, so concatenating
	// it into the format string is safe.
	hitHost := func(n int) {
		req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
		req = req.WithContext(ctx)
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
	}
	for i := 0; i < 4; i++ {
		hitHost(i)
	}
	want := []string{
		"|http|host-0.dns-is-faked.golang:" + port,
		"|http|host-1.dns-is-faked.golang:" + port,
		"|http|host-2.dns-is-faked.golang:" + port,
		"|http|host-3.dns-is-faked.golang:" + port,
	}
	if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
		t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
	}

	// A fifth host exceeds the limit; host-0 should drop out of the idle set.
	hitHost(4)
	want = []string{
		"|http|host-1.dns-is-faked.golang:" + port,
		"|http|host-2.dns-is-faked.golang:" + port,
		"|http|host-3.dns-is-faked.golang:" + port,
		"|http|host-4.dns-is-faked.golang:" + port,
	}
	if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
		t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
	}
}
5337
5338 func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
5339 func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
5340 if testing.Short() {
5341 t.Skip("skipping in short mode")
5342 }
5343
5344 timeout := 1 * time.Millisecond
5345 timeoutLoop:
5346 for {
5347 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5348
5349 }))
5350 tr := cst.tr
5351 tr.IdleConnTimeout = timeout
5352 defer tr.CloseIdleConnections()
5353 c := &Client{Transport: tr}
5354
5355 idleConns := func() []string {
5356 if mode == http2Mode {
5357 return tr.IdleConnStrsForTesting_h2()
5358 } else {
5359 return tr.IdleConnStrsForTesting()
5360 }
5361 }
5362
5363 var conn string
5364 doReq := func(n int) (timeoutOk bool) {
5365 req, _ := NewRequest("GET", cst.ts.URL, nil)
5366 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5367 PutIdleConn: func(err error) {
5368 if err != nil {
5369 t.Errorf("failed to keep idle conn: %v", err)
5370 }
5371 },
5372 }))
5373 res, err := c.Do(req)
5374 if err != nil {
5375 if strings.Contains(err.Error(), "use of closed network connection") {
5376 t.Logf("req %v: connection closed prematurely", n)
5377 return false
5378 }
5379 }
5380 res.Body.Close()
5381 conns := idleConns()
5382 if len(conns) != 1 {
5383 if len(conns) == 0 {
5384 t.Logf("req %v: no idle conns", n)
5385 return false
5386 }
5387 t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
5388 }
5389 if conn == "" {
5390 conn = conns[0]
5391 }
5392 if conn != conns[0] {
5393 t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
5394 return false
5395 }
5396 return true
5397 }
5398 for i := 0; i < 3; i++ {
5399 if !doReq(i) {
5400 t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
5401 timeout *= 2
5402 cst.close()
5403 continue timeoutLoop
5404 }
5405 time.Sleep(timeout / 2)
5406 }
5407
5408 waitCondition(t, timeout/2, func(d time.Duration) bool {
5409 if got := idleConns(); len(got) != 0 {
5410 if d >= timeout*3/2 {
5411 t.Logf("after %v, idle conns = %q", d, got)
5412 }
5413 return false
5414 }
5415 return true
5416 })
5417 break
5418 }
5419 }
5420
5421
5422
5423
5424
5425
5426
5427
5428
5429
5430
5431
// TestIdleConnH2Crash runs testIdleConnH2Crash in HTTP/2 mode only.
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }

// testIdleConnH2Crash makes an HTTP/2 request whose context is canceled from
// inside DialTLS, right after the TLS connection is established. DialTLS then
// holds the connection until Do has returned its error, and only afterwards
// hands it to the Transport. Finally the test sleeps for ten
// IdleConnTimeouts so that any idle handling of that late connection happens
// while the test is still running. NOTE(review): apart from the dial checks
// there is no assertion on the connection's fate; per its name this
// presumably guards against a crash in that path.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Handler body intentionally empty.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request now that the connection is up...
		cancel()

		// ...and don't return the connection until Do has failed (or the
		// test is ending).
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Presumably gives the Transport time to act on the late connection,
	// including its idle timeout, before the test ends.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5482
// funcConn is a net.Conn whose Read and Write are delegated to the read and
// write fields, and whose Close always succeeds without doing anything. Every
// other net.Conn method comes from the embedded Conn, so calling one panics
// when that field is nil.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read forwards to fc.read.
func (fc funcConn) Read(buf []byte) (int, error) { return fc.read(buf) }

// Write forwards to fc.write.
func (fc funcConn) Write(buf []byte) (int, error) { return fc.write(buf) }

// Close is a no-op that always reports success.
func (funcConn) Close() error { return nil }
5492
5493
5494
// TestTransportReturnsPeekError covers a connection whose reads always fail.
// Each Read blocks until the request has begun to be written, then returns
// errValue. RoundTrip must return that exact error value, not a wrapped or
// replaced one.
func TestTransportReturnsPeekError(t *testing.T) {
	errValue := errors.New("specific error value")

	// wrote is closed on the first Write, so Read cannot fail before the
	// request has started going out.
	wrote := make(chan struct{})
	wroteOnce := sync.OnceFunc(func() { close(wrote) })

	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			c := funcConn{
				read: func([]byte) (int, error) {
					<-wrote
					return 0, errValue
				},
				write: func(p []byte) (int, error) {
					wroteOnce()
					return len(p), nil
				},
			}
			return c, nil
		},
	}
	_, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
	// Identity comparison: the original error value must be returned.
	if err != errValue {
		t.Errorf("error = %#v; want %v", err, errValue)
	}
}
5521
5522
// TestTransportIDNA checks that a URL with a Unicode hostname is converted to
// its punycode form for GetConn's hostPort, the DNS lookup, the Host header,
// and (in HTTP/2 mode) the TLS server name seen by the server.
func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
func testTransportIDNA(t *testing.T, mode testMode) {
	const uniDomain = "гофер.го"
	const punyDomain = "xn--c1ae0ajs.xn--c1aw"

	// port is assigned after the server starts (below) and read by the handler.
	var port string
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		want := punyDomain + ":" + port
		if r.Host != want {
			t.Errorf("Host header = %q; want %q", r.Host, want)
		}
		if mode == http2Mode {
			if r.TLS == nil {
				t.Errorf("r.TLS == nil")
			} else if r.TLS.ServerName != punyDomain {
				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
			}
		}
		w.Header().Set("Hit-Handler", "1")
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})

	// port is declared in this same scope, so := assigns to it (ip and err are new).
	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	// Fake resolver: only the punycode name may be looked up, and it maps to
	// the test server.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != punyDomain {
			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
	trace := &httptrace.ClientTrace{
		GetConn: func(hostPort string) {
			want := net.JoinHostPort(punyDomain, port)
			if hostPort != want {
				t.Errorf("getting conn for %q; want %q", hostPort, want)
			}
		},
		DNSStart: func(e httptrace.DNSStartInfo) {
			if e.Host != punyDomain {
				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	res, err := cst.tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	// The marker header proves the response came from our handler.
	if res.Header.Get("Hit-Handler") != "1" {
		out, err := httputil.DumpResponse(res, true)
		if err != nil {
			t.Fatal(err)
		}
		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
	}
}
5591
5592
// TestTransportProxyConnectHeader checks that Transport.ProxyConnectHeader is
// sent on the CONNECT request the client makes to an HTTP proxy for an
// https:// URL.
func TestTransportProxyConnectHeader(t *testing.T) {
	run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
}
func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
	reqc := make(chan *Request, 1)
	// The test server plays the proxy: it records the CONNECT request, then
	// hijacks and closes the connection, so the client's request fails.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", r.Method)
		}
		reqc <- r
		c, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		c.Close()
	})).ts

	c := ts.Client()
	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
		return url.Parse(ts.URL)
	}
	c.Transport.(*Transport).ProxyConnectHeader = Header{
		"User-Agent": {"foo"},
		"Other":      {"bar"},
	}

	// Expected to fail, since the proxy closes the tunnel immediately.
	res, err := c.Get("https://dummy.tld/")
	if err == nil {
		res.Body.Close()
		t.Errorf("unexpected success")
	}

	r := <-reqc
	if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
		t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
	}
	if got, want := r.Header.Get("Other"), "bar"; got != want {
		t.Errorf("CONNECT request Other = %q; want %q", got, want)
	}
}
5634
// TestTransportProxyGetConnectHeader checks the case where both
// ProxyConnectHeader and GetProxyConnectHeader are set. The CONNECT request
// must carry the headers returned by GetProxyConnectHeader.
func TestTransportProxyGetConnectHeader(t *testing.T) {
	run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
}
func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
	reqc := make(chan *Request, 1)
	// The test server plays the proxy: it records the CONNECT request, then
	// hijacks and closes the connection, so the client's request fails.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", r.Method)
		}
		reqc <- r
		c, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		c.Close()
	})).ts

	c := ts.Client()
	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
		return url.Parse(ts.URL)
	}

	// Static headers; these values are expected NOT to be sent.
	c.Transport.(*Transport).ProxyConnectHeader = Header{
		"User-Agent": {"foo"},
		"Other":      {"bar"},
	}
	c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
		return Header{
			"User-Agent": {"foo2"},
			"Other":      {"bar2"},
		}, nil
	}

	// Expected to fail, since the proxy closes the tunnel immediately.
	res, err := c.Get("https://dummy.tld/")
	if err == nil {
		res.Body.Close()
		t.Errorf("unexpected success")
	}

	r := <-reqc
	if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
		t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
	}
	if got, want := r.Header.Get("Other"), "bar2"; got != want {
		t.Errorf("CONNECT request Other = %q; want %q", got, want)
	}
}
5683
// errFakeRoundTrip is the error every funcRoundTripper returns.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that ignores the request, runs the
// underlying func for its side effects, and always fails with
// errFakeRoundTrip.
type funcRoundTripper func()

// RoundTrip invokes f and returns a nil *Response with errFakeRoundTrip.
func (f funcRoundTripper) RoundTrip(_ *Request) (*Response, error) {
	f()
	var noResponse *Response
	return noResponse, errFakeRoundTrip
}
5692
// wantBody checks the result of a request against an expected body.
//
// If err is non-nil it is returned unchanged. Otherwise res.Body is read in
// full and then closed. Close runs on every path, including read failures and
// body mismatches, so the connection is not leaked. The returned error
// describes, in priority order: a read failure, a body mismatch with want, or
// a Close failure. It returns nil when the body equals want and closes
// cleanly.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, readErr := io.ReadAll(res.Body)
	closeErr := res.Body.Close()
	if readErr != nil {
		return fmt.Errorf("error reading body: %w", readErr)
	}
	if string(slurp) != want {
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if closeErr != nil {
		return fmt.Errorf("body Close = %w", closeErr)
	}
	return nil
}
5709
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]). If neither
// works, it calls t.Fatal with the error from the last attempt.
func newLocalListener(t *testing.T) net.Listener {
	candidates := []struct{ network, address string }{
		{"tcp", "127.0.0.1:0"},
		{"tcp6", "[::1]:0"},
	}
	var lastErr error
	for _, c := range candidates {
		ln, err := net.Listen(c.network, c.address)
		if err == nil {
			return ln
		}
		lastErr = err
	}
	t.Fatal(lastErr)
	return nil
}
5720
// countCloseReader is an io.ReadCloser. Reads go to the embedded Reader; each
// Close increments *n and always succeeds.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close records one more call in *cr.n and returns nil.
func (cr countCloseReader) Close() error {
	*cr.n += 1
	return nil
}
5730
5731
// rgz is a gzip stream. Its header carries the magic bytes 0x1f 0x8b,
// compression method 8 (deflate), the FNAME flag, and the embedded file name
// "recursive".
//
// NOTE(review): the name suggests a gzip "quine" that decompresses to exactly
// these bytes. testTransportReuseConnection_Gzip reads len(rgz) decoded bytes
// from a response whose body is rgz, which is consistent with that. Not
// verified here.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5766
5767
5768
// TestMissingStatusNoPanic covers an https:// request made through an HTTP
// proxy whose reply has a status code but no reason phrase ("HTTP/1.1 400").
// RoundTrip must return an error containing "unknown status code" instead of
// panicking.
func TestMissingStatusNoPanic(t *testing.T) {
	t.Parallel()

	const want = "unknown status code"

	ln := newLocalListener(t)
	addr := ln.Addr().String()
	done := make(chan bool)
	fullAddrURL := fmt.Sprintf("http://%s", addr)
	// Canned proxy reply: the status line lacks a reason phrase.
	raw := "HTTP/1.1 400\r\n" +
		"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
		"Content-Type: text/html; charset=utf-8\r\n" +
		"Content-Length: 10\r\n" +
		"Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
		"Vary: Accept-Encoding\r\n\r\n" +
		"Aloha Olaa"

	// Fake proxy: answer the first connection with raw, then drain it until
	// the client hangs up. done is closed when this goroutine exits.
	go func() {
		defer close(done)

		conn, _ := ln.Accept()
		if conn != nil {
			io.WriteString(conn, raw)
			io.ReadAll(conn)
			conn.Close()
		}
	}()

	proxyURL, err := url.Parse(fullAddrURL)
	if err != nil {
		t.Fatalf("proxyURL: %v", err)
	}

	tr := &Transport{Proxy: ProxyURL(proxyURL)}

	req, _ := NewRequest("GET", "https://golang.org/", nil)
	res, err, panicked := doFetchCheckPanic(tr, req)
	if panicked {
		t.Error("panicked, expecting an error")
	}
	if res != nil && res.Body != nil {
		io.Copy(io.Discard, res.Body)
		res.Body.Close()
	}

	if err == nil || !strings.Contains(err.Error(), want) {
		t.Errorf("got=%v want=%q", err, want)
	}

	ln.Close()
	<-done
}
5821
// doFetchCheckPanic calls tr.RoundTrip(req) and converts any panic from that
// call into panicked == true instead of letting it propagate. When a panic is
// recovered, res and err are left at their zero values.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
5831
5832
5833
// TestNoBodyOnChunked304Response checks that a 304 response is given
// Body == NoBody even though it declares "Transfer-Encoding: chunked".
func TestNoBodyOnChunked304Response(t *testing.T) {
	run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
}
func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Hijack to write a raw 304 with a chunked Transfer-Encoding and a
		// terminating zero-length chunk, then close the connection.
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
		buf.Flush()
		conn.Close()
	}))

	// The handler closes the connection after one response; with keep-alives
	// off the client will not try to reuse it. NOTE(review): rationale
	// inferred — confirm.
	cst.tr.DisableKeepAlives = true

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	if res.Body != NoBody {
		t.Errorf("Unexpected body on 304 response")
	}
}
5860
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write passes p to the wrapped function and returns its results unchanged.
func (fw funcWriter) Write(p []byte) (n int, err error) {
	n, err = fw(p)
	return n, err
}
5864
// doneContext wraps a Context but always reports itself as already done.
// Done returns a closed channel and Err returns the configured err. Deadline
// and Value are forwarded to the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a new, already-closed channel on every call.
func (doneContext) Done() <-chan struct{} {
	ch := make(chan struct{})
	defer close(ch)
	return ch
}

// Err returns the error the doneContext was built with.
func (dc doneContext) Err() error { return dc.err }
5877
5878
5879 func TestTransportCheckContextDoneEarly(t *testing.T) {
5880 tr := &Transport{}
5881 req, _ := NewRequest("GET", "http://fake.example/", nil)
5882 wantErr := errors.New("some error")
5883 req = req.WithContext(doneContext{context.Background(), wantErr})
5884 _, err := tr.RoundTrip(req)
5885 if err != wantErr {
5886 t.Errorf("error = %v; want %v", err, wantErr)
5887 }
5888 }
5889
5890
5891
5892
5893
5894
// TestClientTimeoutKillsConn_BeforeHeaders checks that when Client.Timeout
// fires before any response headers arrive, the client closes the connection:
// the server-side handler, after hijacking, must see io.EOF on Read.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	// Start with a tiny timeout and double it on each retry until the
	// handler is observed running (see the tooSlow branch below).
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Block until the server cancels this request's context. Given the
			// EOF check below, this is expected to be caused by the client
			// timing out and closing its connection.
			<-r.Context().Done()
			// Rendezvous with the test goroutine, unless this attempt was
			// abandoned (cancelHandler closed).
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// The client should already have closed the connection, so the
			// hijacked conn must read as EOF.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler never reached the rendezvous in time; presumably the
			// timeout fired before the request reached the handler. Abandon
			// this server and retry with a doubled timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
5953
5954
5955
5956
5957
5958
// TestClientTimeoutKillsConn_AfterHeaders checks that canceling a request
// after its headers have arrived, while the body is still incomplete, makes
// the body read fail and tears down the connection: the server's hijacked
// conn must then see a Read error.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise a 100-byte body and flush the headers so the client's Do
		// can return before any body bytes are sent.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		// Rendezvous with the test goroutine, unless it has given up.
		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		// Send only part of the promised body.
		conn.Write([]byte("foo"))

		// By now the client has canceled the request. The connection is
		// expected to be closed, so this Read must fail. The exact error
		// is not checked.
		n, err := conn.Read([]byte{0})
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// A very long but non-zero Client.Timeout, presumably so the timeout
	// handling path is active without ever firing. Cancellation is driven
	// explicitly through req.Cancel below.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}
	// Always release the response body, whatever happens below.
	defer res.Body.Close()

	// Cancel mid-body. Reading the rest of the body must then fail.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Let the handler run its checks on the connection and wait for it.
	<-inHandler
	<-handlerDone
}
6024
// TestTransportResponseBodyWritableOnProtocolSwitch checks that after a
// "101 Switching Protocols" response, Response.Body is an io.ReadWriteCloser
// over the raw connection. The data the server sent right after the headers
// must be readable, and writes must reach the server.
func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
	run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
}
func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
	// Closed on return to release the handler goroutine.
	release := make(chan struct{})
	defer close(release)

	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		// Upgrade response immediately followed by one line of payload.
		io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
		// Echo one line from the client back in upper case.
		lines := bufio.NewScanner(conn)
		lines.Scan()
		fmt.Fprintf(conn, "%s\n", strings.ToUpper(lines.Text()))
		<-release
	})
	cst := newClientServerTest(t, mode, handler)

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req.Header.Set("Upgrade", "foo")
	req.Header.Set("Connection", "upgrade")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 101 {
		t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
	}
	stream, isRWC := res.Body.(io.ReadWriteCloser)
	if !isRWC {
		t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
	}
	defer stream.Close()

	in := bufio.NewScanner(stream)
	if !in.Scan() {
		t.Fatalf("expected readable input")
	}
	if got, want := in.Text(), "Some buffered data"; got != want {
		t.Errorf("read %q; want %q", got, want)
	}

	io.WriteString(stream, "echo\n")
	if !in.Scan() {
		t.Fatalf("expected another line")
	}
	if got, want := in.Text(), "ECHO"; got != want {
		t.Errorf("read %q; want %q", got, want)
	}
}
6075
// TestTransportCONNECTBidi checks that a CONNECT request streams in both
// directions. Lines written into the request body pipe reach the server, and
// the server's upper-cased echoes are readable from the response body.
func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
func testTransportCONNECTBidi(t *testing.T, mode testMode) {
	const target = "backend:443"
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		switch {
		case r.Method != "CONNECT":
			t.Errorf("unexpected method %q", r.Method)
			w.WriteHeader(500)
			return
		case r.RequestURI != target:
			t.Errorf("unexpected CONNECT target %q", r.RequestURI)
			w.WriteHeader(500)
			return
		}
		conn, rw, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		conn.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))

		// Upper-case echo, one line at a time, until the client hangs up.
		for {
			line, err := rw.ReadString('\n')
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			io.WriteString(rw, strings.ToUpper(line))
			rw.Flush()
		}
	})
	cst := newClientServerTest(t, mode, handler)

	bodyR, bodyW := io.Pipe()
	defer bodyW.Close()
	req, err := NewRequest("CONNECT", cst.ts.URL, bodyR)
	if err != nil {
		t.Fatal(err)
	}
	req.URL.Opaque = target
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Fatalf("status code = %d; want 200", res.StatusCode)
	}

	replies := bufio.NewReader(res.Body)
	for _, word := range []string{"foo", "bar", "baz"} {
		fmt.Fprintf(bodyW, "%s\n", word)
		line, err := replies.ReadString('\n')
		if err != nil {
			t.Fatal(err)
		}
		if got, want := strings.TrimSpace(line), strings.ToUpper(word); got != want {
			t.Fatalf("got %q; want %q", got, want)
		}
	}
}
6139
// TestTransportRequestReplayable checks which requests the transport regards
// as replayable:
//   - GET with no body, or with NoBody, is replayable.
//   - GET with a real body is not.
//   - POST is replayable only when it has an Idempotency-Key or
//     X-Idempotency-Key header and no body.
func TestTransportRequestReplayable(t *testing.T) {
	someBody := io.NopCloser(strings.NewReader(""))
	tests := []struct {
		name string
		req  *Request
		want bool
	}{
		{
			name: "GET",
			req:  &Request{Method: "GET"},
			want: true,
		},
		{
			name: "GET_http.NoBody",
			req:  &Request{Method: "GET", Body: NoBody},
			want: true,
		},
		{
			name: "GET_body",
			req:  &Request{Method: "GET", Body: someBody},
			want: false,
		},
		{
			name: "POST",
			req:  &Request{Method: "POST"},
			want: false,
		},
		{
			name: "POST_idempotency-key",
			req:  &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
			want: true,
		},
		{
			name: "POST_x-idempotency-key",
			req:  &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
			want: true,
		},
		{
			name: "POST_body",
			req:  &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
			want: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.req.ExportIsReplayable()
			if got != tt.want {
				// Fixed typo: was "replyable".
				t.Errorf("replayable = %v; want %v", got, tt.want)
			}
		})
	}
}
6192
6193
6194
6195 type testMockTCPConn struct {
6196 *net.TCPConn
6197
6198 ReadFromCalled bool
6199 }
6200
6201 func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
6202 c.ReadFromCalled = true
6203 return c.TCPConn.ReadFrom(r)
6204 }
6205
// TestTransportRequestWriteRoundTrip checks when a request body is written via
// the connection's ReadFrom method. Only a file-backed body with a known
// positive ContentLength over HTTP/1 is expected to use it. Buffers, unknown
// lengths, and non-HTTP/1 modes must not.
func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
	nBytes := int64(1 << 10)
	// newFileFunc returns a temp file holding nBytes random bytes, rewound to
	// its start, plus a cleanup func that closes and removes it. On failure
	// the file is cleaned up before returning, so nothing is leaked.
	newFileFunc := func() (r io.Reader, done func(), err error) {
		f, err := os.CreateTemp("", "net-http-newfilefunc")
		if err != nil {
			return nil, nil, err
		}
		done = func() {
			f.Close()
			os.Remove(f.Name())
		}

		if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
			done()
			return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
		}
		if _, err := f.Seek(0, io.SeekStart); err != nil {
			done()
			return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
		}

		return f, done, nil
	}

	// newBufferFunc returns an in-memory body of nBytes zero bytes.
	newBufferFunc := func() (io.Reader, func(), error) {
		return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
	}

	cases := []struct {
		name             string
		readerFunc       func() (io.Reader, func(), error)
		contentLength    int64
		expectedReadFrom bool
	}{
		{
			name:             "file, length",
			readerFunc:       newFileFunc,
			contentLength:    nBytes,
			expectedReadFrom: true,
		},
		{
			name:       "file, no length",
			readerFunc: newFileFunc,
		},
		{
			name:          "file, negative length",
			readerFunc:    newFileFunc,
			contentLength: -1,
		},
		{
			name:          "buffer",
			contentLength: nBytes,
			readerFunc:    newBufferFunc,
		},
		{
			name:       "buffer, no length",
			readerFunc: newBufferFunc,
		},
		{
			name:          "buffer, length -1",
			contentLength: -1,
			readerFunc:    newBufferFunc,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			r, cleanup, err := tc.readerFunc()
			if err != nil {
				t.Fatal(err)
			}
			defer cleanup()

			// Dial real TCP connections, but hand the transport a wrapper
			// that records ReadFrom calls.
			tConn := &testMockTCPConn{}
			trFunc := func(tr *Transport) {
				tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
					var d net.Dialer
					conn, err := d.DialContext(ctx, network, addr)
					if err != nil {
						return nil, err
					}

					tcpConn, ok := conn.(*net.TCPConn)
					if !ok {
						return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
					}

					tConn.TCPConn = tcpConn
					return tConn, nil
				}
			}

			cst := newClientServerTest(
				t,
				mode,
				HandlerFunc(func(w ResponseWriter, r *Request) {
					io.Copy(io.Discard, r.Body)
					r.Body.Close()
					w.WriteHeader(200)
				}),
				trFunc,
			)

			req, err := NewRequest("PUT", cst.ts.URL, r)
			if err != nil {
				t.Fatal(err)
			}
			req.ContentLength = tc.contentLength
			req.Header.Set("Content-Type", "application/octet-stream")
			resp, err := cst.c.Do(req)
			if err != nil {
				t.Fatal(err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != 200 {
				t.Fatalf("status code = %d; want 200", resp.StatusCode)
			}

			// The ReadFrom fast path is only expected for HTTP/1.
			expectedReadFrom := tc.expectedReadFrom
			if mode != http1Mode {
				expectedReadFrom = false
			}
			if !tConn.ReadFromCalled && expectedReadFrom {
				t.Fatalf("did not call ReadFrom")
			}

			if tConn.ReadFromCalled && !expectedReadFrom {
				t.Fatalf("ReadFrom was unexpectedly invoked")
			}
		})
	}
}
6339
// TestTransportClone checks that Transport.Clone copies every exported field.
// It starts from a Transport with every exported field populated, so no
// exported field of the clone may be zero. It also checks that the
// TLSNextProto entries survive, and that cloning a zero Transport leaves
// TLSNextProto nil.
func TestTransportClone(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) { panic("") },
		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			return nil
		},
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		HTTP2:                  &HTTP2Config{MaxConcurrentStreams: 1},
		Protocols:              &Protocols{},
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	tr.Protocols.SetHTTP1(true)
	tr.Protocols.SetHTTP2(true)
	tr2 := tr.Clone()

	// Every exported field of the clone must carry a non-zero value.
	cloned := reflect.ValueOf(tr2).Elem()
	cloneType := cloned.Type()
	for i, n := 0, cloneType.NumField(); i < n; i++ {
		field := cloneType.Field(i)
		if token.IsExported(field.Name) && cloned.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", field.Name)
		}
	}

	if _, found := tr2.TLSNextProto["foo"]; !found {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// Cloning a zero Transport must not invent a TLSNextProto map.
	tr = new(Transport)
	tr2 = tr.Clone()
	if tr2.TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
6398
// TestIs408 checks Export_is408Message against a set of status lines. The
// HTTP/1.x "408" lines must match. Wrong major versions, a truncated code,
// lowercase "http", and a malformed version separator must not.
func TestIs408(t *testing.T) {
	cases := []struct {
		line  string
		is408 bool
	}{
		{"HTTP/1.0 408", true},
		{"HTTP/1.1 408", true},
		{"HTTP/1.8 408", true},
		{"HTTP/2.0 408", false},
		{"HTTP/1.1 408 ", true},
		{"HTTP/1.1 40", false},
		{"http/1.0 408", false},
		{"HTTP/1-1 408", false},
	}
	for _, c := range cases {
		got := Export_is408Message([]byte(c.line))
		if got == c.is408 {
			continue
		}
		t.Errorf("is408Message(%q) = %v; want %v", c.line, got, c.is408)
	}
}
6419
// TestTransportIgnores408 checks how the client handles a "408" status line
// that the server sends after a complete response. The connection must leave
// the idle pool, and nothing may be written to the standard logger.
func TestTransportIgnores408(t *testing.T) {
	// Not parallel: the test redirects the global log output.
	run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
}
func testTransportIgnores408(t *testing.T, mode testMode) {
	// Capture the standard logger. log.Writer() is evaluated now, so the
	// deferred call restores the original writer.
	defer log.SetOutput(log.Writer())

	var logout strings.Builder
	log.SetOutput(&logout)

	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		nc, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer nc.Close()
		nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
		// An unsolicited 408 status line after the complete response.
		nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
	}))
	req, err := NewRequest("GET", cst.ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if string(slurp) != "ok" {
		t.Fatalf("got %q; want ok", slurp)
	}

	// Wait until the connection has been dropped from the idle pool.
	waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
		if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
			if d > 0 {
				t.Logf("%v idle conns still present after %v", n, d)
			}
			return false
		}
		return true
	})
	if got := logout.String(); got != "" {
		t.Fatalf("expected no log output; got: %s", got)
	}
}
6474
6475 func TestInvalidHeaderResponse(t *testing.T) {
6476 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6477 }
6478 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6479 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6480 conn, buf, _ := w.(Hijacker).Hijack()
6481 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6482 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6483 "Content-Type: text/html; charset=utf-8\r\n" +
6484 "Content-Length: 0\r\n" +
6485 "Foo : bar\r\n\r\n"))
6486 buf.Flush()
6487 conn.Close()
6488 }))
6489 res, err := cst.c.Get(cst.ts.URL)
6490 if err != nil {
6491 t.Fatal(err)
6492 }
6493 defer res.Body.Close()
6494 if v := res.Header.Get("Foo"); v != "" {
6495 t.Errorf(`unexpected "Foo" header: %q`, v)
6496 }
6497 if v := res.Header.Get("Foo "); v != "bar" {
6498 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6499 }
6500 }
6501
6502 type bodyCloser bool
6503
6504 func (bc *bodyCloser) Close() error {
6505 *bc = true
6506 return nil
6507 }
6508 func (bc *bodyCloser) Read(b []byte) (n int, err error) {
6509 return 0, io.EOF
6510 }
6511
6512
6513
// TestTransportClosesBodyOnInvalidRequests checks what happens to requests
// that are rejected before anything is sent. The handler must never run,
// Client.Do must still close the request body, and each error must end with
// the expected text.
func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
	run(t, testTransportClosesBodyOnInvalidRequests)
}
func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		t.Errorf("Should not have been invoked")
	})).ts

	u, _ := url.Parse(ts.URL)

	tests := []struct {
		name    string
		req     *Request
		wantErr string
	}{
		{
			name: "invalid method",
			req: &Request{
				Method: " ",
				URL:    u,
			},
			wantErr: `invalid method " "`,
		},
		{
			name: "nil URL",
			req: &Request{
				Method: "GET",
			},
			wantErr: `nil Request.URL`,
		},
		{
			name: "invalid header key",
			req: &Request{
				Method: "GET",
				Header: Header{"💡": {"emoji"}},
				URL:    u,
			},
			wantErr: `invalid header field name "💡"`,
		},
		{
			name: "invalid header value",
			req: &Request{
				Method: "POST",
				Header: Header{"key": {"\x19"}},
				URL:    u,
			},
			wantErr: `invalid header field value for "key"`,
		},
		{
			name: "non HTTP(s) scheme",
			req: &Request{
				Method: "POST",
				URL:    &url.URL{Scheme: "faux"},
			},
			wantErr: `unsupported protocol scheme "faux"`,
		},
		{
			name: "no Host in URL",
			req: &Request{
				Method: "POST",
				URL:    &url.URL{Scheme: "http"},
			},
			wantErr: `no Host in request URL`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var closed bodyCloser
			tt.req.Body = &closed

			_, err := ts.Client().Do(tt.req)
			switch {
			case err == nil:
				t.Fatal("Expected an error")
			case !bool(closed):
				t.Fatal("Expected body to have been closed")
			}
			if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
				t.Fatalf("Error mismatch: %q does not end with %q", g, w)
			}
		})
	}
}
6598
6599
6600
6601 type breakableConn struct {
6602 net.Conn
6603 *brokenState
6604 }
6605
6606 type brokenState struct {
6607 sync.Mutex
6608 broken bool
6609 }
6610
6611 func (w *breakableConn) Write(b []byte) (n int, err error) {
6612 w.Lock()
6613 defer w.Unlock()
6614 if w.broken {
6615 return 0, errors.New("some write error")
6616 }
6617 return w.Conn.Write(b)
6618 }
6619
6620
// TestDontCacheBrokenHTTP2Conn checks that an HTTP/2 connection whose writes
// start failing right after the TLS handshake is not kept for reuse. Every
// request except the last breaks its connection. The test therefore expects
// one dial per request, and a GotConn trace event only for the final,
// successful request.
func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
	run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
}
func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)

	// Shared by every dialed connection and toggled per request below.
	var brokenState brokenState

	const numReqs = 5
	// Counters updated with sync/atomic from the transport's callbacks.
	var numDials, gotConns uint32

	// Count dials and wrap each connection so its writes can be broken.
	cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
		atomic.AddUint32(&numDials, 1)
		c, err := net.Dial(netw, addr)
		if err != nil {
			t.Errorf("unexpected Dial error: %v", err)
			return nil, err
		}
		return &breakableConn{c, &brokenState}, err
	}

	for i := 1; i <= numReqs; i++ {
		brokenState.Lock()
		brokenState.broken = false
		brokenState.Unlock()

		// Break the connection for every request except the last one,
		// which is expected to succeed.
		doBreak := i != numReqs

		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			GotConn: func(info httptrace.GotConnInfo) {
				t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
				atomic.AddUint32(&gotConns, 1)
			},
			// Break writes as soon as TLS is established, so the failure
			// hits a fully set-up connection.
			TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
				brokenState.Lock()
				defer brokenState.Unlock()
				if doBreak {
					brokenState.broken = true
				}
			},
		})
		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		_, err = cst.c.Do(req)
		// Broken iterations must fail. The final iteration must succeed.
		if doBreak != (err != nil) {
			t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
		}
	}
	if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
		t.Errorf("GotConn calls = %v; want %v", got, want)
	}
	// A broken connection must never be reused, so every request dials.
	if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
		t.Errorf("Dials = %v; want %v", got, want)
	}
}
6681
6682
6683
6684
6685
// TestTransportDecrementConnWhenIdleConnRemoved issues many concurrent HTTP/2
// requests with MaxConnsPerHost = 1 and checks that every one of them
// succeeds. The name refers to the per-host connection count being
// decremented when an idle connection is removed. Presumably, if that count
// leaked, later requests would wait forever for a free slot.
func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
	run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
}
func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
	CondSkipHTTP2(t)

	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		if _, err := w.Write([]byte("foo")); err != nil {
			// Handlers run on server goroutines, where t.Fatalf (FailNow)
			// must not be called. Report the failure and bail out instead.
			t.Errorf("Write: %v", err)
			return
		}
	})

	ts := newClientServerTest(t, mode, h).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxConnsPerHost = 1

	const numReqs = 300
	// Buffered so that no request goroutine ever blocks reporting an error.
	errCh := make(chan error, numReqs)
	doReq := func() {
		resp, err := c.Get(ts.URL)
		if err != nil {
			errCh <- fmt.Errorf("request failed: %v", err)
			return
		}
		defer resp.Body.Close()
		_, err = io.ReadAll(resp.Body)
		if err != nil {
			errCh <- fmt.Errorf("read body failed: %v", err)
		}
	}

	var wg sync.WaitGroup
	for i := 0; i < numReqs; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			doReq()
		}()
	}
	wg.Wait()
	close(errCh)

	for err := range errCh {
		t.Errorf("error occurred: %v", err)
	}
}
6734
6735
6736
6737
// TestAltProtoCancellation checks that Client.Timeout also cancels requests
// served by a RoundTripper registered with Transport.RegisterProtocol. The
// cancelProto round tripper only returns once its request is canceled, and
// its error must show up in the error returned by Client.Get.
func TestAltProtoCancellation(t *testing.T) {
	defer afterTest(t)
	tr := &Transport{}
	tr.RegisterProtocol("cancel", cancelProto{})
	c := &Client{Transport: tr, Timeout: time.Millisecond}

	_, err := c.Get("cancel://bar.com/path")
	switch {
	case err == nil:
		t.Error("request unexpectedly succeeded")
	case !strings.Contains(err.Error(), errCancelProto.Error()):
		t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
	}
}
6753
6754 var errCancelProto = errors.New("canceled as expected")
6755
6756 type cancelProto struct{}
6757
6758 func (cancelProto) RoundTrip(req *Request) (*Response, error) {
6759 <-req.Cancel
6760 return nil, errCancelProto
6761 }
6762
6763 type roundTripFunc func(r *Request) (*Response, error)
6764
6765 func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) }
6766
6767
// TestIssue32441 sets up an alternate "http" RoundTripper that drains the
// request body and then returns ErrSkipAltProtocol. The transport's normal
// path must still deliver a non-empty body to the server
// (go.dev/issue/32441).
func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
func testIssue32441(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
			t.Error("body length is zero")
		}
	})).ts
	client := ts.Client()

	// Consume the body, then decline so the transport takes its normal path.
	drainAndSkip := roundTripFunc(func(r *Request) (*Response, error) {
		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
			t.Error("body length is zero during round trip")
		}
		return nil, ErrSkipAltProtocol
	})
	client.Transport.(*Transport).RegisterProtocol("http", drainAndSkip)

	body := bytes.NewBufferString("data")
	if _, err := client.Post(ts.URL, "application/octet-stream", body); err != nil {
		t.Error(err)
	}
}
6787
6788
6789
6790 func TestTransportRejectsSignInContentLength(t *testing.T) {
6791 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
6792 }
6793 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
6794 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6795 w.Header().Set("Content-Length", "+3")
6796 w.Write([]byte("abc"))
6797 })).ts
6798
6799 c := cst.Client()
6800 res, err := c.Get(cst.URL)
6801 if err == nil || res != nil {
6802 t.Fatal("Expected a non-nil error and a nil http.Response")
6803 }
6804 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
6805 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
6806 }
6807 }
6808
6809
6810 type dumpConn struct {
6811 io.Writer
6812 io.Reader
6813 }
6814
6815 func (c *dumpConn) Close() error { return nil }
6816 func (c *dumpConn) LocalAddr() net.Addr { return nil }
6817 func (c *dumpConn) RemoteAddr() net.Addr { return nil }
6818 func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
6819 func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
6820 func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
6821
6822
6823
6824 type delegateReader struct {
6825 c chan io.Reader
6826 r io.Reader
6827 }
6828
6829 func (r *delegateReader) Read(p []byte) (int, error) {
6830 if r.r == nil {
6831 var ok bool
6832 if r.r, ok = <-r.c; !ok {
6833 return 0, errors.New("delegate closed")
6834 }
6835 }
6836 return r.r.Read(p)
6837 }
6838
// testTransportRace sends req through a fresh Transport whose only
// "connection" is an in-memory dumpConn. Writes go into a pipe that a helper
// goroutine parses as a request. Reads come from a delegateReader that
// receives a canned 204 response once the helper has drained the request
// body. Results and errors from RoundTrip are ignored; the function exists
// to exercise the transport's concurrent paths. req.Body is restored to its
// original value before returning.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	// Note: t here is the Transport, not a *testing.T. The Dial parameter
	// named net shadows the net package only inside the function body.
	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	quitReadCh := make(chan struct{})

	// Fake server: read the request from the pipe, then either supply the
	// response or give up.
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Drain the whole request body so the client's body write can
			// complete.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:
			// The caller stopped waiting for a response. Close dr.c so any
			// pending delegateReader.Read fails instead of blocking forever.
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Close the write side so the fake server's ReadRequest/Copy ends, then
	// wait for that goroutine: either via the send in its select or via its
	// deferred close(quitReadCh).
	pw.Close()
	<-quitReadCh

	req.Body = save
}
6882
6883
6884
6885
6886
// TestErrorWriteLoopRace runs 1000 POSTs through testTransportRace, each with
// a context that times out after a random 0-4ms delay. Presumably this
// shakes out races between the transport's write loop reporting an error and
// the request being canceled. The test does nothing in -short mode.
func TestErrorWriteLoopRace(t *testing.T) {
	if testing.Short() {
		return
	}
	t.Parallel()
	for i := 0; i < 1000; i++ {
		// Each iteration runs in its own function so its context is
		// canceled (and its timer released) when the iteration ends,
		// rather than all 1000 piling up until the test returns.
		func() {
			delay := time.Duration(mrand.Intn(5)) * time.Millisecond
			ctx, cancel := context.WithTimeout(context.Background(), delay)
			defer cancel()

			r := bytes.NewBuffer(make([]byte, 10000))
			req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
			if err != nil {
				t.Fatal(err)
			}

			testTransportRace(req)
		}()
	}
}
6906
6907
6908
6909
// TestCancelRequestWhenSharingConnection checks cancellation of a second
// request that arrives while the first request's PutIdleConn trace hook is
// still blocked. The transport is limited to one connection per host. The
// second request reaches the server and is then canceled while its handler
// is blocked. It must fail with context.Canceled.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each handler invocation publishes a channel and blocks until the test
	// sends on or closes it.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	// Request 1.
	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Block this hook until the test closes the published
				// channel. putidlec is closed right after the single send.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec)
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Let request 1's handler finish, unless the request already failed.
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		close(r1c)
	}
	// Wait until request 1 is blocked inside its PutIdleConn hook.
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	// Request 2: expected to be canceled. When it finishes, it closes idlec,
	// which also releases request 1's PutIdleConn hook.
	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Request 2 is done; unblock request 1's PutIdleConn hook.
		close(idlec)
	}()

	// Once request 2 reaches the server, cancel it while its handler is
	// still blocked on r2c.
	r2c := <-reqc
	cancel()

	// Wait for request 2 to finish (its goroutine closes idlec).
	<-idlec

	// Release request 2's handler, then wait for both client goroutines.
	close(r2c)
	wg.Wait()
}
6995
// TestHandlerAbortRacesBodyRead uses a handler that starts draining the
// request body in a background goroutine and then aborts with
// ErrAbortHandler. Meanwhile two clients concurrently send 6 MiB POST bodies.
// The test makes no assertions; round-trip errors are expected and ignored,
// and it only fails if something races or crashes.
func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		go io.Copy(io.Discard, req.Body)
		panic(ErrAbortHandler)
	})).ts

	const (
		clients       = 2
		reqsPerClient = 10
		reqLen        = 6 * 1024 * 1024
	)
	// sendOne posts a single large body and discards the outcome.
	sendOne := func() {
		body := &io.LimitedReader{R: neverEnding('x'), N: reqLen}
		req, _ := NewRequest("POST", ts.URL, body)
		req.ContentLength = reqLen
		if resp, _ := ts.Client().Transport.RoundTrip(req); resp != nil {
			resp.Body.Close()
		}
	}

	var wg sync.WaitGroup
	for i := 0; i < clients; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < reqsPerClient; j++ {
				sendOne()
			}
		}()
	}
	wg.Wait()
}
7021
// TestRequestSanitization checks that a Host value containing CRLF cannot
// inject an extra header ("X-Evil") into the request the server receives.
func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
func testRequestSanitization(t *testing.T, mode testMode) {
	if mode == http2Mode {
		// Known to fail over HTTP/2; tracked in the issue below.
		t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
	}
	handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
		if evil, found := req.Header["X-Evil"]; found {
			t.Errorf("request has X-Evil header: %q", evil)
		}
	})
	ts := newClientServerTest(t, mode, handler).ts

	req, _ := NewRequest("GET", ts.URL, nil)
	req.Host = "go.dev\r\nX-Evil:evil"
	if resp, _ := ts.Client().Do(req); resp != nil {
		resp.Body.Close()
	}
}
7040
// TestProxyAuthHeader checks that userinfo in the proxy URL is sent as Basic
// credentials in the Proxy-Authorization header. The password contains
// characters that need escaping in a URL.
func TestProxyAuthHeader(t *testing.T) {
	// Not parallel: the test sets the HTTP_PROXY environment variable.
	run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
}
func testProxyAuthHeader(t *testing.T, mode testMode) {
	const (
		username = "u"
		password = "@/?!"
	)
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		// Decode the proxy credentials with Request.BasicAuth by presenting
		// them as an ordinary Authorization header.
		decoder := Request{Header: Header{"Authorization": req.Header["Proxy-Authorization"]}}
		gotuser, gotpass, ok := decoder.BasicAuth()
		if !ok || gotuser != username || gotpass != password {
			t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
		}
	}))

	proxyURL, err := url.Parse(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	proxyURL.User = url.UserPassword(username, password)
	t.Setenv("HTTP_PROXY", proxyURL.String())
	cst.tr.Proxy = ProxyURL(proxyURL)

	resp, err := cst.c.Get("http://_/")
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()
}
7073
7074
// TestTransportReqCancelerCleanupOnRequestBodyWriteError sets up a server
// that responds after reading only one byte of a 1 GiB request body, then
// closes the connection, so the client's body write cannot complete. Once the
// response body is closed, the transport must not be left with any pending
// requests.
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
	ln := newLocalListener(t)
	addr := ln.Addr().String()

	done := make(chan struct{})
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("ln.Accept: %v", err)
			return
		}
		// Wait for the first byte of the request, then reply with a
		// complete response without reading the rest of the body.
		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
			t.Errorf("conn.Read: %v", err)
			return
		}
		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
		<-done
		// Closing the connection makes the client's remaining body write
		// fail.
		conn.Close()
	}()

	// Signal each time the transport's read loop is about to read its next
	// response.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	tr := &Transport{}

	// A 1 GiB body that the server will never consume in full.
	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}

	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.RoundTrip: %v", err)
	}

	close(done)

	// Wait until the read loop is ready for the next response, i.e. it has
	// finished with this one.
	<-didRead

	resp.Body.Close()

	// The request's bookkeeping should be cleaned up, possibly
	// asynchronously, so poll until no requests remain pending.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
7135
7136 func TestValidateClientRequestTrailers(t *testing.T) {
7137 run(t, testValidateClientRequestTrailers)
7138 }
7139
7140 func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
7141 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7142 rw.Write([]byte("Hello"))
7143 })).ts
7144
7145 cases := []struct {
7146 trailer Header
7147 wantErr string
7148 }{
7149 {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
7150 {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
7151 }
7152
7153 for i, tt := range cases {
7154 testName := fmt.Sprintf("%s%d", mode, i)
7155 t.Run(testName, func(t *testing.T) {
7156 req, err := NewRequest("GET", cst.URL, nil)
7157 if err != nil {
7158 t.Fatal(err)
7159 }
7160 req.Trailer = tt.trailer
7161 res, err := cst.Client().Do(req)
7162 if err == nil {
7163 t.Fatal("Expected an error")
7164 }
7165 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
7166 t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
7167 }
7168 if res != nil {
7169 t.Fatal("Unexpected non-nil response")
7170 }
7171 })
7172 }
7173 }
7174
7175 func TestTransportServerProtocols(t *testing.T) {
7176 CondSkipHTTP2(t)
7177 DefaultTransport.(*Transport).CloseIdleConnections()
7178
7179 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7180 if err != nil {
7181 t.Fatal(err)
7182 }
7183 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7184 if err != nil {
7185 t.Fatal(err)
7186 }
7187 certpool := x509.NewCertPool()
7188 certpool.AddCert(leafCert)
7189
7190 for _, test := range []struct {
7191 name string
7192 scheme string
7193 setup func(t *testing.T)
7194 transport func(*Transport)
7195 server func(*Server)
7196 want string
7197 }{{
7198 name: "http default",
7199 scheme: "http",
7200 want: "HTTP/1.1",
7201 }, {
7202 name: "https default",
7203 scheme: "https",
7204 transport: func(tr *Transport) {
7205
7206 },
7207 want: "HTTP/1.1",
7208 }, {
7209 name: "https transport protocols include HTTP2",
7210 scheme: "https",
7211 transport: func(tr *Transport) {
7212
7213
7214 tr.Protocols = &Protocols{}
7215 tr.Protocols.SetHTTP1(true)
7216 tr.Protocols.SetHTTP2(true)
7217 },
7218 want: "HTTP/2.0",
7219 }, {
7220 name: "https transport protocols only include HTTP1",
7221 scheme: "https",
7222 transport: func(tr *Transport) {
7223
7224 tr.Protocols = &Protocols{}
7225 tr.Protocols.SetHTTP1(true)
7226 },
7227 want: "HTTP/1.1",
7228 }, {
7229 name: "https transport ForceAttemptHTTP2",
7230 scheme: "https",
7231 transport: func(tr *Transport) {
7232
7233 tr.ForceAttemptHTTP2 = true
7234 },
7235 want: "HTTP/2.0",
7236 }, {
7237 name: "https transport protocols override TLSNextProto",
7238 scheme: "https",
7239 transport: func(tr *Transport) {
7240
7241
7242
7243 tr.Protocols = &Protocols{}
7244 tr.Protocols.SetHTTP1(true)
7245 tr.Protocols.SetHTTP2(true)
7246 tr.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{}
7247 },
7248 want: "HTTP/2.0",
7249 }, {
7250 name: "https server disables HTTP2 with TLSNextProto",
7251 scheme: "https",
7252 server: func(srv *Server) {
7253
7254
7255 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7256 },
7257 want: "HTTP/1.1",
7258 }, {
7259 name: "https server Protocols overrides empty TLSNextProto",
7260 scheme: "https",
7261 server: func(srv *Server) {
7262
7263
7264 srv.Protocols = &Protocols{}
7265 srv.Protocols.SetHTTP1(true)
7266 srv.Protocols.SetHTTP2(true)
7267 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7268 },
7269 want: "HTTP/2.0",
7270 }, {
7271 name: "https server protocols only include HTTP1",
7272 scheme: "https",
7273 server: func(srv *Server) {
7274 srv.Protocols = &Protocols{}
7275 srv.Protocols.SetHTTP1(true)
7276 },
7277 want: "HTTP/1.1",
7278 }, {
7279 name: "https server protocols include HTTP2",
7280 scheme: "https",
7281 server: func(srv *Server) {
7282 srv.Protocols = &Protocols{}
7283 srv.Protocols.SetHTTP1(true)
7284 srv.Protocols.SetHTTP2(true)
7285 },
7286 want: "HTTP/2.0",
7287 }, {
7288 name: "GODEBUG disables HTTP2 client",
7289 scheme: "https",
7290 setup: func(t *testing.T) {
7291 t.Setenv("GODEBUG", "http2client=0")
7292 },
7293 transport: func(tr *Transport) {
7294
7295
7296 tr.Protocols = &Protocols{}
7297 tr.Protocols.SetHTTP1(true)
7298 tr.Protocols.SetHTTP2(true)
7299 },
7300 want: "HTTP/1.1",
7301 }, {
7302 name: "GODEBUG disables HTTP2 server",
7303 scheme: "https",
7304 setup: func(t *testing.T) {
7305 t.Setenv("GODEBUG", "http2server=0")
7306 },
7307 transport: func(tr *Transport) {
7308
7309
7310 tr.Protocols = &Protocols{}
7311 tr.Protocols.SetHTTP1(true)
7312 tr.Protocols.SetHTTP2(true)
7313 },
7314 want: "HTTP/1.1",
7315 }} {
7316 t.Run(test.name, func(t *testing.T) {
7317
7318
7319 srv := &Server{
7320 TLSConfig: &tls.Config{
7321 Certificates: []tls.Certificate{cert},
7322 },
7323 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {
7324 w.Header().Set("X-Proto", req.Proto)
7325 }),
7326 }
7327 tr := &Transport{
7328 TLSClientConfig: &tls.Config{
7329 RootCAs: certpool,
7330 },
7331 }
7332
7333 if test.setup != nil {
7334 test.setup(t)
7335 }
7336 if test.server != nil {
7337 test.server(srv)
7338 }
7339 if test.transport != nil {
7340 test.transport(tr)
7341 } else {
7342 tr.Protocols = &Protocols{}
7343 tr.Protocols.SetHTTP1(true)
7344 tr.Protocols.SetHTTP2(true)
7345 }
7346
7347 listener := newLocalListener(t)
7348 srvc := make(chan error, 1)
7349 go func() {
7350 switch test.scheme {
7351 case "http":
7352 srvc <- srv.Serve(listener)
7353 case "https":
7354 srvc <- srv.ServeTLS(listener, "", "")
7355 }
7356 }()
7357 t.Cleanup(func() {
7358 srv.Close()
7359 <-srvc
7360 })
7361
7362 client := &Client{Transport: tr}
7363 resp, err := client.Get(test.scheme + "://" + listener.Addr().String())
7364 if err != nil {
7365 t.Fatal(err)
7366 }
7367 if got := resp.Header.Get("X-Proto"); got != test.want {
7368 t.Fatalf("request proto %q, want %q", got, test.want)
7369 }
7370 })
7371 }
7372 }
7373
View as plain text