Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "io"
26 "log"
27 mrand "math/rand"
28 "net"
29 . "net/http"
30 "net/http/httptest"
31 "net/http/httptrace"
32 "net/http/httputil"
33 "net/http/internal/testcert"
34 "net/textproto"
35 "net/url"
36 "os"
37 "reflect"
38 "runtime"
39 "strconv"
40 "strings"
41 "sync"
42 "sync/atomic"
43 "testing"
44 "testing/iotest"
45 "time"
46
47 "golang.org/x/net/http/httpguts"
48 )
49
50
51
52
53
54 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
55 if r.FormValue("close") == "true" {
56 w.Header().Set("Connection", "close")
57 }
58 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
59 w.Write([]byte(r.RemoteAddr))
60
61
62
63 if c, ok := ResponseWriterConnForTesting(w); ok {
64 fmt.Fprintf(w, ", %T %p", c, c)
65 }
66 })
67
68
// testCloseConn wraps a net.Conn and records in its owning testConnSet
// when the connection is closed.
type testCloseConn struct {
	net.Conn
	set *testConnSet
}

// Close marks the connection as closed in its set, then closes the
// underlying connection.
func (c *testCloseConn) Close() error {
	c.set.remove(c)
	return c.Conn.Close()
}

// testConnSet tracks the connections created by a dial function returned
// from makeTestDial, so a test can later verify that all of them were closed.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex        // guards closed and list
	closed map[net.Conn]bool // whether Close has been called on each conn
	list   []net.Conn        // in order created
}

// insert registers c as an open connection.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = false
	tcs.list = append(tcs.list, c)
}

// remove marks c as closed. The connection stays in list so check can still
// report its position.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}

// makeTestDial returns a new testConnSet and a Transport.Dial-compatible
// function whose connections are wrapped in testCloseConn and recorded in
// that set.
func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
	connSet := &testConnSet{
		t:      t,
		closed: make(map[net.Conn]bool),
	}
	dial := func(n, addr string) (net.Conn, error) {
		c, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		tc := &testCloseConn{c, connSet}
		connSet.insert(tc)
		return tc, nil
	}
	return connSet, dial
}

// check reports a test error for each connection in the set that has not
// been closed. Connections may be closed asynchronously, so check first
// waits (up to four 50ms pauses) for all of them to close, and only then
// reports each connection that is still open, exactly once.
func (tcs *testConnSet) check(t *testing.T) {
	const (
		attempts   = 5
		retryDelay = 50 * time.Millisecond
	)
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	for attempt := 1; attempt < attempts && !tcs.allClosedLocked(); attempt++ {
		// Release the lock while waiting so that Close (via remove) can
		// make progress on other goroutines.
		tcs.mu.Unlock()
		time.Sleep(retryDelay)
		tcs.mu.Lock()
	}
	for i, c := range tcs.list {
		if !tcs.closed[c] {
			t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
		}
	}
}

// allClosedLocked reports whether every connection in the set has been
// closed. tcs.mu must be held.
func (tcs *testConnSet) allClosedLocked() bool {
	for _, c := range tcs.list {
		if !tcs.closed[c] {
			return false
		}
	}
	return true
}
139
140 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
141 func testReuseRequest(t *testing.T, mode testMode) {
142 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
143 w.Write([]byte("{}"))
144 })).ts
145
146 c := ts.Client()
147 req, _ := NewRequest("GET", ts.URL, nil)
148 res, err := c.Do(req)
149 if err != nil {
150 t.Fatal(err)
151 }
152 err = res.Body.Close()
153 if err != nil {
154 t.Fatal(err)
155 }
156
157 res, err = c.Do(req)
158 if err != nil {
159 t.Fatal(err)
160 }
161 err = res.Body.Close()
162 if err != nil {
163 t.Fatal(err)
164 }
165 }
166
167
168
169 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
170 func testTransportKeepAlives(t *testing.T, mode testMode) {
171 ts := newClientServerTest(t, mode, hostPortHandler).ts
172
173 c := ts.Client()
174 for _, disableKeepAlive := range []bool{false, true} {
175 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
176 fetch := func(n int) string {
177 res, err := c.Get(ts.URL)
178 if err != nil {
179 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
180 }
181 body, err := io.ReadAll(res.Body)
182 if err != nil {
183 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
184 }
185 return string(body)
186 }
187
188 body1 := fetch(1)
189 body2 := fetch(2)
190
191 bodiesDiffer := body1 != body2
192 if bodiesDiffer != disableKeepAlive {
193 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
194 disableKeepAlive, bodiesDiffer, body1, body2)
195 }
196 }
197 }
198
199 func TestTransportConnectionCloseOnResponse(t *testing.T) {
200 run(t, testTransportConnectionCloseOnResponse)
201 }
202 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
203 ts := newClientServerTest(t, mode, hostPortHandler).ts
204
205 connSet, testDial := makeTestDial(t)
206
207 c := ts.Client()
208 tr := c.Transport.(*Transport)
209 tr.Dial = testDial
210
211 for _, connectionClose := range []bool{false, true} {
212 fetch := func(n int) string {
213 req := new(Request)
214 var err error
215 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
216 if err != nil {
217 t.Fatalf("URL parse error: %v", err)
218 }
219 req.Method = "GET"
220 req.Proto = "HTTP/1.1"
221 req.ProtoMajor = 1
222 req.ProtoMinor = 1
223
224 res, err := c.Do(req)
225 if err != nil {
226 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
227 }
228 defer res.Body.Close()
229 body, err := io.ReadAll(res.Body)
230 if err != nil {
231 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
232 }
233 return string(body)
234 }
235
236 body1 := fetch(1)
237 body2 := fetch(2)
238 bodiesDiffer := body1 != body2
239 if bodiesDiffer != connectionClose {
240 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
241 connectionClose, bodiesDiffer, body1, body2)
242 }
243
244 tr.CloseIdleConnections()
245 }
246
247 connSet.check(t)
248 }
249
250
251
252
253
254
255
256 func TestTransportConnectionCloseOnRequest(t *testing.T) {
257 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
258 }
259 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
260 ts := newClientServerTest(t, mode, hostPortHandler).ts
261
262 connSet, testDial := makeTestDial(t)
263
264 c := ts.Client()
265 tr := c.Transport.(*Transport)
266 tr.Dial = testDial
267 for _, reqClose := range []bool{false, true} {
268 fetch := func(n int) string {
269 req := new(Request)
270 var err error
271 req.URL, err = url.Parse(ts.URL)
272 if err != nil {
273 t.Fatalf("URL parse error: %v", err)
274 }
275 req.Method = "GET"
276 req.Proto = "HTTP/1.1"
277 req.ProtoMajor = 1
278 req.ProtoMinor = 1
279 req.Close = reqClose
280
281 res, err := c.Do(req)
282 if err != nil {
283 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
284 }
285 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
286 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
287 reqClose, got, !reqClose)
288 }
289 body, err := io.ReadAll(res.Body)
290 if err != nil {
291 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
292 }
293 return string(body)
294 }
295
296 body1 := fetch(1)
297 body2 := fetch(2)
298
299 got := 1
300 if body1 != body2 {
301 got++
302 }
303 want := 1
304 if reqClose {
305 want = 2
306 }
307 if got != want {
308 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
309 reqClose, got, want, body1, body2)
310 }
311
312 tr.CloseIdleConnections()
313 }
314
315 connSet.check(t)
316 }
317
318
319
320
321 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
322 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
323 }
324 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
325 ts := newClientServerTest(t, mode, hostPortHandler).ts
326
327 c := ts.Client()
328 c.Transport.(*Transport).DisableKeepAlives = true
329
330 res, err := c.Get(ts.URL)
331 if err != nil {
332 t.Fatal(err)
333 }
334 res.Body.Close()
335 if res.Header.Get("X-Saw-Close") != "true" {
336 t.Errorf("handler didn't see Connection: close ")
337 }
338 }
339
340
341
342 func TestTransportRespectRequestWantsClose(t *testing.T) {
343 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
344 }
345 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
346 tests := []struct {
347 disableKeepAlives bool
348 close bool
349 }{
350 {disableKeepAlives: false, close: false},
351 {disableKeepAlives: false, close: true},
352 {disableKeepAlives: true, close: false},
353 {disableKeepAlives: true, close: true},
354 }
355
356 for _, tc := range tests {
357 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
358 func(t *testing.T) {
359 ts := newClientServerTest(t, mode, hostPortHandler).ts
360
361 c := ts.Client()
362 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
363 req, err := NewRequest("GET", ts.URL, nil)
364 if err != nil {
365 t.Fatal(err)
366 }
367 count := 0
368 trace := &httptrace.ClientTrace{
369 WroteHeaderField: func(key string, field []string) {
370 if key != "Connection" {
371 return
372 }
373 if httpguts.HeaderValuesContainsToken(field, "close") {
374 count += 1
375 }
376 },
377 }
378 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
379 req.Close = tc.close
380 res, err := c.Do(req)
381 if err != nil {
382 t.Fatal(err)
383 }
384 defer res.Body.Close()
385 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
386 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
387 }
388 })
389 }
390
391 }
392
393 func TestTransportIdleCacheKeys(t *testing.T) {
394 run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
395 }
396 func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
397 ts := newClientServerTest(t, mode, hostPortHandler).ts
398 c := ts.Client()
399 tr := c.Transport.(*Transport)
400
401 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
402 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
403 }
404
405 resp, err := c.Get(ts.URL)
406 if err != nil {
407 t.Error(err)
408 }
409 io.ReadAll(resp.Body)
410
411 keys := tr.IdleConnKeysForTesting()
412 if e, g := 1, len(keys); e != g {
413 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
414 }
415
416 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
417 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
418 }
419
420 tr.CloseIdleConnections()
421 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
422 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
423 }
424 }
425
426
427
// TestTransportReadToEndReusesConn checks that reading a response body to EOF
// lets the Transport reuse the connection. It covers responses with a
// Content-Length and responses sent without one (client ContentLength -1):
// three GETs per path must all arrive from a single client address.
func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
	const msg = "foobar"

	// addrSeen counts requests per client RemoteAddr; it is reset per path.
	var addrSeen map[string]int
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addrSeen[r.RemoteAddr]++
		if r.URL.Path == "/chunked/" {
			// Flush the headers before writing the body, so no
			// Content-Length is sent for this path.
			w.WriteHeader(200)
			w.(Flusher).Flush()
		} else {
			w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
			w.WriteHeader(200)
		}
		w.Write([]byte(msg))
	})).ts

	for pi, path := range []string{"/content-length/", "/chunked/"} {
		// Expected Response.ContentLength: known for the first path, -1
		// (unknown) for the chunked one.
		wantLen := []int{len(msg), -1}[pi]
		addrSeen = make(map[string]int)
		for i := 0; i < 3; i++ {
			res, err := ts.Client().Get(ts.URL + path)
			if err != nil {
				t.Errorf("Get %s: %v", path, err)
				continue
			}
			// The body is deliberately closed only when the test returns,
			// not here in the loop. Closing it early could let the
			// connection be reused because of Close rather than because the
			// body was read to EOF, which is what the len(addrSeen) check
			// below is meant to verify.
			defer res.Body.Close()

			if res.ContentLength != int64(wantLen) {
				t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
			}
			got, err := io.ReadAll(res.Body)
			if string(got) != msg || err != nil {
				t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
			}
		}
		// One distinct client address means all three requests shared one
		// connection.
		if len(addrSeen) != 1 {
			t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
		}
	}
}
474
// TestTransportMaxPerHostIdleConns checks that the Transport keeps at most
// MaxIdleConnsPerHost idle connections per host. Three requests are held
// open concurrently inside the handler. They are then released one at a
// time, and the idle pool is inspected after each response.
func TestTransportMaxPerHostIdleConns(t *testing.T) {
	run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
}
func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
	// Closing stop on return unblocks any handler or doReq goroutine still
	// waiting in one of the selects below.
	stop := make(chan struct{})
	defer close(stop)

	// resch delivers the body for the next blocked handler to write.
	// gotReq signals that a handler has received its request.
	resch := make(chan string)
	gotReq := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReq <- true
		var msg string
		select {
		case <-stop:
			return
		case msg = <-resch:
		}
		_, err := w.Write([]byte(msg))
		if err != nil {
			t.Errorf("Write: %v", err)
			return
		}
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)
	maxIdleConnsPerHost := 2
	tr.MaxIdleConnsPerHost = maxIdleConnsPerHost

	// Start 3 outstanding requests and wait for the server to receive each
	// one. Their responses stay blocked until a value is sent on resch.
	donech := make(chan bool)
	doReq := func() {
		// Tell the main goroutine this request finished (passing whether
		// the test has failed so far), unless the test is already exiting.
		defer func() {
			select {
			case <-stop:
				return
			case donech <- t.Failed():
			}
		}()
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Error(err)
			return
		}
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("ReadAll: %v", err)
			return
		}
	}
	go doReq()
	<-gotReq
	go doReq()
	<-gotReq
	go doReq()
	<-gotReq

	// All three connections are busy, so nothing is idle yet.
	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
	}

	resch <- "res1"
	<-donech
	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
	}
	addr := ts.Listener.Addr().String()
	cacheKey := "|http|" + addr
	if keys[0] != cacheKey {
		t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
	}
	if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
		t.Errorf("after first response, expected %d idle conns; got %d", e, g)
	}

	resch <- "res2"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
		t.Errorf("after second response, idle conns = %d; want %d", g, w)
	}

	// A third idle connection would exceed MaxIdleConnsPerHost, so the idle
	// count is expected to stay at the limit.
	resch <- "res3"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
		t.Errorf("after third response, idle conns = %d; want %d", g, w)
	}
}
563
564 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
565 run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
566 }
567 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
568 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
569 _, err := w.Write([]byte("foo"))
570 if err != nil {
571 t.Fatalf("Write: %v", err)
572 }
573 })).ts
574 c := ts.Client()
575 tr := c.Transport.(*Transport)
576 dialStarted := make(chan struct{})
577 stallDial := make(chan struct{})
578 tr.Dial = func(network, addr string) (net.Conn, error) {
579 dialStarted <- struct{}{}
580 <-stallDial
581 return net.Dial(network, addr)
582 }
583
584 tr.DisableKeepAlives = true
585 tr.MaxConnsPerHost = 1
586
587 preDial := make(chan struct{})
588 reqComplete := make(chan struct{})
589 doReq := func(reqId string) {
590 req, _ := NewRequest("GET", ts.URL, nil)
591 trace := &httptrace.ClientTrace{
592 GetConn: func(hostPort string) {
593 preDial <- struct{}{}
594 },
595 }
596 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
597 resp, err := tr.RoundTrip(req)
598 if err != nil {
599 t.Errorf("unexpected error for request %s: %v", reqId, err)
600 }
601 _, err = io.ReadAll(resp.Body)
602 if err != nil {
603 t.Errorf("unexpected error for request %s: %v", reqId, err)
604 }
605 reqComplete <- struct{}{}
606 }
607
608 go doReq("req1")
609 <-preDial
610 <-dialStarted
611
612
613 go doReq("req2")
614 <-preDial
615 select {
616 case <-dialStarted:
617 t.Error("req2 dial started while req1 dial in progress")
618 return
619 default:
620 }
621
622
623 stallDial <- struct{}{}
624 <-reqComplete
625
626
627 <-dialStarted
628 stallDial <- struct{}{}
629 <-reqComplete
630 }
631
632 func TestTransportMaxConnsPerHost(t *testing.T) {
633 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
634 }
635 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
636 CondSkipHTTP2(t)
637
638 h := HandlerFunc(func(w ResponseWriter, r *Request) {
639 _, err := w.Write([]byte("foo"))
640 if err != nil {
641 t.Fatalf("Write: %v", err)
642 }
643 })
644
645 ts := newClientServerTest(t, mode, h).ts
646 c := ts.Client()
647 tr := c.Transport.(*Transport)
648 tr.MaxConnsPerHost = 1
649
650 mu := sync.Mutex{}
651 var conns []net.Conn
652 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
653 tr.Dial = func(network, addr string) (net.Conn, error) {
654 atomic.AddInt32(&dialCnt, 1)
655 c, err := net.Dial(network, addr)
656 mu.Lock()
657 defer mu.Unlock()
658 conns = append(conns, c)
659 return c, err
660 }
661
662 doReq := func() {
663 trace := &httptrace.ClientTrace{
664 GotConn: func(connInfo httptrace.GotConnInfo) {
665 if !connInfo.Reused {
666 atomic.AddInt32(&gotConnCnt, 1)
667 }
668 },
669 TLSHandshakeStart: func() {
670 atomic.AddInt32(&tlsHandshakeCnt, 1)
671 },
672 }
673 req, _ := NewRequest("GET", ts.URL, nil)
674 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
675
676 resp, err := c.Do(req)
677 if err != nil {
678 t.Fatalf("request failed: %v", err)
679 }
680 defer resp.Body.Close()
681 _, err = io.ReadAll(resp.Body)
682 if err != nil {
683 t.Fatalf("read body failed: %v", err)
684 }
685 }
686
687 wg := sync.WaitGroup{}
688 for i := 0; i < 10; i++ {
689 wg.Add(1)
690 go func() {
691 defer wg.Done()
692 doReq()
693 }()
694 }
695 wg.Wait()
696
697 expected := int32(tr.MaxConnsPerHost)
698 if dialCnt != expected {
699 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
700 }
701 if gotConnCnt != expected {
702 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
703 }
704 if ts.TLS != nil && tlsHandshakeCnt != expected {
705 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
706 }
707
708 if t.Failed() {
709 t.FailNow()
710 }
711
712 mu.Lock()
713 for _, c := range conns {
714 c.Close()
715 }
716 conns = nil
717 mu.Unlock()
718 tr.CloseIdleConnections()
719
720 doReq()
721 expected++
722 if dialCnt != expected {
723 t.Errorf("round 2: too many dials: %d", dialCnt)
724 }
725 if gotConnCnt != expected {
726 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
727 }
728 if ts.TLS != nil && tlsHandshakeCnt != expected {
729 t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
730 }
731 }
732
733 func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
734 run(t, testTransportMaxConnsPerHostDialCancellation,
735 testNotParallel,
736 []testMode{http1Mode, https1Mode, http2Mode},
737 )
738 }
739
740 func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
741 CondSkipHTTP2(t)
742
743 h := HandlerFunc(func(w ResponseWriter, r *Request) {
744 _, err := w.Write([]byte("foo"))
745 if err != nil {
746 t.Fatalf("Write: %v", err)
747 }
748 })
749
750 cst := newClientServerTest(t, mode, h)
751 defer cst.close()
752 ts := cst.ts
753 c := ts.Client()
754 tr := c.Transport.(*Transport)
755 tr.MaxConnsPerHost = 1
756
757
758 ctx, cancel := context.WithCancel(context.Background())
759 defer cancel()
760 SetPendingDialHooks(cancel, nil)
761 defer SetPendingDialHooks(nil, nil)
762
763 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
764 _, err := c.Do(req)
765 if !errors.Is(err, context.Canceled) {
766 t.Errorf("expected error %v, got %v", context.Canceled, err)
767 }
768
769
770 SetPendingDialHooks(nil, nil)
771 req, _ = NewRequest("GET", ts.URL, nil)
772 resp, err := c.Do(req)
773 if err != nil {
774 t.Fatalf("request failed: %v", err)
775 }
776 defer resp.Body.Close()
777 _, err = io.ReadAll(resp.Body)
778 if err != nil {
779 t.Fatalf("read body failed: %v", err)
780 }
781 }
782
783 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
784 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
785 }
786 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
787 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
788 io.WriteString(w, r.RemoteAddr)
789 })).ts
790
791 c := ts.Client()
792 tr := c.Transport.(*Transport)
793
794 doReq := func(name string) {
795
796
797 res, err := c.Post(ts.URL, "", nil)
798 if err != nil {
799 t.Fatalf("%s: %v", name, err)
800 }
801 if res.StatusCode != 200 {
802 t.Fatalf("%s: %v", name, res.Status)
803 }
804 defer res.Body.Close()
805 slurp, err := io.ReadAll(res.Body)
806 if err != nil {
807 t.Fatalf("%s: %v", name, err)
808 }
809 t.Logf("%s: ok (%q)", name, slurp)
810 }
811
812 doReq("first")
813 keys1 := tr.IdleConnKeysForTesting()
814
815 ts.CloseClientConnections()
816
817 var keys2 []string
818 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
819 keys2 = tr.IdleConnKeysForTesting()
820 if len(keys2) != 0 {
821 if d > 0 {
822 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
823 }
824 return false
825 }
826 return true
827 })
828
829 doReq("second")
830 }
831
832
833
// TestTransportServerClosingUnexpectedly checks that the Transport recovers
// when its connections are closed out from under it. Two requests first share
// a connection (equal bodies from hostPortHandler). After the connections are
// closed abruptly, a later request still succeeds, possibly after retries, on
// a different connection.
func TestTransportServerClosingUnexpectedly(t *testing.T) {
	run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
}
func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts
	c := ts.Client()

	// fetch GETs ts.URL until one attempt succeeds and returns its body.
	// retries is decremented before each attempt. A failure that leaves it
	// at or below zero is fatal, so fetch(n, 0) and fetch(n, 1) make a
	// single attempt and fetch(n, 5) makes at most five.
	fetch := func(n, retries int) string {
		// condFatalf fails the test once retries is exhausted. Otherwise it
		// logs the error and sleeps time.Second/retries, which grows as
		// fewer retries remain.
		condFatalf := func(format string, arg ...any) {
			if retries <= 0 {
				t.Fatalf(format, arg...)
			}
			t.Logf("retrying shortly after expected error: "+format, arg...)
			time.Sleep(time.Second / time.Duration(retries))
		}
		for retries >= 0 {
			retries--
			res, err := c.Get(ts.URL)
			if err != nil {
				condFatalf("error in req #%d, GET: %v", n, err)
				continue
			}
			body, err := io.ReadAll(res.Body)
			if err != nil {
				condFatalf("error in req #%d, ReadAll: %v", n, err)
				continue
			}
			res.Body.Close()
			return string(body)
		}
		// Every path above either returns or ends in a fatal condFatalf.
		panic("unreachable")
	}

	body1 := fetch(1, 0)
	body2 := fetch(2, 0)

	// Close the Transport's connections abruptly, presumably simulating the
	// server closing them unexpectedly. The next request may fail on a dead
	// connection before succeeding, so it is allowed several attempts.
	ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))

	body3 := fetch(3, 5)

	if body1 != body2 {
		t.Errorf("expected body1 and body2 to be equal")
	}
	if body2 == body3 {
		t.Errorf("expected body2 and body3 to be different")
	}
}
888
889
890
891 func TestStressSurpriseServerCloses(t *testing.T) {
892 run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
893 }
894 func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
895 if testing.Short() {
896 t.Skip("skipping test in short mode")
897 }
898 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
899 w.Header().Set("Content-Length", "5")
900 w.Header().Set("Content-Type", "text/plain")
901 w.Write([]byte("Hello"))
902 w.(Flusher).Flush()
903 conn, buf, _ := w.(Hijacker).Hijack()
904 buf.Flush()
905 conn.Close()
906 })).ts
907 c := ts.Client()
908
909
910
911
912
913
914
915 const (
916 numClients = 20
917 reqsPerClient = 25
918 )
919 var wg sync.WaitGroup
920 wg.Add(numClients * reqsPerClient)
921 for i := 0; i < numClients; i++ {
922 go func() {
923 for i := 0; i < reqsPerClient; i++ {
924 res, err := c.Get(ts.URL)
925 if err == nil {
926
927
928
929
930
931
932 res.Body.Close()
933 }
934 wg.Done()
935 }
936 }()
937 }
938
939
940 wg.Wait()
941 }
942
943
944
945 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
946 func testTransportHeadResponses(t *testing.T, mode testMode) {
947 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
948 if r.Method != "HEAD" {
949 panic("expected HEAD; got " + r.Method)
950 }
951 w.Header().Set("Content-Length", "123")
952 w.WriteHeader(200)
953 })).ts
954 c := ts.Client()
955
956 for i := 0; i < 2; i++ {
957 res, err := c.Head(ts.URL)
958 if err != nil {
959 t.Errorf("error on loop %d: %v", i, err)
960 continue
961 }
962 if e, g := "123", res.Header.Get("Content-Length"); e != g {
963 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
964 }
965 if e, g := int64(123), res.ContentLength; e != g {
966 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
967 }
968 if all, err := io.ReadAll(res.Body); err != nil {
969 t.Errorf("loop %d: Body ReadAll: %v", i, err)
970 } else if len(all) != 0 {
971 t.Errorf("Bogus body %q", all)
972 }
973 }
974 }
975
976
977
978 func TestTransportHeadChunkedResponse(t *testing.T) {
979 run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
980 }
981 func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
982 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
983 if r.Method != "HEAD" {
984 panic("expected HEAD; got " + r.Method)
985 }
986 w.Header().Set("Transfer-Encoding", "chunked")
987 w.Header().Set("x-client-ipport", r.RemoteAddr)
988 w.WriteHeader(200)
989 })).ts
990 c := ts.Client()
991
992
993
994 didRead := make(chan bool)
995 SetReadLoopBeforeNextReadHook(func() { didRead <- true })
996 defer SetReadLoopBeforeNextReadHook(nil)
997
998 res1, err := c.Head(ts.URL)
999 <-didRead
1000
1001 if err != nil {
1002 t.Fatalf("request 1 error: %v", err)
1003 }
1004
1005 res2, err := c.Head(ts.URL)
1006 <-didRead
1007
1008 if err != nil {
1009 t.Fatalf("request 2 error: %v", err)
1010 }
1011 if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
1012 t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
1013 }
1014 }
1015
// roundTripTests lists the Accept-Encoding cases used by TestRoundTripGzip:
//   - accept is the header the client sets ("" means none),
//   - expectAccept is the value the server should receive,
//   - compressed reports whether the client should receive a still-gzipped body.
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// No Accept-Encoding: the Transport requests gzip itself and returns a
	// decompressed body (transparent compression).
	{accept: "", expectAccept: "gzip", compressed: false},
	// A non-gzip Accept-Encoding is passed through unmodified.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit gzip Accept-Encoding is passed through, and the body is
	// returned still compressed.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1028
1029
1030 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
1031 func testRoundTripGzip(t *testing.T, mode testMode) {
1032 const responseBody = "test response body"
1033 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1034 accept := req.Header.Get("Accept-Encoding")
1035 if expect := req.FormValue("expect_accept"); accept != expect {
1036 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
1037 req.FormValue("testnum"), accept, expect)
1038 }
1039 if accept == "gzip" {
1040 rw.Header().Set("Content-Encoding", "gzip")
1041 gz := gzip.NewWriter(rw)
1042 gz.Write([]byte(responseBody))
1043 gz.Close()
1044 } else {
1045 rw.Header().Set("Content-Encoding", accept)
1046 rw.Write([]byte(responseBody))
1047 }
1048 })).ts
1049 tr := ts.Client().Transport.(*Transport)
1050
1051 for i, test := range roundTripTests {
1052
1053 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1054 if test.accept != "" {
1055 req.Header.Set("Accept-Encoding", test.accept)
1056 }
1057 res, err := tr.RoundTrip(req)
1058 if err != nil {
1059 t.Errorf("%d. RoundTrip: %v", i, err)
1060 continue
1061 }
1062 var body []byte
1063 if test.compressed {
1064 var r *gzip.Reader
1065 r, err = gzip.NewReader(res.Body)
1066 if err != nil {
1067 t.Errorf("%d. gzip NewReader: %v", i, err)
1068 continue
1069 }
1070 body, err = io.ReadAll(r)
1071 res.Body.Close()
1072 } else {
1073 body, err = io.ReadAll(res.Body)
1074 }
1075 if err != nil {
1076 t.Errorf("%d. Error: %q", i, err)
1077 continue
1078 }
1079 if g, e := string(body), responseBody; g != e {
1080 t.Errorf("%d. body = %q; want %q", i, g, e)
1081 }
1082 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1083 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1084 }
1085 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1086 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1087 }
1088 }
1089
1090 }
1091
1092 func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
1093 func testTransportGzip(t *testing.T, mode testMode) {
1094 if mode == http2Mode {
1095 t.Skip("https://go.dev/issue/56020")
1096 }
1097 const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
1098 const nRandBytes = 1024 * 1024
1099 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1100 if req.Method == "HEAD" {
1101 if g := req.Header.Get("Accept-Encoding"); g != "" {
1102 t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
1103 }
1104 return
1105 }
1106 if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
1107 t.Errorf("Accept-Encoding = %q, want %q", g, e)
1108 }
1109 rw.Header().Set("Content-Encoding", "gzip")
1110
1111 var w io.Writer = rw
1112 var buf bytes.Buffer
1113 if req.FormValue("chunked") == "0" {
1114 w = &buf
1115 defer io.Copy(rw, &buf)
1116 defer func() {
1117 rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
1118 }()
1119 }
1120 gz := gzip.NewWriter(w)
1121 gz.Write([]byte(testString))
1122 if req.FormValue("body") == "large" {
1123 io.CopyN(gz, rand.Reader, nRandBytes)
1124 }
1125 gz.Close()
1126 })).ts
1127 c := ts.Client()
1128
1129 for _, chunked := range []string{"1", "0"} {
1130
1131 res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
1132 if err != nil {
1133 t.Fatalf("large get: %v", err)
1134 }
1135 buf := make([]byte, len(testString))
1136 n, err := io.ReadFull(res.Body, buf)
1137 if err != nil {
1138 t.Fatalf("partial read of large response: size=%d, %v", n, err)
1139 }
1140 if e, g := testString, string(buf); e != g {
1141 t.Errorf("partial read got %q, expected %q", g, e)
1142 }
1143 res.Body.Close()
1144
1145 n, err = res.Body.Read(buf)
1146 if n != 0 || err == nil {
1147 t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
1148 }
1149
1150
1151 res, err = c.Get(ts.URL + "/?chunked=" + chunked)
1152 if err != nil {
1153 t.Fatal(err)
1154 }
1155 body, err := io.ReadAll(res.Body)
1156 if err != nil {
1157 t.Fatal(err)
1158 }
1159 if g, e := string(body), testString; g != e {
1160 t.Fatalf("body = %q; want %q", g, e)
1161 }
1162 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1163 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1164 }
1165
1166
1167 n, err = res.Body.Read(buf)
1168 if n != 0 || err == nil {
1169 t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
1170 }
1171 res.Body.Close()
1172 n, err = res.Body.Read(buf)
1173 if n != 0 || err == nil {
1174 t.Errorf("expected Read error after Close; got %d, %v", n, err)
1175 }
1176 }
1177
1178
1179 res, err := c.Head(ts.URL)
1180 if err != nil {
1181 t.Fatalf("Head: %v", err)
1182 }
1183 if res.StatusCode != 200 {
1184 t.Errorf("Head status=%d; want=200", res.StatusCode)
1185 }
1186 }
1187
1188
1189
// transport100ContinueTest is a harness for testing how the Transport handles
// "Expect: 100-continue" requests. The client's RoundTrip runs on its own
// goroutine, and the test plays the server directly over conn.
type transport100ContinueTest struct {
	t *testing.T

	reqdone chan struct{} // closed once the client's RoundTrip has returned
	resp    *Response     // RoundTrip result; valid after reqdone is closed
	respErr error         // RoundTrip error; valid after reqdone is closed

	conn   net.Conn      // server side of the client's connection
	reader *bufio.Reader // buffered reader over conn, used to read the request
}

// transport100ContinueTestBody is the request body the client sends.
const transport100ContinueTestBody = "request body"
1202
1203
1204
1205 func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
1206 ln := newLocalListener(t)
1207 defer ln.Close()
1208
1209 test := &transport100ContinueTest{
1210 t: t,
1211 reqdone: make(chan struct{}),
1212 }
1213
1214 tr := &Transport{
1215 ExpectContinueTimeout: timeout,
1216 }
1217 go func() {
1218 defer close(test.reqdone)
1219 body := strings.NewReader(transport100ContinueTestBody)
1220 req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
1221 req.Header.Set("Expect", "100-continue")
1222 req.ContentLength = int64(len(transport100ContinueTestBody))
1223 test.resp, test.respErr = tr.RoundTrip(req)
1224 test.resp.Body.Close()
1225 }()
1226
1227 c, err := ln.Accept()
1228 if err != nil {
1229 t.Fatalf("Accept: %v", err)
1230 }
1231 t.Cleanup(func() {
1232 c.Close()
1233 })
1234 br := bufio.NewReader(c)
1235 _, err = ReadRequest(br)
1236 if err != nil {
1237 t.Fatalf("ReadRequest: %v", err)
1238 }
1239 test.conn = c
1240 test.reader = br
1241 t.Cleanup(func() {
1242 <-test.reqdone
1243 tr.CloseIdleConnections()
1244 got, _ := io.ReadAll(test.reader)
1245 if len(got) > 0 {
1246 t.Fatalf("Transport sent unexpected bytes: %q", got)
1247 }
1248 })
1249
1250 return test
1251 }
1252
1253
1254 func (test *transport100ContinueTest) respond(lines ...string) {
1255 for _, line := range lines {
1256 if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
1257 test.t.Fatalf("Write: %v", err)
1258 }
1259 }
1260 if _, err := test.conn.Write([]byte("\r\n")); err != nil {
1261 test.t.Fatalf("Write: %v", err)
1262 }
1263 }
1264
1265
1266 func (test *transport100ContinueTest) wantBodySent() {
1267 got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
1268 if err != nil {
1269 test.t.Fatalf("unexpected error reading body: %v", err)
1270 }
1271 if got, want := string(got), transport100ContinueTestBody; got != want {
1272 test.t.Fatalf("unexpected body: got %q, want %q", got, want)
1273 }
1274 }
1275
1276
1277 func (test *transport100ContinueTest) wantRequestDone(want int) {
1278 <-test.reqdone
1279 if test.respErr != nil {
1280 test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
1281 }
1282 if got := test.resp.StatusCode; got != want {
1283 test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
1284 }
1285 }
1286
1287 func TestTransportExpect100ContinueSent(t *testing.T) {
1288 test := newTransport100ContinueTest(t, 1*time.Hour)
1289
1290 test.respond("HTTP/1.1 100 Continue")
1291 test.wantBodySent()
1292 test.respond("HTTP/1.1 200", "Content-Length: 0")
1293 test.wantRequestDone(200)
1294 }
1295
1296 func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
1297 test := newTransport100ContinueTest(t, 1*time.Hour)
1298
1299 test.respond("HTTP/1.1 200", "Content-Length: 0")
1300 test.wantBodySent()
1301 test.wantRequestDone(200)
1302 }
1303
1304 func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
1305 test := newTransport100ContinueTest(t, 1*time.Hour)
1306
1307 test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
1308 test.wantRequestDone(200)
1309 }
1310
1311 func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
1312 test := newTransport100ContinueTest(t, 1*time.Hour)
1313
1314 test.respond("HTTP/1.1 500", "Content-Length: 0")
1315 test.wantBodySent()
1316 test.wantRequestDone(500)
1317 }
1318
1319 func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
1320 test := newTransport100ContinueTest(t, 5*time.Millisecond)
1321 test.wantBodySent()
1322 test.respond("HTTP/1.1 200", "Content-Length: 0")
1323 test.wantRequestDone(200)
1324 }
1325
1326 func TestSOCKS5Proxy(t *testing.T) {
1327 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1328 }
1329 func testSOCKS5Proxy(t *testing.T, mode testMode) {
1330 ch := make(chan string, 1)
1331 l := newLocalListener(t)
1332 defer l.Close()
1333 defer close(ch)
1334 proxy := func(t *testing.T) {
1335 s, err := l.Accept()
1336 if err != nil {
1337 t.Errorf("socks5 proxy Accept(): %v", err)
1338 return
1339 }
1340 defer s.Close()
1341 var buf [22]byte
1342 if _, err := io.ReadFull(s, buf[:3]); err != nil {
1343 t.Errorf("socks5 proxy initial read: %v", err)
1344 return
1345 }
1346 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1347 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1348 return
1349 }
1350 if _, err := s.Write([]byte{5, 0}); err != nil {
1351 t.Errorf("socks5 proxy initial write: %v", err)
1352 return
1353 }
1354 if _, err := io.ReadFull(s, buf[:4]); err != nil {
1355 t.Errorf("socks5 proxy second read: %v", err)
1356 return
1357 }
1358 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1359 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1360 return
1361 }
1362 var ipLen int
1363 switch buf[3] {
1364 case 1:
1365 ipLen = net.IPv4len
1366 case 4:
1367 ipLen = net.IPv6len
1368 default:
1369 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1370 return
1371 }
1372 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1373 t.Errorf("socks5 proxy address read: %v", err)
1374 return
1375 }
1376 ip := net.IP(buf[4 : ipLen+4])
1377 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1378 copy(buf[:3], []byte{5, 0, 0})
1379 if _, err := s.Write(buf[:ipLen+6]); err != nil {
1380 t.Errorf("socks5 proxy connect write: %v", err)
1381 return
1382 }
1383 ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1384
1385
1386 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1387 targetConn, err := net.Dial("tcp", targetHost)
1388 if err != nil {
1389 t.Errorf("net.Dial failed")
1390 return
1391 }
1392 go io.Copy(targetConn, s)
1393 io.Copy(s, targetConn)
1394 targetConn.Close()
1395 }
1396
1397 pu, err := url.Parse("socks5://" + l.Addr().String())
1398 if err != nil {
1399 t.Fatal(err)
1400 }
1401
1402 sentinelHeader := "X-Sentinel"
1403 sentinelValue := "12345"
1404 h := HandlerFunc(func(w ResponseWriter, r *Request) {
1405 w.Header().Set(sentinelHeader, sentinelValue)
1406 })
1407 for _, useTLS := range []bool{false, true} {
1408 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1409 ts := newClientServerTest(t, mode, h).ts
1410 go proxy(t)
1411 c := ts.Client()
1412 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1413 r, err := c.Head(ts.URL)
1414 if err != nil {
1415 t.Fatal(err)
1416 }
1417 if r.Header.Get(sentinelHeader) != sentinelValue {
1418 t.Errorf("Failed to retrieve sentinel value")
1419 }
1420 got := <-ch
1421 ts.Close()
1422 tsu, err := url.Parse(ts.URL)
1423 if err != nil {
1424 t.Fatal(err)
1425 }
1426 want := "proxy for " + tsu.Host
1427 if got != want {
1428 t.Errorf("got %q, want %q", got, want)
1429 }
1430 })
1431 }
1432 }
1433
1434 func TestTransportProxy(t *testing.T) {
1435 defer afterTest(t)
1436 testCases := []struct{ siteMode, proxyMode testMode }{
1437 {http1Mode, http1Mode},
1438 {http1Mode, https1Mode},
1439 {https1Mode, http1Mode},
1440 {https1Mode, https1Mode},
1441 }
1442 for _, testCase := range testCases {
1443 siteMode := testCase.siteMode
1444 proxyMode := testCase.proxyMode
1445 t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
1446 siteCh := make(chan *Request, 1)
1447 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1448 siteCh <- r
1449 })
1450 proxyCh := make(chan *Request, 1)
1451 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1452 proxyCh <- r
1453
1454 if r.Method == "CONNECT" {
1455 hijacker, ok := w.(Hijacker)
1456 if !ok {
1457 t.Errorf("hijack not allowed")
1458 return
1459 }
1460 clientConn, _, err := hijacker.Hijack()
1461 if err != nil {
1462 t.Errorf("hijacking failed")
1463 return
1464 }
1465 res := &Response{
1466 StatusCode: StatusOK,
1467 Proto: "HTTP/1.1",
1468 ProtoMajor: 1,
1469 ProtoMinor: 1,
1470 Header: make(Header),
1471 }
1472
1473 targetConn, err := net.Dial("tcp", r.URL.Host)
1474 if err != nil {
1475 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1476 return
1477 }
1478
1479 if err := res.Write(clientConn); err != nil {
1480 t.Errorf("Writing 200 OK failed: %v", err)
1481 return
1482 }
1483
1484 go io.Copy(targetConn, clientConn)
1485 go func() {
1486 io.Copy(clientConn, targetConn)
1487 targetConn.Close()
1488 }()
1489 }
1490 })
1491 ts := newClientServerTest(t, siteMode, h1).ts
1492 proxy := newClientServerTest(t, proxyMode, h2).ts
1493
1494 pu, err := url.Parse(proxy.URL)
1495 if err != nil {
1496 t.Fatal(err)
1497 }
1498
1499
1500
1501
1502 c := proxy.Client()
1503 if siteMode == https1Mode {
1504 c = ts.Client()
1505 }
1506
1507 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1508 if _, err := c.Head(ts.URL); err != nil {
1509 t.Error(err)
1510 }
1511 got := <-proxyCh
1512 c.Transport.(*Transport).CloseIdleConnections()
1513 ts.Close()
1514 proxy.Close()
1515 if siteMode == https1Mode {
1516
1517 if got.Method != "CONNECT" {
1518 t.Errorf("Wrong method for secure proxying: %q", got.Method)
1519 }
1520 gotHost := got.URL.Host
1521 pu, err := url.Parse(ts.URL)
1522 if err != nil {
1523 t.Fatal("Invalid site URL")
1524 }
1525 if wantHost := pu.Host; gotHost != wantHost {
1526 t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
1527 }
1528
1529
1530 next := <-siteCh
1531 if next.Method != "HEAD" {
1532 t.Errorf("Wrong method at destination: %s", next.Method)
1533 }
1534 if nextURL := next.URL.String(); nextURL != "/" {
1535 t.Errorf("Wrong URL at destination: %s", nextURL)
1536 }
1537 } else {
1538 if got.Method != "HEAD" {
1539 t.Errorf("Wrong method for destination: %q", got.Method)
1540 }
1541 gotURL := got.URL.String()
1542 wantURL := ts.URL + "/"
1543 if gotURL != wantURL {
1544 t.Errorf("Got URL %q, want %q", gotURL, wantURL)
1545 }
1546 }
1547 })
1548 }
1549 }
1550
1551 func TestOnProxyConnectResponse(t *testing.T) {
1552
1553 var tcases = []struct {
1554 proxyStatusCode int
1555 err error
1556 }{
1557 {
1558 StatusOK,
1559 nil,
1560 },
1561 {
1562 StatusForbidden,
1563 errors.New("403"),
1564 },
1565 }
1566 for _, tcase := range tcases {
1567 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1568
1569 })
1570
1571 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1572
1573 if r.Method == "CONNECT" {
1574 if tcase.proxyStatusCode != StatusOK {
1575 w.WriteHeader(tcase.proxyStatusCode)
1576 return
1577 }
1578 hijacker, ok := w.(Hijacker)
1579 if !ok {
1580 t.Errorf("hijack not allowed")
1581 return
1582 }
1583 clientConn, _, err := hijacker.Hijack()
1584 if err != nil {
1585 t.Errorf("hijacking failed")
1586 return
1587 }
1588 res := &Response{
1589 StatusCode: StatusOK,
1590 Proto: "HTTP/1.1",
1591 ProtoMajor: 1,
1592 ProtoMinor: 1,
1593 Header: make(Header),
1594 }
1595
1596 targetConn, err := net.Dial("tcp", r.URL.Host)
1597 if err != nil {
1598 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1599 return
1600 }
1601
1602 if err := res.Write(clientConn); err != nil {
1603 t.Errorf("Writing 200 OK failed: %v", err)
1604 return
1605 }
1606
1607 go io.Copy(targetConn, clientConn)
1608 go func() {
1609 io.Copy(clientConn, targetConn)
1610 targetConn.Close()
1611 }()
1612 }
1613 })
1614 ts := newClientServerTest(t, https1Mode, h1).ts
1615 proxy := newClientServerTest(t, https1Mode, h2).ts
1616
1617 pu, err := url.Parse(proxy.URL)
1618 if err != nil {
1619 t.Fatal(err)
1620 }
1621
1622 c := proxy.Client()
1623
1624 var (
1625 dials atomic.Int32
1626 closes atomic.Int32
1627 )
1628 c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1629 conn, err := net.Dial(network, addr)
1630 if err != nil {
1631 return nil, err
1632 }
1633 dials.Add(1)
1634 return noteCloseConn{
1635 Conn: conn,
1636 closeFunc: func() {
1637 closes.Add(1)
1638 },
1639 }, nil
1640 }
1641
1642 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1643 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1644 if proxyURL.String() != pu.String() {
1645 t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1646 }
1647
1648 if "https://"+connectReq.URL.String() != ts.URL {
1649 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1650 }
1651 return tcase.err
1652 }
1653 wantCloses := int32(0)
1654 if _, err := c.Head(ts.URL); err != nil {
1655 wantCloses = 1
1656 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1657 t.Errorf("got %v, want %v", err, tcase.err)
1658 }
1659 } else {
1660 if tcase.err != nil {
1661 t.Errorf("got %v, want nil", err)
1662 }
1663 }
1664 if got, want := dials.Load(), int32(1); got != want {
1665 t.Errorf("got %v dials, want %v", got, want)
1666 }
1667
1668 if got, want := closes.Load(), wantCloses; got != want {
1669 t.Errorf("got %v closes, want %v", got, want)
1670 }
1671 }
1672 }
1673
1674
1675
// TestTransportProxyHTTPSConnectLeak checks that when the client gives up
// waiting for a proxy's answer to CONNECT, it closes the proxy connection
// (the fake proxy expects to read EOF) rather than leaking it.
//
// A test hook replaces the proxy-connect timeout context with one that is
// canceled when cancelc is closed, so the test decides exactly when the client
// abandons the CONNECT.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
	cancelc := make(chan struct{})
	SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
		ctx, cancel := context.WithCancel(ctx)
		// The configured timeout is ignored. Cancel when the test closes
		// cancelc, or when ctx is done for any other reason.
		go func() {
			select {
			case <-cancelc:
			case <-ctx.Done():
			}
			cancel()
		}()
		return ctx, cancel
	})

	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()
	listenerDone := make(chan struct{})
	// Fake proxy: accept one connection, read the CONNECT request, and never
	// answer it.
	go func() {
		defer close(listenerDone)
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Make the client abandon the CONNECT while it is still unanswered,
		// then expect it to close the connection.
		close(cancelc)
		var buf [1]byte
		_, err = br.Read(buf[:])
		if err != io.EOF {
			t.Errorf("proxy server Read err = %v; want EOF", err)
		}
		return
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
		},
	}
	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Do(req)
	if err == nil {
		t.Errorf("unexpected Get success")
	}

	// Wait for the fake proxy to observe the client closing the connection.
	<-listenerDone
}
1748
1749
1750 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1751 defer afterTest(t)
1752
1753 var errDial = errors.New("some dial error")
1754
1755 tr := &Transport{
1756 Proxy: func(*Request) (*url.URL, error) {
1757 return url.Parse("http://proxy.fake.tld/")
1758 },
1759 Dial: func(string, string) (net.Conn, error) {
1760 return nil, errDial
1761 },
1762 }
1763 defer tr.CloseIdleConnections()
1764
1765 c := &Client{Transport: tr}
1766 req, _ := NewRequest("GET", "http://fake.tld", nil)
1767 res, err := c.Do(req)
1768 if err == nil {
1769 res.Body.Close()
1770 t.Fatal("wanted a non-nil error")
1771 }
1772
1773 uerr, ok := err.(*url.Error)
1774 if !ok {
1775 t.Fatalf("got %T, want *url.Error", err)
1776 }
1777 oe, ok := uerr.Err.(*net.OpError)
1778 if !ok {
1779 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1780 }
1781 want := &net.OpError{
1782 Op: "proxyconnect",
1783 Net: "tcp",
1784 Err: errDial,
1785 }
1786 if !reflect.DeepEqual(oe, want) {
1787 t.Errorf("Got error %#v; want %#v", oe, want)
1788 }
1789 }
1790
1791
1792
1793
1794
1795 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1796 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
1797 }
1798 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1799 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1800 defer proxy.Close()
1801 c := proxy.Client()
1802
1803 tr := c.Transport.(*Transport)
1804 tr.Proxy = func(*Request) (*url.URL, error) {
1805 u, _ := url.Parse(proxy.URL)
1806 u.User = url.UserPassword("aladdin", "opensesame")
1807 return u, nil
1808 }
1809 h := tr.ProxyConnectHeader
1810 if h == nil {
1811 h = make(Header)
1812 }
1813 tr.ProxyConnectHeader = h.Clone()
1814
1815 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1816 if err != nil {
1817 t.Fatal(err)
1818 }
1819 _, err = c.Do(req)
1820 if err == nil {
1821 t.Errorf("unexpected Get success")
1822 }
1823
1824 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
1825 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
1826 }
1827 }
1828
1829
1830
1831
1832
1833 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
1834 func testTransportGzipRecursive(t *testing.T, mode testMode) {
1835 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1836 w.Header().Set("Content-Encoding", "gzip")
1837 w.Write(rgz)
1838 })).ts
1839
1840 c := ts.Client()
1841 res, err := c.Get(ts.URL)
1842 if err != nil {
1843 t.Fatal(err)
1844 }
1845 body, err := io.ReadAll(res.Body)
1846 if err != nil {
1847 t.Fatal(err)
1848 }
1849 if !bytes.Equal(body, rgz) {
1850 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
1851 body, rgz)
1852 }
1853 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1854 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1855 }
1856 }
1857
1858
1859
1860 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
1861 func testTransportGzipShort(t *testing.T, mode testMode) {
1862 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1863 w.Header().Set("Content-Encoding", "gzip")
1864 w.Write([]byte{0x1f, 0x8b})
1865 })).ts
1866
1867 c := ts.Client()
1868 res, err := c.Get(ts.URL)
1869 if err != nil {
1870 t.Fatal(err)
1871 }
1872 defer res.Body.Close()
1873 _, err = io.ReadAll(res.Body)
1874 if err == nil {
1875 t.Fatal("Expect an error from reading a body.")
1876 }
1877 if err != io.ErrUnexpectedEOF {
1878 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
1879 }
1880 }
1881
1882
// waitNumGoroutine polls runtime.NumGoroutine, sleeping 50ms and forcing a GC
// between samples, until the count is at most nmax or 10 retries have been
// used. It returns the last count observed.
func waitNumGoroutine(nmax int) int {
	count := runtime.NumGoroutine()
	for retries := 0; retries < 10 && count > nmax; retries++ {
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		count = runtime.NumGoroutine()
	}
	return count
}
1892
1893
// TestTransportPersistConnLeak checks that issuing many concurrent requests
// and then closing idle connections leaves few, if any, extra goroutines.
// It is marked testNotParallel because it measures runtime.NumGoroutine.
func TestTransportPersistConnLeak(t *testing.T) {
	run(t, testTransportPersistConnLeak, testNotParallel)
}
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	const numReq = 25
	gotReqCh := make(chan bool, numReq)
	unblockCh := make(chan bool, numReq)
	// Each handler signals its arrival and then blocks until unblockCh is
	// closed, so up to numReq requests are held open at the same time.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReqCh <- true
		<-unblockCh
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline goroutine count, taken before any client goroutines start.
	n0 := runtime.NumGoroutine()

	didReqCh := make(chan bool, numReq)
	failed := make(chan bool, numReq)
	for i := 0; i < numReq; i++ {
		go func() {
			res, err := c.Get(ts.URL)
			didReqCh <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				failed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait until every request has either reached a handler or failed.
	for i := 0; i < numReq; i++ {
		select {
		case <-gotReqCh:
			// ok
		case <-failed:
			// A failed request may never reach a handler; count it here so
			// this loop does not wait forever.
		}
	}

	// Peak goroutine count, used only in the diagnostic log below.
	nhigh := runtime.NumGoroutine()

	// Let every blocked handler return.
	close(unblockCh)

	// Wait for every client goroutine's Get to return.
	for i := 0; i < numReq; i++ {
		<-didReqCh
	}

	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Allow up to 5 extra goroutines of slack; presumably some are unrelated
	// to this test.
	if int(growth) > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
		t.Error("too many new goroutines")
	}
}
1965
1966
1967
// TestTransportPersistConnLeakShortBody checks that requests whose actual body
// is longer than their declared ContentLength fail without leaving extra
// goroutines behind. It is marked testNotParallel because it measures
// runtime.NumGoroutine.
func TestTransportPersistConnLeakShortBody(t *testing.T) {
	run(t, testTransportPersistConnLeakShortBody, testNotParallel)
}
func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline goroutine count before any requests are made.
	n0 := runtime.NumGoroutine()
	body := []byte("Hello")
	for i := 0; i < 20; i++ {
		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		// Declare a ContentLength 2 bytes shorter than the real body; the
		// client is expected to reject the request.
		req.ContentLength = int64(len(body) - 2)
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("Expect an error from writing too long of a body.")
		}
	}
	nhigh := runtime.NumGoroutine()
	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Allow up to 5 extra goroutines of slack.
	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
	if int(growth) > 5 {
		t.Error("too many new goroutines")
	}
}
2008
2009
// countedConn wraps a net.Conn so that a finalizer can be attached to the
// value handed to the dialer's caller.
type countedConn struct {
	net.Conn
}

// countingDialer wraps a net.Dialer and tracks how many connections it has
// dialed in total, and how many of the returned wrappers have not yet been
// finalized by the garbage collector.
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials like net.Dialer.DialContext. On success it counts the
// connection and returns it wrapped in a *countedConn whose finalizer
// decrements the live count.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}
	wrapped := &countedConn{Conn: conn}

	d.mu.Lock()
	d.total++
	d.live++
	runtime.SetFinalizer(wrapped, d.decrement)
	d.mu.Unlock()
	return wrapped, nil
}

// decrement is the finalizer for connections returned by DialContext.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	d.live--
	d.mu.Unlock()
}

// Read returns the total number of connections dialed and the number still
// live (not yet finalized).
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	total, live = d.total, d.live
	d.mu.Unlock()
	return total, live
}
2050
2051 func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
2052 run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
2053 }
2054 func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
2055 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2056
2057 conn, _, err := w.(Hijacker).Hijack()
2058 if err != nil {
2059 t.Errorf("Hijack failed unexpectedly: %v", err)
2060 return
2061 }
2062 conn.Close()
2063 })).ts
2064
2065 var d countingDialer
2066 c := ts.Client()
2067 c.Transport.(*Transport).DialContext = d.DialContext
2068
2069 body := []byte("Hello")
2070 for i := 0; ; i++ {
2071 total, live := d.Read()
2072 if live < total {
2073 break
2074 }
2075 if i >= 1<<12 {
2076 t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
2077 }
2078
2079 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2080 if err != nil {
2081 t.Fatal(err)
2082 }
2083 _, err = c.Do(req)
2084 if err == nil {
2085 t.Fatal("expected broken connection")
2086 }
2087
2088 runtime.GC()
2089 }
2090 }
2091
// countedContext wraps a context.Context so that a finalizer can be attached
// to the value handed to the caller.
type countedContext struct {
	context.Context
}

// contextCounter counts the contexts returned by Track that have not yet been
// finalized by the garbage collector.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx in a new countedContext, increments the live count, and
// registers a finalizer that decrements it again.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	wrapped := &countedContext{Context: ctx}
	cc.mu.Lock()
	cc.live++
	runtime.SetFinalizer(wrapped, cc.decrement)
	cc.mu.Unlock()
	return wrapped
}

// decrement is the finalizer for contexts returned by Track.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	cc.live--
	cc.mu.Unlock()
}

// Read returns the number of tracked contexts not yet finalized.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	live = cc.live
	cc.mu.Unlock()
	return live
}
2122
2123 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
2124 run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
2125 }
2126 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
2127 if mode == http2Mode {
2128 t.Skip("https://go.dev/issue/56021")
2129 }
2130
2131 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2132 runtime.Gosched()
2133 w.WriteHeader(StatusOK)
2134 })).ts
2135
2136 c := ts.Client()
2137 c.Transport.(*Transport).MaxConnsPerHost = 1
2138
2139 ctx := context.Background()
2140 body := []byte("Hello")
2141 doPosts := func(cc *contextCounter) {
2142 var wg sync.WaitGroup
2143 for n := 64; n > 0; n-- {
2144 wg.Add(1)
2145 go func() {
2146 defer wg.Done()
2147
2148 ctx := cc.Track(ctx)
2149 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2150 if err != nil {
2151 t.Error(err)
2152 }
2153
2154 _, err = c.Do(req.WithContext(ctx))
2155 if err != nil {
2156 t.Errorf("Do failed with error: %v", err)
2157 }
2158 }()
2159 }
2160 wg.Wait()
2161 }
2162
2163 var initialCC contextCounter
2164 doPosts(&initialCC)
2165
2166
2167
2168
2169 var flushCC contextCounter
2170 for i := 0; ; i++ {
2171 live := initialCC.Read()
2172 if live == 0 {
2173 break
2174 }
2175 if i >= 100 {
2176 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2177 }
2178 doPosts(&flushCC)
2179 runtime.GC()
2180 }
2181 }
2182
2183
2184 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
2185 func testTransportIdleConnCrash(t *testing.T, mode testMode) {
2186 var tr *Transport
2187
2188 unblockCh := make(chan bool, 1)
2189 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2190 <-unblockCh
2191 tr.CloseIdleConnections()
2192 })).ts
2193 c := ts.Client()
2194 tr = c.Transport.(*Transport)
2195
2196 didreq := make(chan bool)
2197 go func() {
2198 res, err := c.Get(ts.URL)
2199 if err != nil {
2200 t.Error(err)
2201 } else {
2202 res.Body.Close()
2203 }
2204 didreq <- true
2205 }()
2206 unblockCh <- true
2207 <-didreq
2208 }
2209
2210
2211
2212
2213
2214 func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2215 func testIssue3644(t *testing.T, mode testMode) {
2216 const numFoos = 5000
2217 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2218 w.Header().Set("Connection", "close")
2219 for i := 0; i < numFoos; i++ {
2220 w.Write([]byte("foo "))
2221 }
2222 })).ts
2223 c := ts.Client()
2224 res, err := c.Get(ts.URL)
2225 if err != nil {
2226 t.Fatal(err)
2227 }
2228 defer res.Body.Close()
2229 bs, err := io.ReadAll(res.Body)
2230 if err != nil {
2231 t.Fatal(err)
2232 }
2233 if len(bs) != numFoos*len("foo ") {
2234 t.Errorf("unexpected response length")
2235 }
2236 }
2237
2238
2239
2240 func TestIssue3595(t *testing.T) {
2241
2242 run(t, testIssue3595, testNotParallel)
2243 }
2244 func testIssue3595(t *testing.T, mode testMode) {
2245 runTimeSensitiveTest(t, []time.Duration{
2246 1 * time.Millisecond,
2247 5 * time.Millisecond,
2248 10 * time.Millisecond,
2249 50 * time.Millisecond,
2250 100 * time.Millisecond,
2251 500 * time.Millisecond,
2252 time.Second,
2253 5 * time.Second,
2254 }, func(t *testing.T, timeout time.Duration) error {
2255 SetRSTAvoidanceDelay(t, timeout)
2256 t.Logf("set RST avoidance delay to %v", timeout)
2257
2258 const deniedMsg = "sorry, denied."
2259 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2260 Error(w, deniedMsg, StatusUnauthorized)
2261 }))
2262
2263
2264 defer cst.close()
2265 ts := cst.ts
2266 c := ts.Client()
2267
2268 res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2269 if err != nil {
2270 return fmt.Errorf("Post: %v", err)
2271 }
2272 got, err := io.ReadAll(res.Body)
2273 if err != nil {
2274 return fmt.Errorf("Body ReadAll: %v", err)
2275 }
2276 t.Logf("server response:\n%s", got)
2277 if !strings.Contains(string(got), deniedMsg) {
2278
2279
2280 t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2281 }
2282 return nil
2283 })
2284 }
2285
2286
2287
2288 func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2289 func testChunkedNoContent(t *testing.T, mode testMode) {
2290 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2291 w.WriteHeader(StatusNoContent)
2292 })).ts
2293
2294 c := ts.Client()
2295 for _, closeBody := range []bool{true, false} {
2296 const n = 4
2297 for i := 1; i <= n; i++ {
2298 res, err := c.Get(ts.URL)
2299 if err != nil {
2300 t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2301 } else {
2302 if closeBody {
2303 res.Body.Close()
2304 }
2305 }
2306 }
2307 }
2308 }
2309
// TestTransportConcurrency sends many requests from a pool of client
// goroutines and checks that each response echoes its own request. It is
// marked testNotParallel because it changes GOMAXPROCS and installs global
// dial hooks.
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}
func testTransportConcurrency(t *testing.T, mode testMode) {
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	// Set GOMAXPROCS now and restore the previous value on return.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	// wg counts one unit per request; each worker calls Done once per
	// request it handles.
	var wg sync.WaitGroup
	wg.Add(numReqs)

	// The dial hooks also add to wg, presumably once per pending dial and
	// removing it when the dial finishes, so wg.Wait below does not return
	// while a dial is still outstanding.
	// NOTE(review): hook semantics inferred from the name; confirm.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	defer close(reqs)

	// Workers; they exit when reqs is closed on return.
	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// Connection resets seen on NetBSD are only logged,
						// not treated as failures.
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2372
2373 func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
2374 func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2375 mux := NewServeMux()
2376 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2377 io.Copy(w, neverEnding('a'))
2378 })
2379 ts := newClientServerTest(t, mode, mux).ts
2380
2381 connc := make(chan net.Conn, 1)
2382 c := ts.Client()
2383 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2384 conn, err := net.Dial(n, addr)
2385 if err != nil {
2386 return nil, err
2387 }
2388 select {
2389 case connc <- conn:
2390 default:
2391 }
2392 return conn, nil
2393 }
2394
2395 res, err := c.Get(ts.URL + "/get")
2396 if err != nil {
2397 t.Fatalf("Error issuing GET: %v", err)
2398 }
2399 defer res.Body.Close()
2400
2401 conn := <-connc
2402 conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2403 _, err = io.Copy(io.Discard, res.Body)
2404 if err == nil {
2405 t.Errorf("Unexpected successful copy")
2406 }
2407 }
2408
// TestIssue4191_InfiniteGetToPutTimeout streams an endless GET response body
// into a PUT request, over connections that each have an absolute deadline.
// It checks that the PUT fails rather than running forever.
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
}
func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
	const debug = false
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
		defer r.Body.Close()
		io.Copy(io.Discard, r.Body)
	})
	ts := newClientServerTest(t, mode, mux).ts
	timeout := 100 * time.Millisecond

	c := ts.Client()
	// Every dialed connection gets a deadline of timeout from the moment it
	// is dialed. The closure reads timeout at dial time, so increasing it
	// below affects later dials.
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		conn.SetDeadline(time.Now().Add(timeout))
		if debug {
			conn = NewLoggingConn("client", conn)
		}
		return conn, nil
	}

	getFailed := false
	nRuns := 5
	if testing.Short() {
		nRuns = 1
	}
	for i := 0; i < nRuns; i++ {
		if debug {
			println("run", i+1, "of", nRuns)
		}
		sres, err := c.Get(ts.URL + "/get")
		if err != nil {
			if !getFailed {
				// The GET itself hit the deadline. Presumably the machine is
				// slow, so retry this run once (i--) with a 10x longer
				// timeout.
				getFailed = true
				t.Logf("increasing timeout")
				i--
				timeout *= 10
				continue
			}
			t.Errorf("Error issuing GET: %v", err)
			break
		}
		// The PUT's body is the endless GET body, so the PUT can only end
		// when a connection deadline fires.
		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
		_, err = c.Do(req)
		if err == nil {
			sres.Body.Close()
			t.Errorf("Unexpected successful PUT")
			break
		}
		sres.Body.Close()
	}
	if debug {
		println("tests complete; waiting for handlers to finish")
	}
	ts.Close()
}
2474
// TestTransportResponseHeaderTimeout exercises Transport.ResponseHeaderTimeout
// in every default test mode.
func TestTransportResponseHeaderTimeout(t *testing.T) {
	run(t, testTransportResponseHeaderTimeout)
}
// testTransportResponseHeaderTimeout checks that, with
// Transport.ResponseHeaderTimeout set, a request whose handler never writes
// headers (/slow) fails with a "timeout awaiting response headers" error, while
// requests whose handler returns immediately (/fast) succeed on the same client.
//
// The initial timeout is very short. If a /fast request times out anyway, the
// whole sequence is re-run against a fresh server with the timeout doubled.
func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping timeout test in -short mode")
	}

	timeout := 2 * time.Millisecond
	retry := true
	for retry && !t.Failed() {
		// srvWG tracks handler completion, one Done per request below.
		var srvWG sync.WaitGroup
		// inHandler is signaled by each handler as soon as it starts.
		inHandler := make(chan bool, 1)
		mux := NewServeMux()
		mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
			inHandler <- true
			srvWG.Done()
		})
		mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
			inHandler <- true
			// Never write headers; block until the request is abandoned.
			<-r.Context().Done()
			srvWG.Done()
		})
		ts := newClientServerTest(t, mode, mux).ts

		c := ts.Client()
		c.Transport.(*Transport).ResponseHeaderTimeout = timeout

		retry = false
		srvWG.Add(3)
		tests := []struct {
			path        string
			wantTimeout bool
		}{
			{path: "/fast"},
			{path: "/slow", wantTimeout: true},
			{path: "/fast"},
		}
		for i, tt := range tests {
			req, _ := NewRequest("GET", ts.URL+tt.path, nil)
			req = req.WithT(t)
			res, err := c.Do(req)
			<-inHandler
			if err != nil {
				uerr, ok := err.(*url.Error)
				if !ok {
					t.Errorf("error is not a url.Error; got: %#v", err)
					continue
				}
				nerr, ok := uerr.Err.(net.Error)
				if !ok {
					t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
					continue
				}
				if !nerr.Timeout() {
					t.Errorf("want timeout error; got: %q", nerr)
					continue
				}
				if !tt.wantTimeout {
					if !retry {
						// Treat a /fast timeout as a flake: retry the whole
						// sequence with a longer timeout.
						t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
						timeout *= 2
						retry = true
					}
				}
				if !strings.Contains(err.Error(), "timeout awaiting response headers") {
					t.Errorf("%d. unexpected error: %v", i, err)
				}
				continue
			}
			// The body is never needed; close it so the response and its
			// connection are released (previously it was leaked).
			res.Body.Close()
			if tt.wantTimeout {
				t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
				continue
			}
			if res.StatusCode != 200 {
				t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
			}
		}

		srvWG.Wait()
		ts.Close()
	}
}
2557
2558
// cancelTest bundles one request-cancellation mechanism so the same test body
// can be run against each of them (see the runCancelTest* helpers).
type cancelTest struct {
	// mode is the client/server test mode the test body should use.
	mode testMode
	// newReq prepares req so that cancel can later cancel it; callers use
	// the returned request.
	newReq func(req *Request) *Request
	// cancel cancels req (which must have been returned by newReq) using tr
	// or the mechanism installed by newReq.
	cancel func(tr *Transport, req *Request)
	// checkErr reports a test error if err is not the error this mechanism
	// is expected to produce; when labels the operation in the message.
	checkErr func(when string, err error)
}
2565
2566
// runCancelTestTransport runs f with a cancelTest that cancels requests through
// the Transport's CancelRequest method.
//
// Unlike the original version, this does not open its own "TransportCancel"
// subtest. runCancelTest already runs it inside one, and the extra level
// produced redundant TransportCancel/TransportCancel test names. It now
// matches runCancelTestChannel and runCancelTestContext.
func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
	f(t, cancelTest{
		mode: mode,
		newReq: func(req *Request) *Request {
			// No preparation is needed; cancellation goes through the Transport.
			return req
		},
		cancel: func(tr *Transport, req *Request) {
			tr.CancelRequest(req)
		},
		checkErr: func(when string, err error) {
			// Either cancellation sentinel is accepted.
			if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
				t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
			}
		},
	})
}
2585
2586
// runCancelTestChannel runs f with a cancelTest that cancels requests by
// closing the channel stored in Request.Cancel. The channel is closed at most
// once, so cancel may safely be invoked repeatedly.
func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
	var (
		closeOnce sync.Once
		cancelCh  = make(chan struct{})
	)
	attach := func(req *Request) *Request {
		req.Cancel = cancelCh
		return req
	}
	fire := func(_ *Transport, _ *Request) {
		closeOnce.Do(func() { close(cancelCh) })
	}
	check := func(when string, err error) {
		if errors.Is(err, ExportErrRequestCanceled) || errors.Is(err, ExportErrRequestCanceledConn) {
			return
		}
		t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
	}
	f(t, cancelTest{mode: mode, newReq: attach, cancel: fire, checkErr: check})
}
2608
2609
// runCancelTestContext runs f with a cancelTest that cancels requests by
// canceling the context attached to them, and expects errors that wrap
// context.Canceled.
func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
	ctx, cancel := context.WithCancel(context.Background())
	// Release the context's resources even if f never calls test.cancel.
	defer cancel()
	f(t, cancelTest{
		mode: mode,
		newReq: func(req *Request) *Request {
			return req.WithContext(ctx)
		},
		cancel: func(tr *Transport, req *Request) {
			cancel()
		},
		checkErr: func(when string, err error) {
			if !errors.Is(err, context.Canceled) {
				t.Errorf("%v error = %v, want context.Canceled", when, err)
			}
		},
	})
}
2627
// runCancelTest runs f once per cancellation mechanism in every mode selected
// by opts. Each mechanism gets its own subtest: "TransportCancel" (HTTP/1
// only), "RequestCancel" and "ContextCancel".
func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
	run(t, func(t *testing.T, mode testMode) {
		sub := func(name string, runner func(*testing.T, testMode, func(*testing.T, cancelTest))) {
			t.Run(name, func(t *testing.T) {
				runner(t, mode, f)
			})
		}
		if mode == http1Mode {
			sub("TransportCancel", runCancelTestTransport)
		}
		sub("RequestCancel", runCancelTestChannel)
		sub("ContextCancel", runCancelTestContext)
	}, opts...)
}
2643
// TestTransportCancelRequest cancels a request mid-body with every cancellation mechanism.
func TestTransportCancelRequest(t *testing.T) { runCancelTest(t, testTransportCancelRequest) }
// testTransportCancelRequest reads the first chunk of a streaming GET response,
// cancels the request using the mechanism in test, and then checks three
// things: the rest of the body read fails with the expected error, no further
// bytes arrive, and the Transport eventually reports no pending requests.
func testTransportCancelRequest(t *testing.T, test cancelTest) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const greeting = "Hello"
	release := make(chan bool)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, greeting)
		w.(Flusher).Flush()
		// Keep the response open until the test returns.
		<-release
	})
	ts := newClientServerTest(t, test.mode, handler).ts
	defer close(release)

	client := ts.Client()
	tr := client.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	req = test.newReq(req)
	res, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}

	head := make([]byte, len(greeting))
	if n, _ := io.ReadFull(res.Body, head); n != len(head) || !bytes.Equal(head, []byte(greeting)) {
		t.Errorf("Body = %q; want %q", head[:n], greeting)
	}
	test.cancel(tr, req)

	rest, err := io.ReadAll(res.Body)
	res.Body.Close()
	test.checkErr("Body.Read", err)
	if len(rest) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", rest)
	}

	// Poll until the Transport no longer counts the request as pending.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		pending := tr.NumPendingRequestsForTesting()
		if pending <= 0 {
			return true
		}
		if d > 0 {
			t.Logf("pending requests = %d after %v (want 0)", pending, d)
		}
		return false
	})
}
2697
// testTransportCancelRequestInDo starts Do in a goroutine for a request whose
// handler blocks, then repeatedly cancels the request (using the mechanism in
// test) until Do returns. body, which may be nil, is sent as the request body.
func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	unblockc := make(chan bool)
	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	donec := make(chan bool)
	req, _ := NewRequest("GET", ts.URL, body)
	req = test.newReq(req)
	go func() {
		defer close(donec)
		// Do is expected to fail once canceled. If it wins the race against
		// cancellation and succeeds, close the body so the response and its
		// connection are not leaked.
		if res, err := c.Do(req); err == nil {
			res.Body.Close()
		}
	}()

	// unblockc is unbuffered, so this send completes only once the handler
	// is running, i.e. the request has reached the server.
	unblockc <- true
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		test.cancel(tr, req)
		select {
		case <-donec:
			return true
		default:
			if d > 0 {
				t.Logf("Do of canceled request has not returned after %v", d)
			}
			return false
		}
	})
}
2733
2734 func TestTransportCancelRequestInDo(t *testing.T) {
2735 runCancelTest(t, func(t *testing.T, test cancelTest) {
2736 testTransportCancelRequestInDo(t, test, nil)
2737 })
2738 }
2739
// TestTransportCancelRequestWithBodyInDo cancels a request carrying a one-byte
// body while Do is waiting on the server.
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
	runCancelTest(t, func(t *testing.T, test cancelTest) {
		oneByte := bytes.NewBuffer([]byte{0})
		testTransportCancelRequestInDo(t, test, oneByte)
	})
}
2745
// TestTransportCancelRequestInDial cancels requests while the Transport is still dialing.
func TestTransportCancelRequestInDial(t *testing.T) { runCancelTest(t, testTransportCancelRequestInDial) }
// testTransportCancelRequestInDial cancels a request while the Transport's Dial
// function is blocked. It checks that Do returns an error without waiting for
// the dial to finish, and that events are logged in order: the dial starts,
// the test cancels, then Do returns an error.
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// eventLog records events from both goroutines so their order can be
	// compared against want below.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// unblockDial is only closed when the test returns. The test waits on
	// gotres first, so Do must return while Dial is still blocked.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// inDial tells the test that Dial has been entered.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			// NOTE(review): nothing visible here sends false on or closes
			// inDial, so this error branch looks unreachable in this test.
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	// Wait until Dial is running before canceling.
	inDial <- true

	eventLog.Printf("canceling")
	// Cancel twice to exercise repeated cancellation of the same request.
	test.cancel(tr, req)
	test.cancel(tr, req)

	if d, ok := t.Deadline(); ok {
		// If Do hangs, panic shortly before the test deadline so the
		// recorded event log is included in the failure output.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2808
2809
// TestTransportCancelRequestWithBody cancels a POST mid-response with every cancellation mechanism.
func TestTransportCancelRequestWithBody(t *testing.T) { runCancelTest(t, testTransportCancelRequestWithBody) }
// testTransportCancelRequestWithBody is the POST-with-body counterpart of
// testTransportCancelRequest. After the first chunk of the streaming response
// is read, the request is canceled. The remaining body read must then fail
// with the expected error and yield nothing, and the Transport must drop the
// request from its pending set.
func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	hold := make(chan struct{})
	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()
		<-hold
	})).ts
	defer close(hold)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
	req = test.newReq(req)

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	prefix := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, prefix)
	if n != len(prefix) || string(prefix) != msg {
		t.Errorf("Body = %q; want %q", prefix[:n], msg)
	}
	test.cancel(tr, req)

	remainder, err := io.ReadAll(res.Body)
	res.Body.Close()
	test.checkErr("Body.Read", err)
	if len(remainder) != 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", remainder)
	}

	// Poll until the Transport no longer counts the request as pending.
	noPending := func(d time.Duration) bool {
		pending := tr.NumPendingRequestsForTesting()
		if pending > 0 && d > 0 {
			t.Logf("pending requests = %d after %v (want 0)", pending, d)
		}
		return pending <= 0
	}
	waitCondition(t, 10*time.Millisecond, noPending)
}
2864
// TestTransportCancelRequestBeforeDo cancels a request before Do is called.
// Unlike runCancelTest, only the Request.Cancel channel and context variants
// are run here; the Transport.CancelRequest variant is not.
func TestTransportCancelRequestBeforeDo(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		viaChannel := func(t *testing.T) {
			runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
		}
		viaContext := func(t *testing.T) {
			runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
		}
		t.Run("RequestCancel", viaChannel)
		t.Run("ContextCancel", viaContext)
	})
}
// testTransportCancelRequestBeforeDo cancels a request before it is ever
// passed to Do, then checks that Do reports the expected cancellation error.
func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
	block := make(chan bool)
	cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-block
	}))
	defer close(block)

	client := cst.ts.Client()

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = test.newReq(req)
	test.cancel(cst.tr, req)

	_, doErr := client.Do(req)
	test.checkErr("Do", doErr)
}
2892
2893
// TestTransportCancelRequestBeforeResponseHeaders cancels a request after it
// is written but before any response arrives (HTTP/1 only).
func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
	modes := []testMode{http1Mode}
	runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, modes)
}
// testTransportCancelRequestBeforeResponseHeaders dials an in-memory pipe and
// waits until the server end has read the start of the request line ("GET").
// Before any response is written, it cancels the request and checks that
// RoundTrip fails with the expected cancellation error.
func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
	defer afterTest(t)

	serverSide := make(chan net.Conn, 1)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			clientEnd, serverEnd := net.Pipe()
			serverSide <- serverEnd
			return clientEnd, nil
		},
	}
	defer tr.CloseIdleConnections()

	req, _ := NewRequest("GET", "http://example.com/", nil)
	req = test.newReq(req)
	errc := make(chan error, 1)
	go func() {
		_, err := tr.RoundTrip(req)
		errc <- err
	}()

	sc := <-serverSide
	defer sc.Close()

	verb := make([]byte, 3)
	if _, err := io.ReadFull(sc, verb); err != nil {
		t.Errorf("Error reading HTTP verb from server: %v", err)
	}
	if got := string(verb); got != "GET" {
		t.Errorf("server received %q; want GET", verb)
	}

	test.cancel(tr, req)

	rtErr := <-errc
	if rtErr == nil {
		t.Fatalf("unexpected success from RoundTrip")
	}
	test.checkErr("RoundTrip", rtErr)
}
2935
2936
2937
2938
// TestTransportCloseResponseBody checks that closing a streaming response body
// makes the server's writes fail.
func TestTransportCloseResponseBody(t *testing.T) {
	run(t, testTransportCloseResponseBody)
}
// testTransportCloseResponseBody reads a few chunks of an endless streaming
// response and closes the body. It then checks that the close succeeds and
// that the server handler's next write fails.
func testTransportCloseResponseBody(t *testing.T, mode testMode) {
	writeErr := make(chan error, 1)
	chunk := []byte("young\n")
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Write and flush forever; report the first failing write.
		for {
			if _, err := w.Write(chunk); err != nil {
				writeErr <- err
				return
			}
			w.(Flusher).Flush()
		}
	})).ts

	client := ts.Client()
	tr := client.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	defer tr.CancelRequest(req)

	res, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}

	const repeats = 3
	want := bytes.Repeat(chunk, repeats)
	got := make([]byte, len(want))
	if _, err := io.ReadFull(res.Body, got); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(got, want) {
		t.Fatalf("read %q; want %q", got, want)
	}

	if err := res.Body.Close(); err != nil {
		t.Errorf("Close = %v", err)
	}

	if err := <-writeErr; err == nil {
		t.Errorf("expected non-nil write error")
	}
}
2985
2986 type fooProto struct{}
2987
2988 func (fooProto) RoundTrip(req *Request) (*Response, error) {
2989 res := &Response{
2990 Status: "200 OK",
2991 StatusCode: 200,
2992 Header: make(Header),
2993 Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
2994 }
2995 return res, nil
2996 }
2997
// TestTransportAltProto checks that a RoundTripper registered with
// Transport.RegisterProtocol handles requests for its URL scheme.
func TestTransportAltProto(t *testing.T) {
	defer afterTest(t)
	tr := &Transport{}
	c := &Client{Transport: tr}
	tr.RegisterProtocol("foo", fooProto{})
	res, err := c.Get("foo://bar.com/path")
	if err != nil {
		t.Fatal(err)
	}
	// Always release the response body (it was previously never closed).
	defer res.Body.Close()
	bodyb, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	body := string(bodyb)
	if e := "You wanted foo://bar.com/path"; body != e {
		t.Errorf("got response %q, want %q", body, e)
	}
}
3016
3017 func TestTransportNoHost(t *testing.T) {
3018 defer afterTest(t)
3019 tr := &Transport{}
3020 _, err := tr.RoundTrip(&Request{
3021 Header: make(Header),
3022 URL: &url.URL{
3023 Scheme: "http",
3024 },
3025 })
3026 want := "http: no Host in request URL"
3027 if got := fmt.Sprint(err); got != want {
3028 t.Errorf("error = %v; want %q", err, want)
3029 }
3030 }
3031
3032
// TestTransportEmptyMethod checks that an outgoing request whose Method has
// been cleared is serialized with a "GET " request line.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = "" // expected to be treated as GET when written
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if text := string(dump); !strings.Contains(text, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
3044
// TestTransportSocketLateBinding checks that a queued request reuses a
// connection freed while another dial is still in progress.
func TestTransportSocketLateBinding(t *testing.T) {
	run(t, testTransportSocketLateBinding)
}
// testTransportSocketLateBinding checks that a request is not tied to a
// connection that is still being dialed. /foo holds the first connection
// while its handler waits on fooGate. A request for /bar is then issued; in
// HTTP/1 this starts a second Dial, which is kept blocked. Once /foo's
// response is drained and closed, /bar must be served on the same connection,
// detected by comparing the RemoteAddr the server saw for each request.
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		// Hold the connection busy until the test lets /foo finish.
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// dialGate admits one real dial per value received (false aborts the
	// dial); dialing is sent to repeatedly while a dial waits on the gate.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	// Closing dialGate makes any still-blocked Dial return an error.
	defer close(dialGate)

	// Admit exactly one dial, for /foo.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// Release /foo only after the /bar request is in flight.
		if mode == http2Mode {
			// In HTTP/2 mode no second Dial is expected; wait briefly to
			// detect one, then continue.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			// HTTP/1: wait until /bar's Dial is blocked on dialGate.
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	// Registered after close(dialGate), so this runs first on return.
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3124
3125
// TestTransportReading100Continue drives the Transport over an in-memory pipe
// "server". For each request, the server writes a 100 Continue interim
// response immediately followed by a final 200 OK. The test checks that the
// client ends up with the 200 response, echoing the request's Request-Id, for
// numReqs sequential POSTs.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response plays the server side of one connection: it reads
	// requests from r and writes raw responses to w until EOF, an error, or
	// the last request ID has been answered.
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		// n counts requests on this connection. The body check below compares
		// against the client's loop index, so it presumes all requests share
		// one connection.
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			// NOTE(review): no request in this test sets
			// X-Want-Response-Code, so the interim status is always
			// "100 Continue" here.
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			body := fmt.Sprintf("Response number %d", n)
			// Interim response, then the final response, with CRLF line endings.
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			if id == reqID(numReqs) {
				return
			}
		}
	}

	tr := &Transport{
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe() // client writes sw, server reads sr
			cr, cw := io.Pipe() // server writes cw, client reads cr
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse sends req and checks the status code, the echoed
	// Request-Id (if any), and that the body can be read fully.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// The 100 Continue interim response must be skipped in favor of the 200.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
3225
3226
3227
// TestTransportIgnore1xxResponses checks that a 1xx interim response is
// reported via tracing and skipped (HTTP/1 only).
func TestTransportIgnore1xxResponses(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testTransportIgnore1xxResponses, modes)
}
// testTransportIgnore1xxResponses has the server send a non-standard 123
// interim response followed by a 200. It checks that the 123 is delivered to
// the Got1xxResponse trace hook and that Do returns the final 200 response
// unchanged.
func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
	const rawResponse = "HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Write([]byte(rawResponse))
		buf.Flush()
		conn.Close()
	}))
	cst.tr.DisableKeepAlives = true

	var got strings.Builder
	trace := &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
			return nil
		},
	}
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))

	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	res.Write(&got)
	const want = "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
	if s := got.String(); s != want {
		t.Errorf(" got: %q\nwant: %q\n", s, want)
	}
}
3261
// TestTransportLimits1xxResponses checks that an excessive number of 1xx
// responses is rejected (HTTP/1 only).
func TestTransportLimits1xxResponses(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testTransportLimits1xxResponses, modes)
}
3265 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3266 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3267 conn, buf, _ := w.(Hijacker).Hijack()
3268 for i := 0; i < 10; i++ {
3269 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
3270 }
3271 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3272 buf.Flush()
3273 conn.Close()
3274 }))
3275 cst.tr.DisableKeepAlives = true
3276
3277 res, err := cst.c.Get(cst.ts.URL)
3278 if res != nil {
3279 defer res.Body.Close()
3280 }
3281 got := fmt.Sprint(err)
3282 wantSub := "too many 1xx informational responses"
3283 if !strings.Contains(got, wantSub) {
3284 t.Errorf("Get error = %v; want substring %q", err, wantSub)
3285 }
3286 }
3287
3288
3289
// TestTransportTreat101Terminal checks that 101 Switching Protocols is treated
// as the final response (HTTP/1 only).
func TestTransportTreat101Terminal(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testTransportTreat101Terminal, modes)
}
// testTransportTreat101Terminal has the server send a 101 followed by a 204.
// It expects the client to return the 101 as the response.
func testTransportTreat101Terminal(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, buf, _ := w.(Hijacker).Hijack()
		for _, status := range []string{
			"HTTP/1.1 101 Switching Protocols\r\n\r\n",
			"HTTP/1.1 204 No Content\r\n\r\n",
		} {
			buf.Write([]byte(status))
		}
		buf.Flush()
		conn.Close()
	}))
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if got := res.StatusCode; got != StatusSwitchingProtocols {
		t.Errorf("StatusCode = %v; want 101 Switching Protocols", got)
	}
}
3310
3311 type proxyFromEnvTest struct {
3312 req string
3313
3314 env string
3315 httpsenv string
3316 noenv string
3317 reqmeth string
3318
3319 want string
3320 wanterr error
3321 }
3322
3323 func (t proxyFromEnvTest) String() string {
3324 var buf strings.Builder
3325 space := func() {
3326 if buf.Len() > 0 {
3327 buf.WriteByte(' ')
3328 }
3329 }
3330 if t.env != "" {
3331 fmt.Fprintf(&buf, "http_proxy=%q", t.env)
3332 }
3333 if t.httpsenv != "" {
3334 space()
3335 fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv)
3336 }
3337 if t.noenv != "" {
3338 space()
3339 fmt.Fprintf(&buf, "no_proxy=%q", t.noenv)
3340 }
3341 if t.reqmeth != "" {
3342 space()
3343 fmt.Fprintf(&buf, "request_method=%q", t.reqmeth)
3344 }
3345 req := "http://example.com"
3346 if t.req != "" {
3347 req = t.req
3348 }
3349 space()
3350 fmt.Fprintf(&buf, "req=%q", req)
3351 return strings.TrimSpace(buf.String())
3352 }
3353
// proxyFromEnvTests lists proxy environment settings and the proxy URL each
// should select; want "<nil>" means no proxy is used.
var proxyFromEnvTests = []proxyFromEnvTest{
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// An http:// request uses HTTP_PROXY even when HTTPS_PROXY is also set.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https:// request uses HTTPS_PROXY; a value without a scheme is
	// given the http scheme.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// With REQUEST_METHOD set (a CGI environment), HTTP_PROXY is refused
	// with an error.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy variables set: no proxy.
	{want: "<nil>"},

	// NO_PROXY matching against the request host.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3384
// testProxyForRequest builds a GET request for tt's URL (http://example.com if
// empty) and resolves its proxy with proxyForRequest. It compares the error
// text first; only if that matches does it compare the proxy URL, formatted
// with %s, against tt.want.
func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
	t.Helper()
	target := tt.req
	if target == "" {
		target = "http://example.com"
	}
	req, _ := NewRequest("GET", target, nil)
	proxyURL, err := proxyForRequest(req)

	gotErr, wantErr := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr)
	if gotErr != wantErr {
		t.Errorf("%v: got error = %q, want %q", tt, gotErr, wantErr)
		return
	}
	if got := fmt.Sprintf("%s", proxyURL); got != tt.want {
		t.Errorf("%v: got URL = %q, want %q", tt, proxyURL, tt.want)
	}
}
3401
// TestProxyFromEnvironment runs proxyFromEnvTests using the upper-case proxy
// environment variable names.
func TestProxyFromEnvironment(t *testing.T) {
	ResetProxyEnv()
	defer ResetProxyEnv()
	for _, tc := range proxyFromEnvTests {
		resolve := func(req *Request) (*url.URL, error) {
			os.Setenv("HTTP_PROXY", tc.env)
			os.Setenv("HTTPS_PROXY", tc.httpsenv)
			os.Setenv("NO_PROXY", tc.noenv)
			os.Setenv("REQUEST_METHOD", tc.reqmeth)
			// Drop the cached proxy config so the new variables are read.
			ResetCachedEnvironment()
			return ProxyFromEnvironment(req)
		}
		testProxyForRequest(t, tc, resolve)
	}
}
3416
// TestProxyFromEnvironmentLowerCase runs proxyFromEnvTests using the
// lower-case proxy variable names (REQUEST_METHOD stays upper-case).
func TestProxyFromEnvironmentLowerCase(t *testing.T) {
	ResetProxyEnv()
	defer ResetProxyEnv()
	for _, tc := range proxyFromEnvTests {
		resolve := func(req *Request) (*url.URL, error) {
			os.Setenv("http_proxy", tc.env)
			os.Setenv("https_proxy", tc.httpsenv)
			os.Setenv("no_proxy", tc.noenv)
			os.Setenv("REQUEST_METHOD", tc.reqmeth)
			// Drop the cached proxy config so the new variables are read.
			ResetCachedEnvironment()
			return ProxyFromEnvironment(req)
		}
		testProxyForRequest(t, tc, resolve)
	}
}
3431
// TestIdleConnChannelLeak runs in HTTP/1 mode only and not in parallel;
// presumably because the test installs a package-level read-loop hook.
func TestIdleConnChannelLeak(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testIdleConnChannelLeak, modes, testNotParallel)
}
// testIdleConnChannelLeak issues requests to nReqs distinct host names, all
// dialed to the same test server. It then checks that the map reported by
// IdleConnWaitMapSizeForTesting is empty. This runs once with keep-alives
// disabled and once with them enabled.
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// NOTE(review): n counts handled requests but is never read here.
	var mu sync.Mutex
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// didRead receives one value each time a connection's read loop reaches
	// the before-next-read hook.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Route every host name to the single test server.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// The response is discarded without reading or closing; the
			// handler writes no body.
		}

		// Wait until each request's read loop has reached the hook before
		// inspecting the Transport's internal map.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3487
3488
3489
3490
// TestTransportClosesRequestBody checks that the Transport closes a request
// body it was given (HTTP/1 only).
func TestTransportClosesRequestBody(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testTransportClosesRequestBody, modes)
}
3494 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3495 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3496 io.Copy(io.Discard, r.Body)
3497 })).ts
3498
3499 c := ts.Client()
3500
3501 closes := 0
3502
3503 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3504 if err != nil {
3505 t.Fatal(err)
3506 }
3507 res.Body.Close()
3508 if closes != 1 {
3509 t.Errorf("closes = %d; want 1", closes)
3510 }
3511 }
3512
// TestTransportTLSHandshakeTimeout connects to a listener that accepts a TCP
// connection but never answers. It checks that an https request fails with a
// timeout net.Error, wrapped in a *url.Error, whose message mentions
// "handshake timeout".
func TestTransportTLSHandshakeTimeout(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	ln := newLocalListener(t)
	defer ln.Close()
	done := make(chan struct{})
	defer close(done)

	// Accept one connection and hold it open, silently, until the test ends.
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		<-done
		conn.Close()
	}()

	client := &Client{Transport: &Transport{
		Dial: func(_, _ string) (net.Conn, error) {
			return net.Dial("tcp", ln.Addr().String())
		},
		TLSHandshakeTimeout: 250 * time.Millisecond,
	}}
	_, err := client.Get("https://dummy.tld/")
	if err == nil {
		t.Error("expected error")
		return
	}
	urlErr, isURLErr := err.(*url.Error)
	if !isURLErr {
		t.Errorf("expected url.Error; got %#v", err)
		return
	}
	netErr, isNetErr := urlErr.Err.(net.Error)
	if !isNetErr {
		t.Errorf("expected net.Error; got %#v", err)
		return
	}
	if !netErr.Timeout() {
		t.Errorf("expected timeout error; got %v", err)
	}
	if msg := err.Error(); !strings.Contains(msg, "handshake timeout") {
		t.Errorf("expected 'handshake timeout' in error; got %v", err)
	}
}
3562
3563
// TestTLSServerClosesConnection runs in HTTPS/1 mode only.
func TestTLSServerClosesConnection(t *testing.T) {
	modes := []testMode{https1Mode}
	run(t, testTLSServerClosesConnection, modes)
}
// testTLSServerClosesConnection repeatedly fetches a path whose handler
// hijacks the TLS connection, writes a complete response, and closes the
// connection. After each such response it issues a second request on the same
// client. The test fails only if every one of these follow-up requests fails;
// individual errors are logged.
//
// Response bodies are now closed after reading (previously they were never
// closed), and the unused read result of the follow-up body is discarded.
func testTLSServerClosesConnection(t *testing.T, mode testMode) {
	closedc := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
			conn, _, _ := w.(Hijacker).Hijack()
			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
			conn.Close()
			closedc <- true
			return
		}
		fmt.Fprintf(w, "hello")
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	nSuccess := 0
	var errs []error
	const trials = 20
	for i := 0; i < trials; i++ {
		tr.CloseIdleConnections()
		res, err := c.Get(ts.URL + "/keep-alive-then-die")
		if err != nil {
			t.Fatal(err)
		}
		// Wait until the server side has closed the connection.
		<-closedc
		slurp, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			t.Fatal(err)
		}
		if string(slurp) != "foo" {
			t.Errorf("Got %q, want foo", slurp)
		}

		// The previous connection is dead; this request is expected to
		// succeed anyway (presumably on a freshly dialed connection).
		res, err = c.Get(ts.URL + "/")
		if err != nil {
			errs = append(errs, err)
			continue
		}
		_, err = io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			errs = append(errs, err)
			continue
		}
		nSuccess++
	}
	if nSuccess > 0 {
		t.Logf("successes = %d of %d", nSuccess, trials)
	} else {
		t.Errorf("All runs failed:")
	}
	for _, err := range errs {
		t.Logf("  err: %v", err)
	}
}
3624
3625
3626
3627
// byteFromChanReader is an io.Reader that delivers at most one byte per Read,
// taken from the channel. Read blocks until a byte is available and returns
// io.EOF once the channel has been closed and drained.
type byteFromChanReader chan byte

// Read stores the next byte from the channel in p[0]. An empty p returns
// (0, nil) immediately without receiving from the channel.
func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	if b, ok := <-c; ok {
		p[0] = b
		return 1, nil
	}
	return 0, io.EOF
}
3641
3642
3643
3644
3645
3646
3647
// TestTransportNoReuseAfterEarlyResponse verifies that when the server sends a
// complete response to a POST before the client has finished writing the
// request body, the Transport does not reuse that connection: the following
// GET must be answered on a different connection.
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Temporarily lower MaxWriteWaitBeforeConnReuse (restored on return).
	// NOTE(review): presumably how long the Transport waits for an unfinished
	// request write before refusing to reuse the conn — confirm.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the hijacked server-side connection of the POST.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	var getOkay bool
	var copying sync.WaitGroup
	// closeConn closes the hijacked connection (once); the log line is only
	// emitted if the GET did not succeed.
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// POST: take over the connection and answer immediately, before the
		// request body has been read.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
		// Drain whatever body arrives in the background; this connection
		// never answers another request.
		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	// The body is bodySize-1 bytes of 'x' followed by one final byte that is
	// only made available (via finalBit) after the GET below has succeeded,
	// so the POST's body write is still unfinished when its response arrives.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	// The hijacked connection never serves a second response, so "bar" can
	// only arrive over a different connection.
	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true // suppresses the "Closed server connection" log
	finalBit <- 'x'
	close(finalBit)
}
3715
3716
3717
3718 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3719 func testTransportIssue10457(t *testing.T, mode testMode) {
3720 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3721
3722
3723
3724
3725
3726 conn, _, _ := w.(Hijacker).Hijack()
3727 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3728 conn.Close()
3729 })).ts
3730 c := ts.Client()
3731
3732 res, err := c.Get(ts.URL)
3733 if err != nil {
3734 t.Fatalf("Get: %v", err)
3735 }
3736 defer res.Body.Close()
3737
3738
3739
3740
3741 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3742 t.Errorf("Foo header = %q; want %q", got, want)
3743 }
3744 }
3745
// closerFunc adapts a plain function to the io.Closer interface.
type closerFunc func() error

// Close calls f and returns whatever it returns.
func (f closerFunc) Close() error {
	err := f()
	return err
}
3749
// writerFuncConn is a net.Conn whose Write is handled by the write callback;
// every other method is served by the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

// Write passes p to the write callback and returns its results unchanged.
func (c writerFuncConn) Write(p []byte) (n int, err error) {
	n, err = c.write(p)
	return n, err
}
3756
3757
3758
3759
3760
3761
3762
3763
3764
3765
3766
3767
3768
3769 func TestRetryRequestsOnError(t *testing.T) {
3770 run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
3771 }
3772 func testRetryRequestsOnError(t *testing.T, mode testMode) {
3773 newRequest := func(method, urlStr string, body io.Reader) *Request {
3774 req, err := NewRequest(method, urlStr, body)
3775 if err != nil {
3776 t.Fatal(err)
3777 }
3778 return req
3779 }
3780
3781 testCases := []struct {
3782 name string
3783 failureN int
3784 failureErr error
3785
3786
3787
3788 req func() *Request
3789 reqString string
3790 }{
3791 {
3792 name: "IdempotentNoBodySomeWritten",
3793
3794
3795 failureN: 1,
3796
3797 failureErr: ExportErrServerClosedIdle,
3798 req: func() *Request {
3799 return newRequest("GET", "http://fake.golang", nil)
3800 },
3801 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3802 },
3803 {
3804 name: "IdempotentGetBodySomeWritten",
3805
3806
3807 failureN: 1,
3808
3809 failureErr: ExportErrServerClosedIdle,
3810 req: func() *Request {
3811 return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
3812 },
3813 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3814 },
3815 {
3816 name: "NothingWrittenNoBody",
3817
3818
3819 failureN: 0,
3820 failureErr: errors.New("second write fails"),
3821 req: func() *Request {
3822 return newRequest("DELETE", "http://fake.golang", nil)
3823 },
3824 reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3825 },
3826 {
3827 name: "NothingWrittenGetBody",
3828
3829
3830 failureN: 0,
3831 failureErr: errors.New("second write fails"),
3832
3833
3834 req: func() *Request {
3835 return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
3836 },
3837 reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3838 },
3839 }
3840
3841 for _, tc := range testCases {
3842 t.Run(tc.name, func(t *testing.T) {
3843 var (
3844 mu sync.Mutex
3845 logbuf strings.Builder
3846 )
3847 logf := func(format string, args ...any) {
3848 mu.Lock()
3849 defer mu.Unlock()
3850 fmt.Fprintf(&logbuf, format, args...)
3851 logbuf.WriteByte('\n')
3852 }
3853
3854 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3855 logf("Handler")
3856 w.Header().Set("X-Status", "ok")
3857 })).ts
3858
3859 var writeNumAtomic int32
3860 c := ts.Client()
3861 c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
3862 logf("Dial")
3863 c, err := net.Dial(network, ts.Listener.Addr().String())
3864 if err != nil {
3865 logf("Dial error: %v", err)
3866 return nil, err
3867 }
3868 return &writerFuncConn{
3869 Conn: c,
3870 write: func(p []byte) (n int, err error) {
3871 if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
3872 logf("intentional write failure")
3873 return tc.failureN, tc.failureErr
3874 }
3875 logf("Write(%q)", p)
3876 return c.Write(p)
3877 },
3878 }, nil
3879 }
3880
3881 SetRoundTripRetried(func() {
3882 logf("Retried.")
3883 })
3884 defer SetRoundTripRetried(nil)
3885
3886 for i := 0; i < 3; i++ {
3887 t0 := time.Now()
3888 req := tc.req()
3889 res, err := c.Do(req)
3890 if err != nil {
3891 if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
3892 mu.Lock()
3893 got := logbuf.String()
3894 mu.Unlock()
3895 t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
3896 }
3897 t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
3898 }
3899 res.Body.Close()
3900 if res.Request != req {
3901 t.Errorf("Response.Request != original request; want identical Request")
3902 }
3903 }
3904
3905 mu.Lock()
3906 got := logbuf.String()
3907 mu.Unlock()
3908 want := fmt.Sprintf(`Dial
3909 Write("%s")
3910 Handler
3911 intentional write failure
3912 Retried.
3913 Dial
3914 Write("%s")
3915 Handler
3916 Write("%s")
3917 Handler
3918 `, tc.reqString, tc.reqString, tc.reqString)
3919 if got != want {
3920 t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
3921 }
3922 })
3923 }
3924 }
3925
3926
3927 func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
3928 func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
3929 readBody := make(chan error, 1)
3930 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3931 _, err := io.ReadAll(r.Body)
3932 readBody <- err
3933 })).ts
3934 c := ts.Client()
3935 fakeErr := errors.New("fake error")
3936 didClose := make(chan bool, 1)
3937 req, _ := NewRequest("POST", ts.URL, struct {
3938 io.Reader
3939 io.Closer
3940 }{
3941 io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
3942 closerFunc(func() error {
3943 select {
3944 case didClose <- true:
3945 default:
3946 }
3947 return nil
3948 }),
3949 })
3950 res, err := c.Do(req)
3951 if res != nil {
3952 defer res.Body.Close()
3953 }
3954 if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
3955 t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
3956 }
3957 if err := <-readBody; err == nil {
3958 t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
3959 }
3960 select {
3961 case <-didClose:
3962 default:
3963 t.Errorf("didn't see Body.Close")
3964 }
3965 }
3966
3967 func TestTransportDialTLS(t *testing.T) {
3968 run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
3969 }
3970 func testTransportDialTLS(t *testing.T, mode testMode) {
3971 var mu sync.Mutex
3972 var gotReq, didDial bool
3973
3974 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3975 mu.Lock()
3976 gotReq = true
3977 mu.Unlock()
3978 })).ts
3979 c := ts.Client()
3980 c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
3981 mu.Lock()
3982 didDial = true
3983 mu.Unlock()
3984 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
3985 if err != nil {
3986 return nil, err
3987 }
3988 return c, c.Handshake()
3989 }
3990
3991 res, err := c.Get(ts.URL)
3992 if err != nil {
3993 t.Fatal(err)
3994 }
3995 res.Body.Close()
3996 mu.Lock()
3997 if !gotReq {
3998 t.Error("didn't get request")
3999 }
4000 if !didDial {
4001 t.Error("didn't use dial hook")
4002 }
4003 }
4004
4005 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
4006 func testTransportDialContext(t *testing.T, mode testMode) {
4007 ctxKey := "some-key"
4008 ctxValue := "some-value"
4009 var (
4010 mu sync.Mutex
4011 gotReq bool
4012 gotCtxValue any
4013 )
4014
4015 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4016 mu.Lock()
4017 gotReq = true
4018 mu.Unlock()
4019 })).ts
4020 c := ts.Client()
4021 c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4022 mu.Lock()
4023 gotCtxValue = ctx.Value(ctxKey)
4024 mu.Unlock()
4025 return net.Dial(netw, addr)
4026 }
4027
4028 req, err := NewRequest("GET", ts.URL, nil)
4029 if err != nil {
4030 t.Fatal(err)
4031 }
4032 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4033 res, err := c.Do(req.WithContext(ctx))
4034 if err != nil {
4035 t.Fatal(err)
4036 }
4037 res.Body.Close()
4038 mu.Lock()
4039 if !gotReq {
4040 t.Error("didn't get request")
4041 }
4042 if got, want := gotCtxValue, ctxValue; got != want {
4043 t.Errorf("got context with value %v, want %v", got, want)
4044 }
4045 }
4046
4047 func TestTransportDialTLSContext(t *testing.T) {
4048 run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
4049 }
4050 func testTransportDialTLSContext(t *testing.T, mode testMode) {
4051 ctxKey := "some-key"
4052 ctxValue := "some-value"
4053 var (
4054 mu sync.Mutex
4055 gotReq bool
4056 gotCtxValue any
4057 )
4058
4059 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4060 mu.Lock()
4061 gotReq = true
4062 mu.Unlock()
4063 })).ts
4064 c := ts.Client()
4065 c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4066 mu.Lock()
4067 gotCtxValue = ctx.Value(ctxKey)
4068 mu.Unlock()
4069 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4070 if err != nil {
4071 return nil, err
4072 }
4073 return c, c.HandshakeContext(ctx)
4074 }
4075
4076 req, err := NewRequest("GET", ts.URL, nil)
4077 if err != nil {
4078 t.Fatal(err)
4079 }
4080 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4081 res, err := c.Do(req.WithContext(ctx))
4082 if err != nil {
4083 t.Fatal(err)
4084 }
4085 res.Body.Close()
4086 mu.Lock()
4087 if !gotReq {
4088 t.Error("didn't get request")
4089 }
4090 if got, want := gotCtxValue, ctxValue; got != want {
4091 t.Errorf("got context with value %v, want %v", got, want)
4092 }
4093 }
4094
4095
4096
// TestRoundTripReturnsProxyError verifies that an error returned by the
// Transport's Proxy function is propagated out of RoundTrip.
func TestRoundTripReturnsProxyError(t *testing.T) {
	proxyErr := errors.New("errorMessage")
	badProxy := func(*Request) (*url.URL, error) {
		return nil, proxyErr
	}

	tr := &Transport{Proxy: badProxy}

	req, _ := NewRequest("GET", "http://example.com", nil)

	_, err := tr.RoundTrip(req)

	if err == nil {
		t.Error("Expected proxy error to be returned by RoundTrip")
	} else if !errors.Is(err, proxyErr) {
		// Some error is not enough: it must be the one Proxy produced.
		t.Errorf("RoundTrip error = %v; want %v", err, proxyErr)
	}
}
4112
4113
4114 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4115 tr := &Transport{}
4116 wantIdle := func(when string, n int) bool {
4117 got := tr.IdleConnCountForTesting("http", "example.com")
4118 if got == n {
4119 return true
4120 }
4121 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4122 return false
4123 }
4124 wantIdle("start", 0)
4125 if !tr.PutIdleTestConn("http", "example.com") {
4126 t.Fatal("put failed")
4127 }
4128 if !tr.PutIdleTestConn("http", "example.com") {
4129 t.Fatal("second put failed")
4130 }
4131 wantIdle("after put", 2)
4132 tr.CloseIdleConnections()
4133 if !tr.IsIdleForTesting() {
4134 t.Error("should be idle after CloseIdleConnections")
4135 }
4136 wantIdle("after close idle", 0)
4137 if tr.PutIdleTestConn("http", "example.com") {
4138 t.Fatal("put didn't fail")
4139 }
4140 wantIdle("after second put", 0)
4141
4142 tr.QueueForIdleConnForTesting()
4143 if tr.IsIdleForTesting() {
4144 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4145 }
4146 if !tr.PutIdleTestConn("http", "example.com") {
4147 t.Fatal("after re-activation")
4148 }
4149 wantIdle("after final put", 1)
4150 }
4151
4152
4153
4154 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4155 tr := &Transport{}
4156 wantIdle := func(when string, n int) bool {
4157 got := tr.IdleConnCountForTesting("https", "example.com:443")
4158 if got == n {
4159 return true
4160 }
4161 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4162 return false
4163 }
4164 wantIdle("start", 0)
4165 alt := funcRoundTripper(func() {})
4166 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4167 t.Fatal("put failed")
4168 }
4169 wantIdle("after put", 1)
4170 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4171 GotConn: func(httptrace.GotConnInfo) {
4172
4173 t.Error("GotConn called")
4174 },
4175 })
4176 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4177 _, err := tr.RoundTrip(req)
4178 if err != errFakeRoundTrip {
4179 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4180 }
4181 wantIdle("after round trip", 1)
4182 }
4183
4184 func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
4185 run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
4186 }
4187 func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
4188 if testing.Short() {
4189 t.Skip("skipping in short mode")
4190 }
4191
4192 timeout := 1 * time.Millisecond
4193 retry := true
4194 for retry {
4195 trFunc := func(tr *Transport) {
4196 tr.MaxConnsPerHost = 1
4197 tr.MaxIdleConnsPerHost = 1
4198 tr.IdleConnTimeout = timeout
4199 }
4200 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
4201
4202 retry = false
4203 tooShort := func(err error) bool {
4204 if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
4205 return false
4206 }
4207 if !retry {
4208 t.Helper()
4209 t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
4210 timeout *= 2
4211 retry = true
4212 cst.close()
4213 }
4214 return true
4215 }
4216
4217 if _, err := cst.c.Get(cst.ts.URL); err != nil {
4218 if tooShort(err) {
4219 continue
4220 }
4221 t.Fatalf("got error: %s", err)
4222 }
4223
4224 time.Sleep(10 * timeout)
4225 if _, err := cst.c.Get(cst.ts.URL); err != nil {
4226 if tooShort(err) {
4227 continue
4228 }
4229 t.Fatalf("got error: %s", err)
4230 }
4231 }
4232 }
4233
4234
4235
4236
4237
4238 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
4239 func testTransportRangeAndGzip(t *testing.T, mode testMode) {
4240 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4241 if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
4242 t.Error("Transport advertised gzip support in the Accept header")
4243 }
4244 if r.Header.Get("Range") == "" {
4245 t.Error("no Range in request")
4246 }
4247 })).ts
4248 c := ts.Client()
4249
4250 req, _ := NewRequest("GET", ts.URL, nil)
4251 req.Header.Set("Range", "bytes=7-11")
4252 res, err := c.Do(req)
4253 if err != nil {
4254 t.Fatal(err)
4255 }
4256 res.Body.Close()
4257 }
4258
4259
// TestTransportResponseCancelRace verifies that calling CancelRequest on a
// request whose response body has already been fully read does not break a
// subsequent request issued right afterwards on the same Transport.
func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
func testTransportResponseCancelRace(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Always reply with a non-empty (1 KiB, zero-filled) body.
		var b [1024]byte
		w.Write(b[:])
	})).ts
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	// Read the body to the end rather than closing it early.
	// NOTE(review): presumably so the connection is returned for reuse by
	// req2, which is what makes the cancel below racy — confirm.
	if _, err := io.Copy(io.Discard, res.Body); err != nil {
		t.Fatal(err)
	}

	req2, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Cancel the already-completed first request, then immediately send the
	// second; the stale cancellation must not make req2 fail.
	tr.CancelRequest(req)
	res, err = tr.RoundTrip(req2)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
4295
4296
4297 func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
4298 run(t, testTransportContentEncodingCaseInsensitive)
4299 }
4300 func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
4301 for _, ce := range []string{"gzip", "GZIP"} {
4302 ce := ce
4303 t.Run(ce, func(t *testing.T) {
4304 const encodedString = "Hello Gopher"
4305 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4306 w.Header().Set("Content-Encoding", ce)
4307 gz := gzip.NewWriter(w)
4308 gz.Write([]byte(encodedString))
4309 gz.Close()
4310 })).ts
4311
4312 res, err := ts.Client().Get(ts.URL)
4313 if err != nil {
4314 t.Fatal(err)
4315 }
4316
4317 body, err := io.ReadAll(res.Body)
4318 res.Body.Close()
4319 if err != nil {
4320 t.Fatal(err)
4321 }
4322
4323 if string(body) != encodedString {
4324 t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
4325 }
4326 })
4327 }
4328 }
4329
4330
4331 func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
4332 run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
4333 }
4334 func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
4335 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
4336 func(tr *Transport) {
4337 tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
4338
4339 return &funcConn{
4340 read: func([]byte) (int, error) {
4341 return 0, errors.New("error")
4342 },
4343 write: func([]byte) (int, error) {
4344 return 0, errors.New("error")
4345 },
4346 }, nil
4347 }
4348 },
4349 ).ts
4350
4351
4352
4353
4354
4355 SetEnterRoundTripHook(func() {
4356 time.Sleep(1 * time.Millisecond)
4357 })
4358 defer SetEnterRoundTripHook(nil)
4359 var closes int
4360 _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
4361 if err == nil {
4362 t.Fatalf("expected request to fail, but it did not")
4363 }
4364 if closes != 1 {
4365 t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
4366 }
4367 }
4368
4369
4370
4371
// logWritesConn is a fake net.Conn that records each Write (as a string)
// before forwarding it to w, and serves Reads from a reader that arrives on
// rch. Methods it does not override go to the embedded Conn.
type logWritesConn struct {
	net.Conn // not overridden methods are delegated here

	w io.Writer // receives every Write

	rch <-chan io.Reader // delivers r; received on the first Read
	r   io.Reader

	mu     sync.Mutex
	writes []string // guarded by mu
}

// Write logs p and then writes it to c.w, holding c.mu throughout.
func (c *logWritesConn) Write(p []byte) (int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writes = append(c.writes, string(p))
	n, err := c.w.Write(p)
	return n, err
}

// Read blocks on the first call until a reader arrives on rch, then reads
// from that reader on this and every later call.
func (c *logWritesConn) Read(p []byte) (int, error) {
	r := c.r
	if r == nil {
		r = <-c.rch
		c.r = r
	}
	return r.Read(p)
}

// Close does nothing and always succeeds.
func (c *logWritesConn) Close() error {
	return nil
}
4399
4400
4401 func TestTransportFlushesBodyChunks(t *testing.T) {
4402 defer afterTest(t)
4403 resBody := make(chan io.Reader, 1)
4404 connr, connw := io.Pipe()
4405 lw := &logWritesConn{
4406 rch: resBody,
4407 w: connw,
4408 }
4409 tr := &Transport{
4410 Dial: func(network, addr string) (net.Conn, error) {
4411 return lw, nil
4412 },
4413 }
4414 bodyr, bodyw := io.Pipe()
4415 go func() {
4416 defer bodyw.Close()
4417 for i := 0; i < 3; i++ {
4418 fmt.Fprintf(bodyw, "num%d\n", i)
4419 }
4420 }()
4421 resc := make(chan *Response)
4422 go func() {
4423 req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
4424 req.Header.Set("User-Agent", "x")
4425 res, err := tr.RoundTrip(req)
4426 if err != nil {
4427 t.Errorf("RoundTrip: %v", err)
4428 close(resc)
4429 return
4430 }
4431 resc <- res
4432
4433 }()
4434
4435 req, err := ReadRequest(bufio.NewReader(connr))
4436 if err != nil {
4437 t.Fatal(err)
4438 }
4439 io.Copy(io.Discard, req.Body)
4440
4441
4442 resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
4443 res, ok := <-resc
4444 if !ok {
4445 return
4446 }
4447 defer res.Body.Close()
4448
4449 want := []string{
4450 "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
4451 "5\r\nnum0\n\r\n",
4452 "5\r\nnum1\n\r\n",
4453 "5\r\nnum2\n\r\n",
4454 "0\r\n\r\n",
4455 }
4456 if !reflect.DeepEqual(lw.writes, want) {
4457 t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
4458 }
4459 }
4460
4461
4462 func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
4463 func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
4464 gotReq := make(chan struct{})
4465 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4466 close(gotReq)
4467 }))
4468
4469 pr, pw := io.Pipe()
4470 req, err := NewRequest("POST", cst.ts.URL, pr)
4471 if err != nil {
4472 t.Fatal(err)
4473 }
4474 gotRes := make(chan struct{})
4475 go func() {
4476 defer close(gotRes)
4477 res, err := cst.tr.RoundTrip(req)
4478 if err != nil {
4479 t.Error(err)
4480 return
4481 }
4482 res.Body.Close()
4483 }()
4484
4485 <-gotReq
4486 pw.Close()
4487 <-gotRes
4488 }
4489
// wgReadCloser is an io.ReadCloser that marks wg as done on its first Close.
// Any later Close returns net.ErrClosed and leaves wg untouched.
type wgReadCloser struct {
	io.Reader
	wg     *sync.WaitGroup
	closed bool
}

// Close calls wg.Done exactly once; repeated calls report net.ErrClosed.
func (c *wgReadCloser) Close() error {
	if !c.closed {
		c.closed = true
		c.wg.Done()
		return nil
	}
	return net.ErrClosed
}
4504
4505
4506 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4507
4508 run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
4509 }
4510 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4511 if testing.Short() {
4512 t.Skip("skipping in short mode")
4513 }
4514
4515 runTimeSensitiveTest(t, []time.Duration{
4516 1 * time.Millisecond,
4517 5 * time.Millisecond,
4518 10 * time.Millisecond,
4519 50 * time.Millisecond,
4520 100 * time.Millisecond,
4521 500 * time.Millisecond,
4522 time.Second,
4523 5 * time.Second,
4524 }, func(t *testing.T, timeout time.Duration) error {
4525 SetRSTAvoidanceDelay(t, timeout)
4526 t.Logf("set RST avoidance delay to %v", timeout)
4527
4528 const contentLengthLimit = 1024 * 1024
4529 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4530 if r.ContentLength >= contentLengthLimit {
4531 w.WriteHeader(StatusBadRequest)
4532 r.Body.Close()
4533 return
4534 }
4535 w.WriteHeader(StatusOK)
4536 }))
4537
4538
4539 defer cst.close()
4540 ts := cst.ts
4541 c := ts.Client()
4542
4543 count := 100
4544
4545 bigBody := strings.Repeat("a", contentLengthLimit*2)
4546 var wg sync.WaitGroup
4547 defer wg.Wait()
4548 getBody := func() (io.ReadCloser, error) {
4549 wg.Add(1)
4550 body := &wgReadCloser{
4551 Reader: strings.NewReader(bigBody),
4552 wg: &wg,
4553 }
4554 return body, nil
4555 }
4556
4557 for i := 0; i < count; i++ {
4558 reqBody, _ := getBody()
4559 req, err := NewRequest("PUT", ts.URL, reqBody)
4560 if err != nil {
4561 reqBody.Close()
4562 t.Fatal(err)
4563 }
4564 req.ContentLength = int64(len(bigBody))
4565 req.GetBody = getBody
4566
4567 resp, err := c.Do(req)
4568 if err != nil {
4569 return fmt.Errorf("Do %d: %v", i, err)
4570 } else {
4571 resp.Body.Close()
4572 if resp.StatusCode != 400 {
4573 t.Errorf("Expected status code 400, got %v", resp.Status)
4574 }
4575 }
4576 }
4577 return nil
4578 })
4579 }
4580
4581 func TestTransportAutomaticHTTP2(t *testing.T) {
4582 testTransportAutoHTTP(t, &Transport{}, true)
4583 }
4584
4585 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4586 testTransportAutoHTTP(t, &Transport{
4587 ForceAttemptHTTP2: true,
4588 TLSClientConfig: new(tls.Config),
4589 }, true)
4590 }
4591
4592
4593 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4594 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4595 }
4596
4597 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4598 testTransportAutoHTTP(t, &Transport{
4599 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4600 }, false)
4601 }
4602
4603 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4604 testTransportAutoHTTP(t, &Transport{
4605 TLSClientConfig: new(tls.Config),
4606 }, false)
4607 }
4608
4609 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4610 testTransportAutoHTTP(t, &Transport{
4611 ExpectContinueTimeout: 1 * time.Second,
4612 }, true)
4613 }
4614
4615 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4616 var d net.Dialer
4617 testTransportAutoHTTP(t, &Transport{
4618 Dial: d.Dial,
4619 }, false)
4620 }
4621
4622 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4623 var d net.Dialer
4624 testTransportAutoHTTP(t, &Transport{
4625 DialContext: d.DialContext,
4626 }, false)
4627 }
4628
4629 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
4630 testTransportAutoHTTP(t, &Transport{
4631 DialTLS: func(network, addr string) (net.Conn, error) {
4632 panic("unused")
4633 },
4634 }, false)
4635 }
4636
4637 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
4638 CondSkipHTTP2(t)
4639 _, err := tr.RoundTrip(new(Request))
4640 if err == nil {
4641 t.Error("expected error from RoundTrip")
4642 }
4643 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
4644 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
4645 }
4646 }
4647
4648
4649
4650
4651
4652
4653
4654
4655 func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
4656 run(t, testTransportReuseConnEmptyResponseBody)
4657 }
4658 func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
4659 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4660 w.Header().Set("X-Addr", r.RemoteAddr)
4661
4662 }))
4663 n := 100
4664 if testing.Short() {
4665 n = 10
4666 }
4667 var firstAddr string
4668 for i := 0; i < n; i++ {
4669 res, err := cst.c.Get(cst.ts.URL)
4670 if err != nil {
4671 log.Fatal(err)
4672 }
4673 addr := res.Header.Get("X-Addr")
4674 if i == 0 {
4675 firstAddr = addr
4676 } else if addr != firstAddr {
4677 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
4678 }
4679 res.Body.Close()
4680 }
4681 }
4682
4683
// TestNoCrashReturningTransportAltConn covers a request that is canceled while
// DialTLS is still running, where the late TLS connection negotiated a
// protocol ("foo") that has a TLSNextProto handler. Do must fail with the
// canceled-conn error, the "foo" handler must still be invoked for the late
// connection, and the RoundTripper it returns must never be used.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// Count pending dials so the test can wait for all of them to finish.
	// NOTE(review): hook semantics inferred from SetPendingDialHooks' name — confirm.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	testDone := make(chan struct{})
	defer close(testDone)
	// Server: accept one TLS connection advertising "foo", complete the
	// handshake, and keep the connection open until the test ends.
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			// Cancel the request, then return the connection only after
			// Do has already returned.
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Release the dialer, then wait for the "foo" handler to run and for
	// all pending dials to complete.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
4765
4766 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
4767 run(t, func(t *testing.T, mode testMode) {
4768 testTransportReuseConnection_Gzip(t, mode, true)
4769 })
4770 }
4771
4772 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
4773 run(t, func(t *testing.T, mode testMode) {
4774 testTransportReuseConnection_Gzip(t, mode, false)
4775 })
4776 }
4777
4778
// testTransportReuseConnection_Gzip checks that the connection is reused for
// responses served with Content-Encoding: gzip (flushed before the body when
// chunked is true): both requests must reach the server from the same client
// address.
func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
	addr := make(chan string, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addr <- r.RemoteAddr
		w.Header().Set("Content-Encoding", "gzip")
		if chunked {
			// Flushing before the body is written presumably forces a
			// chunked response instead of one with a Content-Length.
			w.(Flusher).Flush()
		}
		w.Write(rgz) // NOTE(review): rgz is defined elsewhere; assumed gzip-compressed bytes.
	})).ts
	c := ts.Client()

	// Log connection lifecycle events to help diagnose a failed reuse.
	trace := &httptrace.ClientTrace{
		GetConn:      func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
		GotConn:      func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
		PutIdleConn:  func(err error) { t.Logf("PutIdleConn(%v)", err) },
		ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
		ConnectDone:  func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
	}
	ctx := httptrace.WithClientTrace(context.Background(), trace)

	for i := 0; i < 2; i++ {
		req, _ := NewRequest("GET", ts.URL, nil)
		req = req.WithContext(ctx)
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		buf := make([]byte, len(rgz))
		if n, err := io.ReadFull(res.Body, buf); err != nil {
			t.Errorf("%d. ReadFull = %v, %v", i, n, err)
		}
		// res.Body is not closed here; the address check below still
		// requires the second request to reuse the first connection.
	}
	a1, a2 := <-addr, <-addr
	if a1 != a2 {
		t.Fatalf("didn't reuse connection")
	}
}
4820
4821 func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
4822 func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
4823 if mode == http2Mode {
4824 t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
4825 }
4826 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4827 if r.URL.Path == "/long" {
4828 w.Header().Set("Long", strings.Repeat("a", 1<<20))
4829 }
4830 })).ts
4831 c := ts.Client()
4832 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
4833
4834 if res, err := c.Get(ts.URL); err != nil {
4835 t.Fatal(err)
4836 } else {
4837 res.Body.Close()
4838 }
4839
4840 res, err := c.Get(ts.URL + "/long")
4841 if err == nil {
4842 defer res.Body.Close()
4843 var n int64
4844 for k, vv := range res.Header {
4845 for _, v := range vv {
4846 n += int64(len(k)) + int64(len(v))
4847 }
4848 }
4849 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
4850 }
4851 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
4852 t.Errorf("got error: %v; want %q", err, want)
4853 }
4854 }
4855
4856 func TestTransportEventTrace(t *testing.T) {
4857 run(t, func(t *testing.T, mode testMode) {
4858 testTransportEventTrace(t, mode, false)
4859 }, testNotParallel)
4860 }
4861
4862
4863 func TestTransportEventTrace_NoHooks(t *testing.T) {
4864 run(t, func(t *testing.T, mode testMode) {
4865 testTransportEventTrace(t, mode, true)
4866 }, testNotParallel)
4867 }
4868
// testTransportEventTrace sends a POST carrying "Expect: 100-continue" to the
// fake host "dns-is-faked.golang". That host is resolved by a function
// installed under nettrace.LookupIPAltResolverKey in the request context.
// The test checks the httptrace.ClientTrace events written to a log buffer,
// then sends a GET on the same client and checks that GetConn fired again.
//
// When noHooks is true, the ClientTrace is emptied before use and only the
// POST round trip itself is checked.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// Signaled by the WroteRequest hook. Unless noHooks is set, the POST
	// handler waits on it before responding, so the response is only written
	// after the client reports the request as written. Buffered so the hook
	// does not block.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The follow-up GET at the end of the test needs no body.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	cst.tr.ExpectContinueTimeout = 1 * time.Second

	var mu sync.Mutex // guards buf
	var buf strings.Builder
	// logf appends one formatted line to buf; trace hooks may call it
	// concurrently.
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Fake DNS: the only lookup expected is "dns-is-faked.golang", and it
	// resolves to the test server's IP.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	// The TLS handshake hooks are only installed (and checked) in http2Mode.
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Clear every hook in place. The (now empty) trace is still attached
		// below, so nothing is logged and gotWroteReqEvent is never signaled.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// No hooks means no trace output to inspect; a successful round
		// trip is everything this variant checks.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// The Host header must carry the fake host name together with the
		// port from the request URL.
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// Issue a second request with the same trace. GetConn must fire again,
	// so "Getting conn for" should now appear exactly twice in the log.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}
}
5050
5051 func TestTransportEventTraceTLSVerify(t *testing.T) {
5052 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
5053 }
5054 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
5055 var mu sync.Mutex
5056 var buf strings.Builder
5057 logf := func(format string, args ...any) {
5058 mu.Lock()
5059 defer mu.Unlock()
5060 fmt.Fprintf(&buf, format, args...)
5061 buf.WriteByte('\n')
5062 }
5063
5064 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5065 t.Error("Unexpected request")
5066 }), func(ts *httptest.Server) {
5067 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
5068 logf("%s", p)
5069 return len(p), nil
5070 }), "", 0)
5071 }).ts
5072
5073 certpool := x509.NewCertPool()
5074 certpool.AddCert(ts.Certificate())
5075
5076 c := &Client{Transport: &Transport{
5077 TLSClientConfig: &tls.Config{
5078 ServerName: "dns-is-faked.golang",
5079 RootCAs: certpool,
5080 },
5081 }}
5082
5083 trace := &httptrace.ClientTrace{
5084 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
5085 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5086 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
5087 },
5088 }
5089
5090 req, _ := NewRequest("GET", ts.URL, nil)
5091 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5092 _, err := c.Do(req)
5093 if err == nil {
5094 t.Error("Expected request to fail TLS verification")
5095 }
5096
5097 mu.Lock()
5098 got := buf.String()
5099 mu.Unlock()
5100
5101 wantOnce := func(sub string) {
5102 if strings.Count(got, sub) != 1 {
5103 t.Errorf("expected substring %q exactly once in output.", sub)
5104 }
5105 }
5106
5107 wantOnce("TLSHandshakeStart")
5108 wantOnce("TLSHandshakeDone")
5109 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
5110
5111 if t.Failed() {
5112 t.Errorf("Output:\n%s", got)
5113 }
5114 }
5115
// Cached result of the one-time DNS probe in skipIfDNSHijacked:
// isDNSHijackedOnce makes the lookup run at most once per process, and
// isDNSHijacked records whether a name that should not resolve did.
var (
	isDNSHijackedOnce sync.Once
	isDNSHijacked     bool
)
5120
5121 func skipIfDNSHijacked(t *testing.T) {
5122
5123
5124
5125 isDNSHijackedOnce.Do(func() {
5126 addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
5127 isDNSHijacked = len(addrs) != 0
5128 })
5129 if isDNSHijacked {
5130 t.Skip("skipping; test requires non-hijacking DNS server")
5131 }
5132 }
5133
5134 func TestTransportEventTraceRealDNS(t *testing.T) {
5135 skipIfDNSHijacked(t)
5136 defer afterTest(t)
5137 tr := &Transport{}
5138 defer tr.CloseIdleConnections()
5139 c := &Client{Transport: tr}
5140
5141 var mu sync.Mutex
5142 var buf strings.Builder
5143 logf := func(format string, args ...any) {
5144 mu.Lock()
5145 defer mu.Unlock()
5146 fmt.Fprintf(&buf, format, args...)
5147 buf.WriteByte('\n')
5148 }
5149
5150 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
5151 trace := &httptrace.ClientTrace{
5152 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
5153 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
5154 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
5155 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
5156 }
5157 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5158
5159 resp, err := c.Do(req)
5160 if err == nil {
5161 resp.Body.Close()
5162 t.Fatal("expected error during DNS lookup")
5163 }
5164
5165 mu.Lock()
5166 got := buf.String()
5167 mu.Unlock()
5168
5169 wantSub := func(sub string) {
5170 if !strings.Contains(got, sub) {
5171 t.Errorf("expected substring %q in output.", sub)
5172 }
5173 }
5174 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
5175 wantSub("DNSDone: {Addrs:[] Err:")
5176 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
5177 t.Errorf("should not see Connect events")
5178 }
5179 if t.Failed() {
5180 t.Errorf("Output:\n%s", got)
5181 }
5182 }
5183
5184
// TestTransportRejectsAlphaPort checks that a URL whose port contains
// non-digit characters makes Get fail with a *url.Error whose underlying
// message names the bad port.
func TestTransportRejectsAlphaPort(t *testing.T) {
	res, err := Get("http://dummy.tld:123foo/bar")
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	ue, isURLErr := err.(*url.Error)
	if !isURLErr {
		t.Fatalf("got %#v; want *url.Error", err)
	}
	const want = `invalid port ":123foo" after host`
	if got := ue.Err.Error(); got != want {
		t.Errorf("got error %q; want %q", got, want)
	}
}
5201
5202
5203
5204 func TestTLSHandshakeTrace(t *testing.T) {
5205 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
5206 }
5207 func testTLSHandshakeTrace(t *testing.T, mode testMode) {
5208 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
5209
5210 var mu sync.Mutex
5211 var start, done bool
5212 trace := &httptrace.ClientTrace{
5213 TLSHandshakeStart: func() {
5214 mu.Lock()
5215 defer mu.Unlock()
5216 start = true
5217 },
5218 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5219 mu.Lock()
5220 defer mu.Unlock()
5221 done = true
5222 if err != nil {
5223 t.Fatal("Expected error to be nil but was:", err)
5224 }
5225 },
5226 }
5227
5228 c := ts.Client()
5229 req, err := NewRequest("GET", ts.URL, nil)
5230 if err != nil {
5231 t.Fatal("Unable to construct test request:", err)
5232 }
5233 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
5234
5235 r, err := c.Do(req)
5236 if err != nil {
5237 t.Fatal("Unexpected error making request:", err)
5238 }
5239 r.Body.Close()
5240 mu.Lock()
5241 defer mu.Unlock()
5242 if !start {
5243 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
5244 }
5245 if !done {
5246 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5247 }
5248 }
5249
5250 func TestTransportMaxIdleConns(t *testing.T) {
5251 run(t, testTransportMaxIdleConns, []testMode{http1Mode})
5252 }
5253 func testTransportMaxIdleConns(t *testing.T, mode testMode) {
5254 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5255
5256 })).ts
5257 c := ts.Client()
5258 tr := c.Transport.(*Transport)
5259 tr.MaxIdleConns = 4
5260
5261 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
5262 if err != nil {
5263 t.Fatal(err)
5264 }
5265 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
5266 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5267 })
5268
5269 hitHost := func(n int) {
5270 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
5271 req = req.WithContext(ctx)
5272 res, err := c.Do(req)
5273 if err != nil {
5274 t.Fatal(err)
5275 }
5276 res.Body.Close()
5277 }
5278 for i := 0; i < 4; i++ {
5279 hitHost(i)
5280 }
5281 want := []string{
5282 "|http|host-0.dns-is-faked.golang:" + port,
5283 "|http|host-1.dns-is-faked.golang:" + port,
5284 "|http|host-2.dns-is-faked.golang:" + port,
5285 "|http|host-3.dns-is-faked.golang:" + port,
5286 }
5287 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5288 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
5289 }
5290
5291
5292 hitHost(4)
5293 want = []string{
5294 "|http|host-1.dns-is-faked.golang:" + port,
5295 "|http|host-2.dns-is-faked.golang:" + port,
5296 "|http|host-3.dns-is-faked.golang:" + port,
5297 "|http|host-4.dns-is-faked.golang:" + port,
5298 }
5299 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5300 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
5301 }
5302 }
5303
5304 func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
5305 func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
5306 if testing.Short() {
5307 t.Skip("skipping in short mode")
5308 }
5309
5310 timeout := 1 * time.Millisecond
5311 timeoutLoop:
5312 for {
5313 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5314
5315 }))
5316 tr := cst.tr
5317 tr.IdleConnTimeout = timeout
5318 defer tr.CloseIdleConnections()
5319 c := &Client{Transport: tr}
5320
5321 idleConns := func() []string {
5322 if mode == http2Mode {
5323 return tr.IdleConnStrsForTesting_h2()
5324 } else {
5325 return tr.IdleConnStrsForTesting()
5326 }
5327 }
5328
5329 var conn string
5330 doReq := func(n int) (timeoutOk bool) {
5331 req, _ := NewRequest("GET", cst.ts.URL, nil)
5332 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5333 PutIdleConn: func(err error) {
5334 if err != nil {
5335 t.Errorf("failed to keep idle conn: %v", err)
5336 }
5337 },
5338 }))
5339 res, err := c.Do(req)
5340 if err != nil {
5341 if strings.Contains(err.Error(), "use of closed network connection") {
5342 t.Logf("req %v: connection closed prematurely", n)
5343 return false
5344 }
5345 }
5346 res.Body.Close()
5347 conns := idleConns()
5348 if len(conns) != 1 {
5349 if len(conns) == 0 {
5350 t.Logf("req %v: no idle conns", n)
5351 return false
5352 }
5353 t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
5354 }
5355 if conn == "" {
5356 conn = conns[0]
5357 }
5358 if conn != conns[0] {
5359 t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
5360 return false
5361 }
5362 return true
5363 }
5364 for i := 0; i < 3; i++ {
5365 if !doReq(i) {
5366 t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
5367 timeout *= 2
5368 cst.close()
5369 continue timeoutLoop
5370 }
5371 time.Sleep(timeout / 2)
5372 }
5373
5374 waitCondition(t, timeout/2, func(d time.Duration) bool {
5375 if got := idleConns(); len(got) != 0 {
5376 if d >= timeout*3/2 {
5377 t.Logf("after %v, idle conns = %q", d, got)
5378 }
5379 return false
5380 }
5381 return true
5382 })
5383 break
5384 }
5385 }
5386
5387
5388
5389
5390
5391
5392
5393
5394
5395
5396
5397
// TestIdleConnH2Crash is an HTTP/2-only regression test. Besides its explicit
// checks, it relies on the process not crashing while an orphaned connection
// is handled.
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }

// testIdleConnH2Crash dials a TLS connection negotiating "h2" whose request is
// canceled while the dial is in progress. DialTLS cancels ctx, then blocks
// until Do has returned its error (or the test ends) before handing the
// connection back. IdleConnTimeout is 5ms, and the test sleeps ten times that
// so any idle-timeout handling of that connection runs before the test ends.
// NOTE(review): presumably the late connection lands in the idle pool;
// confirm.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Empty: the only request is canceled while dialing.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is sent once Do has returned its (expected) error; testDone
	// unblocks DialTLS if the test exits first.
	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request while its connection is still being dialed.
		cancel()

		// Hand the connection back only after Do has already failed.
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Give the idle timer (5ms) ample time to fire on the orphaned conn.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5448
// funcConn is a net.Conn stub. Read and Write delegate to the read and write
// funcs, and Close does nothing and returns nil. All other net.Conn methods
// are promoted from the embedded Conn, which may be left nil; calling one of
// them then panics.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read forwards to fc.read.
func (fc funcConn) Read(b []byte) (int, error) { return fc.read(b) }

// Write forwards to fc.write.
func (fc funcConn) Write(b []byte) (int, error) { return fc.write(b) }

// Close is a no-op that always succeeds.
func (fc funcConn) Close() error { return nil }
5458
5459
5460
5461 func TestTransportReturnsPeekError(t *testing.T) {
5462 errValue := errors.New("specific error value")
5463
5464 wrote := make(chan struct{})
5465 var wroteOnce sync.Once
5466
5467 tr := &Transport{
5468 Dial: func(network, addr string) (net.Conn, error) {
5469 c := funcConn{
5470 read: func([]byte) (int, error) {
5471 <-wrote
5472 return 0, errValue
5473 },
5474 write: func(p []byte) (int, error) {
5475 wroteOnce.Do(func() { close(wrote) })
5476 return len(p), nil
5477 },
5478 }
5479 return c, nil
5480 },
5481 }
5482 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
5483 if err != errValue {
5484 t.Errorf("error = %#v; want %v", err, errValue)
5485 }
5486 }
5487
5488
// TestTransportIDNA checks that a request to a Unicode host name uses the
// name's punycode form everywhere: the Host header, the GetConn and DNSStart
// trace hooks, the name given to the resolver, and (in http2Mode) the TLS
// server name.
func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
func testTransportIDNA(t *testing.T, mode testMode) {
	const uniDomain = "гофер.го"
	const punyDomain = "xn--c1ae0ajs.xn--c1aw"

	// Read by the handler; assigned below once the listener's address is
	// known (before any request is sent).
	var port string
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		want := punyDomain + ":" + port
		if r.Host != want {
			t.Errorf("Host header = %q; want %q", r.Host, want)
		}
		if mode == http2Mode {
			if r.TLS == nil {
				t.Errorf("r.TLS == nil")
			} else if r.TLS.ServerName != punyDomain {
				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
			}
		}
		// Marks the response as coming from this handler.
		w.Header().Set("Hit-Handler", "1")
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})

	// Assigns the outer port (ip and err are the only new variables here).
	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	// Fake DNS: only the punycode name may be looked up, and it resolves to
	// the test server's IP.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != punyDomain {
			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
	trace := &httptrace.ClientTrace{
		GetConn: func(hostPort string) {
			want := net.JoinHostPort(punyDomain, port)
			if hostPort != want {
				t.Errorf("getting conn for %q; want %q", hostPort, want)
			}
		},
		DNSStart: func(e httptrace.DNSStartInfo) {
			if e.Host != punyDomain {
				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	res, err := cst.tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.Header.Get("Hit-Handler") != "1" {
		out, err := httputil.DumpResponse(res, true)
		if err != nil {
			t.Fatal(err)
		}
		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
	}
}
5557
5558
5559 func TestTransportProxyConnectHeader(t *testing.T) {
5560 run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
5561 }
5562 func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
5563 reqc := make(chan *Request, 1)
5564 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5565 if r.Method != "CONNECT" {
5566 t.Errorf("method = %q; want CONNECT", r.Method)
5567 }
5568 reqc <- r
5569 c, _, err := w.(Hijacker).Hijack()
5570 if err != nil {
5571 t.Errorf("Hijack: %v", err)
5572 return
5573 }
5574 c.Close()
5575 })).ts
5576
5577 c := ts.Client()
5578 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5579 return url.Parse(ts.URL)
5580 }
5581 c.Transport.(*Transport).ProxyConnectHeader = Header{
5582 "User-Agent": {"foo"},
5583 "Other": {"bar"},
5584 }
5585
5586 res, err := c.Get("https://dummy.tld/")
5587 if err == nil {
5588 res.Body.Close()
5589 t.Errorf("unexpected success")
5590 }
5591
5592 r := <-reqc
5593 if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
5594 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5595 }
5596 if got, want := r.Header.Get("Other"), "bar"; got != want {
5597 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5598 }
5599 }
5600
5601 func TestTransportProxyGetConnectHeader(t *testing.T) {
5602 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5603 }
5604 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5605 reqc := make(chan *Request, 1)
5606 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5607 if r.Method != "CONNECT" {
5608 t.Errorf("method = %q; want CONNECT", r.Method)
5609 }
5610 reqc <- r
5611 c, _, err := w.(Hijacker).Hijack()
5612 if err != nil {
5613 t.Errorf("Hijack: %v", err)
5614 return
5615 }
5616 c.Close()
5617 })).ts
5618
5619 c := ts.Client()
5620 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5621 return url.Parse(ts.URL)
5622 }
5623
5624 c.Transport.(*Transport).ProxyConnectHeader = Header{
5625 "User-Agent": {"foo"},
5626 "Other": {"bar"},
5627 }
5628 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
5629 return Header{
5630 "User-Agent": {"foo2"},
5631 "Other": {"bar2"},
5632 }, nil
5633 }
5634
5635 res, err := c.Get("https://dummy.tld/")
5636 if err == nil {
5637 res.Body.Close()
5638 t.Errorf("unexpected success")
5639 }
5640
5641 r := <-reqc
5642 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
5643 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5644 }
5645 if got, want := r.Header.Get("Other"), "bar2"; got != want {
5646 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5647 }
5648 }
5649
// errFakeRoundTrip is the error returned by every funcRoundTripper.RoundTrip
// call.
var errFakeRoundTrip = errors.New("fake roundtrip")
5651
5652 type funcRoundTripper func()
5653
5654 func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
5655 fn()
5656 return nil, errFakeRoundTrip
5657 }
5658
// wantBody checks the outcome of a request. It returns err unchanged if it is
// non-nil. Otherwise it reads all of res.Body and returns an error if the read
// fails, if the body differs from want, or if closing the body fails; nil
// means success.
//
// res.Body is closed on every path where it was opened for reading, including
// the failure paths. An unclosed body can keep the Transport from reusing the
// underlying connection.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		res.Body.Close()
		return fmt.Errorf("error reading body: %v", err)
	}
	if string(slurp) != want {
		res.Body.Close()
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if err := res.Body.Close(); err != nil {
		return fmt.Errorf("body Close = %v", err)
	}
	return nil
}
5675
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries 127.0.0.1 first, falls back to [::1], and fails the test if neither
// can be opened (reporting the IPv6 error).
func newLocalListener(t *testing.T) net.Listener {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return ln
	}
	if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
		t.Fatal(err)
	}
	return ln
}
5686
// countCloseReader is an io.ReadCloser. Reads come from the embedded Reader,
// and each Close increments *n and never fails.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close counts the call in the shared counter and returns nil.
func (r countCloseReader) Close() error {
	*r.n++
	return nil
}
5696
5697
// rgz is a 250-byte gzip stream used as an arbitrary gzip-encoded response
// body. The header has magic 0x1f 0x8b, CM 8 (deflate), and the FNAME flag set
// with the stored file name "recursive". The ISIZE field (the last four bytes,
// 0xfa 0x00 0x00 0x00) declares an uncompressed size of 250 bytes, which
// equals len(rgz).
// NOTE(review): presumably a gzip quine (decompresses to itself); confirm.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5732
5733
5734
5735 func TestMissingStatusNoPanic(t *testing.T) {
5736 t.Parallel()
5737
5738 const want = "unknown status code"
5739
5740 ln := newLocalListener(t)
5741 addr := ln.Addr().String()
5742 done := make(chan bool)
5743 fullAddrURL := fmt.Sprintf("http://%s", addr)
5744 raw := "HTTP/1.1 400\r\n" +
5745 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
5746 "Content-Type: text/html; charset=utf-8\r\n" +
5747 "Content-Length: 10\r\n" +
5748 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
5749 "Vary: Accept-Encoding\r\n\r\n" +
5750 "Aloha Olaa"
5751
5752 go func() {
5753 defer close(done)
5754
5755 conn, _ := ln.Accept()
5756 if conn != nil {
5757 io.WriteString(conn, raw)
5758 io.ReadAll(conn)
5759 conn.Close()
5760 }
5761 }()
5762
5763 proxyURL, err := url.Parse(fullAddrURL)
5764 if err != nil {
5765 t.Fatalf("proxyURL: %v", err)
5766 }
5767
5768 tr := &Transport{Proxy: ProxyURL(proxyURL)}
5769
5770 req, _ := NewRequest("GET", "https://golang.org/", nil)
5771 res, err, panicked := doFetchCheckPanic(tr, req)
5772 if panicked {
5773 t.Error("panicked, expecting an error")
5774 }
5775 if res != nil && res.Body != nil {
5776 io.Copy(io.Discard, res.Body)
5777 res.Body.Close()
5778 }
5779
5780 if err == nil || !strings.Contains(err.Error(), want) {
5781 t.Errorf("got=%v want=%q", err, want)
5782 }
5783
5784 ln.Close()
5785 <-done
5786 }
5787
// doFetchCheckPanic calls tr.RoundTrip(req). If RoundTrip panics, the panic
// is recovered and reported as panicked == true, with res and err left at
// their zero values; otherwise RoundTrip's results are returned unchanged.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		if recover() != nil {
			panicked = true
		}
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
5797
5798
5799
// TestNoBodyOnChunked304Response checks that a 304 response carrying
// "Transfer-Encoding: chunked" is given res.Body == NoBody.
func TestNoBodyOnChunked304Response(t *testing.T) {
	run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
}
func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Write the raw response by hand: a 304 that claims a chunked body,
		// followed by a chunked terminator ("0\r\n\r\n"). Hijack errors are
		// ignored.
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
		buf.Flush()
		conn.Close()
	}))

	// The client is expected to treat the 304 as bodiless (asserted below),
	// so the trailing "0\r\n\r\n" would be left unread on the connection.
	// Disable keep-alives so that connection is never reused.
	cst.tr.DisableKeepAlives = true

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	if res.Body != NoBody {
		t.Errorf("Unexpected body on 304 response")
	}
}
5826
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write passes b straight to the underlying function.
func (fw funcWriter) Write(b []byte) (int, error) {
	return fw(b)
}
5830
// doneContext wraps a Context but reports itself as already finished: Done
// returns a closed channel, and Err returns the err field. Every other
// method comes from the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a newly allocated, already-closed channel on each call.
func (doneContext) Done() <-chan struct{} {
	closed := make(chan struct{})
	close(closed)
	return closed
}

// Err returns the configured error.
func (dc doneContext) Err() error {
	return dc.err
}
5843
5844
5845 func TestTransportCheckContextDoneEarly(t *testing.T) {
5846 tr := &Transport{}
5847 req, _ := NewRequest("GET", "http://fake.example/", nil)
5848 wantErr := errors.New("some error")
5849 req = req.WithContext(doneContext{context.Background(), wantErr})
5850 _, err := tr.RoundTrip(req)
5851 if err != wantErr {
5852 t.Errorf("error = %v; want %v", err, wantErr)
5853 }
5854 }
5855
5856
5857
5858
5859
5860
// TestClientTimeoutKillsConn_BeforeHeaders checks that when Client.Timeout
// expires before any response headers arrive, the client closes the
// underlying connection.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}

// testClientTimeoutKillsConn_BeforeHeaders uses a handler that never writes a
// response. It waits for its request context to end, then hijacks the
// connection and expects a read to return (0, io.EOF), showing that the
// client hung up. Starting from 1ms, the client timeout is doubled whenever
// the handler fails to report in within 10× the timeout.
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{}) // closed when this attempt is abandoned
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			<-r.Context().Done()

			// Report to the test goroutine, unless it has already given up
			// on this attempt.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// Read from the raw connection: a zero-byte io.EOF read means
			// the client closed its side.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler did not report in within 10× the timeout. Assume
			// the timeout was too short for this machine, abandon this
			// attempt, and retry with double the timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
5919
5920
5921
5922
5923
5924
5925 func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
5926 run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
5927 }
// testClientTimeoutKillsConn_AfterHeaders checks that canceling a request
// after its response headers have arrived, while the body is still
// incomplete, makes the body read fail and closes the connection.
//
// The handler declares a 100-byte body and flushes the headers, then blocks
// until the test has canceled the request and observed the failed body read.
// Only after that does it hijack the conn, write 3 bytes, and expect a read to
// fail because the client has closed the connection.
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Send headers promising 100 body bytes.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		// Wait until the test receives from inHandler (after its body read
		// has failed), or return if the test gave up early.
		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		// Only part of the promised body.
		conn.Write([]byte("foo"))

		// The client is expected to have closed the connection, so the read
		// must return no data and some error. Any error is accepted, not just
		// io.EOF — NOTE(review): presumably because some platforms report a
		// connection reset instead; confirm.
		n, err := conn.Read([]byte{0})
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// A non-zero but very long Client timeout, presumably so the Client's
	// timeout handling is active without the timer ever firing; the request
	// is ended through req.Cancel below instead.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Cancel with only headers received; reading the body must now fail.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Release the handler and wait for its connection check to finish.
	<-inHandler
	<-handlerDone
}
5990
5991 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
5992 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
5993 }
5994 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
5995 done := make(chan struct{})
5996 defer close(done)
5997 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5998 conn, _, err := w.(Hijacker).Hijack()
5999 if err != nil {
6000 t.Error(err)
6001 return
6002 }
6003 defer conn.Close()
6004 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
6005 bs := bufio.NewScanner(conn)
6006 bs.Scan()
6007 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
6008 <-done
6009 }))
6010
6011 req, _ := NewRequest("GET", cst.ts.URL, nil)
6012 req.Header.Set("Upgrade", "foo")
6013 req.Header.Set("Connection", "upgrade")
6014 res, err := cst.c.Do(req)
6015 if err != nil {
6016 t.Fatal(err)
6017 }
6018 if res.StatusCode != 101 {
6019 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
6020 }
6021 rwc, ok := res.Body.(io.ReadWriteCloser)
6022 if !ok {
6023 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
6024 }
6025 defer rwc.Close()
6026 bs := bufio.NewScanner(rwc)
6027 if !bs.Scan() {
6028 t.Fatalf("expected readable input")
6029 }
6030 if got, want := bs.Text(), "Some buffered data"; got != want {
6031 t.Errorf("read %q; want %q", got, want)
6032 }
6033 io.WriteString(rwc, "echo\n")
6034 if !bs.Scan() {
6035 t.Fatalf("expected another line")
6036 }
6037 if got, want := bs.Text(), "ECHO"; got != want {
6038 t.Errorf("read %q; want %q", got, want)
6039 }
6040 }
6041
6042 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
6043 func testTransportCONNECTBidi(t *testing.T, mode testMode) {
6044 const target = "backend:443"
6045 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6046 if r.Method != "CONNECT" {
6047 t.Errorf("unexpected method %q", r.Method)
6048 w.WriteHeader(500)
6049 return
6050 }
6051 if r.RequestURI != target {
6052 t.Errorf("unexpected CONNECT target %q", r.RequestURI)
6053 w.WriteHeader(500)
6054 return
6055 }
6056 nc, brw, err := w.(Hijacker).Hijack()
6057 if err != nil {
6058 t.Error(err)
6059 return
6060 }
6061 defer nc.Close()
6062 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
6063
6064 for {
6065 line, err := brw.ReadString('\n')
6066 if err != nil {
6067 if err != io.EOF {
6068 t.Error(err)
6069 }
6070 return
6071 }
6072 io.WriteString(brw, strings.ToUpper(line))
6073 brw.Flush()
6074 }
6075 }))
6076 pr, pw := io.Pipe()
6077 defer pw.Close()
6078 req, err := NewRequest("CONNECT", cst.ts.URL, pr)
6079 if err != nil {
6080 t.Fatal(err)
6081 }
6082 req.URL.Opaque = target
6083 res, err := cst.c.Do(req)
6084 if err != nil {
6085 t.Fatal(err)
6086 }
6087 defer res.Body.Close()
6088 if res.StatusCode != 200 {
6089 t.Fatalf("status code = %d; want 200", res.StatusCode)
6090 }
6091 br := bufio.NewReader(res.Body)
6092 for _, str := range []string{"foo", "bar", "baz"} {
6093 fmt.Fprintf(pw, "%s\n", str)
6094 got, err := br.ReadString('\n')
6095 if err != nil {
6096 t.Fatal(err)
6097 }
6098 got = strings.TrimSpace(got)
6099 want := strings.ToUpper(str)
6100 if got != want {
6101 t.Fatalf("got %q; want %q", got, want)
6102 }
6103 }
6104 }
6105
6106 func TestTransportRequestReplayable(t *testing.T) {
6107 someBody := io.NopCloser(strings.NewReader(""))
6108 tests := []struct {
6109 name string
6110 req *Request
6111 want bool
6112 }{
6113 {
6114 name: "GET",
6115 req: &Request{Method: "GET"},
6116 want: true,
6117 },
6118 {
6119 name: "GET_http.NoBody",
6120 req: &Request{Method: "GET", Body: NoBody},
6121 want: true,
6122 },
6123 {
6124 name: "GET_body",
6125 req: &Request{Method: "GET", Body: someBody},
6126 want: false,
6127 },
6128 {
6129 name: "POST",
6130 req: &Request{Method: "POST"},
6131 want: false,
6132 },
6133 {
6134 name: "POST_idempotency-key",
6135 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6136 want: true,
6137 },
6138 {
6139 name: "POST_x-idempotency-key",
6140 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6141 want: true,
6142 },
6143 {
6144 name: "POST_body",
6145 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6146 want: false,
6147 },
6148 }
6149 for _, tt := range tests {
6150 t.Run(tt.name, func(t *testing.T) {
6151 got := tt.req.ExportIsReplayable()
6152 if got != tt.want {
6153 t.Errorf("replyable = %v; want %v", got, tt.want)
6154 }
6155 })
6156 }
6157 }
6158
6159
6160
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom method
// was used. testTransportRequestWriteRoundTrip uses it to tell whether the
// Transport wrote a request body through the io.ReaderFrom path.
type testMockTCPConn struct {
	*net.TCPConn

	// ReadFromCalled is set to true whenever ReadFrom is invoked.
	ReadFromCalled bool
}

// ReadFrom records the call and then delegates to the embedded TCP connection.
func (m *testMockTCPConn) ReadFrom(src io.Reader) (int64, error) {
	m.ReadFromCalled = true
	return m.TCPConn.ReadFrom(src)
}
6171
6172 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6173 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6174 nBytes := int64(1 << 10)
6175 newFileFunc := func() (r io.Reader, done func(), err error) {
6176 f, err := os.CreateTemp("", "net-http-newfilefunc")
6177 if err != nil {
6178 return nil, nil, err
6179 }
6180
6181
6182 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6183 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6184 }
6185 if _, err := f.Seek(0, 0); err != nil {
6186 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6187 }
6188
6189 done = func() {
6190 f.Close()
6191 os.Remove(f.Name())
6192 }
6193
6194 return f, done, nil
6195 }
6196
6197 newBufferFunc := func() (io.Reader, func(), error) {
6198 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6199 }
6200
6201 cases := []struct {
6202 name string
6203 readerFunc func() (io.Reader, func(), error)
6204 contentLength int64
6205 expectedReadFrom bool
6206 }{
6207 {
6208 name: "file, length",
6209 readerFunc: newFileFunc,
6210 contentLength: nBytes,
6211 expectedReadFrom: true,
6212 },
6213 {
6214 name: "file, no length",
6215 readerFunc: newFileFunc,
6216 },
6217 {
6218 name: "file, negative length",
6219 readerFunc: newFileFunc,
6220 contentLength: -1,
6221 },
6222 {
6223 name: "buffer",
6224 contentLength: nBytes,
6225 readerFunc: newBufferFunc,
6226 },
6227 {
6228 name: "buffer, no length",
6229 readerFunc: newBufferFunc,
6230 },
6231 {
6232 name: "buffer, length -1",
6233 contentLength: -1,
6234 readerFunc: newBufferFunc,
6235 },
6236 }
6237
6238 for _, tc := range cases {
6239 t.Run(tc.name, func(t *testing.T) {
6240 r, cleanup, err := tc.readerFunc()
6241 if err != nil {
6242 t.Fatal(err)
6243 }
6244 defer cleanup()
6245
6246 tConn := &testMockTCPConn{}
6247 trFunc := func(tr *Transport) {
6248 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6249 var d net.Dialer
6250 conn, err := d.DialContext(ctx, network, addr)
6251 if err != nil {
6252 return nil, err
6253 }
6254
6255 tcpConn, ok := conn.(*net.TCPConn)
6256 if !ok {
6257 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6258 }
6259
6260 tConn.TCPConn = tcpConn
6261 return tConn, nil
6262 }
6263 }
6264
6265 cst := newClientServerTest(
6266 t,
6267 mode,
6268 HandlerFunc(func(w ResponseWriter, r *Request) {
6269 io.Copy(io.Discard, r.Body)
6270 r.Body.Close()
6271 w.WriteHeader(200)
6272 }),
6273 trFunc,
6274 )
6275
6276 req, err := NewRequest("PUT", cst.ts.URL, r)
6277 if err != nil {
6278 t.Fatal(err)
6279 }
6280 req.ContentLength = tc.contentLength
6281 req.Header.Set("Content-Type", "application/octet-stream")
6282 resp, err := cst.c.Do(req)
6283 if err != nil {
6284 t.Fatal(err)
6285 }
6286 defer resp.Body.Close()
6287 if resp.StatusCode != 200 {
6288 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6289 }
6290
6291 expectedReadFrom := tc.expectedReadFrom
6292 if mode != http1Mode {
6293 expectedReadFrom = false
6294 }
6295 if !tConn.ReadFromCalled && expectedReadFrom {
6296 t.Fatalf("did not call ReadFrom")
6297 }
6298
6299 if tConn.ReadFromCalled && !expectedReadFrom {
6300 t.Fatalf("ReadFrom was unexpectedly invoked")
6301 }
6302 })
6303 }
6304 }
6305
// TestTransportClone checks Transport.Clone. The source Transport sets the
// exported fields listed below to non-zero values, and the clone must have no
// exported field left at its zero value. The TLSNextProto entries must also be
// copied, and cloning a zero Transport must leave TLSNextProto nil.
func TestTransportClone(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) { panic("") },
		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			return nil
		},
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}

	clone := tr.Clone()
	v := reflect.ValueOf(clone).Elem()
	typ := v.Type()
	for i, n := 0, typ.NumField(); i < n; i++ {
		field := typ.Field(i)
		if token.IsExported(field.Name) && v.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", field.Name)
		}
	}

	if _, ok := clone.TLSNextProto["foo"]; !ok {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// Cloning a zero Transport must leave TLSNextProto nil.
	if empty := new(Transport).Clone(); empty.TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
6360
6361 func TestIs408(t *testing.T) {
6362 tests := []struct {
6363 in string
6364 want bool
6365 }{
6366 {"HTTP/1.0 408", true},
6367 {"HTTP/1.1 408", true},
6368 {"HTTP/1.8 408", true},
6369 {"HTTP/2.0 408", false},
6370 {"HTTP/1.1 408 ", true},
6371 {"HTTP/1.1 40", false},
6372 {"http/1.0 408", false},
6373 {"HTTP/1-1 408", false},
6374 }
6375 for _, tt := range tests {
6376 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6377 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6378 }
6379 }
6380 }
6381
6382 func TestTransportIgnores408(t *testing.T) {
6383 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6384 }
6385 func testTransportIgnores408(t *testing.T, mode testMode) {
6386
6387 defer log.SetOutput(log.Writer())
6388
6389 var logout strings.Builder
6390 log.SetOutput(&logout)
6391
6392 const target = "backend:443"
6393
6394 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6395 nc, _, err := w.(Hijacker).Hijack()
6396 if err != nil {
6397 t.Error(err)
6398 return
6399 }
6400 defer nc.Close()
6401 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6402 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6403 }))
6404 req, err := NewRequest("GET", cst.ts.URL, nil)
6405 if err != nil {
6406 t.Fatal(err)
6407 }
6408 res, err := cst.c.Do(req)
6409 if err != nil {
6410 t.Fatal(err)
6411 }
6412 slurp, err := io.ReadAll(res.Body)
6413 if err != nil {
6414 t.Fatal(err)
6415 }
6416 if err != nil {
6417 t.Fatal(err)
6418 }
6419 if string(slurp) != "ok" {
6420 t.Fatalf("got %q; want ok", slurp)
6421 }
6422
6423 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6424 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6425 if d > 0 {
6426 t.Logf("%v idle conns still present after %v", n, d)
6427 }
6428 return false
6429 }
6430 return true
6431 })
6432 if got := logout.String(); got != "" {
6433 t.Fatalf("expected no log output; got: %s", got)
6434 }
6435 }
6436
6437 func TestInvalidHeaderResponse(t *testing.T) {
6438 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6439 }
6440 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6441 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6442 conn, buf, _ := w.(Hijacker).Hijack()
6443 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6444 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6445 "Content-Type: text/html; charset=utf-8\r\n" +
6446 "Content-Length: 0\r\n" +
6447 "Foo : bar\r\n\r\n"))
6448 buf.Flush()
6449 conn.Close()
6450 }))
6451 res, err := cst.c.Get(cst.ts.URL)
6452 if err != nil {
6453 t.Fatal(err)
6454 }
6455 defer res.Body.Close()
6456 if v := res.Header.Get("Foo"); v != "" {
6457 t.Errorf(`unexpected "Foo" header: %q`, v)
6458 }
6459 if v := res.Header.Get("Foo "); v != "bar" {
6460 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6461 }
6462 }
6463
// bodyCloser is an always-empty request body that records whether its Close
// method has been called.
type bodyCloser bool

// Close marks the body as closed; it never fails.
func (bc *bodyCloser) Close() error {
	*bc = bodyCloser(true)
	return nil
}

// Read reports io.EOF immediately and never fills b.
func (bc *bodyCloser) Read(b []byte) (int, error) {
	return 0, io.EOF
}
6473
6474
6475
6476 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6477 run(t, testTransportClosesBodyOnInvalidRequests)
6478 }
6479 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6480 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6481 t.Errorf("Should not have been invoked")
6482 })).ts
6483
6484 u, _ := url.Parse(cst.URL)
6485
6486 tests := []struct {
6487 name string
6488 req *Request
6489 wantErr string
6490 }{
6491 {
6492 name: "invalid method",
6493 req: &Request{
6494 Method: " ",
6495 URL: u,
6496 },
6497 wantErr: `invalid method " "`,
6498 },
6499 {
6500 name: "nil URL",
6501 req: &Request{
6502 Method: "GET",
6503 },
6504 wantErr: `nil Request.URL`,
6505 },
6506 {
6507 name: "invalid header key",
6508 req: &Request{
6509 Method: "GET",
6510 Header: Header{"💡": {"emoji"}},
6511 URL: u,
6512 },
6513 wantErr: `invalid header field name "💡"`,
6514 },
6515 {
6516 name: "invalid header value",
6517 req: &Request{
6518 Method: "POST",
6519 Header: Header{"key": {"\x19"}},
6520 URL: u,
6521 },
6522 wantErr: `invalid header field value for "key"`,
6523 },
6524 {
6525 name: "non HTTP(s) scheme",
6526 req: &Request{
6527 Method: "POST",
6528 URL: &url.URL{Scheme: "faux"},
6529 },
6530 wantErr: `unsupported protocol scheme "faux"`,
6531 },
6532 {
6533 name: "no Host in URL",
6534 req: &Request{
6535 Method: "POST",
6536 URL: &url.URL{Scheme: "http"},
6537 },
6538 wantErr: `no Host in request URL`,
6539 },
6540 }
6541
6542 for _, tt := range tests {
6543 t.Run(tt.name, func(t *testing.T) {
6544 var bc bodyCloser
6545 req := tt.req
6546 req.Body = &bc
6547 _, err := cst.Client().Do(tt.req)
6548 if err == nil {
6549 t.Fatal("Expected an error")
6550 }
6551 if !bc {
6552 t.Fatal("Expected body to have been closed")
6553 }
6554 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6555 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6556 }
6557 })
6558 }
6559 }
6560
6561
6562
// breakableConn wraps a net.Conn so that its writes can be made to fail on
// demand through a brokenState that may be shared by several conns.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState holds the flag that controls a breakableConn. While broken is
// true, every Write on a conn sharing this state fails.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write fails with "some write error" while the state is marked broken, and
// otherwise delegates to the wrapped conn. The state lock is held for the
// whole call, including the underlying Write.
func (bc *breakableConn) Write(p []byte) (int, error) {
	bc.brokenState.Lock()
	defer bc.brokenState.Unlock()
	if !bc.brokenState.broken {
		return bc.Conn.Write(p)
	}
	return 0, errors.New("some write error")
}
6581
6582
6583 func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
6584 run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
6585 }
6586 func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
6587 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
6588
6589 var brokenState brokenState
6590
6591 const numReqs = 5
6592 var numDials, gotConns uint32
6593
6594 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
6595 atomic.AddUint32(&numDials, 1)
6596 c, err := net.Dial(netw, addr)
6597 if err != nil {
6598 t.Errorf("unexpected Dial error: %v", err)
6599 return nil, err
6600 }
6601 return &breakableConn{c, &brokenState}, err
6602 }
6603
6604 for i := 1; i <= numReqs; i++ {
6605 brokenState.Lock()
6606 brokenState.broken = false
6607 brokenState.Unlock()
6608
6609
6610
6611
6612 doBreak := i != numReqs
6613
6614 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6615 GotConn: func(info httptrace.GotConnInfo) {
6616 t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
6617 atomic.AddUint32(&gotConns, 1)
6618 },
6619 TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
6620 brokenState.Lock()
6621 defer brokenState.Unlock()
6622 if doBreak {
6623 brokenState.broken = true
6624 }
6625 },
6626 })
6627 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6628 if err != nil {
6629 t.Fatal(err)
6630 }
6631 _, err = cst.c.Do(req)
6632 if doBreak != (err != nil) {
6633 t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
6634 }
6635 }
6636 if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
6637 t.Errorf("GotConn calls = %v; want %v", got, want)
6638 }
6639 if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
6640 t.Errorf("Dials = %v; want %v", got, want)
6641 }
6642 }
6643
6644
6645
6646
6647
6648 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6649 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6650 }
6651 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6652 CondSkipHTTP2(t)
6653
6654 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6655 _, err := w.Write([]byte("foo"))
6656 if err != nil {
6657 t.Fatalf("Write: %v", err)
6658 }
6659 })
6660
6661 ts := newClientServerTest(t, mode, h).ts
6662
6663 c := ts.Client()
6664 tr := c.Transport.(*Transport)
6665 tr.MaxConnsPerHost = 1
6666
6667 errCh := make(chan error, 300)
6668 doReq := func() {
6669 resp, err := c.Get(ts.URL)
6670 if err != nil {
6671 errCh <- fmt.Errorf("request failed: %v", err)
6672 return
6673 }
6674 defer resp.Body.Close()
6675 _, err = io.ReadAll(resp.Body)
6676 if err != nil {
6677 errCh <- fmt.Errorf("read body failed: %v", err)
6678 }
6679 }
6680
6681 var wg sync.WaitGroup
6682 for i := 0; i < 300; i++ {
6683 wg.Add(1)
6684 go func() {
6685 defer wg.Done()
6686 doReq()
6687 }()
6688 }
6689 wg.Wait()
6690 close(errCh)
6691
6692 for err := range errCh {
6693 t.Errorf("error occurred: %v", err)
6694 }
6695 }
6696
6697
6698
6699
6700 func TestAltProtoCancellation(t *testing.T) {
6701 defer afterTest(t)
6702 tr := &Transport{}
6703 c := &Client{
6704 Transport: tr,
6705 Timeout: time.Millisecond,
6706 }
6707 tr.RegisterProtocol("cancel", cancelProto{})
6708 _, err := c.Get("cancel://bar.com/path")
6709 if err == nil {
6710 t.Error("request unexpectedly succeeded")
6711 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
6712 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
6713 }
6714 }
6715
// errCancelProto is the error cancelProto.RoundTrip returns once unblocked.
var errCancelProto = errors.New("canceled as expected")

// cancelProto is a RoundTripper that never completes a request by itself.
// RoundTrip waits until a value is received from, or a close of, the request's
// Cancel channel, and then fails with errCancelProto. TestAltProtoCancellation
// relies on the Client's timeout to unblock it.
type cancelProto struct{}

// RoundTrip blocks on r.Cancel and then reports errCancelProto.
func (cancelProto) RoundTrip(r *Request) (resp *Response, err error) {
	<-r.Cancel
	resp, err = nil, errCancelProto
	return resp, err
}
6724
// roundTripFunc adapts an ordinary function to the RoundTripper interface.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip calls the wrapped function with req and returns its results.
func (fn roundTripFunc) RoundTrip(req *Request) (*Response, error) {
	resp, err := fn(req)
	return resp, err
}
6728
6729
6730 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
6731 func testIssue32441(t *testing.T, mode testMode) {
6732 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6733 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6734 t.Error("body length is zero")
6735 }
6736 })).ts
6737 c := ts.Client()
6738 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
6739
6740 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6741 t.Error("body length is zero during round trip")
6742 }
6743 return nil, ErrSkipAltProtocol
6744 }))
6745 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
6746 t.Error(err)
6747 }
6748 }
6749
6750
6751
6752 func TestTransportRejectsSignInContentLength(t *testing.T) {
6753 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
6754 }
6755 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
6756 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6757 w.Header().Set("Content-Length", "+3")
6758 w.Write([]byte("abc"))
6759 })).ts
6760
6761 c := cst.Client()
6762 res, err := c.Get(cst.URL)
6763 if err == nil || res != nil {
6764 t.Fatal("Expected a non-nil error and a nil http.Response")
6765 }
6766 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
6767 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
6768 }
6769 }
6770
6771
// dumpConn is a minimal net.Conn built from a separate Writer and Reader.
// Close and the deadline setters are no-ops that always succeed, and both
// addresses are nil.
type dumpConn struct {
	io.Writer
	io.Reader
}

func (*dumpConn) Close() error                      { return nil }
func (*dumpConn) LocalAddr() net.Addr               { return nil }
func (*dumpConn) RemoteAddr() net.Addr              { return nil }
func (*dumpConn) SetDeadline(time.Time) error       { return nil }
func (*dumpConn) SetReadDeadline(time.Time) error   { return nil }
func (*dumpConn) SetWriteDeadline(time.Time) error  { return nil }
6783
6784
6785
// delegateReader is an io.Reader whose real source arrives later on channel c.
// While no source has been received, Read blocks on c. Once one arrives it is
// kept and serves this and every later Read. If c is closed before a source
// arrives, Read fails with "delegate closed".
type delegateReader struct {
	c chan io.Reader
	r io.Reader
}

func (d *delegateReader) Read(p []byte) (int, error) {
	if d.r == nil {
		next, ok := <-d.c
		d.r = next
		if !ok {
			return 0, errors.New("delegate closed")
		}
	}
	return d.r.Read(p)
}
6800
// testTransportRace performs one RoundTrip of req on a fresh Transport whose
// Dial always returns an in-memory dumpConn. Bytes the Transport writes go into
// a pipe, which a helper goroutine parses as an HTTP request. Bytes it reads
// come from a delegateReader that the helper feeds a canned "204 No Content"
// response. RoundTrip's results are ignored. Before returning, req.Body is
// restored to the value it had on entry.
func testTransportRace(req *Request) {
	// Saved so req.Body can be restored after RoundTrip.
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	// quitReadCh is closed when the helper goroutine exits. The helper may
	// also send on it (see the select below).
	quitReadCh := make(chan struct{})

	// Helper goroutine acting as the server side of the connection.
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Drain and close the body of the request the Transport wrote.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
			// A Read on dr received the canned response.
		case quitReadCh <- struct{}{}:
			// The main goroutine is already waiting on quitReadCh, so
			// RoundTrip returned without taking the response. Close dr.c so
			// any later delegateReader.Read fails instead of blocking forever.
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Closing the write side makes the helper's pipe reads see EOF, so it
	// cannot block in ReadRequest/io.Copy. Then wait for the helper to finish
	// before restoring req.Body.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
6844
6845
6846
6847
6848
6849 func TestErrorWriteLoopRace(t *testing.T) {
6850 if testing.Short() {
6851 return
6852 }
6853 t.Parallel()
6854 for i := 0; i < 1000; i++ {
6855 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
6856 ctx, cancel := context.WithTimeout(context.Background(), delay)
6857 defer cancel()
6858
6859 r := bytes.NewBuffer(make([]byte, 10000))
6860 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
6861 if err != nil {
6862 t.Fatal(err)
6863 }
6864
6865 testTransportRace(req)
6866 }
6867 }
6868
6869
6870
6871
6872 func TestCancelRequestWhenSharingConnection(t *testing.T) {
6873 run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
6874 }
// testCancelRequestWhenSharingConnection uses a Transport limited to one idle
// connection and one connection per host. The sequence is:
//  1. Request 1 completes, and its PutIdleConn trace hook is held blocked.
//  2. Request 2, using a cancelable context, reaches the server.
//  3. Request 2 is canceled and must fail with context.Canceled.
//
// NOTE(review): the name suggests request 2 is served on request 1's
// connection; MaxConnsPerHost = 1 makes that likely, but it is not asserted
// directly.
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each handler invocation sends a channel here and waits on it before
	// responding, so the test controls when each response is written.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	// Request 1.
	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Hand the test a channel and block until the test closes it.
				// putidlec is closed after this first send, so a second
				// PutIdleConn call would panic — the hook is expected to run
				// only once.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec)
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Let request 1's handler respond, unless request 1 already failed.
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		close(r1c)
	}
	// Wait until request 1's connection has reached the PutIdleConn hook.
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	// Request 2, which the test cancels once it reaches the server.
	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Request 2 has returned; release request 1's blocked hook.
		close(idlec)
	}()

	// Wait for request 2 to reach the server, then cancel it.
	r2c := <-reqc
	cancel()

	// Wait until request 2's goroutine has closed idlec (its Do returned).
	<-idlec

	// Let request 2's handler finish, then wait for both client goroutines.
	close(r2c)
	wg.Wait()
}
6957
6958 func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
6959 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
6960 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6961 go io.Copy(io.Discard, req.Body)
6962 panic(ErrAbortHandler)
6963 })).ts
6964
6965 var wg sync.WaitGroup
6966 for i := 0; i < 2; i++ {
6967 wg.Add(1)
6968 go func() {
6969 defer wg.Done()
6970 for j := 0; j < 10; j++ {
6971 const reqLen = 6 * 1024 * 1024
6972 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
6973 req.ContentLength = reqLen
6974 resp, _ := ts.Client().Transport.RoundTrip(req)
6975 if resp != nil {
6976 resp.Body.Close()
6977 }
6978 }
6979 }()
6980 }
6981 wg.Wait()
6982 }
6983
6984 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
6985 func testRequestSanitization(t *testing.T, mode testMode) {
6986 if mode == http2Mode {
6987
6988 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
6989 }
6990 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6991 if h, ok := req.Header["X-Evil"]; ok {
6992 t.Errorf("request has X-Evil header: %q", h)
6993 }
6994 })).ts
6995 req, _ := NewRequest("GET", ts.URL, nil)
6996 req.Host = "go.dev\r\nX-Evil:evil"
6997 resp, _ := ts.Client().Do(req)
6998 if resp != nil {
6999 resp.Body.Close()
7000 }
7001 }
7002
7003 func TestProxyAuthHeader(t *testing.T) {
7004
7005 run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
7006 }
7007 func testProxyAuthHeader(t *testing.T, mode testMode) {
7008 const username = "u"
7009 const password = "@/?!"
7010 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7011
7012
7013 var r2 Request
7014 r2.Header = Header{
7015 "Authorization": req.Header["Proxy-Authorization"],
7016 }
7017 gotuser, gotpass, ok := r2.BasicAuth()
7018 if !ok || gotuser != username || gotpass != password {
7019 t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
7020 }
7021 }))
7022 u, err := url.Parse(cst.ts.URL)
7023 if err != nil {
7024 t.Fatal(err)
7025 }
7026 u.User = url.UserPassword(username, password)
7027 t.Setenv("HTTP_PROXY", u.String())
7028 cst.tr.Proxy = ProxyURL(u)
7029 resp, err := cst.c.Get("http://_/")
7030 if err != nil {
7031 t.Fatal(err)
7032 }
7033 resp.Body.Close()
7034 }
7035
7036
// TestTransportReqCancelerCleanupOnRequestBodyWriteError uses a raw listener
// that replies with a complete 3-byte response after reading just one byte,
// while the client is still sending a 1 GiB (1<<30 bytes) request body. Once
// the response is received and its body closed, the Transport's count of
// pending requests (NumPendingRequestsForTesting) must drop back to 0.
// NOTE(review): per the name, the unsent body is presumably expected to hit
// a write error once the server closes the connection — confirm.
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
	ln := newLocalListener(t)
	addr := ln.Addr().String()

	// done tells the fake server when it may close the connection.
	done := make(chan struct{})
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("ln.Accept: %v", err)
			return
		}

		// Wait for the first byte of the request, then answer immediately
		// without consuming the rest of it.
		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
			t.Errorf("conn.Read: %v", err)
			return
		}
		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
		<-done
		conn.Close()
	}()

	// Hook invoked by the Transport's read loop; judging by its name,
	// presumably just before it reads the next response — confirm.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	tr := &Transport{}

	// A request body far larger than the server will ever read.
	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}

	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.RoundTrip: %v", err)
	}

	// Allow the fake server to close the connection.
	close(done)

	// Wait for the read-loop hook to fire before closing the response body.
	<-didRead

	resp.Body.Close()

	// The request must eventually be removed from the Transport's pending
	// requests.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
7097
7098 func TestValidateClientRequestTrailers(t *testing.T) {
7099 run(t, testValidateClientRequestTrailers)
7100 }
7101
7102 func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
7103 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7104 rw.Write([]byte("Hello"))
7105 })).ts
7106
7107 cases := []struct {
7108 trailer Header
7109 wantErr string
7110 }{
7111 {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
7112 {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
7113 }
7114
7115 for i, tt := range cases {
7116 testName := fmt.Sprintf("%s%d", mode, i)
7117 t.Run(testName, func(t *testing.T) {
7118 req, err := NewRequest("GET", cst.URL, nil)
7119 if err != nil {
7120 t.Fatal(err)
7121 }
7122 req.Trailer = tt.trailer
7123 res, err := cst.Client().Do(req)
7124 if err == nil {
7125 t.Fatal("Expected an error")
7126 }
7127 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
7128 t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
7129 }
7130 if res != nil {
7131 t.Fatal("Unexpected non-nil response")
7132 }
7133 })
7134 }
7135 }
7136
View as plain text