Source file src/net/http/transport_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Tests for transport.go.
     6  //
     7  // More tests are in clientserver_test.go (for things testing both client & server for both
     8  // HTTP/1 and HTTP/2).
     9  
    10  package http_test
    11  
    12  import (
    13  	"bufio"
    14  	"bytes"
    15  	"compress/gzip"
    16  	"context"
    17  	"crypto/rand"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"encoding/binary"
    21  	"errors"
    22  	"fmt"
    23  	"go/token"
    24  	"internal/nettrace"
    25  	"io"
    26  	"log"
    27  	mrand "math/rand"
    28  	"net"
    29  	. "net/http"
    30  	"net/http/httptest"
    31  	"net/http/httptrace"
    32  	"net/http/httputil"
    33  	"net/http/internal/testcert"
    34  	"net/textproto"
    35  	"net/url"
    36  	"os"
    37  	"reflect"
    38  	"runtime"
    39  	"strconv"
    40  	"strings"
    41  	"sync"
    42  	"sync/atomic"
    43  	"testing"
    44  	"testing/iotest"
    45  	"time"
    46  
    47  	"golang.org/x/net/http/httpguts"
    48  )
    49  
    50  // TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
    51  // and then verify that the final 2 responses get errors back.
    52  
    53  // hostPortHandler writes back the client's "host:port".
    54  var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
    55  	if r.FormValue("close") == "true" {
    56  		w.Header().Set("Connection", "close")
    57  	}
    58  	w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
    59  	w.Write([]byte(r.RemoteAddr))
    60  
    61  	// Include the address of the net.Conn in addition to the RemoteAddr,
    62  	// in case kernels reuse source ports quickly (see Issue 52450)
    63  	if c, ok := ResponseWriterConnForTesting(w); ok {
    64  		fmt.Fprintf(w, ", %T %p", c, c)
    65  	}
    66  })
    67  
    68  // testCloseConn is a net.Conn tracked by a testConnSet.
    69  type testCloseConn struct {
    70  	net.Conn
    71  	set *testConnSet
    72  }
    73  
    74  func (c *testCloseConn) Close() error {
    75  	c.set.remove(c)
    76  	return c.Conn.Close()
    77  }
    78  
    79  // testConnSet tracks a set of TCP connections and whether they've
    80  // been closed.
    81  type testConnSet struct {
    82  	t      *testing.T
    83  	mu     sync.Mutex // guards closed and list
    84  	closed map[net.Conn]bool
    85  	list   []net.Conn // in order created
    86  }
    87  
    88  func (tcs *testConnSet) insert(c net.Conn) {
    89  	tcs.mu.Lock()
    90  	defer tcs.mu.Unlock()
    91  	tcs.closed[c] = false
    92  	tcs.list = append(tcs.list, c)
    93  }
    94  
    95  func (tcs *testConnSet) remove(c net.Conn) {
    96  	tcs.mu.Lock()
    97  	defer tcs.mu.Unlock()
    98  	tcs.closed[c] = true
    99  }
   100  
   101  // some tests use this to manage raw tcp connections for later inspection
   102  func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
   103  	connSet := &testConnSet{
   104  		t:      t,
   105  		closed: make(map[net.Conn]bool),
   106  	}
   107  	dial := func(n, addr string) (net.Conn, error) {
   108  		c, err := net.Dial(n, addr)
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  		tc := &testCloseConn{c, connSet}
   113  		connSet.insert(tc)
   114  		return tc, nil
   115  	}
   116  	return connSet, dial
   117  }
   118  
   119  func (tcs *testConnSet) check(t *testing.T) {
   120  	tcs.mu.Lock()
   121  	defer tcs.mu.Unlock()
   122  	for i := 4; i >= 0; i-- {
   123  		for i, c := range tcs.list {
   124  			if tcs.closed[c] {
   125  				continue
   126  			}
   127  			if i != 0 {
   128  				// TODO(bcmills): What is the Sleep here doing, and why is this
   129  				// Unlock/Sleep/Lock cycle needed at all?
   130  				tcs.mu.Unlock()
   131  				time.Sleep(50 * time.Millisecond)
   132  				tcs.mu.Lock()
   133  				continue
   134  			}
   135  			t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
   136  		}
   137  	}
   138  }
   139  
   140  func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
   141  func testReuseRequest(t *testing.T, mode testMode) {
   142  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   143  		w.Write([]byte("{}"))
   144  	})).ts
   145  
   146  	c := ts.Client()
   147  	req, _ := NewRequest("GET", ts.URL, nil)
   148  	res, err := c.Do(req)
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  	err = res.Body.Close()
   153  	if err != nil {
   154  		t.Fatal(err)
   155  	}
   156  
   157  	res, err = c.Do(req)
   158  	if err != nil {
   159  		t.Fatal(err)
   160  	}
   161  	err = res.Body.Close()
   162  	if err != nil {
   163  		t.Fatal(err)
   164  	}
   165  }
   166  
   167  // Two subsequent requests and verify their response is the same.
   168  // The response from the server is our own IP:port
   169  func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
   170  func testTransportKeepAlives(t *testing.T, mode testMode) {
   171  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   172  
   173  	c := ts.Client()
   174  	for _, disableKeepAlive := range []bool{false, true} {
   175  		c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
   176  		fetch := func(n int) string {
   177  			res, err := c.Get(ts.URL)
   178  			if err != nil {
   179  				t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
   180  			}
   181  			body, err := io.ReadAll(res.Body)
   182  			if err != nil {
   183  				t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
   184  			}
   185  			return string(body)
   186  		}
   187  
   188  		body1 := fetch(1)
   189  		body2 := fetch(2)
   190  
   191  		bodiesDiffer := body1 != body2
   192  		if bodiesDiffer != disableKeepAlive {
   193  			t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
   194  				disableKeepAlive, bodiesDiffer, body1, body2)
   195  		}
   196  	}
   197  }
   198  
   199  func TestTransportConnectionCloseOnResponse(t *testing.T) {
   200  	run(t, testTransportConnectionCloseOnResponse)
   201  }
   202  func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
   203  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   204  
   205  	connSet, testDial := makeTestDial(t)
   206  
   207  	c := ts.Client()
   208  	tr := c.Transport.(*Transport)
   209  	tr.Dial = testDial
   210  
   211  	for _, connectionClose := range []bool{false, true} {
   212  		fetch := func(n int) string {
   213  			req := new(Request)
   214  			var err error
   215  			req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
   216  			if err != nil {
   217  				t.Fatalf("URL parse error: %v", err)
   218  			}
   219  			req.Method = "GET"
   220  			req.Proto = "HTTP/1.1"
   221  			req.ProtoMajor = 1
   222  			req.ProtoMinor = 1
   223  
   224  			res, err := c.Do(req)
   225  			if err != nil {
   226  				t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
   227  			}
   228  			defer res.Body.Close()
   229  			body, err := io.ReadAll(res.Body)
   230  			if err != nil {
   231  				t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
   232  			}
   233  			return string(body)
   234  		}
   235  
   236  		body1 := fetch(1)
   237  		body2 := fetch(2)
   238  		bodiesDiffer := body1 != body2
   239  		if bodiesDiffer != connectionClose {
   240  			t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
   241  				connectionClose, bodiesDiffer, body1, body2)
   242  		}
   243  
   244  		tr.CloseIdleConnections()
   245  	}
   246  
   247  	connSet.check(t)
   248  }
   249  
   250  // TestTransportConnectionCloseOnRequest tests that the Transport's doesn't reuse
   251  // an underlying TCP connection after making an http.Request with Request.Close set.
   252  //
   253  // It tests the behavior by making an HTTP request to a server which
   254  // describes the source connection it got (remote port number +
   255  // address of its net.Conn).
   256  func TestTransportConnectionCloseOnRequest(t *testing.T) {
   257  	run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
   258  }
   259  func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
   260  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   261  
   262  	connSet, testDial := makeTestDial(t)
   263  
   264  	c := ts.Client()
   265  	tr := c.Transport.(*Transport)
   266  	tr.Dial = testDial
   267  	for _, reqClose := range []bool{false, true} {
   268  		fetch := func(n int) string {
   269  			req := new(Request)
   270  			var err error
   271  			req.URL, err = url.Parse(ts.URL)
   272  			if err != nil {
   273  				t.Fatalf("URL parse error: %v", err)
   274  			}
   275  			req.Method = "GET"
   276  			req.Proto = "HTTP/1.1"
   277  			req.ProtoMajor = 1
   278  			req.ProtoMinor = 1
   279  			req.Close = reqClose
   280  
   281  			res, err := c.Do(req)
   282  			if err != nil {
   283  				t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
   284  			}
   285  			if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
   286  				t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
   287  					reqClose, got, !reqClose)
   288  			}
   289  			body, err := io.ReadAll(res.Body)
   290  			if err != nil {
   291  				t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
   292  			}
   293  			return string(body)
   294  		}
   295  
   296  		body1 := fetch(1)
   297  		body2 := fetch(2)
   298  
   299  		got := 1
   300  		if body1 != body2 {
   301  			got++
   302  		}
   303  		want := 1
   304  		if reqClose {
   305  			want = 2
   306  		}
   307  		if got != want {
   308  			t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
   309  				reqClose, got, want, body1, body2)
   310  		}
   311  
   312  		tr.CloseIdleConnections()
   313  	}
   314  
   315  	connSet.check(t)
   316  }
   317  
   318  // if the Transport's DisableKeepAlives is set, all requests should
   319  // send Connection: close.
   320  // HTTP/1-only (Connection: close doesn't exist in h2)
   321  func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
   322  	run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
   323  }
   324  func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
   325  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   326  
   327  	c := ts.Client()
   328  	c.Transport.(*Transport).DisableKeepAlives = true
   329  
   330  	res, err := c.Get(ts.URL)
   331  	if err != nil {
   332  		t.Fatal(err)
   333  	}
   334  	res.Body.Close()
   335  	if res.Header.Get("X-Saw-Close") != "true" {
   336  		t.Errorf("handler didn't see Connection: close ")
   337  	}
   338  }
   339  
   340  // Test that Transport only sends one "Connection: close", regardless of
   341  // how "close" was indicated.
   342  func TestTransportRespectRequestWantsClose(t *testing.T) {
   343  	run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
   344  }
   345  func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
   346  	tests := []struct {
   347  		disableKeepAlives bool
   348  		close             bool
   349  	}{
   350  		{disableKeepAlives: false, close: false},
   351  		{disableKeepAlives: false, close: true},
   352  		{disableKeepAlives: true, close: false},
   353  		{disableKeepAlives: true, close: true},
   354  	}
   355  
   356  	for _, tc := range tests {
   357  		t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
   358  			func(t *testing.T) {
   359  				ts := newClientServerTest(t, mode, hostPortHandler).ts
   360  
   361  				c := ts.Client()
   362  				c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
   363  				req, err := NewRequest("GET", ts.URL, nil)
   364  				if err != nil {
   365  					t.Fatal(err)
   366  				}
   367  				count := 0
   368  				trace := &httptrace.ClientTrace{
   369  					WroteHeaderField: func(key string, field []string) {
   370  						if key != "Connection" {
   371  							return
   372  						}
   373  						if httpguts.HeaderValuesContainsToken(field, "close") {
   374  							count += 1
   375  						}
   376  					},
   377  				}
   378  				req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   379  				req.Close = tc.close
   380  				res, err := c.Do(req)
   381  				if err != nil {
   382  					t.Fatal(err)
   383  				}
   384  				defer res.Body.Close()
   385  				if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
   386  					t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
   387  				}
   388  			})
   389  	}
   390  
   391  }
   392  
   393  func TestTransportIdleCacheKeys(t *testing.T) {
   394  	run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
   395  }
   396  func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
   397  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   398  	c := ts.Client()
   399  	tr := c.Transport.(*Transport)
   400  
   401  	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
   402  		t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
   403  	}
   404  
   405  	resp, err := c.Get(ts.URL)
   406  	if err != nil {
   407  		t.Error(err)
   408  	}
   409  	io.ReadAll(resp.Body)
   410  
   411  	keys := tr.IdleConnKeysForTesting()
   412  	if e, g := 1, len(keys); e != g {
   413  		t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
   414  	}
   415  
   416  	if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
   417  		t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
   418  	}
   419  
   420  	tr.CloseIdleConnections()
   421  	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
   422  		t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
   423  	}
   424  }
   425  
   426  // Tests that the HTTP transport re-uses connections when a client
   427  // reads to the end of a response Body without closing it.
   428  func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
   429  func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
   430  	const msg = "foobar"
   431  
   432  	var addrSeen map[string]int
   433  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   434  		addrSeen[r.RemoteAddr]++
   435  		if r.URL.Path == "/chunked/" {
   436  			w.WriteHeader(200)
   437  			w.(Flusher).Flush()
   438  		} else {
   439  			w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
   440  			w.WriteHeader(200)
   441  		}
   442  		w.Write([]byte(msg))
   443  	})).ts
   444  
   445  	for pi, path := range []string{"/content-length/", "/chunked/"} {
   446  		wantLen := []int{len(msg), -1}[pi]
   447  		addrSeen = make(map[string]int)
   448  		for i := 0; i < 3; i++ {
   449  			res, err := ts.Client().Get(ts.URL + path)
   450  			if err != nil {
   451  				t.Errorf("Get %s: %v", path, err)
   452  				continue
   453  			}
   454  			// We want to close this body eventually (before the
   455  			// defer afterTest at top runs), but not before the
   456  			// len(addrSeen) check at the bottom of this test,
   457  			// since Closing this early in the loop would risk
   458  			// making connections be re-used for the wrong reason.
   459  			defer res.Body.Close()
   460  
   461  			if res.ContentLength != int64(wantLen) {
   462  				t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
   463  			}
   464  			got, err := io.ReadAll(res.Body)
   465  			if string(got) != msg || err != nil {
   466  				t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
   467  			}
   468  		}
   469  		if len(addrSeen) != 1 {
   470  			t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
   471  		}
   472  	}
   473  }
   474  
   475  func TestTransportMaxPerHostIdleConns(t *testing.T) {
   476  	run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
   477  }
   478  func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
   479  	stop := make(chan struct{}) // stop marks the exit of main Test goroutine
   480  	defer close(stop)
   481  
   482  	resch := make(chan string)
   483  	gotReq := make(chan bool)
   484  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   485  		gotReq <- true
   486  		var msg string
   487  		select {
   488  		case <-stop:
   489  			return
   490  		case msg = <-resch:
   491  		}
   492  		_, err := w.Write([]byte(msg))
   493  		if err != nil {
   494  			t.Errorf("Write: %v", err)
   495  			return
   496  		}
   497  	})).ts
   498  
   499  	c := ts.Client()
   500  	tr := c.Transport.(*Transport)
   501  	maxIdleConnsPerHost := 2
   502  	tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
   503  
   504  	// Start 3 outstanding requests and wait for the server to get them.
   505  	// Their responses will hang until we write to resch, though.
   506  	donech := make(chan bool)
   507  	doReq := func() {
   508  		defer func() {
   509  			select {
   510  			case <-stop:
   511  				return
   512  			case donech <- t.Failed():
   513  			}
   514  		}()
   515  		resp, err := c.Get(ts.URL)
   516  		if err != nil {
   517  			t.Error(err)
   518  			return
   519  		}
   520  		if _, err := io.ReadAll(resp.Body); err != nil {
   521  			t.Errorf("ReadAll: %v", err)
   522  			return
   523  		}
   524  	}
   525  	go doReq()
   526  	<-gotReq
   527  	go doReq()
   528  	<-gotReq
   529  	go doReq()
   530  	<-gotReq
   531  
   532  	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
   533  		t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
   534  	}
   535  
   536  	resch <- "res1"
   537  	<-donech
   538  	keys := tr.IdleConnKeysForTesting()
   539  	if e, g := 1, len(keys); e != g {
   540  		t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
   541  	}
   542  	addr := ts.Listener.Addr().String()
   543  	cacheKey := "|http|" + addr
   544  	if keys[0] != cacheKey {
   545  		t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
   546  	}
   547  	if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
   548  		t.Errorf("after first response, expected %d idle conns; got %d", e, g)
   549  	}
   550  
   551  	resch <- "res2"
   552  	<-donech
   553  	if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
   554  		t.Errorf("after second response, idle conns = %d; want %d", g, w)
   555  	}
   556  
   557  	resch <- "res3"
   558  	<-donech
   559  	if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
   560  		t.Errorf("after third response, idle conns = %d; want %d", g, w)
   561  	}
   562  }
   563  
   564  func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
   565  	run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
   566  }
   567  func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
   568  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   569  		_, err := w.Write([]byte("foo"))
   570  		if err != nil {
   571  			t.Fatalf("Write: %v", err)
   572  		}
   573  	})).ts
   574  	c := ts.Client()
   575  	tr := c.Transport.(*Transport)
   576  	dialStarted := make(chan struct{})
   577  	stallDial := make(chan struct{})
   578  	tr.Dial = func(network, addr string) (net.Conn, error) {
   579  		dialStarted <- struct{}{}
   580  		<-stallDial
   581  		return net.Dial(network, addr)
   582  	}
   583  
   584  	tr.DisableKeepAlives = true
   585  	tr.MaxConnsPerHost = 1
   586  
   587  	preDial := make(chan struct{})
   588  	reqComplete := make(chan struct{})
   589  	doReq := func(reqId string) {
   590  		req, _ := NewRequest("GET", ts.URL, nil)
   591  		trace := &httptrace.ClientTrace{
   592  			GetConn: func(hostPort string) {
   593  				preDial <- struct{}{}
   594  			},
   595  		}
   596  		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   597  		resp, err := tr.RoundTrip(req)
   598  		if err != nil {
   599  			t.Errorf("unexpected error for request %s: %v", reqId, err)
   600  		}
   601  		_, err = io.ReadAll(resp.Body)
   602  		if err != nil {
   603  			t.Errorf("unexpected error for request %s: %v", reqId, err)
   604  		}
   605  		reqComplete <- struct{}{}
   606  	}
   607  	// get req1 to dial-in-progress
   608  	go doReq("req1")
   609  	<-preDial
   610  	<-dialStarted
   611  
   612  	// get req2 to waiting on conns per host to go down below max
   613  	go doReq("req2")
   614  	<-preDial
   615  	select {
   616  	case <-dialStarted:
   617  		t.Error("req2 dial started while req1 dial in progress")
   618  		return
   619  	default:
   620  	}
   621  
   622  	// let req1 complete
   623  	stallDial <- struct{}{}
   624  	<-reqComplete
   625  
   626  	// let req2 complete
   627  	<-dialStarted
   628  	stallDial <- struct{}{}
   629  	<-reqComplete
   630  }
   631  
   632  func TestTransportMaxConnsPerHost(t *testing.T) {
   633  	run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
   634  }
   635  func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
   636  	CondSkipHTTP2(t)
   637  
   638  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   639  		_, err := w.Write([]byte("foo"))
   640  		if err != nil {
   641  			t.Fatalf("Write: %v", err)
   642  		}
   643  	})
   644  
   645  	ts := newClientServerTest(t, mode, h).ts
   646  	c := ts.Client()
   647  	tr := c.Transport.(*Transport)
   648  	tr.MaxConnsPerHost = 1
   649  
   650  	mu := sync.Mutex{}
   651  	var conns []net.Conn
   652  	var dialCnt, gotConnCnt, tlsHandshakeCnt int32
   653  	tr.Dial = func(network, addr string) (net.Conn, error) {
   654  		atomic.AddInt32(&dialCnt, 1)
   655  		c, err := net.Dial(network, addr)
   656  		mu.Lock()
   657  		defer mu.Unlock()
   658  		conns = append(conns, c)
   659  		return c, err
   660  	}
   661  
   662  	doReq := func() {
   663  		trace := &httptrace.ClientTrace{
   664  			GotConn: func(connInfo httptrace.GotConnInfo) {
   665  				if !connInfo.Reused {
   666  					atomic.AddInt32(&gotConnCnt, 1)
   667  				}
   668  			},
   669  			TLSHandshakeStart: func() {
   670  				atomic.AddInt32(&tlsHandshakeCnt, 1)
   671  			},
   672  		}
   673  		req, _ := NewRequest("GET", ts.URL, nil)
   674  		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   675  
   676  		resp, err := c.Do(req)
   677  		if err != nil {
   678  			t.Fatalf("request failed: %v", err)
   679  		}
   680  		defer resp.Body.Close()
   681  		_, err = io.ReadAll(resp.Body)
   682  		if err != nil {
   683  			t.Fatalf("read body failed: %v", err)
   684  		}
   685  	}
   686  
   687  	wg := sync.WaitGroup{}
   688  	for i := 0; i < 10; i++ {
   689  		wg.Add(1)
   690  		go func() {
   691  			defer wg.Done()
   692  			doReq()
   693  		}()
   694  	}
   695  	wg.Wait()
   696  
   697  	expected := int32(tr.MaxConnsPerHost)
   698  	if dialCnt != expected {
   699  		t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
   700  	}
   701  	if gotConnCnt != expected {
   702  		t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
   703  	}
   704  	if ts.TLS != nil && tlsHandshakeCnt != expected {
   705  		t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
   706  	}
   707  
   708  	if t.Failed() {
   709  		t.FailNow()
   710  	}
   711  
   712  	mu.Lock()
   713  	for _, c := range conns {
   714  		c.Close()
   715  	}
   716  	conns = nil
   717  	mu.Unlock()
   718  	tr.CloseIdleConnections()
   719  
   720  	doReq()
   721  	expected++
   722  	if dialCnt != expected {
   723  		t.Errorf("round 2: too many dials: %d", dialCnt)
   724  	}
   725  	if gotConnCnt != expected {
   726  		t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
   727  	}
   728  	if ts.TLS != nil && tlsHandshakeCnt != expected {
   729  		t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
   730  	}
   731  }
   732  
   733  func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
   734  	run(t, testTransportMaxConnsPerHostDialCancellation,
   735  		testNotParallel, // because test uses SetPendingDialHooks
   736  		[]testMode{http1Mode, https1Mode, http2Mode},
   737  	)
   738  }
   739  
   740  func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
   741  	CondSkipHTTP2(t)
   742  
   743  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   744  		_, err := w.Write([]byte("foo"))
   745  		if err != nil {
   746  			t.Fatalf("Write: %v", err)
   747  		}
   748  	})
   749  
   750  	cst := newClientServerTest(t, mode, h)
   751  	defer cst.close()
   752  	ts := cst.ts
   753  	c := ts.Client()
   754  	tr := c.Transport.(*Transport)
   755  	tr.MaxConnsPerHost = 1
   756  
   757  	// This request is canceled when dial is queued, which preempts dialing.
   758  	ctx, cancel := context.WithCancel(context.Background())
   759  	defer cancel()
   760  	SetPendingDialHooks(cancel, nil)
   761  	defer SetPendingDialHooks(nil, nil)
   762  
   763  	req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
   764  	_, err := c.Do(req)
   765  	if !errors.Is(err, context.Canceled) {
   766  		t.Errorf("expected error %v, got %v", context.Canceled, err)
   767  	}
   768  
   769  	// This request should succeed.
   770  	SetPendingDialHooks(nil, nil)
   771  	req, _ = NewRequest("GET", ts.URL, nil)
   772  	resp, err := c.Do(req)
   773  	if err != nil {
   774  		t.Fatalf("request failed: %v", err)
   775  	}
   776  	defer resp.Body.Close()
   777  	_, err = io.ReadAll(resp.Body)
   778  	if err != nil {
   779  		t.Fatalf("read body failed: %v", err)
   780  	}
   781  }
   782  
   783  func TestTransportRemovesDeadIdleConnections(t *testing.T) {
   784  	run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
   785  }
   786  func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
   787  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   788  		io.WriteString(w, r.RemoteAddr)
   789  	})).ts
   790  
   791  	c := ts.Client()
   792  	tr := c.Transport.(*Transport)
   793  
   794  	doReq := func(name string) {
   795  		// Do a POST instead of a GET to prevent the Transport's
   796  		// idempotent request retry logic from kicking in...
   797  		res, err := c.Post(ts.URL, "", nil)
   798  		if err != nil {
   799  			t.Fatalf("%s: %v", name, err)
   800  		}
   801  		if res.StatusCode != 200 {
   802  			t.Fatalf("%s: %v", name, res.Status)
   803  		}
   804  		defer res.Body.Close()
   805  		slurp, err := io.ReadAll(res.Body)
   806  		if err != nil {
   807  			t.Fatalf("%s: %v", name, err)
   808  		}
   809  		t.Logf("%s: ok (%q)", name, slurp)
   810  	}
   811  
   812  	doReq("first")
   813  	keys1 := tr.IdleConnKeysForTesting()
   814  
   815  	ts.CloseClientConnections()
   816  
   817  	var keys2 []string
   818  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
   819  		keys2 = tr.IdleConnKeysForTesting()
   820  		if len(keys2) != 0 {
   821  			if d > 0 {
   822  				t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
   823  			}
   824  			return false
   825  		}
   826  		return true
   827  	})
   828  
   829  	doReq("second")
   830  }
   831  
   832  // Test that the Transport notices when a server hangs up on its
   833  // unexpectedly (a keep-alive connection is closed).
   834  func TestTransportServerClosingUnexpectedly(t *testing.T) {
   835  	run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
   836  }
   837  func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
   838  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   839  	c := ts.Client()
   840  
   841  	fetch := func(n, retries int) string {
   842  		condFatalf := func(format string, arg ...any) {
   843  			if retries <= 0 {
   844  				t.Fatalf(format, arg...)
   845  			}
   846  			t.Logf("retrying shortly after expected error: "+format, arg...)
   847  			time.Sleep(time.Second / time.Duration(retries))
   848  		}
   849  		for retries >= 0 {
   850  			retries--
   851  			res, err := c.Get(ts.URL)
   852  			if err != nil {
   853  				condFatalf("error in req #%d, GET: %v", n, err)
   854  				continue
   855  			}
   856  			body, err := io.ReadAll(res.Body)
   857  			if err != nil {
   858  				condFatalf("error in req #%d, ReadAll: %v", n, err)
   859  				continue
   860  			}
   861  			res.Body.Close()
   862  			return string(body)
   863  		}
   864  		panic("unreachable")
   865  	}
   866  
   867  	body1 := fetch(1, 0)
   868  	body2 := fetch(2, 0)
   869  
   870  	// Close all the idle connections in a way that's similar to
   871  	// the server hanging up on us. We don't use
   872  	// httptest.Server.CloseClientConnections because it's
   873  	// best-effort and stops blocking after 5 seconds. On a loaded
   874  	// machine running many tests concurrently it's possible for
   875  	// that method to be async and cause the body3 fetch below to
   876  	// run on an old connection. This function is synchronous.
   877  	ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
   878  
   879  	body3 := fetch(3, 5)
   880  
   881  	if body1 != body2 {
   882  		t.Errorf("expected body1 and body2 to be equal")
   883  	}
   884  	if body2 == body3 {
   885  		t.Errorf("expected body2 and body3 to be different")
   886  	}
   887  }
   888  
   889  // Test for https://golang.org/issue/2616 (appropriate issue number)
   890  // This fails pretty reliably with GOMAXPROCS=100 or something high.
   891  func TestStressSurpriseServerCloses(t *testing.T) {
   892  	run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
   893  }
   894  func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
   895  	if testing.Short() {
   896  		t.Skip("skipping test in short mode")
   897  	}
   898  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   899  		w.Header().Set("Content-Length", "5")
   900  		w.Header().Set("Content-Type", "text/plain")
   901  		w.Write([]byte("Hello"))
   902  		w.(Flusher).Flush()
   903  		conn, buf, _ := w.(Hijacker).Hijack()
   904  		buf.Flush()
   905  		conn.Close()
   906  	})).ts
   907  	c := ts.Client()
   908  
   909  	// Do a bunch of traffic from different goroutines. Send to activityc
   910  	// after each request completes, regardless of whether it failed.
   911  	// If these are too high, OS X exhausts its ephemeral ports
   912  	// and hangs waiting for them to transition TCP states. That's
   913  	// not what we want to test. TODO(bradfitz): use an io.Pipe
   914  	// dialer for this test instead?
   915  	const (
   916  		numClients    = 20
   917  		reqsPerClient = 25
   918  	)
   919  	var wg sync.WaitGroup
   920  	wg.Add(numClients * reqsPerClient)
   921  	for i := 0; i < numClients; i++ {
   922  		go func() {
   923  			for i := 0; i < reqsPerClient; i++ {
   924  				res, err := c.Get(ts.URL)
   925  				if err == nil {
   926  					// We expect errors since the server is
   927  					// hanging up on us after telling us to
   928  					// send more requests, so we don't
   929  					// actually care what the error is.
   930  					// But we want to close the body in cases
   931  					// where we won the race.
   932  					res.Body.Close()
   933  				}
   934  				wg.Done()
   935  			}
   936  		}()
   937  	}
   938  
   939  	// Make sure all the request come back, one way or another.
   940  	wg.Wait()
   941  }
   942  
   943  // TestTransportHeadResponses verifies that we deal with Content-Lengths
   944  // with no bodies properly
   945  func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
   946  func testTransportHeadResponses(t *testing.T, mode testMode) {
   947  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   948  		if r.Method != "HEAD" {
   949  			panic("expected HEAD; got " + r.Method)
   950  		}
   951  		w.Header().Set("Content-Length", "123")
   952  		w.WriteHeader(200)
   953  	})).ts
   954  	c := ts.Client()
   955  
   956  	for i := 0; i < 2; i++ {
   957  		res, err := c.Head(ts.URL)
   958  		if err != nil {
   959  			t.Errorf("error on loop %d: %v", i, err)
   960  			continue
   961  		}
   962  		if e, g := "123", res.Header.Get("Content-Length"); e != g {
   963  			t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
   964  		}
   965  		if e, g := int64(123), res.ContentLength; e != g {
   966  			t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
   967  		}
   968  		if all, err := io.ReadAll(res.Body); err != nil {
   969  			t.Errorf("loop %d: Body ReadAll: %v", i, err)
   970  		} else if len(all) != 0 {
   971  			t.Errorf("Bogus body %q", all)
   972  		}
   973  	}
   974  }
   975  
   976  // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
   977  // on responses to HEAD requests.
   978  func TestTransportHeadChunkedResponse(t *testing.T) {
   979  	run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
   980  }
   981  func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
   982  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   983  		if r.Method != "HEAD" {
   984  			panic("expected HEAD; got " + r.Method)
   985  		}
   986  		w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
   987  		w.Header().Set("x-client-ipport", r.RemoteAddr)
   988  		w.WriteHeader(200)
   989  	})).ts
   990  	c := ts.Client()
   991  
   992  	// Ensure that we wait for the readLoop to complete before
   993  	// calling Head again
   994  	didRead := make(chan bool)
   995  	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
   996  	defer SetReadLoopBeforeNextReadHook(nil)
   997  
   998  	res1, err := c.Head(ts.URL)
   999  	<-didRead
  1000  
  1001  	if err != nil {
  1002  		t.Fatalf("request 1 error: %v", err)
  1003  	}
  1004  
  1005  	res2, err := c.Head(ts.URL)
  1006  	<-didRead
  1007  
  1008  	if err != nil {
  1009  		t.Fatalf("request 2 error: %v", err)
  1010  	}
  1011  	if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
  1012  		t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
  1013  	}
  1014  }
  1015  
  1016  var roundTripTests = []struct {
  1017  	accept       string
  1018  	expectAccept string
  1019  	compressed   bool
  1020  }{
  1021  	// Requests with no accept-encoding header use transparent compression
  1022  	{"", "gzip", false},
  1023  	// Requests with other accept-encoding should pass through unmodified
  1024  	{"foo", "foo", false},
  1025  	// Requests with accept-encoding == gzip should be passed through
  1026  	{"gzip", "gzip", true},
  1027  }
  1028  
  1029  // Test that the modification made to the Request by the RoundTripper is cleaned up
  1030  func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
  1031  func testRoundTripGzip(t *testing.T, mode testMode) {
  1032  	const responseBody = "test response body"
  1033  	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  1034  		accept := req.Header.Get("Accept-Encoding")
  1035  		if expect := req.FormValue("expect_accept"); accept != expect {
  1036  			t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
  1037  				req.FormValue("testnum"), accept, expect)
  1038  		}
  1039  		if accept == "gzip" {
  1040  			rw.Header().Set("Content-Encoding", "gzip")
  1041  			gz := gzip.NewWriter(rw)
  1042  			gz.Write([]byte(responseBody))
  1043  			gz.Close()
  1044  		} else {
  1045  			rw.Header().Set("Content-Encoding", accept)
  1046  			rw.Write([]byte(responseBody))
  1047  		}
  1048  	})).ts
  1049  	tr := ts.Client().Transport.(*Transport)
  1050  
  1051  	for i, test := range roundTripTests {
  1052  		// Test basic request (no accept-encoding)
  1053  		req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
  1054  		if test.accept != "" {
  1055  			req.Header.Set("Accept-Encoding", test.accept)
  1056  		}
  1057  		res, err := tr.RoundTrip(req)
  1058  		if err != nil {
  1059  			t.Errorf("%d. RoundTrip: %v", i, err)
  1060  			continue
  1061  		}
  1062  		var body []byte
  1063  		if test.compressed {
  1064  			var r *gzip.Reader
  1065  			r, err = gzip.NewReader(res.Body)
  1066  			if err != nil {
  1067  				t.Errorf("%d. gzip NewReader: %v", i, err)
  1068  				continue
  1069  			}
  1070  			body, err = io.ReadAll(r)
  1071  			res.Body.Close()
  1072  		} else {
  1073  			body, err = io.ReadAll(res.Body)
  1074  		}
  1075  		if err != nil {
  1076  			t.Errorf("%d. Error: %q", i, err)
  1077  			continue
  1078  		}
  1079  		if g, e := string(body), responseBody; g != e {
  1080  			t.Errorf("%d. body = %q; want %q", i, g, e)
  1081  		}
  1082  		if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
  1083  			t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
  1084  		}
  1085  		if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
  1086  			t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
  1087  		}
  1088  	}
  1089  
  1090  }
  1091  
  1092  func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
  1093  func testTransportGzip(t *testing.T, mode testMode) {
  1094  	if mode == http2Mode {
  1095  		t.Skip("https://go.dev/issue/56020")
  1096  	}
  1097  	const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
  1098  	const nRandBytes = 1024 * 1024
  1099  	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  1100  		if req.Method == "HEAD" {
  1101  			if g := req.Header.Get("Accept-Encoding"); g != "" {
  1102  				t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
  1103  			}
  1104  			return
  1105  		}
  1106  		if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
  1107  			t.Errorf("Accept-Encoding = %q, want %q", g, e)
  1108  		}
  1109  		rw.Header().Set("Content-Encoding", "gzip")
  1110  
  1111  		var w io.Writer = rw
  1112  		var buf bytes.Buffer
  1113  		if req.FormValue("chunked") == "0" {
  1114  			w = &buf
  1115  			defer io.Copy(rw, &buf)
  1116  			defer func() {
  1117  				rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
  1118  			}()
  1119  		}
  1120  		gz := gzip.NewWriter(w)
  1121  		gz.Write([]byte(testString))
  1122  		if req.FormValue("body") == "large" {
  1123  			io.CopyN(gz, rand.Reader, nRandBytes)
  1124  		}
  1125  		gz.Close()
  1126  	})).ts
  1127  	c := ts.Client()
  1128  
  1129  	for _, chunked := range []string{"1", "0"} {
  1130  		// First fetch something large, but only read some of it.
  1131  		res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
  1132  		if err != nil {
  1133  			t.Fatalf("large get: %v", err)
  1134  		}
  1135  		buf := make([]byte, len(testString))
  1136  		n, err := io.ReadFull(res.Body, buf)
  1137  		if err != nil {
  1138  			t.Fatalf("partial read of large response: size=%d, %v", n, err)
  1139  		}
  1140  		if e, g := testString, string(buf); e != g {
  1141  			t.Errorf("partial read got %q, expected %q", g, e)
  1142  		}
  1143  		res.Body.Close()
  1144  		// Read on the body, even though it's closed
  1145  		n, err = res.Body.Read(buf)
  1146  		if n != 0 || err == nil {
  1147  			t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
  1148  		}
  1149  
  1150  		// Then something small.
  1151  		res, err = c.Get(ts.URL + "/?chunked=" + chunked)
  1152  		if err != nil {
  1153  			t.Fatal(err)
  1154  		}
  1155  		body, err := io.ReadAll(res.Body)
  1156  		if err != nil {
  1157  			t.Fatal(err)
  1158  		}
  1159  		if g, e := string(body), testString; g != e {
  1160  			t.Fatalf("body = %q; want %q", g, e)
  1161  		}
  1162  		if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
  1163  			t.Fatalf("Content-Encoding = %q; want %q", g, e)
  1164  		}
  1165  
  1166  		// Read on the body after it's been fully read:
  1167  		n, err = res.Body.Read(buf)
  1168  		if n != 0 || err == nil {
  1169  			t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
  1170  		}
  1171  		res.Body.Close()
  1172  		n, err = res.Body.Read(buf)
  1173  		if n != 0 || err == nil {
  1174  			t.Errorf("expected Read error after Close; got %d, %v", n, err)
  1175  		}
  1176  	}
  1177  
  1178  	// And a HEAD request too, because they're always weird.
  1179  	res, err := c.Head(ts.URL)
  1180  	if err != nil {
  1181  		t.Fatalf("Head: %v", err)
  1182  	}
  1183  	if res.StatusCode != 200 {
  1184  		t.Errorf("Head status=%d; want=200", res.StatusCode)
  1185  	}
  1186  }
  1187  
  1188  // A transport100Continue test exercises Transport behaviors when sending a
  1189  // request with an Expect: 100-continue header.
  1190  type transport100ContinueTest struct {
  1191  	t *testing.T
  1192  
  1193  	reqdone chan struct{}
  1194  	resp    *Response
  1195  	respErr error
  1196  
  1197  	conn   net.Conn
  1198  	reader *bufio.Reader
  1199  }
  1200  
  1201  const transport100ContinueTestBody = "request body"
  1202  
  1203  // newTransport100ContinueTest creates a Transport and sends an Expect: 100-continue
  1204  // request on it.
  1205  func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
  1206  	ln := newLocalListener(t)
  1207  	defer ln.Close()
  1208  
  1209  	test := &transport100ContinueTest{
  1210  		t:       t,
  1211  		reqdone: make(chan struct{}),
  1212  	}
  1213  
  1214  	tr := &Transport{
  1215  		ExpectContinueTimeout: timeout,
  1216  	}
  1217  	go func() {
  1218  		defer close(test.reqdone)
  1219  		body := strings.NewReader(transport100ContinueTestBody)
  1220  		req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
  1221  		req.Header.Set("Expect", "100-continue")
  1222  		req.ContentLength = int64(len(transport100ContinueTestBody))
  1223  		test.resp, test.respErr = tr.RoundTrip(req)
  1224  		test.resp.Body.Close()
  1225  	}()
  1226  
  1227  	c, err := ln.Accept()
  1228  	if err != nil {
  1229  		t.Fatalf("Accept: %v", err)
  1230  	}
  1231  	t.Cleanup(func() {
  1232  		c.Close()
  1233  	})
  1234  	br := bufio.NewReader(c)
  1235  	_, err = ReadRequest(br)
  1236  	if err != nil {
  1237  		t.Fatalf("ReadRequest: %v", err)
  1238  	}
  1239  	test.conn = c
  1240  	test.reader = br
  1241  	t.Cleanup(func() {
  1242  		<-test.reqdone
  1243  		tr.CloseIdleConnections()
  1244  		got, _ := io.ReadAll(test.reader)
  1245  		if len(got) > 0 {
  1246  			t.Fatalf("Transport sent unexpected bytes: %q", got)
  1247  		}
  1248  	})
  1249  
  1250  	return test
  1251  }
  1252  
  1253  // respond sends response lines from the server to the transport.
  1254  func (test *transport100ContinueTest) respond(lines ...string) {
  1255  	for _, line := range lines {
  1256  		if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
  1257  			test.t.Fatalf("Write: %v", err)
  1258  		}
  1259  	}
  1260  	if _, err := test.conn.Write([]byte("\r\n")); err != nil {
  1261  		test.t.Fatalf("Write: %v", err)
  1262  	}
  1263  }
  1264  
  1265  // wantBodySent ensures the transport has sent the request body to the server.
  1266  func (test *transport100ContinueTest) wantBodySent() {
  1267  	got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
  1268  	if err != nil {
  1269  		test.t.Fatalf("unexpected error reading body: %v", err)
  1270  	}
  1271  	if got, want := string(got), transport100ContinueTestBody; got != want {
  1272  		test.t.Fatalf("unexpected body: got %q, want %q", got, want)
  1273  	}
  1274  }
  1275  
  1276  // wantRequestDone ensures the Transport.RoundTrip has completed with the expected status.
  1277  func (test *transport100ContinueTest) wantRequestDone(want int) {
  1278  	<-test.reqdone
  1279  	if test.respErr != nil {
  1280  		test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
  1281  	}
  1282  	if got := test.resp.StatusCode; got != want {
  1283  		test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
  1284  	}
  1285  }
  1286  
  1287  func TestTransportExpect100ContinueSent(t *testing.T) {
  1288  	test := newTransport100ContinueTest(t, 1*time.Hour)
  1289  	// Server sends a 100 Continue response, and the client sends the request body.
  1290  	test.respond("HTTP/1.1 100 Continue")
  1291  	test.wantBodySent()
  1292  	test.respond("HTTP/1.1 200", "Content-Length: 0")
  1293  	test.wantRequestDone(200)
  1294  }
  1295  
  1296  func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
  1297  	test := newTransport100ContinueTest(t, 1*time.Hour)
  1298  	// No 100 Continue response, no Connection: close header.
  1299  	test.respond("HTTP/1.1 200", "Content-Length: 0")
  1300  	test.wantBodySent()
  1301  	test.wantRequestDone(200)
  1302  }
  1303  
  1304  func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
  1305  	test := newTransport100ContinueTest(t, 1*time.Hour)
  1306  	// No 100 Continue response, Connection: close header set.
  1307  	test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
  1308  	test.wantRequestDone(200)
  1309  }
  1310  
  1311  func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
  1312  	test := newTransport100ContinueTest(t, 1*time.Hour)
  1313  	// No 100 Continue response, no Connection: close header.
  1314  	test.respond("HTTP/1.1 500", "Content-Length: 0")
  1315  	test.wantBodySent()
  1316  	test.wantRequestDone(500)
  1317  }
  1318  
  1319  func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
  1320  	test := newTransport100ContinueTest(t, 5*time.Millisecond) // short timeout
  1321  	test.wantBodySent()                                        // after timeout
  1322  	test.respond("HTTP/1.1 200", "Content-Length: 0")
  1323  	test.wantRequestDone(200)
  1324  }
  1325  
  1326  func TestSOCKS5Proxy(t *testing.T) {
  1327  	run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
  1328  }
  1329  func testSOCKS5Proxy(t *testing.T, mode testMode) {
  1330  	ch := make(chan string, 1)
  1331  	l := newLocalListener(t)
  1332  	defer l.Close()
  1333  	defer close(ch)
  1334  	proxy := func(t *testing.T) {
  1335  		s, err := l.Accept()
  1336  		if err != nil {
  1337  			t.Errorf("socks5 proxy Accept(): %v", err)
  1338  			return
  1339  		}
  1340  		defer s.Close()
  1341  		var buf [22]byte
  1342  		if _, err := io.ReadFull(s, buf[:3]); err != nil {
  1343  			t.Errorf("socks5 proxy initial read: %v", err)
  1344  			return
  1345  		}
  1346  		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
  1347  			t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
  1348  			return
  1349  		}
  1350  		if _, err := s.Write([]byte{5, 0}); err != nil {
  1351  			t.Errorf("socks5 proxy initial write: %v", err)
  1352  			return
  1353  		}
  1354  		if _, err := io.ReadFull(s, buf[:4]); err != nil {
  1355  			t.Errorf("socks5 proxy second read: %v", err)
  1356  			return
  1357  		}
  1358  		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
  1359  			t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
  1360  			return
  1361  		}
  1362  		var ipLen int
  1363  		switch buf[3] {
  1364  		case 1:
  1365  			ipLen = net.IPv4len
  1366  		case 4:
  1367  			ipLen = net.IPv6len
  1368  		default:
  1369  			t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
  1370  			return
  1371  		}
  1372  		if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
  1373  			t.Errorf("socks5 proxy address read: %v", err)
  1374  			return
  1375  		}
  1376  		ip := net.IP(buf[4 : ipLen+4])
  1377  		port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
  1378  		copy(buf[:3], []byte{5, 0, 0})
  1379  		if _, err := s.Write(buf[:ipLen+6]); err != nil {
  1380  			t.Errorf("socks5 proxy connect write: %v", err)
  1381  			return
  1382  		}
  1383  		ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
  1384  
  1385  		// Implement proxying.
  1386  		targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
  1387  		targetConn, err := net.Dial("tcp", targetHost)
  1388  		if err != nil {
  1389  			t.Errorf("net.Dial failed")
  1390  			return
  1391  		}
  1392  		go io.Copy(targetConn, s)
  1393  		io.Copy(s, targetConn) // Wait for the client to close the socket.
  1394  		targetConn.Close()
  1395  	}
  1396  
  1397  	pu, err := url.Parse("socks5://" + l.Addr().String())
  1398  	if err != nil {
  1399  		t.Fatal(err)
  1400  	}
  1401  
  1402  	sentinelHeader := "X-Sentinel"
  1403  	sentinelValue := "12345"
  1404  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
  1405  		w.Header().Set(sentinelHeader, sentinelValue)
  1406  	})
  1407  	for _, useTLS := range []bool{false, true} {
  1408  		t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
  1409  			ts := newClientServerTest(t, mode, h).ts
  1410  			go proxy(t)
  1411  			c := ts.Client()
  1412  			c.Transport.(*Transport).Proxy = ProxyURL(pu)
  1413  			r, err := c.Head(ts.URL)
  1414  			if err != nil {
  1415  				t.Fatal(err)
  1416  			}
  1417  			if r.Header.Get(sentinelHeader) != sentinelValue {
  1418  				t.Errorf("Failed to retrieve sentinel value")
  1419  			}
  1420  			got := <-ch
  1421  			ts.Close()
  1422  			tsu, err := url.Parse(ts.URL)
  1423  			if err != nil {
  1424  				t.Fatal(err)
  1425  			}
  1426  			want := "proxy for " + tsu.Host
  1427  			if got != want {
  1428  				t.Errorf("got %q, want %q", got, want)
  1429  			}
  1430  		})
  1431  	}
  1432  }
  1433  
  1434  func TestTransportProxy(t *testing.T) {
  1435  	defer afterTest(t)
  1436  	testCases := []struct{ siteMode, proxyMode testMode }{
  1437  		{http1Mode, http1Mode},
  1438  		{http1Mode, https1Mode},
  1439  		{https1Mode, http1Mode},
  1440  		{https1Mode, https1Mode},
  1441  	}
  1442  	for _, testCase := range testCases {
  1443  		siteMode := testCase.siteMode
  1444  		proxyMode := testCase.proxyMode
  1445  		t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
  1446  			siteCh := make(chan *Request, 1)
  1447  			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
  1448  				siteCh <- r
  1449  			})
  1450  			proxyCh := make(chan *Request, 1)
  1451  			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
  1452  				proxyCh <- r
  1453  				// Implement an entire CONNECT proxy
  1454  				if r.Method == "CONNECT" {
  1455  					hijacker, ok := w.(Hijacker)
  1456  					if !ok {
  1457  						t.Errorf("hijack not allowed")
  1458  						return
  1459  					}
  1460  					clientConn, _, err := hijacker.Hijack()
  1461  					if err != nil {
  1462  						t.Errorf("hijacking failed")
  1463  						return
  1464  					}
  1465  					res := &Response{
  1466  						StatusCode: StatusOK,
  1467  						Proto:      "HTTP/1.1",
  1468  						ProtoMajor: 1,
  1469  						ProtoMinor: 1,
  1470  						Header:     make(Header),
  1471  					}
  1472  
  1473  					targetConn, err := net.Dial("tcp", r.URL.Host)
  1474  					if err != nil {
  1475  						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
  1476  						return
  1477  					}
  1478  
  1479  					if err := res.Write(clientConn); err != nil {
  1480  						t.Errorf("Writing 200 OK failed: %v", err)
  1481  						return
  1482  					}
  1483  
  1484  					go io.Copy(targetConn, clientConn)
  1485  					go func() {
  1486  						io.Copy(clientConn, targetConn)
  1487  						targetConn.Close()
  1488  					}()
  1489  				}
  1490  			})
  1491  			ts := newClientServerTest(t, siteMode, h1).ts
  1492  			proxy := newClientServerTest(t, proxyMode, h2).ts
  1493  
  1494  			pu, err := url.Parse(proxy.URL)
  1495  			if err != nil {
  1496  				t.Fatal(err)
  1497  			}
  1498  
  1499  			// If neither server is HTTPS or both are, then c may be derived from either.
  1500  			// If only one server is HTTPS, c must be derived from that server in order
  1501  			// to ensure that it is configured to use the fake root CA from testcert.go.
  1502  			c := proxy.Client()
  1503  			if siteMode == https1Mode {
  1504  				c = ts.Client()
  1505  			}
  1506  
  1507  			c.Transport.(*Transport).Proxy = ProxyURL(pu)
  1508  			if _, err := c.Head(ts.URL); err != nil {
  1509  				t.Error(err)
  1510  			}
  1511  			got := <-proxyCh
  1512  			c.Transport.(*Transport).CloseIdleConnections()
  1513  			ts.Close()
  1514  			proxy.Close()
  1515  			if siteMode == https1Mode {
  1516  				// First message should be a CONNECT, asking for a socket to the real server,
  1517  				if got.Method != "CONNECT" {
  1518  					t.Errorf("Wrong method for secure proxying: %q", got.Method)
  1519  				}
  1520  				gotHost := got.URL.Host
  1521  				pu, err := url.Parse(ts.URL)
  1522  				if err != nil {
  1523  					t.Fatal("Invalid site URL")
  1524  				}
  1525  				if wantHost := pu.Host; gotHost != wantHost {
  1526  					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
  1527  				}
  1528  
  1529  				// The next message on the channel should be from the site's server.
  1530  				next := <-siteCh
  1531  				if next.Method != "HEAD" {
  1532  					t.Errorf("Wrong method at destination: %s", next.Method)
  1533  				}
  1534  				if nextURL := next.URL.String(); nextURL != "/" {
  1535  					t.Errorf("Wrong URL at destination: %s", nextURL)
  1536  				}
  1537  			} else {
  1538  				if got.Method != "HEAD" {
  1539  					t.Errorf("Wrong method for destination: %q", got.Method)
  1540  				}
  1541  				gotURL := got.URL.String()
  1542  				wantURL := ts.URL + "/"
  1543  				if gotURL != wantURL {
  1544  					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
  1545  				}
  1546  			}
  1547  		})
  1548  	}
  1549  }
  1550  
  1551  func TestOnProxyConnectResponse(t *testing.T) {
  1552  
  1553  	var tcases = []struct {
  1554  		proxyStatusCode int
  1555  		err             error
  1556  	}{
  1557  		{
  1558  			StatusOK,
  1559  			nil,
  1560  		},
  1561  		{
  1562  			StatusForbidden,
  1563  			errors.New("403"),
  1564  		},
  1565  	}
  1566  	for _, tcase := range tcases {
  1567  		h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
  1568  
  1569  		})
  1570  
  1571  		h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
  1572  			// Implement an entire CONNECT proxy
  1573  			if r.Method == "CONNECT" {
  1574  				if tcase.proxyStatusCode != StatusOK {
  1575  					w.WriteHeader(tcase.proxyStatusCode)
  1576  					return
  1577  				}
  1578  				hijacker, ok := w.(Hijacker)
  1579  				if !ok {
  1580  					t.Errorf("hijack not allowed")
  1581  					return
  1582  				}
  1583  				clientConn, _, err := hijacker.Hijack()
  1584  				if err != nil {
  1585  					t.Errorf("hijacking failed")
  1586  					return
  1587  				}
  1588  				res := &Response{
  1589  					StatusCode: StatusOK,
  1590  					Proto:      "HTTP/1.1",
  1591  					ProtoMajor: 1,
  1592  					ProtoMinor: 1,
  1593  					Header:     make(Header),
  1594  				}
  1595  
  1596  				targetConn, err := net.Dial("tcp", r.URL.Host)
  1597  				if err != nil {
  1598  					t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
  1599  					return
  1600  				}
  1601  
  1602  				if err := res.Write(clientConn); err != nil {
  1603  					t.Errorf("Writing 200 OK failed: %v", err)
  1604  					return
  1605  				}
  1606  
  1607  				go io.Copy(targetConn, clientConn)
  1608  				go func() {
  1609  					io.Copy(clientConn, targetConn)
  1610  					targetConn.Close()
  1611  				}()
  1612  			}
  1613  		})
  1614  		ts := newClientServerTest(t, https1Mode, h1).ts
  1615  		proxy := newClientServerTest(t, https1Mode, h2).ts
  1616  
  1617  		pu, err := url.Parse(proxy.URL)
  1618  		if err != nil {
  1619  			t.Fatal(err)
  1620  		}
  1621  
  1622  		c := proxy.Client()
  1623  
  1624  		var (
  1625  			dials  atomic.Int32
  1626  			closes atomic.Int32
  1627  		)
  1628  		c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  1629  			conn, err := net.Dial(network, addr)
  1630  			if err != nil {
  1631  				return nil, err
  1632  			}
  1633  			dials.Add(1)
  1634  			return noteCloseConn{
  1635  				Conn: conn,
  1636  				closeFunc: func() {
  1637  					closes.Add(1)
  1638  				},
  1639  			}, nil
  1640  		}
  1641  
  1642  		c.Transport.(*Transport).Proxy = ProxyURL(pu)
  1643  		c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
  1644  			if proxyURL.String() != pu.String() {
  1645  				t.Errorf("proxy url got %s, want %s", proxyURL, pu)
  1646  			}
  1647  
  1648  			if "https://"+connectReq.URL.String() != ts.URL {
  1649  				t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
  1650  			}
  1651  			return tcase.err
  1652  		}
  1653  		wantCloses := int32(0)
  1654  		if _, err := c.Head(ts.URL); err != nil {
  1655  			wantCloses = 1
  1656  			if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
  1657  				t.Errorf("got %v, want %v", err, tcase.err)
  1658  			}
  1659  		} else {
  1660  			if tcase.err != nil {
  1661  				t.Errorf("got %v, want nil", err)
  1662  			}
  1663  		}
  1664  		if got, want := dials.Load(), int32(1); got != want {
  1665  			t.Errorf("got %v dials, want %v", got, want)
  1666  		}
  1667  		// #64804: If OnProxyConnectResponse returns an error, we should close the conn.
  1668  		if got, want := closes.Load(), wantCloses; got != want {
  1669  			t.Errorf("got %v closes, want %v", got, want)
  1670  		}
  1671  	}
  1672  }
  1673  
  1674  // Issue 28012: verify that the Transport closes its TCP connection to http proxies
  1675  // when they're slow to reply to HTTPS CONNECT responses.
  1676  func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
  1677  	cancelc := make(chan struct{})
  1678  	SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
  1679  		ctx, cancel := context.WithCancel(ctx)
  1680  		go func() {
  1681  			select {
  1682  			case <-cancelc:
  1683  			case <-ctx.Done():
  1684  			}
  1685  			cancel()
  1686  		}()
  1687  		return ctx, cancel
  1688  	})
  1689  
  1690  	defer afterTest(t)
  1691  
  1692  	ln := newLocalListener(t)
  1693  	defer ln.Close()
  1694  	listenerDone := make(chan struct{})
  1695  	go func() {
  1696  		defer close(listenerDone)
  1697  		c, err := ln.Accept()
  1698  		if err != nil {
  1699  			t.Errorf("Accept: %v", err)
  1700  			return
  1701  		}
  1702  		defer c.Close()
  1703  		// Read the CONNECT request
  1704  		br := bufio.NewReader(c)
  1705  		cr, err := ReadRequest(br)
  1706  		if err != nil {
  1707  			t.Errorf("proxy server failed to read CONNECT request")
  1708  			return
  1709  		}
  1710  		if cr.Method != "CONNECT" {
  1711  			t.Errorf("unexpected method %q", cr.Method)
  1712  			return
  1713  		}
  1714  
  1715  		// Now hang and never write a response; instead, cancel the request and wait
  1716  		// for the client to close.
  1717  		// (Prior to Issue 28012 being fixed, we never closed.)
  1718  		close(cancelc)
  1719  		var buf [1]byte
  1720  		_, err = br.Read(buf[:])
  1721  		if err != io.EOF {
  1722  			t.Errorf("proxy server Read err = %v; want EOF", err)
  1723  		}
  1724  		return
  1725  	}()
  1726  
  1727  	c := &Client{
  1728  		Transport: &Transport{
  1729  			Proxy: func(*Request) (*url.URL, error) {
  1730  				return url.Parse("http://" + ln.Addr().String())
  1731  			},
  1732  		},
  1733  	}
  1734  	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
  1735  	if err != nil {
  1736  		t.Fatal(err)
  1737  	}
  1738  	_, err = c.Do(req)
  1739  	if err == nil {
  1740  		t.Errorf("unexpected Get success")
  1741  	}
  1742  
  1743  	// Wait unconditionally for the listener goroutine to exit: this should never
  1744  	// hang, so if it does we want a full goroutine dump — and that's exactly what
  1745  	// the testing package will give us when the test run times out.
  1746  	<-listenerDone
  1747  }
  1748  
  1749  // Issue 16997: test transport dial preserves typed errors
  1750  func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
  1751  	defer afterTest(t)
  1752  
  1753  	var errDial = errors.New("some dial error")
  1754  
  1755  	tr := &Transport{
  1756  		Proxy: func(*Request) (*url.URL, error) {
  1757  			return url.Parse("http://proxy.fake.tld/")
  1758  		},
  1759  		Dial: func(string, string) (net.Conn, error) {
  1760  			return nil, errDial
  1761  		},
  1762  	}
  1763  	defer tr.CloseIdleConnections()
  1764  
  1765  	c := &Client{Transport: tr}
  1766  	req, _ := NewRequest("GET", "http://fake.tld", nil)
  1767  	res, err := c.Do(req)
  1768  	if err == nil {
  1769  		res.Body.Close()
  1770  		t.Fatal("wanted a non-nil error")
  1771  	}
  1772  
  1773  	uerr, ok := err.(*url.Error)
  1774  	if !ok {
  1775  		t.Fatalf("got %T, want *url.Error", err)
  1776  	}
  1777  	oe, ok := uerr.Err.(*net.OpError)
  1778  	if !ok {
  1779  		t.Fatalf("url.Error.Err =  %T; want *net.OpError", uerr.Err)
  1780  	}
  1781  	want := &net.OpError{
  1782  		Op:  "proxyconnect",
  1783  		Net: "tcp",
  1784  		Err: errDial, // original error, unwrapped.
  1785  	}
  1786  	if !reflect.DeepEqual(oe, want) {
  1787  		t.Errorf("Got error %#v; want %#v", oe, want)
  1788  	}
  1789  }
  1790  
  1791  // Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader.
  1792  //
  1793  // (A bug caused dialConn to instead write the per-request Proxy-Authorization
  1794  // header through to the shared Header instance, introducing a data race.)
  1795  func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
  1796  	run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
  1797  }
  1798  func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
  1799  	proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
  1800  	defer proxy.Close()
  1801  	c := proxy.Client()
  1802  
  1803  	tr := c.Transport.(*Transport)
  1804  	tr.Proxy = func(*Request) (*url.URL, error) {
  1805  		u, _ := url.Parse(proxy.URL)
  1806  		u.User = url.UserPassword("aladdin", "opensesame")
  1807  		return u, nil
  1808  	}
  1809  	h := tr.ProxyConnectHeader
  1810  	if h == nil {
  1811  		h = make(Header)
  1812  	}
  1813  	tr.ProxyConnectHeader = h.Clone()
  1814  
  1815  	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
  1816  	if err != nil {
  1817  		t.Fatal(err)
  1818  	}
  1819  	_, err = c.Do(req)
  1820  	if err == nil {
  1821  		t.Errorf("unexpected Get success")
  1822  	}
  1823  
  1824  	if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
  1825  		t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
  1826  	}
  1827  }
  1828  
  1829  // TestTransportGzipRecursive sends a gzip quine and checks that the
  1830  // client gets the same value back. This is more cute than anything,
  1831  // but checks that we don't recurse forever, and checks that
  1832  // Content-Encoding is removed.
  1833  func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
  1834  func testTransportGzipRecursive(t *testing.T, mode testMode) {
  1835  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1836  		w.Header().Set("Content-Encoding", "gzip")
  1837  		w.Write(rgz)
  1838  	})).ts
  1839  
  1840  	c := ts.Client()
  1841  	res, err := c.Get(ts.URL)
  1842  	if err != nil {
  1843  		t.Fatal(err)
  1844  	}
  1845  	body, err := io.ReadAll(res.Body)
  1846  	if err != nil {
  1847  		t.Fatal(err)
  1848  	}
  1849  	if !bytes.Equal(body, rgz) {
  1850  		t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
  1851  			body, rgz)
  1852  	}
  1853  	if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
  1854  		t.Fatalf("Content-Encoding = %q; want %q", g, e)
  1855  	}
  1856  }
  1857  
  1858  // golang.org/issue/7750: request fails when server replies with
  1859  // a short gzip body
  1860  func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
  1861  func testTransportGzipShort(t *testing.T, mode testMode) {
  1862  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1863  		w.Header().Set("Content-Encoding", "gzip")
  1864  		w.Write([]byte{0x1f, 0x8b})
  1865  	})).ts
  1866  
  1867  	c := ts.Client()
  1868  	res, err := c.Get(ts.URL)
  1869  	if err != nil {
  1870  		t.Fatal(err)
  1871  	}
  1872  	defer res.Body.Close()
  1873  	_, err = io.ReadAll(res.Body)
  1874  	if err == nil {
  1875  		t.Fatal("Expect an error from reading a body.")
  1876  	}
  1877  	if err != io.ErrUnexpectedEOF {
  1878  		t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
  1879  	}
  1880  }
  1881  
  1882  // Wait until number of goroutines is no greater than nmax, or time out.
  1883  func waitNumGoroutine(nmax int) int {
  1884  	nfinal := runtime.NumGoroutine()
  1885  	for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- {
  1886  		time.Sleep(50 * time.Millisecond)
  1887  		runtime.GC()
  1888  		nfinal = runtime.NumGoroutine()
  1889  	}
  1890  	return nfinal
  1891  }
  1892  
  1893  // tests that persistent goroutine connections shut down when no longer desired.
  1894  func TestTransportPersistConnLeak(t *testing.T) {
  1895  	run(t, testTransportPersistConnLeak, testNotParallel)
  1896  }
  1897  func testTransportPersistConnLeak(t *testing.T, mode testMode) {
  1898  	if mode == http2Mode {
  1899  		t.Skip("flaky in HTTP/2")
  1900  	}
  1901  	// Not parallel: counts goroutines
  1902  
  1903  	const numReq = 25
  1904  	gotReqCh := make(chan bool, numReq)
  1905  	unblockCh := make(chan bool, numReq)
  1906  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1907  		gotReqCh <- true
  1908  		<-unblockCh
  1909  		w.Header().Set("Content-Length", "0")
  1910  		w.WriteHeader(204)
  1911  	})).ts
  1912  	c := ts.Client()
  1913  	tr := c.Transport.(*Transport)
  1914  
  1915  	n0 := runtime.NumGoroutine()
  1916  
  1917  	didReqCh := make(chan bool, numReq)
  1918  	failed := make(chan bool, numReq)
  1919  	for i := 0; i < numReq; i++ {
  1920  		go func() {
  1921  			res, err := c.Get(ts.URL)
  1922  			didReqCh <- true
  1923  			if err != nil {
  1924  				t.Logf("client fetch error: %v", err)
  1925  				failed <- true
  1926  				return
  1927  			}
  1928  			res.Body.Close()
  1929  		}()
  1930  	}
  1931  
  1932  	// Wait for all goroutines to be stuck in the Handler.
  1933  	for i := 0; i < numReq; i++ {
  1934  		select {
  1935  		case <-gotReqCh:
  1936  			// ok
  1937  		case <-failed:
  1938  			// Not great but not what we are testing:
  1939  			// sometimes an overloaded system will fail to make all the connections.
  1940  		}
  1941  	}
  1942  
  1943  	nhigh := runtime.NumGoroutine()
  1944  
  1945  	// Tell all handlers to unblock and reply.
  1946  	close(unblockCh)
  1947  
  1948  	// Wait for all HTTP clients to be done.
  1949  	for i := 0; i < numReq; i++ {
  1950  		<-didReqCh
  1951  	}
  1952  
  1953  	tr.CloseIdleConnections()
  1954  	nfinal := waitNumGoroutine(n0 + 5)
  1955  
  1956  	growth := nfinal - n0
  1957  
  1958  	// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
  1959  	// Previously we were leaking one per numReq.
  1960  	if int(growth) > 5 {
  1961  		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
  1962  		t.Error("too many new goroutines")
  1963  	}
  1964  }
  1965  
  1966  // golang.org/issue/4531: Transport leaks goroutines when
  1967  // request.ContentLength is explicitly short
  1968  func TestTransportPersistConnLeakShortBody(t *testing.T) {
  1969  	run(t, testTransportPersistConnLeakShortBody, testNotParallel)
  1970  }
  1971  func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
  1972  	if mode == http2Mode {
  1973  		t.Skip("flaky in HTTP/2")
  1974  	}
  1975  
  1976  	// Not parallel: measures goroutines.
  1977  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1978  	})).ts
  1979  	c := ts.Client()
  1980  	tr := c.Transport.(*Transport)
  1981  
  1982  	n0 := runtime.NumGoroutine()
  1983  	body := []byte("Hello")
  1984  	for i := 0; i < 20; i++ {
  1985  		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
  1986  		if err != nil {
  1987  			t.Fatal(err)
  1988  		}
  1989  		req.ContentLength = int64(len(body) - 2) // explicitly short
  1990  		_, err = c.Do(req)
  1991  		if err == nil {
  1992  			t.Fatal("Expect an error from writing too long of a body.")
  1993  		}
  1994  	}
  1995  	nhigh := runtime.NumGoroutine()
  1996  	tr.CloseIdleConnections()
  1997  	nfinal := waitNumGoroutine(n0 + 5)
  1998  
  1999  	growth := nfinal - n0
  2000  
  2001  	// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
  2002  	// Previously we were leaking one per numReq.
  2003  	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
  2004  	if int(growth) > 5 {
  2005  		t.Error("too many new goroutines")
  2006  	}
  2007  }
  2008  
  2009  // A countedConn is a net.Conn that decrements an atomic counter when finalized.
  2010  type countedConn struct {
  2011  	net.Conn
  2012  }
  2013  
  2014  // A countingDialer dials connections and counts the number that remain reachable.
  2015  type countingDialer struct {
  2016  	dialer      net.Dialer
  2017  	mu          sync.Mutex
  2018  	total, live int64
  2019  }
  2020  
  2021  func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  2022  	conn, err := d.dialer.DialContext(ctx, network, address)
  2023  	if err != nil {
  2024  		return nil, err
  2025  	}
  2026  
  2027  	counted := new(countedConn)
  2028  	counted.Conn = conn
  2029  
  2030  	d.mu.Lock()
  2031  	defer d.mu.Unlock()
  2032  	d.total++
  2033  	d.live++
  2034  
  2035  	runtime.SetFinalizer(counted, d.decrement)
  2036  	return counted, nil
  2037  }
  2038  
  2039  func (d *countingDialer) decrement(*countedConn) {
  2040  	d.mu.Lock()
  2041  	defer d.mu.Unlock()
  2042  	d.live--
  2043  }
  2044  
  2045  func (d *countingDialer) Read() (total, live int64) {
  2046  	d.mu.Lock()
  2047  	defer d.mu.Unlock()
  2048  	return d.total, d.live
  2049  }
  2050  
  2051  func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
  2052  	run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
  2053  }
  2054  func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
  2055  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2056  		// Close every connection so that it cannot be kept alive.
  2057  		conn, _, err := w.(Hijacker).Hijack()
  2058  		if err != nil {
  2059  			t.Errorf("Hijack failed unexpectedly: %v", err)
  2060  			return
  2061  		}
  2062  		conn.Close()
  2063  	})).ts
  2064  
  2065  	var d countingDialer
  2066  	c := ts.Client()
  2067  	c.Transport.(*Transport).DialContext = d.DialContext
  2068  
  2069  	body := []byte("Hello")
  2070  	for i := 0; ; i++ {
  2071  		total, live := d.Read()
  2072  		if live < total {
  2073  			break
  2074  		}
  2075  		if i >= 1<<12 {
  2076  			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
  2077  		}
  2078  
  2079  		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
  2080  		if err != nil {
  2081  			t.Fatal(err)
  2082  		}
  2083  		_, err = c.Do(req)
  2084  		if err == nil {
  2085  			t.Fatal("expected broken connection")
  2086  		}
  2087  
  2088  		runtime.GC()
  2089  	}
  2090  }
  2091  
  2092  type countedContext struct {
  2093  	context.Context
  2094  }
  2095  
  2096  type contextCounter struct {
  2097  	mu   sync.Mutex
  2098  	live int64
  2099  }
  2100  
  2101  func (cc *contextCounter) Track(ctx context.Context) context.Context {
  2102  	counted := new(countedContext)
  2103  	counted.Context = ctx
  2104  	cc.mu.Lock()
  2105  	defer cc.mu.Unlock()
  2106  	cc.live++
  2107  	runtime.SetFinalizer(counted, cc.decrement)
  2108  	return counted
  2109  }
  2110  
  2111  func (cc *contextCounter) decrement(*countedContext) {
  2112  	cc.mu.Lock()
  2113  	defer cc.mu.Unlock()
  2114  	cc.live--
  2115  }
  2116  
  2117  func (cc *contextCounter) Read() (live int64) {
  2118  	cc.mu.Lock()
  2119  	defer cc.mu.Unlock()
  2120  	return cc.live
  2121  }
  2122  
  2123  func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
  2124  	run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
  2125  }
  2126  func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
  2127  	if mode == http2Mode {
  2128  		t.Skip("https://go.dev/issue/56021")
  2129  	}
  2130  
  2131  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2132  		runtime.Gosched()
  2133  		w.WriteHeader(StatusOK)
  2134  	})).ts
  2135  
  2136  	c := ts.Client()
  2137  	c.Transport.(*Transport).MaxConnsPerHost = 1
  2138  
  2139  	ctx := context.Background()
  2140  	body := []byte("Hello")
  2141  	doPosts := func(cc *contextCounter) {
  2142  		var wg sync.WaitGroup
  2143  		for n := 64; n > 0; n-- {
  2144  			wg.Add(1)
  2145  			go func() {
  2146  				defer wg.Done()
  2147  
  2148  				ctx := cc.Track(ctx)
  2149  				req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
  2150  				if err != nil {
  2151  					t.Error(err)
  2152  				}
  2153  
  2154  				_, err = c.Do(req.WithContext(ctx))
  2155  				if err != nil {
  2156  					t.Errorf("Do failed with error: %v", err)
  2157  				}
  2158  			}()
  2159  		}
  2160  		wg.Wait()
  2161  	}
  2162  
  2163  	var initialCC contextCounter
  2164  	doPosts(&initialCC)
  2165  
  2166  	// flushCC exists only to put pressure on the GC to finalize the initialCC
  2167  	// contexts: the flushCC allocations should eventually displace the initialCC
  2168  	// allocations.
  2169  	var flushCC contextCounter
  2170  	for i := 0; ; i++ {
  2171  		live := initialCC.Read()
  2172  		if live == 0 {
  2173  			break
  2174  		}
  2175  		if i >= 100 {
  2176  			t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
  2177  		}
  2178  		doPosts(&flushCC)
  2179  		runtime.GC()
  2180  	}
  2181  }
  2182  
  2183  // This used to crash; https://golang.org/issue/3266
  2184  func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
  2185  func testTransportIdleConnCrash(t *testing.T, mode testMode) {
  2186  	var tr *Transport
  2187  
  2188  	unblockCh := make(chan bool, 1)
  2189  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2190  		<-unblockCh
  2191  		tr.CloseIdleConnections()
  2192  	})).ts
  2193  	c := ts.Client()
  2194  	tr = c.Transport.(*Transport)
  2195  
  2196  	didreq := make(chan bool)
  2197  	go func() {
  2198  		res, err := c.Get(ts.URL)
  2199  		if err != nil {
  2200  			t.Error(err)
  2201  		} else {
  2202  			res.Body.Close() // returns idle conn
  2203  		}
  2204  		didreq <- true
  2205  	}()
  2206  	unblockCh <- true
  2207  	<-didreq
  2208  }
  2209  
  2210  // Test that the transport doesn't close the TCP connection early,
  2211  // before the response body has been read. This was a regression
  2212  // which sadly lacked a triggering test. The large response body made
  2213  // the old race easier to trigger.
  2214  func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
  2215  func testIssue3644(t *testing.T, mode testMode) {
  2216  	const numFoos = 5000
  2217  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2218  		w.Header().Set("Connection", "close")
  2219  		for i := 0; i < numFoos; i++ {
  2220  			w.Write([]byte("foo "))
  2221  		}
  2222  	})).ts
  2223  	c := ts.Client()
  2224  	res, err := c.Get(ts.URL)
  2225  	if err != nil {
  2226  		t.Fatal(err)
  2227  	}
  2228  	defer res.Body.Close()
  2229  	bs, err := io.ReadAll(res.Body)
  2230  	if err != nil {
  2231  		t.Fatal(err)
  2232  	}
  2233  	if len(bs) != numFoos*len("foo ") {
  2234  		t.Errorf("unexpected response length")
  2235  	}
  2236  }
  2237  
  2238  // Test that a client receives a server's reply, even if the server doesn't read
  2239  // the entire request body.
  2240  func TestIssue3595(t *testing.T) {
  2241  	// Not parallel: modifies the global rstAvoidanceDelay.
  2242  	run(t, testIssue3595, testNotParallel)
  2243  }
  2244  func testIssue3595(t *testing.T, mode testMode) {
  2245  	runTimeSensitiveTest(t, []time.Duration{
  2246  		1 * time.Millisecond,
  2247  		5 * time.Millisecond,
  2248  		10 * time.Millisecond,
  2249  		50 * time.Millisecond,
  2250  		100 * time.Millisecond,
  2251  		500 * time.Millisecond,
  2252  		time.Second,
  2253  		5 * time.Second,
  2254  	}, func(t *testing.T, timeout time.Duration) error {
  2255  		SetRSTAvoidanceDelay(t, timeout)
  2256  		t.Logf("set RST avoidance delay to %v", timeout)
  2257  
  2258  		const deniedMsg = "sorry, denied."
  2259  		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2260  			Error(w, deniedMsg, StatusUnauthorized)
  2261  		}))
  2262  		// We need to close cst explicitly here so that in-flight server
  2263  		// requests don't race with the call to SetRSTAvoidanceDelay for a retry.
  2264  		defer cst.close()
  2265  		ts := cst.ts
  2266  		c := ts.Client()
  2267  
  2268  		res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
  2269  		if err != nil {
  2270  			return fmt.Errorf("Post: %v", err)
  2271  		}
  2272  		got, err := io.ReadAll(res.Body)
  2273  		if err != nil {
  2274  			return fmt.Errorf("Body ReadAll: %v", err)
  2275  		}
  2276  		t.Logf("server response:\n%s", got)
  2277  		if !strings.Contains(string(got), deniedMsg) {
  2278  			// If we got an RST packet too early, we should have seen an error
  2279  			// from io.ReadAll, not a silently-truncated body.
  2280  			t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
  2281  		}
  2282  		return nil
  2283  	})
  2284  }
  2285  
  2286  // From https://golang.org/issue/4454 ,
  2287  // "client fails to handle requests with no body and chunked encoding"
  2288  func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
  2289  func testChunkedNoContent(t *testing.T, mode testMode) {
  2290  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2291  		w.WriteHeader(StatusNoContent)
  2292  	})).ts
  2293  
  2294  	c := ts.Client()
  2295  	for _, closeBody := range []bool{true, false} {
  2296  		const n = 4
  2297  		for i := 1; i <= n; i++ {
  2298  			res, err := c.Get(ts.URL)
  2299  			if err != nil {
  2300  				t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
  2301  			} else {
  2302  				if closeBody {
  2303  					res.Body.Close()
  2304  				}
  2305  			}
  2306  		}
  2307  	}
  2308  }
  2309  
  2310  func TestTransportConcurrency(t *testing.T) {
  2311  	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
  2312  }
  2313  func testTransportConcurrency(t *testing.T, mode testMode) {
  2314  	// Not parallel: uses global test hooks.
  2315  	maxProcs, numReqs := 16, 500
  2316  	if testing.Short() {
  2317  		maxProcs, numReqs = 4, 50
  2318  	}
  2319  	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
  2320  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2321  		fmt.Fprintf(w, "%v", r.FormValue("echo"))
  2322  	})).ts
  2323  
  2324  	var wg sync.WaitGroup
  2325  	wg.Add(numReqs)
  2326  
  2327  	// Due to the Transport's "socket late binding" (see
  2328  	// idleConnCh in transport.go), the numReqs HTTP requests
  2329  	// below can finish with a dial still outstanding. To keep
  2330  	// the leak checker happy, keep track of pending dials and
  2331  	// wait for them to finish (and be closed or returned to the
  2332  	// idle pool) before we close idle connections.
  2333  	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
  2334  	defer SetPendingDialHooks(nil, nil)
  2335  
  2336  	c := ts.Client()
  2337  	reqs := make(chan string)
  2338  	defer close(reqs)
  2339  
  2340  	for i := 0; i < maxProcs*2; i++ {
  2341  		go func() {
  2342  			for req := range reqs {
  2343  				res, err := c.Get(ts.URL + "/?echo=" + req)
  2344  				if err != nil {
  2345  					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
  2346  						// https://go.dev/issue/52168: this test was observed to fail with
  2347  						// ECONNRESET errors in Dial on various netbsd builders.
  2348  						t.Logf("error on req %s: %v", req, err)
  2349  						t.Logf("(see https://go.dev/issue/52168)")
  2350  					} else {
  2351  						t.Errorf("error on req %s: %v", req, err)
  2352  					}
  2353  					wg.Done()
  2354  					continue
  2355  				}
  2356  				all, err := io.ReadAll(res.Body)
  2357  				if err != nil {
  2358  					t.Errorf("read error on req %s: %v", req, err)
  2359  				} else if string(all) != req {
  2360  					t.Errorf("body of req %s = %q; want %q", req, all, req)
  2361  				}
  2362  				res.Body.Close()
  2363  				wg.Done()
  2364  			}
  2365  		}()
  2366  	}
  2367  	for i := 0; i < numReqs; i++ {
  2368  		reqs <- fmt.Sprintf("request-%d", i)
  2369  	}
  2370  	wg.Wait()
  2371  }
  2372  
  2373  func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
  2374  func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
  2375  	mux := NewServeMux()
  2376  	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
  2377  		io.Copy(w, neverEnding('a'))
  2378  	})
  2379  	ts := newClientServerTest(t, mode, mux).ts
  2380  
  2381  	connc := make(chan net.Conn, 1)
  2382  	c := ts.Client()
  2383  	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
  2384  		conn, err := net.Dial(n, addr)
  2385  		if err != nil {
  2386  			return nil, err
  2387  		}
  2388  		select {
  2389  		case connc <- conn:
  2390  		default:
  2391  		}
  2392  		return conn, nil
  2393  	}
  2394  
  2395  	res, err := c.Get(ts.URL + "/get")
  2396  	if err != nil {
  2397  		t.Fatalf("Error issuing GET: %v", err)
  2398  	}
  2399  	defer res.Body.Close()
  2400  
  2401  	conn := <-connc
  2402  	conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
  2403  	_, err = io.Copy(io.Discard, res.Body)
  2404  	if err == nil {
  2405  		t.Errorf("Unexpected successful copy")
  2406  	}
  2407  }
  2408  
  2409  func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
  2410  	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
  2411  }
  2412  func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
  2413  	const debug = false
  2414  	mux := NewServeMux()
  2415  	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
  2416  		io.Copy(w, neverEnding('a'))
  2417  	})
  2418  	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
  2419  		defer r.Body.Close()
  2420  		io.Copy(io.Discard, r.Body)
  2421  	})
  2422  	ts := newClientServerTest(t, mode, mux).ts
  2423  	timeout := 100 * time.Millisecond
  2424  
  2425  	c := ts.Client()
  2426  	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
  2427  		conn, err := net.Dial(n, addr)
  2428  		if err != nil {
  2429  			return nil, err
  2430  		}
  2431  		conn.SetDeadline(time.Now().Add(timeout))
  2432  		if debug {
  2433  			conn = NewLoggingConn("client", conn)
  2434  		}
  2435  		return conn, nil
  2436  	}
  2437  
  2438  	getFailed := false
  2439  	nRuns := 5
  2440  	if testing.Short() {
  2441  		nRuns = 1
  2442  	}
  2443  	for i := 0; i < nRuns; i++ {
  2444  		if debug {
  2445  			println("run", i+1, "of", nRuns)
  2446  		}
  2447  		sres, err := c.Get(ts.URL + "/get")
  2448  		if err != nil {
  2449  			if !getFailed {
  2450  				// Make the timeout longer, once.
  2451  				getFailed = true
  2452  				t.Logf("increasing timeout")
  2453  				i--
  2454  				timeout *= 10
  2455  				continue
  2456  			}
  2457  			t.Errorf("Error issuing GET: %v", err)
  2458  			break
  2459  		}
  2460  		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
  2461  		_, err = c.Do(req)
  2462  		if err == nil {
  2463  			sres.Body.Close()
  2464  			t.Errorf("Unexpected successful PUT")
  2465  			break
  2466  		}
  2467  		sres.Body.Close()
  2468  	}
  2469  	if debug {
  2470  		println("tests complete; waiting for handlers to finish")
  2471  	}
  2472  	ts.Close()
  2473  }
  2474  
  2475  func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
  2476  func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
  2477  	if testing.Short() {
  2478  		t.Skip("skipping timeout test in -short mode")
  2479  	}
  2480  
  2481  	timeout := 2 * time.Millisecond
  2482  	retry := true
  2483  	for retry && !t.Failed() {
  2484  		var srvWG sync.WaitGroup
  2485  		inHandler := make(chan bool, 1)
  2486  		mux := NewServeMux()
  2487  		mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
  2488  			inHandler <- true
  2489  			srvWG.Done()
  2490  		})
  2491  		mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
  2492  			inHandler <- true
  2493  			<-r.Context().Done()
  2494  			srvWG.Done()
  2495  		})
  2496  		ts := newClientServerTest(t, mode, mux).ts
  2497  
  2498  		c := ts.Client()
  2499  		c.Transport.(*Transport).ResponseHeaderTimeout = timeout
  2500  
  2501  		retry = false
  2502  		srvWG.Add(3)
  2503  		tests := []struct {
  2504  			path        string
  2505  			wantTimeout bool
  2506  		}{
  2507  			{path: "/fast"},
  2508  			{path: "/slow", wantTimeout: true},
  2509  			{path: "/fast"},
  2510  		}
  2511  		for i, tt := range tests {
  2512  			req, _ := NewRequest("GET", ts.URL+tt.path, nil)
  2513  			req = req.WithT(t)
  2514  			res, err := c.Do(req)
  2515  			<-inHandler
  2516  			if err != nil {
  2517  				uerr, ok := err.(*url.Error)
  2518  				if !ok {
  2519  					t.Errorf("error is not a url.Error; got: %#v", err)
  2520  					continue
  2521  				}
  2522  				nerr, ok := uerr.Err.(net.Error)
  2523  				if !ok {
  2524  					t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
  2525  					continue
  2526  				}
  2527  				if !nerr.Timeout() {
  2528  					t.Errorf("want timeout error; got: %q", nerr)
  2529  					continue
  2530  				}
  2531  				if !tt.wantTimeout {
  2532  					if !retry {
  2533  						// The timeout may be set too short. Retry with a longer one.
  2534  						t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
  2535  						timeout *= 2
  2536  						retry = true
  2537  					}
  2538  				}
  2539  				if !strings.Contains(err.Error(), "timeout awaiting response headers") {
  2540  					t.Errorf("%d. unexpected error: %v", i, err)
  2541  				}
  2542  				continue
  2543  			}
  2544  			if tt.wantTimeout {
  2545  				t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
  2546  				continue
  2547  			}
  2548  			if res.StatusCode != 200 {
  2549  				t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
  2550  			}
  2551  		}
  2552  
  2553  		srvWG.Wait()
  2554  		ts.Close()
  2555  	}
  2556  }
  2557  
  2558  // A cancelTest is a test of request cancellation.
  2559  type cancelTest struct {
  2560  	mode     testMode
  2561  	newReq   func(req *Request) *Request       // prepare the request to cancel
  2562  	cancel   func(tr *Transport, req *Request) // cancel the request
  2563  	checkErr func(when string, err error)      // verify the expected error
  2564  }
  2565  
  2566  // runCancelTestTransport uses Transport.CancelRequest.
  2567  func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
  2568  	t.Run("TransportCancel", func(t *testing.T) {
  2569  		f(t, cancelTest{
  2570  			mode: mode,
  2571  			newReq: func(req *Request) *Request {
  2572  				return req
  2573  			},
  2574  			cancel: func(tr *Transport, req *Request) {
  2575  				tr.CancelRequest(req)
  2576  			},
  2577  			checkErr: func(when string, err error) {
  2578  				if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
  2579  					t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
  2580  				}
  2581  			},
  2582  		})
  2583  	})
  2584  }
  2585  
  2586  // runCancelTestChannel uses Request.Cancel.
  2587  func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
  2588  	var cancelOnce sync.Once
  2589  	cancelc := make(chan struct{})
  2590  	f(t, cancelTest{
  2591  		mode: mode,
  2592  		newReq: func(req *Request) *Request {
  2593  			req.Cancel = cancelc
  2594  			return req
  2595  		},
  2596  		cancel: func(tr *Transport, req *Request) {
  2597  			cancelOnce.Do(func() {
  2598  				close(cancelc)
  2599  			})
  2600  		},
  2601  		checkErr: func(when string, err error) {
  2602  			if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
  2603  				t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
  2604  			}
  2605  		},
  2606  	})
  2607  }
  2608  
  2609  // runCancelTestContext uses a request context.
  2610  func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
  2611  	ctx, cancel := context.WithCancel(context.Background())
  2612  	f(t, cancelTest{
  2613  		mode: mode,
  2614  		newReq: func(req *Request) *Request {
  2615  			return req.WithContext(ctx)
  2616  		},
  2617  		cancel: func(tr *Transport, req *Request) {
  2618  			cancel()
  2619  		},
  2620  		checkErr: func(when string, err error) {
  2621  			if !errors.Is(err, context.Canceled) {
  2622  				t.Errorf("%v error = %v, want context.Canceled", when, err)
  2623  			}
  2624  		},
  2625  	})
  2626  }
  2627  
  2628  func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
  2629  	run(t, func(t *testing.T, mode testMode) {
  2630  		if mode == http1Mode {
  2631  			t.Run("TransportCancel", func(t *testing.T) {
  2632  				runCancelTestTransport(t, mode, f)
  2633  			})
  2634  		}
  2635  		t.Run("RequestCancel", func(t *testing.T) {
  2636  			runCancelTestChannel(t, mode, f)
  2637  		})
  2638  		t.Run("ContextCancel", func(t *testing.T) {
  2639  			runCancelTestContext(t, mode, f)
  2640  		})
  2641  	}, opts...)
  2642  }
  2643  
  2644  func TestTransportCancelRequest(t *testing.T) {
  2645  	runCancelTest(t, testTransportCancelRequest)
  2646  }
  2647  func testTransportCancelRequest(t *testing.T, test cancelTest) {
  2648  	if testing.Short() {
  2649  		t.Skip("skipping test in -short mode")
  2650  	}
  2651  
  2652  	const msg = "Hello"
  2653  	unblockc := make(chan bool)
  2654  	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2655  		io.WriteString(w, msg)
  2656  		w.(Flusher).Flush() // send headers and some body
  2657  		<-unblockc
  2658  	})).ts
  2659  	defer close(unblockc)
  2660  
  2661  	c := ts.Client()
  2662  	tr := c.Transport.(*Transport)
  2663  
  2664  	req, _ := NewRequest("GET", ts.URL, nil)
  2665  	req = test.newReq(req)
  2666  	res, err := c.Do(req)
  2667  	if err != nil {
  2668  		t.Fatal(err)
  2669  	}
  2670  	body := make([]byte, len(msg))
  2671  	n, _ := io.ReadFull(res.Body, body)
  2672  	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
  2673  		t.Errorf("Body = %q; want %q", body[:n], msg)
  2674  	}
  2675  	test.cancel(tr, req)
  2676  
  2677  	tail, err := io.ReadAll(res.Body)
  2678  	res.Body.Close()
  2679  	test.checkErr("Body.Read", err)
  2680  	if len(tail) > 0 {
  2681  		t.Errorf("Spurious bytes from Body.Read: %q", tail)
  2682  	}
  2683  
  2684  	// Verify no outstanding requests after readLoop/writeLoop
  2685  	// goroutines shut down.
  2686  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
  2687  		n := tr.NumPendingRequestsForTesting()
  2688  		if n > 0 {
  2689  			if d > 0 {
  2690  				t.Logf("pending requests = %d after %v (want 0)", n, d)
  2691  			}
  2692  			return false
  2693  		}
  2694  		return true
  2695  	})
  2696  }
  2697  
  2698  func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
  2699  	if testing.Short() {
  2700  		t.Skip("skipping test in -short mode")
  2701  	}
  2702  	unblockc := make(chan bool)
  2703  	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2704  		<-unblockc
  2705  	})).ts
  2706  	defer close(unblockc)
  2707  
  2708  	c := ts.Client()
  2709  	tr := c.Transport.(*Transport)
  2710  
  2711  	donec := make(chan bool)
  2712  	req, _ := NewRequest("GET", ts.URL, body)
  2713  	req = test.newReq(req)
  2714  	go func() {
  2715  		defer close(donec)
  2716  		c.Do(req)
  2717  	}()
  2718  
  2719  	unblockc <- true
  2720  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
  2721  		test.cancel(tr, req)
  2722  		select {
  2723  		case <-donec:
  2724  			return true
  2725  		default:
  2726  			if d > 0 {
  2727  				t.Logf("Do of canceled request has not returned after %v", d)
  2728  			}
  2729  			return false
  2730  		}
  2731  	})
  2732  }
  2733  
  2734  func TestTransportCancelRequestInDo(t *testing.T) {
  2735  	runCancelTest(t, func(t *testing.T, test cancelTest) {
  2736  		testTransportCancelRequestInDo(t, test, nil)
  2737  	})
  2738  }
  2739  
  2740  func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
  2741  	runCancelTest(t, func(t *testing.T, test cancelTest) {
  2742  		testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
  2743  	})
  2744  }
  2745  
  2746  func TestTransportCancelRequestInDial(t *testing.T) {
  2747  	runCancelTest(t, testTransportCancelRequestInDial)
  2748  }
// testTransportCancelRequestInDial cancels a request (twice) while its Dial
// is blocked, and checks that Client.Do returns an error that test.checkErr
// accepts. The order of events is recorded in logbuf and compared against an
// exact expected sequence at the end.
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// eventLog records the sequence of events; log.Logger serializes its
	// writes to logbuf.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// Nothing sends on unblockDial, so Dial stays blocked until this deferred
	// close runs when the test function returns.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// inDial is a handshake: Dial logs "dial: blocking" before receiving from
	// it, so the main goroutine's send below guarantees that line is logged
	// before "canceling".
	// NOTE(review): inDial is never closed in this function, so the
	// "main Test goroutine exited" branch cannot be reached from here.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	// Wait until Dial is running.
	inDial <- true

	eventLog.Printf("canceling")
	test.cancel(tr, req)
	test.cancel(tr, req) // used to panic on second call to Transport.Cancel

	if d, ok := t.Deadline(); ok {
		// When the test's deadline is about to expire, log the pending events for
		// better debugging.
		timeout := time.Until(d) * 19 / 20 // Allow 5% for cleanup.
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	// Do must return while Dial is still blocked on unblockDial, because
	// unblockDial is only closed after this function returns.
	<-gotres

	// Receiving from gotres orders this read after the goroutine's last
	// eventLog write.
	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
  2808  
  2809  // Issue 51354
  2810  func TestTransportCancelRequestWithBody(t *testing.T) {
  2811  	runCancelTest(t, testTransportCancelRequestWithBody)
  2812  }
  2813  func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
  2814  	if testing.Short() {
  2815  		t.Skip("skipping test in -short mode")
  2816  	}
  2817  
  2818  	const msg = "Hello"
  2819  	unblockc := make(chan struct{})
  2820  	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2821  		io.WriteString(w, msg)
  2822  		w.(Flusher).Flush() // send headers and some body
  2823  		<-unblockc
  2824  	})).ts
  2825  	defer close(unblockc)
  2826  
  2827  	c := ts.Client()
  2828  	tr := c.Transport.(*Transport)
  2829  
  2830  	req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
  2831  	req = test.newReq(req)
  2832  
  2833  	res, err := c.Do(req)
  2834  	if err != nil {
  2835  		t.Fatal(err)
  2836  	}
  2837  	body := make([]byte, len(msg))
  2838  	n, _ := io.ReadFull(res.Body, body)
  2839  	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
  2840  		t.Errorf("Body = %q; want %q", body[:n], msg)
  2841  	}
  2842  	test.cancel(tr, req)
  2843  
  2844  	tail, err := io.ReadAll(res.Body)
  2845  	res.Body.Close()
  2846  	test.checkErr("Body.Read", err)
  2847  	if len(tail) > 0 {
  2848  		t.Errorf("Spurious bytes from Body.Read: %q", tail)
  2849  	}
  2850  
  2851  	// Verify no outstanding requests after readLoop/writeLoop
  2852  	// goroutines shut down.
  2853  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
  2854  		n := tr.NumPendingRequestsForTesting()
  2855  		if n > 0 {
  2856  			if d > 0 {
  2857  				t.Logf("pending requests = %d after %v (want 0)", n, d)
  2858  			}
  2859  			return false
  2860  		}
  2861  		return true
  2862  	})
  2863  }
  2864  
  2865  func TestTransportCancelRequestBeforeDo(t *testing.T) {
  2866  	// We can't cancel a request that hasn't started using Transport.CancelRequest.
  2867  	run(t, func(t *testing.T, mode testMode) {
  2868  		t.Run("RequestCancel", func(t *testing.T) {
  2869  			runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
  2870  		})
  2871  		t.Run("ContextCancel", func(t *testing.T) {
  2872  			runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
  2873  		})
  2874  	})
  2875  }
  2876  func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
  2877  	unblockc := make(chan bool)
  2878  	cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2879  		<-unblockc
  2880  	}))
  2881  	defer close(unblockc)
  2882  
  2883  	c := cst.ts.Client()
  2884  
  2885  	req, _ := NewRequest("GET", cst.ts.URL, nil)
  2886  	req = test.newReq(req)
  2887  	test.cancel(cst.tr, req)
  2888  
  2889  	_, err := c.Do(req)
  2890  	test.checkErr("Do", err)
  2891  }
  2892  
  2893  // Issue 11020. The returned error message should be errRequestCanceled
  2894  func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
  2895  	runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
  2896  }
  2897  func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
  2898  	defer afterTest(t)
  2899  
  2900  	serverConnCh := make(chan net.Conn, 1)
  2901  	tr := &Transport{
  2902  		Dial: func(network, addr string) (net.Conn, error) {
  2903  			cc, sc := net.Pipe()
  2904  			serverConnCh <- sc
  2905  			return cc, nil
  2906  		},
  2907  	}
  2908  	defer tr.CloseIdleConnections()
  2909  	errc := make(chan error, 1)
  2910  	req, _ := NewRequest("GET", "http://example.com/", nil)
  2911  	req = test.newReq(req)
  2912  	go func() {
  2913  		_, err := tr.RoundTrip(req)
  2914  		errc <- err
  2915  	}()
  2916  
  2917  	sc := <-serverConnCh
  2918  	verb := make([]byte, 3)
  2919  	if _, err := io.ReadFull(sc, verb); err != nil {
  2920  		t.Errorf("Error reading HTTP verb from server: %v", err)
  2921  	}
  2922  	if string(verb) != "GET" {
  2923  		t.Errorf("server received %q; want GET", verb)
  2924  	}
  2925  	defer sc.Close()
  2926  
  2927  	test.cancel(tr, req)
  2928  
  2929  	err := <-errc
  2930  	if err == nil {
  2931  		t.Fatalf("unexpected success from RoundTrip")
  2932  	}
  2933  	test.checkErr("RoundTrip", err)
  2934  }
  2935  
  2936  // golang.org/issue/3672 -- Client can't close HTTP stream
  2937  // Calling Close on a Response.Body used to just read until EOF.
  2938  // Now it actually closes the TCP connection.
  2939  func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
  2940  func testTransportCloseResponseBody(t *testing.T, mode testMode) {
  2941  	writeErr := make(chan error, 1)
  2942  	msg := []byte("young\n")
  2943  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2944  		for {
  2945  			_, err := w.Write(msg)
  2946  			if err != nil {
  2947  				writeErr <- err
  2948  				return
  2949  			}
  2950  			w.(Flusher).Flush()
  2951  		}
  2952  	})).ts
  2953  
  2954  	c := ts.Client()
  2955  	tr := c.Transport.(*Transport)
  2956  
  2957  	req, _ := NewRequest("GET", ts.URL, nil)
  2958  	defer tr.CancelRequest(req)
  2959  
  2960  	res, err := c.Do(req)
  2961  	if err != nil {
  2962  		t.Fatal(err)
  2963  	}
  2964  
  2965  	const repeats = 3
  2966  	buf := make([]byte, len(msg)*repeats)
  2967  	want := bytes.Repeat(msg, repeats)
  2968  
  2969  	_, err = io.ReadFull(res.Body, buf)
  2970  	if err != nil {
  2971  		t.Fatal(err)
  2972  	}
  2973  	if !bytes.Equal(buf, want) {
  2974  		t.Fatalf("read %q; want %q", buf, want)
  2975  	}
  2976  
  2977  	if err := res.Body.Close(); err != nil {
  2978  		t.Errorf("Close = %v", err)
  2979  	}
  2980  
  2981  	if err := <-writeErr; err == nil {
  2982  		t.Errorf("expected non-nil write error")
  2983  	}
  2984  }
  2985  
  2986  type fooProto struct{}
  2987  
  2988  func (fooProto) RoundTrip(req *Request) (*Response, error) {
  2989  	res := &Response{
  2990  		Status:     "200 OK",
  2991  		StatusCode: 200,
  2992  		Header:     make(Header),
  2993  		Body:       io.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
  2994  	}
  2995  	return res, nil
  2996  }
  2997  
  2998  func TestTransportAltProto(t *testing.T) {
  2999  	defer afterTest(t)
  3000  	tr := &Transport{}
  3001  	c := &Client{Transport: tr}
  3002  	tr.RegisterProtocol("foo", fooProto{})
  3003  	res, err := c.Get("foo://bar.com/path")
  3004  	if err != nil {
  3005  		t.Fatal(err)
  3006  	}
  3007  	bodyb, err := io.ReadAll(res.Body)
  3008  	if err != nil {
  3009  		t.Fatal(err)
  3010  	}
  3011  	body := string(bodyb)
  3012  	if e := "You wanted foo://bar.com/path"; body != e {
  3013  		t.Errorf("got response %q, want %q", body, e)
  3014  	}
  3015  }
  3016  
  3017  func TestTransportNoHost(t *testing.T) {
  3018  	defer afterTest(t)
  3019  	tr := &Transport{}
  3020  	_, err := tr.RoundTrip(&Request{
  3021  		Header: make(Header),
  3022  		URL: &url.URL{
  3023  			Scheme: "http",
  3024  		},
  3025  	})
  3026  	want := "http: no Host in request URL"
  3027  	if got := fmt.Sprint(err); got != want {
  3028  		t.Errorf("error = %v; want %q", err, want)
  3029  	}
  3030  }
  3031  
  3032  // Issue 13311
  3033  func TestTransportEmptyMethod(t *testing.T) {
  3034  	req, _ := NewRequest("GET", "http://foo.com/", nil)
  3035  	req.Method = ""                                 // docs say "For client requests an empty string means GET"
  3036  	got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport
  3037  	if err != nil {
  3038  		t.Fatal(err)
  3039  	}
  3040  	if !strings.Contains(string(got), "GET ") {
  3041  		t.Fatalf("expected substring 'GET '; got: %s", got)
  3042  	}
  3043  }
  3044  
  3045  func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
// testTransportSocketLateBinding holds a /foo response open on the only
// connection the dialer allows, starts /bar, and then lets /foo finish. /bar
// must report the same RemoteAddr as /foo, meaning it reused /foo's
// connection instead of a fresh dial.
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// fooGate keeps the /foo handler (and thus its connection) busy after
	// the headers have been flushed, until the test releases it.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// dialGate holds one permit per allowed dial; closing it makes any
	// waiting Dial fail. While a Dial waits for a permit, it repeatedly
	// offers a value on dialing, which lets the test observe that a dial is
	// in progress.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	dialGate <- true // only allow one dial
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	// fooDone is closed once the /foo body has been drained and closed.
	fooDone := make(chan struct{})
	go func() {
		// We know that the foo Dial completed and reached the handler because we
		// read its header. Wait for the bar request to block in Dial, then
		// let the foo response finish so we can use its connection for /bar.

		if mode == http2Mode {
			// In HTTP/2 mode, the second Dial won't happen because the protocol
			// multiplexes the streams by default. Just sleep for an arbitrary time;
			// the test should pass regardless of how far the bar request gets by this
			// point.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	// Registered after close(dialGate), so it runs first: wait for the
	// helper goroutine before the gate is closed.
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
  3124  
  3125  // Issue 2184
  3126  func TestTransportReading100Continue(t *testing.T) {
  3127  	defer afterTest(t)
  3128  
  3129  	const numReqs = 5
  3130  	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
  3131  	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
  3132  
  3133  	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
  3134  		defer w.Close()
  3135  		defer r.Close()
  3136  		br := bufio.NewReader(r)
  3137  		n := 0
  3138  		for {
  3139  			n++
  3140  			req, err := ReadRequest(br)
  3141  			if err == io.EOF {
  3142  				return
  3143  			}
  3144  			if err != nil {
  3145  				t.Error(err)
  3146  				return
  3147  			}
  3148  			slurp, err := io.ReadAll(req.Body)
  3149  			if err != nil {
  3150  				t.Errorf("Server request body slurp: %v", err)
  3151  				return
  3152  			}
  3153  			id := req.Header.Get("Request-Id")
  3154  			resCode := req.Header.Get("X-Want-Response-Code")
  3155  			if resCode == "" {
  3156  				resCode = "100 Continue"
  3157  				if string(slurp) != reqBody(n) {
  3158  					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
  3159  				}
  3160  			}
  3161  			body := fmt.Sprintf("Response number %d", n)
  3162  			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
  3163  Date: Thu, 28 Feb 2013 17:55:41 GMT
  3164  
  3165  HTTP/1.1 200 OK
  3166  Content-Type: text/html
  3167  Echo-Request-Id: %s
  3168  Content-Length: %d
  3169  
  3170  %s`, resCode, id, len(body), body), "\n", "\r\n", -1))
  3171  			w.Write(v)
  3172  			if id == reqID(numReqs) {
  3173  				return
  3174  			}
  3175  		}
  3176  
  3177  	}
  3178  
  3179  	tr := &Transport{
  3180  		Dial: func(n, addr string) (net.Conn, error) {
  3181  			sr, sw := io.Pipe() // server read/write
  3182  			cr, cw := io.Pipe() // client read/write
  3183  			conn := &rwTestConn{
  3184  				Reader: cr,
  3185  				Writer: sw,
  3186  				closeFunc: func() error {
  3187  					sw.Close()
  3188  					cw.Close()
  3189  					return nil
  3190  				},
  3191  			}
  3192  			go send100Response(cw, sr)
  3193  			return conn, nil
  3194  		},
  3195  		DisableKeepAlives: false,
  3196  	}
  3197  	defer tr.CloseIdleConnections()
  3198  	c := &Client{Transport: tr}
  3199  
  3200  	testResponse := func(req *Request, name string, wantCode int) {
  3201  		t.Helper()
  3202  		res, err := c.Do(req)
  3203  		if err != nil {
  3204  			t.Fatalf("%s: Do: %v", name, err)
  3205  		}
  3206  		if res.StatusCode != wantCode {
  3207  			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
  3208  		}
  3209  		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
  3210  			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
  3211  		}
  3212  		_, err = io.ReadAll(res.Body)
  3213  		if err != nil {
  3214  			t.Fatalf("%s: Slurp error: %v", name, err)
  3215  		}
  3216  	}
  3217  
  3218  	// Few 100 responses, making sure we're not off-by-one.
  3219  	for i := 1; i <= numReqs; i++ {
  3220  		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
  3221  		req.Header.Set("Request-Id", reqID(i))
  3222  		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
  3223  	}
  3224  }
  3225  
  3226  // Issue 17739: the HTTP client must ignore any unknown 1xx
  3227  // informational responses before the actual response.
  3228  func TestTransportIgnore1xxResponses(t *testing.T) {
  3229  	run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
  3230  }
  3231  func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
  3232  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3233  		conn, buf, _ := w.(Hijacker).Hijack()
  3234  		buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
  3235  		buf.Flush()
  3236  		conn.Close()
  3237  	}))
  3238  	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
  3239  
  3240  	var got strings.Builder
  3241  
  3242  	req, _ := NewRequest("GET", cst.ts.URL, nil)
  3243  	req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  3244  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  3245  			fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
  3246  			return nil
  3247  		},
  3248  	}))
  3249  	res, err := cst.c.Do(req)
  3250  	if err != nil {
  3251  		t.Fatal(err)
  3252  	}
  3253  	defer res.Body.Close()
  3254  
  3255  	res.Write(&got)
  3256  	want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
  3257  	if got.String() != want {
  3258  		t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
  3259  	}
  3260  }
  3261  
  3262  func TestTransportLimits1xxResponses(t *testing.T) {
  3263  	run(t, testTransportLimits1xxResponses, []testMode{http1Mode})
  3264  }
  3265  func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
  3266  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3267  		conn, buf, _ := w.(Hijacker).Hijack()
  3268  		for i := 0; i < 10; i++ {
  3269  			buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
  3270  		}
  3271  		buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
  3272  		buf.Flush()
  3273  		conn.Close()
  3274  	}))
  3275  	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
  3276  
  3277  	res, err := cst.c.Get(cst.ts.URL)
  3278  	if res != nil {
  3279  		defer res.Body.Close()
  3280  	}
  3281  	got := fmt.Sprint(err)
  3282  	wantSub := "too many 1xx informational responses"
  3283  	if !strings.Contains(got, wantSub) {
  3284  		t.Errorf("Get error = %v; want substring %q", err, wantSub)
  3285  	}
  3286  }
  3287  
  3288  // Issue 26161: the HTTP client must treat 101 responses
  3289  // as the final response.
  3290  func TestTransportTreat101Terminal(t *testing.T) {
  3291  	run(t, testTransportTreat101Terminal, []testMode{http1Mode})
  3292  }
  3293  func testTransportTreat101Terminal(t *testing.T, mode testMode) {
  3294  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3295  		conn, buf, _ := w.(Hijacker).Hijack()
  3296  		buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
  3297  		buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
  3298  		buf.Flush()
  3299  		conn.Close()
  3300  	}))
  3301  	res, err := cst.c.Get(cst.ts.URL)
  3302  	if err != nil {
  3303  		t.Fatal(err)
  3304  	}
  3305  	defer res.Body.Close()
  3306  	if res.StatusCode != StatusSwitchingProtocols {
  3307  		t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
  3308  	}
  3309  }
  3310  
  3311  type proxyFromEnvTest struct {
  3312  	req string // URL to fetch; blank means "http://example.com"
  3313  
  3314  	env      string // HTTP_PROXY
  3315  	httpsenv string // HTTPS_PROXY
  3316  	noenv    string // NO_PROXY
  3317  	reqmeth  string // REQUEST_METHOD
  3318  
  3319  	want    string
  3320  	wanterr error
  3321  }
  3322  
  3323  func (t proxyFromEnvTest) String() string {
  3324  	var buf strings.Builder
  3325  	space := func() {
  3326  		if buf.Len() > 0 {
  3327  			buf.WriteByte(' ')
  3328  		}
  3329  	}
  3330  	if t.env != "" {
  3331  		fmt.Fprintf(&buf, "http_proxy=%q", t.env)
  3332  	}
  3333  	if t.httpsenv != "" {
  3334  		space()
  3335  		fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv)
  3336  	}
  3337  	if t.noenv != "" {
  3338  		space()
  3339  		fmt.Fprintf(&buf, "no_proxy=%q", t.noenv)
  3340  	}
  3341  	if t.reqmeth != "" {
  3342  		space()
  3343  		fmt.Fprintf(&buf, "request_method=%q", t.reqmeth)
  3344  	}
  3345  	req := "http://example.com"
  3346  	if t.req != "" {
  3347  		req = t.req
  3348  	}
  3349  	space()
  3350  	fmt.Fprintf(&buf, "req=%q", req)
  3351  	return strings.TrimSpace(buf.String())
  3352  }
  3353  
// proxyFromEnvTests lists proxy-selection cases shared by the upper- and
// lower-case environment-variable tests. A want of "<nil>" means no proxy is
// expected.
var proxyFromEnvTests = []proxyFromEnvTest{
	// A scheme-less HTTP_PROXY value defaults to http://; explicit http,
	// https, socks5 and socks5h schemes are kept.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// For an http:// request, HTTPS_PROXY is ignored in favor of HTTP_PROXY.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// For an https:// request, HTTPS_PROXY is used; a scheme-less value
	// still defaults to http://.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// Issue 16405: don't use HTTP_PROXY in a CGI environment,
	// where HTTP_PROXY can be attacker-controlled.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy variables set: no proxy.
	{want: "<nil>"},

	// NO_PROXY matching: "example.com" excludes the host and its subdomains;
	// ".example.com" does not exclude example.com itself; "ample.com" and
	// ".foo.com" do not match example.com.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
  3384  
  3385  func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
  3386  	t.Helper()
  3387  	reqURL := tt.req
  3388  	if reqURL == "" {
  3389  		reqURL = "http://example.com"
  3390  	}
  3391  	req, _ := NewRequest("GET", reqURL, nil)
  3392  	url, err := proxyForRequest(req)
  3393  	if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
  3394  		t.Errorf("%v: got error = %q, want %q", tt, g, e)
  3395  		return
  3396  	}
  3397  	if got := fmt.Sprintf("%s", url); got != tt.want {
  3398  		t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
  3399  	}
  3400  }
  3401  
  3402  func TestProxyFromEnvironment(t *testing.T) {
  3403  	ResetProxyEnv()
  3404  	defer ResetProxyEnv()
  3405  	for _, tt := range proxyFromEnvTests {
  3406  		testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
  3407  			os.Setenv("HTTP_PROXY", tt.env)
  3408  			os.Setenv("HTTPS_PROXY", tt.httpsenv)
  3409  			os.Setenv("NO_PROXY", tt.noenv)
  3410  			os.Setenv("REQUEST_METHOD", tt.reqmeth)
  3411  			ResetCachedEnvironment()
  3412  			return ProxyFromEnvironment(req)
  3413  		})
  3414  	}
  3415  }
  3416  
  3417  func TestProxyFromEnvironmentLowerCase(t *testing.T) {
  3418  	ResetProxyEnv()
  3419  	defer ResetProxyEnv()
  3420  	for _, tt := range proxyFromEnvTests {
  3421  		testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
  3422  			os.Setenv("http_proxy", tt.env)
  3423  			os.Setenv("https_proxy", tt.httpsenv)
  3424  			os.Setenv("no_proxy", tt.noenv)
  3425  			os.Setenv("REQUEST_METHOD", tt.reqmeth)
  3426  			ResetCachedEnvironment()
  3427  			return ProxyFromEnvironment(req)
  3428  		})
  3429  	}
  3430  }
  3431  
  3432  func TestIdleConnChannelLeak(t *testing.T) {
  3433  	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
  3434  }
  3435  func testIdleConnChannelLeak(t *testing.T, mode testMode) {
  3436  	// Not parallel: uses global test hooks.
  3437  	var mu sync.Mutex
  3438  	var n int
  3439  
  3440  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3441  		mu.Lock()
  3442  		n++
  3443  		mu.Unlock()
  3444  	})).ts
  3445  
  3446  	const nReqs = 5
  3447  	didRead := make(chan bool, nReqs)
  3448  	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
  3449  	defer SetReadLoopBeforeNextReadHook(nil)
  3450  
  3451  	c := ts.Client()
  3452  	tr := c.Transport.(*Transport)
  3453  	tr.Dial = func(netw, addr string) (net.Conn, error) {
  3454  		return net.Dial(netw, ts.Listener.Addr().String())
  3455  	}
  3456  
  3457  	// First, without keep-alives.
  3458  	for _, disableKeep := range []bool{true, false} {
  3459  		tr.DisableKeepAlives = disableKeep
  3460  		for i := 0; i < nReqs; i++ {
  3461  			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
  3462  			if err != nil {
  3463  				t.Fatal(err)
  3464  			}
  3465  			// Note: no res.Body.Close is needed here, since the
  3466  			// response Content-Length is zero. Perhaps the test
  3467  			// should be more explicit and use a HEAD, but tests
  3468  			// elsewhere guarantee that zero byte responses generate
  3469  			// a "Content-Length: 0" instead of chunking.
  3470  		}
  3471  
  3472  		// At this point, each of the 5 Transport.readLoop goroutines
  3473  		// are scheduling noting that there are no response bodies (see
  3474  		// earlier comment), and are then calling putIdleConn, which
  3475  		// decrements this count. Usually that happens quickly, which is
  3476  		// why this test has seemed to work for ages. But it's still
  3477  		// racey: we have wait for them to finish first. See Issue 10427
  3478  		for i := 0; i < nReqs; i++ {
  3479  			<-didRead
  3480  		}
  3481  
  3482  		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
  3483  			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
  3484  		}
  3485  	}
  3486  }
  3487  
  3488  // Verify the status quo: that the Client.Post function coerces its
  3489  // body into a ReadCloser if it's a Closer, and that the Transport
  3490  // then closes it.
  3491  func TestTransportClosesRequestBody(t *testing.T) {
  3492  	run(t, testTransportClosesRequestBody, []testMode{http1Mode})
  3493  }
  3494  func testTransportClosesRequestBody(t *testing.T, mode testMode) {
  3495  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3496  		io.Copy(io.Discard, r.Body)
  3497  	})).ts
  3498  
  3499  	c := ts.Client()
  3500  
  3501  	closes := 0
  3502  
  3503  	res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  3504  	if err != nil {
  3505  		t.Fatal(err)
  3506  	}
  3507  	res.Body.Close()
  3508  	if closes != 1 {
  3509  		t.Errorf("closes = %d; want 1", closes)
  3510  	}
  3511  }
  3512  
  3513  func TestTransportTLSHandshakeTimeout(t *testing.T) {
  3514  	defer afterTest(t)
  3515  	if testing.Short() {
  3516  		t.Skip("skipping in short mode")
  3517  	}
  3518  	ln := newLocalListener(t)
  3519  	defer ln.Close()
  3520  	testdonec := make(chan struct{})
  3521  	defer close(testdonec)
  3522  
  3523  	go func() {
  3524  		c, err := ln.Accept()
  3525  		if err != nil {
  3526  			t.Error(err)
  3527  			return
  3528  		}
  3529  		<-testdonec
  3530  		c.Close()
  3531  	}()
  3532  
  3533  	tr := &Transport{
  3534  		Dial: func(_, _ string) (net.Conn, error) {
  3535  			return net.Dial("tcp", ln.Addr().String())
  3536  		},
  3537  		TLSHandshakeTimeout: 250 * time.Millisecond,
  3538  	}
  3539  	cl := &Client{Transport: tr}
  3540  	_, err := cl.Get("https://dummy.tld/")
  3541  	if err == nil {
  3542  		t.Error("expected error")
  3543  		return
  3544  	}
  3545  	ue, ok := err.(*url.Error)
  3546  	if !ok {
  3547  		t.Errorf("expected url.Error; got %#v", err)
  3548  		return
  3549  	}
  3550  	ne, ok := ue.Err.(net.Error)
  3551  	if !ok {
  3552  		t.Errorf("expected net.Error; got %#v", err)
  3553  		return
  3554  	}
  3555  	if !ne.Timeout() {
  3556  		t.Errorf("expected timeout error; got %v", err)
  3557  	}
  3558  	if !strings.Contains(err.Error(), "handshake timeout") {
  3559  		t.Errorf("expected 'handshake timeout' in error; got %v", err)
  3560  	}
  3561  }
  3562  
  3563  // Trying to repro golang.org/issue/3514
  3564  func TestTLSServerClosesConnection(t *testing.T) {
  3565  	run(t, testTLSServerClosesConnection, []testMode{https1Mode})
  3566  }
  3567  func testTLSServerClosesConnection(t *testing.T, mode testMode) {
  3568  	closedc := make(chan bool, 1)
  3569  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3570  		if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
  3571  			conn, _, _ := w.(Hijacker).Hijack()
  3572  			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
  3573  			conn.Close()
  3574  			closedc <- true
  3575  			return
  3576  		}
  3577  		fmt.Fprintf(w, "hello")
  3578  	})).ts
  3579  
  3580  	c := ts.Client()
  3581  	tr := c.Transport.(*Transport)
  3582  
  3583  	var nSuccess = 0
  3584  	var errs []error
  3585  	const trials = 20
  3586  	for i := 0; i < trials; i++ {
  3587  		tr.CloseIdleConnections()
  3588  		res, err := c.Get(ts.URL + "/keep-alive-then-die")
  3589  		if err != nil {
  3590  			t.Fatal(err)
  3591  		}
  3592  		<-closedc
  3593  		slurp, err := io.ReadAll(res.Body)
  3594  		if err != nil {
  3595  			t.Fatal(err)
  3596  		}
  3597  		if string(slurp) != "foo" {
  3598  			t.Errorf("Got %q, want foo", slurp)
  3599  		}
  3600  
  3601  		// Now try again and see if we successfully
  3602  		// pick a new connection.
  3603  		res, err = c.Get(ts.URL + "/")
  3604  		if err != nil {
  3605  			errs = append(errs, err)
  3606  			continue
  3607  		}
  3608  		slurp, err = io.ReadAll(res.Body)
  3609  		if err != nil {
  3610  			errs = append(errs, err)
  3611  			continue
  3612  		}
  3613  		nSuccess++
  3614  	}
  3615  	if nSuccess > 0 {
  3616  		t.Logf("successes = %d of %d", nSuccess, trials)
  3617  	} else {
  3618  		t.Errorf("All runs failed:")
  3619  	}
  3620  	for _, err := range errs {
  3621  		t.Logf("  err: %v", err)
  3622  	}
  3623  }
  3624  
  3625  // byteFromChanReader is an io.Reader that reads a single byte at a
  3626  // time from the channel. When the channel is closed, the reader
  3627  // returns io.EOF.
  3628  type byteFromChanReader chan byte
  3629  
  3630  func (c byteFromChanReader) Read(p []byte) (n int, err error) {
  3631  	if len(p) == 0 {
  3632  		return
  3633  	}
  3634  	b, ok := <-c
  3635  	if !ok {
  3636  		return 0, io.EOF
  3637  	}
  3638  	p[0] = b
  3639  	return 1, nil
  3640  }
  3641  
  3642  // Verifies that the Transport doesn't reuse a connection in the case
  3643  // where the server replies before the request has been fully
  3644  // written. We still honor that reply (see TestIssue3595), but don't
  3645  // send future requests on the connection because it's then in a
  3646  // questionable state.
  3647  // golang.org/issue/7569
  3648  func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
  3649  	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
  3650  }
  3651  func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
  3652  	defer func(d time.Duration) {
  3653  		*MaxWriteWaitBeforeConnReuse = d
  3654  	}(*MaxWriteWaitBeforeConnReuse)
  3655  	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
  3656  	var sconn struct {
  3657  		sync.Mutex
  3658  		c net.Conn
  3659  	}
  3660  	var getOkay bool
  3661  	var copying sync.WaitGroup
  3662  	closeConn := func() {
  3663  		sconn.Lock()
  3664  		defer sconn.Unlock()
  3665  		if sconn.c != nil {
  3666  			sconn.c.Close()
  3667  			sconn.c = nil
  3668  			if !getOkay {
  3669  				t.Logf("Closed server connection")
  3670  			}
  3671  		}
  3672  	}
  3673  	defer func() {
  3674  		closeConn()
  3675  		copying.Wait()
  3676  	}()
  3677  
  3678  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3679  		if r.Method == "GET" {
  3680  			io.WriteString(w, "bar")
  3681  			return
  3682  		}
  3683  		conn, _, _ := w.(Hijacker).Hijack()
  3684  		sconn.Lock()
  3685  		sconn.c = conn
  3686  		sconn.Unlock()
  3687  		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
  3688  
  3689  		copying.Add(1)
  3690  		go func() {
  3691  			io.Copy(io.Discard, conn)
  3692  			copying.Done()
  3693  		}()
  3694  	})).ts
  3695  	c := ts.Client()
  3696  
  3697  	const bodySize = 256 << 10
  3698  	finalBit := make(byteFromChanReader, 1)
  3699  	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
  3700  	req.ContentLength = bodySize
  3701  	res, err := c.Do(req)
  3702  	if err := wantBody(res, err, "foo"); err != nil {
  3703  		t.Errorf("POST response: %v", err)
  3704  	}
  3705  
  3706  	res, err = c.Get(ts.URL)
  3707  	if err := wantBody(res, err, "bar"); err != nil {
  3708  		t.Errorf("GET response: %v", err)
  3709  		return
  3710  	}
  3711  	getOkay = true  // suppress test noise
  3712  	finalBit <- 'x' // unblock the writeloop of the first Post
  3713  	close(finalBit)
  3714  }
  3715  
  3716  // Tests that we don't leak Transport persistConn.readLoop goroutines
  3717  // when a server hangs up immediately after saying it would keep-alive.
  3718  func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
  3719  func testTransportIssue10457(t *testing.T, mode testMode) {
  3720  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3721  		// Send a response with no body, keep-alive
  3722  		// (implicit), and then lie and immediately close the
  3723  		// connection. This forces the Transport's readLoop to
  3724  		// immediately Peek an io.EOF and get to the point
  3725  		// that used to hang.
  3726  		conn, _, _ := w.(Hijacker).Hijack()
  3727  		conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive
  3728  		conn.Close()
  3729  	})).ts
  3730  	c := ts.Client()
  3731  
  3732  	res, err := c.Get(ts.URL)
  3733  	if err != nil {
  3734  		t.Fatalf("Get: %v", err)
  3735  	}
  3736  	defer res.Body.Close()
  3737  
  3738  	// Just a sanity check that we at least get the response. The real
  3739  	// test here is that the "defer afterTest" above doesn't find any
  3740  	// leaked goroutines.
  3741  	if got, want := res.Header.Get("Foo"), "Bar"; got != want {
  3742  		t.Errorf("Foo header = %q; want %q", got, want)
  3743  	}
  3744  }
  3745  
  3746  type closerFunc func() error
  3747  
  3748  func (f closerFunc) Close() error { return f() }
  3749  
  3750  type writerFuncConn struct {
  3751  	net.Conn
  3752  	write func(p []byte) (n int, err error)
  3753  }
  3754  
  3755  func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }
  3756  
  3757  // Issues 4677, 18241, and 17844. If we try to reuse a connection that the
  3758  // server is in the process of closing, we may end up successfully writing out
  3759  // our request (or a portion of our request) only to find a connection error
  3760  // when we try to read from (or finish writing to) the socket.
  3761  //
  3762  // NOTE: we resend a request only if:
  3763  //   - we reused a keep-alive connection
  3764  //   - we haven't yet received any header data
  3765  //   - either we wrote no bytes to the server, or the request is idempotent
  3766  //
  3767  // This automatically prevents an infinite resend loop because we'll run out of
  3768  // the cached keep-alive connections eventually.
  3769  func TestRetryRequestsOnError(t *testing.T) {
  3770  	run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
  3771  }
  3772  func testRetryRequestsOnError(t *testing.T, mode testMode) {
  3773  	newRequest := func(method, urlStr string, body io.Reader) *Request {
  3774  		req, err := NewRequest(method, urlStr, body)
  3775  		if err != nil {
  3776  			t.Fatal(err)
  3777  		}
  3778  		return req
  3779  	}
  3780  
  3781  	testCases := []struct {
  3782  		name       string
  3783  		failureN   int
  3784  		failureErr error
  3785  		// Note that we can't just re-use the Request object across calls to c.Do
  3786  		// because we need to rewind Body between calls.  (GetBody is only used to
  3787  		// rewind Body on failure and redirects, not just because it's done.)
  3788  		req       func() *Request
  3789  		reqString string
  3790  	}{
  3791  		{
  3792  			name: "IdempotentNoBodySomeWritten",
  3793  			// Believe that we've written some bytes to the server, so we know we're
  3794  			// not just in the "retry when no bytes sent" case".
  3795  			failureN: 1,
  3796  			// Use the specific error that shouldRetryRequest looks for with idempotent requests.
  3797  			failureErr: ExportErrServerClosedIdle,
  3798  			req: func() *Request {
  3799  				return newRequest("GET", "http://fake.golang", nil)
  3800  			},
  3801  			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
  3802  		},
  3803  		{
  3804  			name: "IdempotentGetBodySomeWritten",
  3805  			// Believe that we've written some bytes to the server, so we know we're
  3806  			// not just in the "retry when no bytes sent" case".
  3807  			failureN: 1,
  3808  			// Use the specific error that shouldRetryRequest looks for with idempotent requests.
  3809  			failureErr: ExportErrServerClosedIdle,
  3810  			req: func() *Request {
  3811  				return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
  3812  			},
  3813  			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
  3814  		},
  3815  		{
  3816  			name: "NothingWrittenNoBody",
  3817  			// It's key that we return 0 here -- that's what enables Transport to know
  3818  			// that nothing was written, even though this is a non-idempotent request.
  3819  			failureN:   0,
  3820  			failureErr: errors.New("second write fails"),
  3821  			req: func() *Request {
  3822  				return newRequest("DELETE", "http://fake.golang", nil)
  3823  			},
  3824  			reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
  3825  		},
  3826  		{
  3827  			name: "NothingWrittenGetBody",
  3828  			// It's key that we return 0 here -- that's what enables Transport to know
  3829  			// that nothing was written, even though this is a non-idempotent request.
  3830  			failureN:   0,
  3831  			failureErr: errors.New("second write fails"),
  3832  			// Note that NewRequest will set up GetBody for strings.Reader, which is
  3833  			// required for the retry to occur
  3834  			req: func() *Request {
  3835  				return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
  3836  			},
  3837  			reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
  3838  		},
  3839  	}
  3840  
  3841  	for _, tc := range testCases {
  3842  		t.Run(tc.name, func(t *testing.T) {
  3843  			var (
  3844  				mu     sync.Mutex
  3845  				logbuf strings.Builder
  3846  			)
  3847  			logf := func(format string, args ...any) {
  3848  				mu.Lock()
  3849  				defer mu.Unlock()
  3850  				fmt.Fprintf(&logbuf, format, args...)
  3851  				logbuf.WriteByte('\n')
  3852  			}
  3853  
  3854  			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3855  				logf("Handler")
  3856  				w.Header().Set("X-Status", "ok")
  3857  			})).ts
  3858  
  3859  			var writeNumAtomic int32
  3860  			c := ts.Client()
  3861  			c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
  3862  				logf("Dial")
  3863  				c, err := net.Dial(network, ts.Listener.Addr().String())
  3864  				if err != nil {
  3865  					logf("Dial error: %v", err)
  3866  					return nil, err
  3867  				}
  3868  				return &writerFuncConn{
  3869  					Conn: c,
  3870  					write: func(p []byte) (n int, err error) {
  3871  						if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
  3872  							logf("intentional write failure")
  3873  							return tc.failureN, tc.failureErr
  3874  						}
  3875  						logf("Write(%q)", p)
  3876  						return c.Write(p)
  3877  					},
  3878  				}, nil
  3879  			}
  3880  
  3881  			SetRoundTripRetried(func() {
  3882  				logf("Retried.")
  3883  			})
  3884  			defer SetRoundTripRetried(nil)
  3885  
  3886  			for i := 0; i < 3; i++ {
  3887  				t0 := time.Now()
  3888  				req := tc.req()
  3889  				res, err := c.Do(req)
  3890  				if err != nil {
  3891  					if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
  3892  						mu.Lock()
  3893  						got := logbuf.String()
  3894  						mu.Unlock()
  3895  						t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
  3896  					}
  3897  					t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
  3898  				}
  3899  				res.Body.Close()
  3900  				if res.Request != req {
  3901  					t.Errorf("Response.Request != original request; want identical Request")
  3902  				}
  3903  			}
  3904  
  3905  			mu.Lock()
  3906  			got := logbuf.String()
  3907  			mu.Unlock()
  3908  			want := fmt.Sprintf(`Dial
  3909  Write("%s")
  3910  Handler
  3911  intentional write failure
  3912  Retried.
  3913  Dial
  3914  Write("%s")
  3915  Handler
  3916  Write("%s")
  3917  Handler
  3918  `, tc.reqString, tc.reqString, tc.reqString)
  3919  			if got != want {
  3920  				t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
  3921  			}
  3922  		})
  3923  	}
  3924  }
  3925  
  3926  // Issue 6981
  3927  func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
  3928  func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
  3929  	readBody := make(chan error, 1)
  3930  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3931  		_, err := io.ReadAll(r.Body)
  3932  		readBody <- err
  3933  	})).ts
  3934  	c := ts.Client()
  3935  	fakeErr := errors.New("fake error")
  3936  	didClose := make(chan bool, 1)
  3937  	req, _ := NewRequest("POST", ts.URL, struct {
  3938  		io.Reader
  3939  		io.Closer
  3940  	}{
  3941  		io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
  3942  		closerFunc(func() error {
  3943  			select {
  3944  			case didClose <- true:
  3945  			default:
  3946  			}
  3947  			return nil
  3948  		}),
  3949  	})
  3950  	res, err := c.Do(req)
  3951  	if res != nil {
  3952  		defer res.Body.Close()
  3953  	}
  3954  	if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
  3955  		t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
  3956  	}
  3957  	if err := <-readBody; err == nil {
  3958  		t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
  3959  	}
  3960  	select {
  3961  	case <-didClose:
  3962  	default:
  3963  		t.Errorf("didn't see Body.Close")
  3964  	}
  3965  }
  3966  
  3967  func TestTransportDialTLS(t *testing.T) {
  3968  	run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
  3969  }
  3970  func testTransportDialTLS(t *testing.T, mode testMode) {
  3971  	var mu sync.Mutex // guards following
  3972  	var gotReq, didDial bool
  3973  
  3974  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3975  		mu.Lock()
  3976  		gotReq = true
  3977  		mu.Unlock()
  3978  	})).ts
  3979  	c := ts.Client()
  3980  	c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
  3981  		mu.Lock()
  3982  		didDial = true
  3983  		mu.Unlock()
  3984  		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
  3985  		if err != nil {
  3986  			return nil, err
  3987  		}
  3988  		return c, c.Handshake()
  3989  	}
  3990  
  3991  	res, err := c.Get(ts.URL)
  3992  	if err != nil {
  3993  		t.Fatal(err)
  3994  	}
  3995  	res.Body.Close()
  3996  	mu.Lock()
  3997  	if !gotReq {
  3998  		t.Error("didn't get request")
  3999  	}
  4000  	if !didDial {
  4001  		t.Error("didn't use dial hook")
  4002  	}
  4003  }
  4004  
  4005  func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
  4006  func testTransportDialContext(t *testing.T, mode testMode) {
  4007  	ctxKey := "some-key"
  4008  	ctxValue := "some-value"
  4009  	var (
  4010  		mu          sync.Mutex // guards following
  4011  		gotReq      bool
  4012  		gotCtxValue any
  4013  	)
  4014  
  4015  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4016  		mu.Lock()
  4017  		gotReq = true
  4018  		mu.Unlock()
  4019  	})).ts
  4020  	c := ts.Client()
  4021  	c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
  4022  		mu.Lock()
  4023  		gotCtxValue = ctx.Value(ctxKey)
  4024  		mu.Unlock()
  4025  		return net.Dial(netw, addr)
  4026  	}
  4027  
  4028  	req, err := NewRequest("GET", ts.URL, nil)
  4029  	if err != nil {
  4030  		t.Fatal(err)
  4031  	}
  4032  	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
  4033  	res, err := c.Do(req.WithContext(ctx))
  4034  	if err != nil {
  4035  		t.Fatal(err)
  4036  	}
  4037  	res.Body.Close()
  4038  	mu.Lock()
  4039  	if !gotReq {
  4040  		t.Error("didn't get request")
  4041  	}
  4042  	if got, want := gotCtxValue, ctxValue; got != want {
  4043  		t.Errorf("got context with value %v, want %v", got, want)
  4044  	}
  4045  }
  4046  
  4047  func TestTransportDialTLSContext(t *testing.T) {
  4048  	run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
  4049  }
  4050  func testTransportDialTLSContext(t *testing.T, mode testMode) {
  4051  	ctxKey := "some-key"
  4052  	ctxValue := "some-value"
  4053  	var (
  4054  		mu          sync.Mutex // guards following
  4055  		gotReq      bool
  4056  		gotCtxValue any
  4057  	)
  4058  
  4059  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4060  		mu.Lock()
  4061  		gotReq = true
  4062  		mu.Unlock()
  4063  	})).ts
  4064  	c := ts.Client()
  4065  	c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
  4066  		mu.Lock()
  4067  		gotCtxValue = ctx.Value(ctxKey)
  4068  		mu.Unlock()
  4069  		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
  4070  		if err != nil {
  4071  			return nil, err
  4072  		}
  4073  		return c, c.HandshakeContext(ctx)
  4074  	}
  4075  
  4076  	req, err := NewRequest("GET", ts.URL, nil)
  4077  	if err != nil {
  4078  		t.Fatal(err)
  4079  	}
  4080  	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
  4081  	res, err := c.Do(req.WithContext(ctx))
  4082  	if err != nil {
  4083  		t.Fatal(err)
  4084  	}
  4085  	res.Body.Close()
  4086  	mu.Lock()
  4087  	if !gotReq {
  4088  		t.Error("didn't get request")
  4089  	}
  4090  	if got, want := gotCtxValue, ctxValue; got != want {
  4091  		t.Errorf("got context with value %v, want %v", got, want)
  4092  	}
  4093  }
  4094  
  4095  // Test for issue 8755
  4096  // Ensure that if a proxy returns an error, it is exposed by RoundTrip
  4097  func TestRoundTripReturnsProxyError(t *testing.T) {
  4098  	badProxy := func(*Request) (*url.URL, error) {
  4099  		return nil, errors.New("errorMessage")
  4100  	}
  4101  
  4102  	tr := &Transport{Proxy: badProxy}
  4103  
  4104  	req, _ := NewRequest("GET", "http://example.com", nil)
  4105  
  4106  	_, err := tr.RoundTrip(req)
  4107  
  4108  	if err == nil {
  4109  		t.Error("Expected proxy error to be returned by RoundTrip")
  4110  	}
  4111  }
  4112  
  4113  // tests that putting an idle conn after a call to CloseIdleConns does return it
  4114  func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
  4115  	tr := &Transport{}
  4116  	wantIdle := func(when string, n int) bool {
  4117  		got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
  4118  		if got == n {
  4119  			return true
  4120  		}
  4121  		t.Errorf("%s: idle conns = %d; want %d", when, got, n)
  4122  		return false
  4123  	}
  4124  	wantIdle("start", 0)
  4125  	if !tr.PutIdleTestConn("http", "example.com") {
  4126  		t.Fatal("put failed")
  4127  	}
  4128  	if !tr.PutIdleTestConn("http", "example.com") {
  4129  		t.Fatal("second put failed")
  4130  	}
  4131  	wantIdle("after put", 2)
  4132  	tr.CloseIdleConnections()
  4133  	if !tr.IsIdleForTesting() {
  4134  		t.Error("should be idle after CloseIdleConnections")
  4135  	}
  4136  	wantIdle("after close idle", 0)
  4137  	if tr.PutIdleTestConn("http", "example.com") {
  4138  		t.Fatal("put didn't fail")
  4139  	}
  4140  	wantIdle("after second put", 0)
  4141  
  4142  	tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode
  4143  	if tr.IsIdleForTesting() {
  4144  		t.Error("shouldn't be idle after QueueForIdleConnForTesting")
  4145  	}
  4146  	if !tr.PutIdleTestConn("http", "example.com") {
  4147  		t.Fatal("after re-activation")
  4148  	}
  4149  	wantIdle("after final put", 1)
  4150  }
  4151  
  4152  // Test for issue 34282
  4153  // Ensure that getConn doesn't call the GotConn trace hook on an HTTP/2 idle conn
  4154  func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
  4155  	tr := &Transport{}
  4156  	wantIdle := func(when string, n int) bool {
  4157  		got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2
  4158  		if got == n {
  4159  			return true
  4160  		}
  4161  		t.Errorf("%s: idle conns = %d; want %d", when, got, n)
  4162  		return false
  4163  	}
  4164  	wantIdle("start", 0)
  4165  	alt := funcRoundTripper(func() {})
  4166  	if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
  4167  		t.Fatal("put failed")
  4168  	}
  4169  	wantIdle("after put", 1)
  4170  	ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  4171  		GotConn: func(httptrace.GotConnInfo) {
  4172  			// tr.getConn should leave it for the HTTP/2 alt to call GotConn.
  4173  			t.Error("GotConn called")
  4174  		},
  4175  	})
  4176  	req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
  4177  	_, err := tr.RoundTrip(req)
  4178  	if err != errFakeRoundTrip {
  4179  		t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
  4180  	}
  4181  	wantIdle("after round trip", 1)
  4182  }
  4183  
  4184  func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
  4185  	run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
  4186  }
  4187  func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
  4188  	if testing.Short() {
  4189  		t.Skip("skipping in short mode")
  4190  	}
  4191  
  4192  	timeout := 1 * time.Millisecond
  4193  	retry := true
  4194  	for retry {
  4195  		trFunc := func(tr *Transport) {
  4196  			tr.MaxConnsPerHost = 1
  4197  			tr.MaxIdleConnsPerHost = 1
  4198  			tr.IdleConnTimeout = timeout
  4199  		}
  4200  		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
  4201  
  4202  		retry = false
  4203  		tooShort := func(err error) bool {
  4204  			if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
  4205  				return false
  4206  			}
  4207  			if !retry {
  4208  				t.Helper()
  4209  				t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
  4210  				timeout *= 2
  4211  				retry = true
  4212  				cst.close()
  4213  			}
  4214  			return true
  4215  		}
  4216  
  4217  		if _, err := cst.c.Get(cst.ts.URL); err != nil {
  4218  			if tooShort(err) {
  4219  				continue
  4220  			}
  4221  			t.Fatalf("got error: %s", err)
  4222  		}
  4223  
  4224  		time.Sleep(10 * timeout)
  4225  		if _, err := cst.c.Get(cst.ts.URL); err != nil {
  4226  			if tooShort(err) {
  4227  				continue
  4228  			}
  4229  			t.Fatalf("got error: %s", err)
  4230  		}
  4231  	}
  4232  }
  4233  
  4234  // This tests that a client requesting a content range won't also
  4235  // implicitly ask for gzip support. If they want that, they need to do it
  4236  // on their own.
  4237  // golang.org/issue/8923
  4238  func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
  4239  func testTransportRangeAndGzip(t *testing.T, mode testMode) {
  4240  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4241  		if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
  4242  			t.Error("Transport advertised gzip support in the Accept header")
  4243  		}
  4244  		if r.Header.Get("Range") == "" {
  4245  			t.Error("no Range in request")
  4246  		}
  4247  	})).ts
  4248  	c := ts.Client()
  4249  
  4250  	req, _ := NewRequest("GET", ts.URL, nil)
  4251  	req.Header.Set("Range", "bytes=7-11")
  4252  	res, err := c.Do(req)
  4253  	if err != nil {
  4254  		t.Fatal(err)
  4255  	}
  4256  	res.Body.Close()
  4257  }
  4258  
  4259  // Test for issue 10474
  4260  func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
  4261  func testTransportResponseCancelRace(t *testing.T, mode testMode) {
  4262  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4263  		// important that this response has a body.
  4264  		var b [1024]byte
  4265  		w.Write(b[:])
  4266  	})).ts
  4267  	tr := ts.Client().Transport.(*Transport)
  4268  
  4269  	req, err := NewRequest("GET", ts.URL, nil)
  4270  	if err != nil {
  4271  		t.Fatal(err)
  4272  	}
  4273  	res, err := tr.RoundTrip(req)
  4274  	if err != nil {
  4275  		t.Fatal(err)
  4276  	}
  4277  	// If we do an early close, Transport just throws the connection away and
  4278  	// doesn't reuse it. In order to trigger the bug, it has to reuse the connection
  4279  	// so read the body
  4280  	if _, err := io.Copy(io.Discard, res.Body); err != nil {
  4281  		t.Fatal(err)
  4282  	}
  4283  
  4284  	req2, err := NewRequest("GET", ts.URL, nil)
  4285  	if err != nil {
  4286  		t.Fatal(err)
  4287  	}
  4288  	tr.CancelRequest(req)
  4289  	res, err = tr.RoundTrip(req2)
  4290  	if err != nil {
  4291  		t.Fatal(err)
  4292  	}
  4293  	res.Body.Close()
  4294  }
  4295  
  4296  // Test for issue 19248: Content-Encoding's value is case insensitive.
  4297  func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
  4298  	run(t, testTransportContentEncodingCaseInsensitive)
  4299  }
  4300  func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
  4301  	for _, ce := range []string{"gzip", "GZIP"} {
  4302  		ce := ce
  4303  		t.Run(ce, func(t *testing.T) {
  4304  			const encodedString = "Hello Gopher"
  4305  			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4306  				w.Header().Set("Content-Encoding", ce)
  4307  				gz := gzip.NewWriter(w)
  4308  				gz.Write([]byte(encodedString))
  4309  				gz.Close()
  4310  			})).ts
  4311  
  4312  			res, err := ts.Client().Get(ts.URL)
  4313  			if err != nil {
  4314  				t.Fatal(err)
  4315  			}
  4316  
  4317  			body, err := io.ReadAll(res.Body)
  4318  			res.Body.Close()
  4319  			if err != nil {
  4320  				t.Fatal(err)
  4321  			}
  4322  
  4323  			if string(body) != encodedString {
  4324  				t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
  4325  			}
  4326  		})
  4327  	}
  4328  }
  4329  
  4330  // https://go.dev/issue/49621
  4331  func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
  4332  	run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
  4333  }
  4334  func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
  4335  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
  4336  		func(tr *Transport) {
  4337  			tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
  4338  				// Connection immediately returns errors.
  4339  				return &funcConn{
  4340  					read: func([]byte) (int, error) {
  4341  						return 0, errors.New("error")
  4342  					},
  4343  					write: func([]byte) (int, error) {
  4344  						return 0, errors.New("error")
  4345  					},
  4346  				}, nil
  4347  			}
  4348  		},
  4349  	).ts
  4350  	// Set a short delay in RoundTrip to give the persistConn time to notice
  4351  	// the connection is broken. We want to exercise the path where writeLoop exits
  4352  	// before it reads the request to send. If this delay is too short, we may instead
  4353  	// exercise the path where writeLoop accepts the request and then fails to write it.
  4354  	// That's fine, so long as we get the desired path often enough.
  4355  	SetEnterRoundTripHook(func() {
  4356  		time.Sleep(1 * time.Millisecond)
  4357  	})
  4358  	defer SetEnterRoundTripHook(nil)
  4359  	var closes int
  4360  	_, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  4361  	if err == nil {
  4362  		t.Fatalf("expected request to fail, but it did not")
  4363  	}
  4364  	if closes != 1 {
  4365  		t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
  4366  	}
  4367  }
  4368  
  4369  // logWritesConn is a net.Conn that logs each Write call to writes
  4370  // and then proxies to w.
  4371  // It proxies Read calls to a reader it receives from rch.
  4372  type logWritesConn struct {
  4373  	net.Conn // nil. crash on use.
  4374  
  4375  	w io.Writer
  4376  
  4377  	rch <-chan io.Reader
  4378  	r   io.Reader // nil until received by rch
  4379  
  4380  	mu     sync.Mutex
  4381  	writes []string
  4382  }
  4383  
  4384  func (c *logWritesConn) Write(p []byte) (n int, err error) {
  4385  	c.mu.Lock()
  4386  	defer c.mu.Unlock()
  4387  	c.writes = append(c.writes, string(p))
  4388  	return c.w.Write(p)
  4389  }
  4390  
  4391  func (c *logWritesConn) Read(p []byte) (n int, err error) {
  4392  	if c.r == nil {
  4393  		c.r = <-c.rch
  4394  	}
  4395  	return c.r.Read(p)
  4396  }
  4397  
  4398  func (c *logWritesConn) Close() error { return nil }
  4399  
  4400  // Issue 6574
  4401  func TestTransportFlushesBodyChunks(t *testing.T) {
  4402  	defer afterTest(t)
  4403  	resBody := make(chan io.Reader, 1)
  4404  	connr, connw := io.Pipe() // connection pipe pair
  4405  	lw := &logWritesConn{
  4406  		rch: resBody,
  4407  		w:   connw,
  4408  	}
  4409  	tr := &Transport{
  4410  		Dial: func(network, addr string) (net.Conn, error) {
  4411  			return lw, nil
  4412  		},
  4413  	}
  4414  	bodyr, bodyw := io.Pipe() // body pipe pair
  4415  	go func() {
  4416  		defer bodyw.Close()
  4417  		for i := 0; i < 3; i++ {
  4418  			fmt.Fprintf(bodyw, "num%d\n", i)
  4419  		}
  4420  	}()
  4421  	resc := make(chan *Response)
  4422  	go func() {
  4423  		req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
  4424  		req.Header.Set("User-Agent", "x") // known value for test
  4425  		res, err := tr.RoundTrip(req)
  4426  		if err != nil {
  4427  			t.Errorf("RoundTrip: %v", err)
  4428  			close(resc)
  4429  			return
  4430  		}
  4431  		resc <- res
  4432  
  4433  	}()
  4434  	// Fully consume the request before checking the Write log vs. want.
  4435  	req, err := ReadRequest(bufio.NewReader(connr))
  4436  	if err != nil {
  4437  		t.Fatal(err)
  4438  	}
  4439  	io.Copy(io.Discard, req.Body)
  4440  
  4441  	// Unblock the transport's roundTrip goroutine.
  4442  	resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
  4443  	res, ok := <-resc
  4444  	if !ok {
  4445  		return
  4446  	}
  4447  	defer res.Body.Close()
  4448  
  4449  	want := []string{
  4450  		"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
  4451  		"5\r\nnum0\n\r\n",
  4452  		"5\r\nnum1\n\r\n",
  4453  		"5\r\nnum2\n\r\n",
  4454  		"0\r\n\r\n",
  4455  	}
  4456  	if !reflect.DeepEqual(lw.writes, want) {
  4457  		t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
  4458  	}
  4459  }
  4460  
  4461  // Issue 22088: flush Transport request headers if we're not sure the body won't block on read.
  4462  func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
  4463  func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
  4464  	gotReq := make(chan struct{})
  4465  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4466  		close(gotReq)
  4467  	}))
  4468  
  4469  	pr, pw := io.Pipe()
  4470  	req, err := NewRequest("POST", cst.ts.URL, pr)
  4471  	if err != nil {
  4472  		t.Fatal(err)
  4473  	}
  4474  	gotRes := make(chan struct{})
  4475  	go func() {
  4476  		defer close(gotRes)
  4477  		res, err := cst.tr.RoundTrip(req)
  4478  		if err != nil {
  4479  			t.Error(err)
  4480  			return
  4481  		}
  4482  		res.Body.Close()
  4483  	}()
  4484  
  4485  	<-gotReq
  4486  	pw.Close()
  4487  	<-gotRes
  4488  }
  4489  
  4490  type wgReadCloser struct {
  4491  	io.Reader
  4492  	wg     *sync.WaitGroup
  4493  	closed bool
  4494  }
  4495  
  4496  func (c *wgReadCloser) Close() error {
  4497  	if c.closed {
  4498  		return net.ErrClosed
  4499  	}
  4500  	c.closed = true
  4501  	c.wg.Done()
  4502  	return nil
  4503  }
  4504  
  4505  // Issue 11745.
  4506  func TestTransportPrefersResponseOverWriteError(t *testing.T) {
  4507  	// Not parallel: modifies the global rstAvoidanceDelay.
  4508  	run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
  4509  }
  4510  func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
  4511  	if testing.Short() {
  4512  		t.Skip("skipping in short mode")
  4513  	}
  4514  
  4515  	runTimeSensitiveTest(t, []time.Duration{
  4516  		1 * time.Millisecond,
  4517  		5 * time.Millisecond,
  4518  		10 * time.Millisecond,
  4519  		50 * time.Millisecond,
  4520  		100 * time.Millisecond,
  4521  		500 * time.Millisecond,
  4522  		time.Second,
  4523  		5 * time.Second,
  4524  	}, func(t *testing.T, timeout time.Duration) error {
  4525  		SetRSTAvoidanceDelay(t, timeout)
  4526  		t.Logf("set RST avoidance delay to %v", timeout)
  4527  
  4528  		const contentLengthLimit = 1024 * 1024 // 1MB
  4529  		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4530  			if r.ContentLength >= contentLengthLimit {
  4531  				w.WriteHeader(StatusBadRequest)
  4532  				r.Body.Close()
  4533  				return
  4534  			}
  4535  			w.WriteHeader(StatusOK)
  4536  		}))
  4537  		// We need to close cst explicitly here so that in-flight server
  4538  		// requests don't race with the call to SetRSTAvoidanceDelay for a retry.
  4539  		defer cst.close()
  4540  		ts := cst.ts
  4541  		c := ts.Client()
  4542  
  4543  		count := 100
  4544  
  4545  		bigBody := strings.Repeat("a", contentLengthLimit*2)
  4546  		var wg sync.WaitGroup
  4547  		defer wg.Wait()
  4548  		getBody := func() (io.ReadCloser, error) {
  4549  			wg.Add(1)
  4550  			body := &wgReadCloser{
  4551  				Reader: strings.NewReader(bigBody),
  4552  				wg:     &wg,
  4553  			}
  4554  			return body, nil
  4555  		}
  4556  
  4557  		for i := 0; i < count; i++ {
  4558  			reqBody, _ := getBody()
  4559  			req, err := NewRequest("PUT", ts.URL, reqBody)
  4560  			if err != nil {
  4561  				reqBody.Close()
  4562  				t.Fatal(err)
  4563  			}
  4564  			req.ContentLength = int64(len(bigBody))
  4565  			req.GetBody = getBody
  4566  
  4567  			resp, err := c.Do(req)
  4568  			if err != nil {
  4569  				return fmt.Errorf("Do %d: %v", i, err)
  4570  			} else {
  4571  				resp.Body.Close()
  4572  				if resp.StatusCode != 400 {
  4573  					t.Errorf("Expected status code 400, got %v", resp.Status)
  4574  				}
  4575  			}
  4576  		}
  4577  		return nil
  4578  	})
  4579  }
  4580  
  4581  func TestTransportAutomaticHTTP2(t *testing.T) {
  4582  	testTransportAutoHTTP(t, &Transport{}, true)
  4583  }
  4584  
  4585  func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
  4586  	testTransportAutoHTTP(t, &Transport{
  4587  		ForceAttemptHTTP2: true,
  4588  		TLSClientConfig:   new(tls.Config),
  4589  	}, true)
  4590  }
  4591  
  4592  // golang.org/issue/14391: also check DefaultTransport
  4593  func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
  4594  	testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
  4595  }
  4596  
  4597  func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
  4598  	testTransportAutoHTTP(t, &Transport{
  4599  		TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
  4600  	}, false)
  4601  }
  4602  
  4603  func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
  4604  	testTransportAutoHTTP(t, &Transport{
  4605  		TLSClientConfig: new(tls.Config),
  4606  	}, false)
  4607  }
  4608  
  4609  func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
  4610  	testTransportAutoHTTP(t, &Transport{
  4611  		ExpectContinueTimeout: 1 * time.Second,
  4612  	}, true)
  4613  }
  4614  
  4615  func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
  4616  	var d net.Dialer
  4617  	testTransportAutoHTTP(t, &Transport{
  4618  		Dial: d.Dial,
  4619  	}, false)
  4620  }
  4621  
  4622  func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
  4623  	var d net.Dialer
  4624  	testTransportAutoHTTP(t, &Transport{
  4625  		DialContext: d.DialContext,
  4626  	}, false)
  4627  }
  4628  
  4629  func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
  4630  	testTransportAutoHTTP(t, &Transport{
  4631  		DialTLS: func(network, addr string) (net.Conn, error) {
  4632  			panic("unused")
  4633  		},
  4634  	}, false)
  4635  }
  4636  
  4637  func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
  4638  	CondSkipHTTP2(t)
  4639  	_, err := tr.RoundTrip(new(Request))
  4640  	if err == nil {
  4641  		t.Error("expected error from RoundTrip")
  4642  	}
  4643  	if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
  4644  		t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
  4645  	}
  4646  }
  4647  
  4648  // Issue 13633: there was a race where we returned bodyless responses
  4649  // to callers before recycling the persistent connection, which meant
  4650  // a client doing two subsequent requests could end up on different
  4651  // connections. It's somewhat harmless but enough tests assume it's
  4652  // not true in order to test other things that it's worth fixing.
  4653  // Plus it's nice to be consistent and not have timing-dependent
  4654  // behavior.
  4655  func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
  4656  	run(t, testTransportReuseConnEmptyResponseBody)
  4657  }
  4658  func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
  4659  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4660  		w.Header().Set("X-Addr", r.RemoteAddr)
  4661  		// Empty response body.
  4662  	}))
  4663  	n := 100
  4664  	if testing.Short() {
  4665  		n = 10
  4666  	}
  4667  	var firstAddr string
  4668  	for i := 0; i < n; i++ {
  4669  		res, err := cst.c.Get(cst.ts.URL)
  4670  		if err != nil {
  4671  			log.Fatal(err)
  4672  		}
  4673  		addr := res.Header.Get("X-Addr")
  4674  		if i == 0 {
  4675  			firstAddr = addr
  4676  		} else if addr != firstAddr {
  4677  			t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
  4678  		}
  4679  		res.Body.Close()
  4680  	}
  4681  }
  4682  
// Issue 13839
//
// TestNoCrashReturningTransportAltConn checks that the Transport does not
// crash in the following situation. A TLS dial negotiates an alternate
// protocol registered in TLSNextProto ("foo" here), but the dial completes
// only after the request that triggered it has already been canceled. The
// DialTLS hook cancels the request after the handshake and then blocks until
// Client.Do has returned, so the finished connection has no caller waiting
// for it.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// Track in-flight dials so the test can wait for the background dial to
	// finish before returning. NOTE(review): assumes SetPendingDialHooks
	// invokes the first func when a dial starts and the second when it ends.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	// Server side: accept one TLS connection that offers only the "foo"
	// protocol, complete the handshake, and keep the connection open until
	// the test returns.
	testDone := make(chan struct{})
	defer close(testDone)
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	// The host is never resolved: DialTLS below ignores it and dials addr.
	// The request is canceled through the legacy Request.Cancel channel,
	// which DialTLS closes.
	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	// Buffered so the sends below never block.
	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			// Signals that the negotiated "foo" protocol was handed to its
			// constructor. The returned RoundTripper must never be used.
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		// Only DialTLS should be used for this https request.
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			// Cancel the request only after a successful handshake, then
			// hold the connection until Do has returned.
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	// Do must fail with the Transport's "request canceled while waiting
	// for connection" error.
	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Let DialTLS return its connection, then wait for the "foo"
	// RoundTripper to be built and for the pending dial to finish.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
  4765  
  4766  func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
  4767  	run(t, func(t *testing.T, mode testMode) {
  4768  		testTransportReuseConnection_Gzip(t, mode, true)
  4769  	})
  4770  }
  4771  
  4772  func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
  4773  	run(t, func(t *testing.T, mode testMode) {
  4774  		testTransportReuseConnection_Gzip(t, mode, false)
  4775  	})
  4776  }
  4777  
  4778  // Make sure we re-use underlying TCP connection for gzipped responses too.
  4779  func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
  4780  	addr := make(chan string, 2)
  4781  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4782  		addr <- r.RemoteAddr
  4783  		w.Header().Set("Content-Encoding", "gzip")
  4784  		if chunked {
  4785  			w.(Flusher).Flush()
  4786  		}
  4787  		w.Write(rgz) // arbitrary gzip response
  4788  	})).ts
  4789  	c := ts.Client()
  4790  
  4791  	trace := &httptrace.ClientTrace{
  4792  		GetConn:      func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
  4793  		GotConn:      func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
  4794  		PutIdleConn:  func(err error) { t.Logf("PutIdleConn(%v)", err) },
  4795  		ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
  4796  		ConnectDone:  func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
  4797  	}
  4798  	ctx := httptrace.WithClientTrace(context.Background(), trace)
  4799  
  4800  	for i := 0; i < 2; i++ {
  4801  		req, _ := NewRequest("GET", ts.URL, nil)
  4802  		req = req.WithContext(ctx)
  4803  		res, err := c.Do(req)
  4804  		if err != nil {
  4805  			t.Fatal(err)
  4806  		}
  4807  		buf := make([]byte, len(rgz))
  4808  		if n, err := io.ReadFull(res.Body, buf); err != nil {
  4809  			t.Errorf("%d. ReadFull = %v, %v", i, n, err)
  4810  		}
  4811  		// Note: no res.Body.Close call. It should work without it,
  4812  		// since the flate.Reader's internal buffering will hit EOF
  4813  		// and that should be sufficient.
  4814  	}
  4815  	a1, a2 := <-addr, <-addr
  4816  	if a1 != a2 {
  4817  		t.Fatalf("didn't reuse connection")
  4818  	}
  4819  }
  4820  
  4821  func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
  4822  func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
  4823  	if mode == http2Mode {
  4824  		t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
  4825  	}
  4826  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4827  		if r.URL.Path == "/long" {
  4828  			w.Header().Set("Long", strings.Repeat("a", 1<<20))
  4829  		}
  4830  	})).ts
  4831  	c := ts.Client()
  4832  	c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
  4833  
  4834  	if res, err := c.Get(ts.URL); err != nil {
  4835  		t.Fatal(err)
  4836  	} else {
  4837  		res.Body.Close()
  4838  	}
  4839  
  4840  	res, err := c.Get(ts.URL + "/long")
  4841  	if err == nil {
  4842  		defer res.Body.Close()
  4843  		var n int64
  4844  		for k, vv := range res.Header {
  4845  			for _, v := range vv {
  4846  				n += int64(len(k)) + int64(len(v))
  4847  			}
  4848  		}
  4849  		t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
  4850  	}
  4851  	if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
  4852  		t.Errorf("got error: %v; want %q", err, want)
  4853  	}
  4854  }
  4855  
  4856  func TestTransportEventTrace(t *testing.T) {
  4857  	run(t, func(t *testing.T, mode testMode) {
  4858  		testTransportEventTrace(t, mode, false)
  4859  	}, testNotParallel)
  4860  }
  4861  
  4862  // test a non-nil httptrace.ClientTrace but with all hooks set to zero.
  4863  func TestTransportEventTrace_NoHooks(t *testing.T) {
  4864  	run(t, func(t *testing.T, mode testMode) {
  4865  		testTransportEventTrace(t, mode, true)
  4866  	}, testNotParallel)
  4867  }
  4868  
// testTransportEventTrace sends a POST with "Expect: 100-continue" to a
// hostname that is resolved by a fake resolver installed through
// nettrace.LookupIPAltResolverKey. Every httptrace.ClientTrace hook that
// fires is recorded into buf. The test then checks that the expected events
// appear (exactly once where that matters), and finally issues a second GET
// to confirm that GetConn fires again.
//
// When noHooks is true, the trace is replaced by a zero ClientTrace and the
// test only checks that the request completes.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// Receives one value per WroteRequest hook call. The POST handler waits
	// on it (unless noHooks) before writing its response. Buffered so the
	// hook never blocks.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// Do nothing for the second request.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		// Trust the test server's certificate when TLS is in use.
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	// A non-zero timeout makes the Transport wait for the server's response
	// to the "Expect: 100-continue" header set below before sending the body.
	cst.tr.ExpectContinueTimeout = 1 * time.Second

	var mu sync.Mutex // guards buf
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Install a fake DNS server.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	// TLS hooks are only installed (and their events only checked) in
	// http2Mode.
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// zero out all func pointers, trying to get some path to crash
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// Done at this point. Just testing a full HTTP
		// requests can happen with a trace pointing to a zero
		// ClientTrace, full of nil func pointers.
		return
	}

	// Snapshot the recorded events under mu.
	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	// Connect events are only required at least once; presumably a dial
	// may be attempted more than once.
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
		// WroteHeaderField hook is not yet implemented in h2.)
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	// DNS is faked, so no UDP connections should appear.
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// And do a second request:
	// Reusing the same trace must make GetConn fire again, for a total of two.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
  5050  
  5051  func TestTransportEventTraceTLSVerify(t *testing.T) {
  5052  	run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
  5053  }
  5054  func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
  5055  	var mu sync.Mutex
  5056  	var buf strings.Builder
  5057  	logf := func(format string, args ...any) {
  5058  		mu.Lock()
  5059  		defer mu.Unlock()
  5060  		fmt.Fprintf(&buf, format, args...)
  5061  		buf.WriteByte('\n')
  5062  	}
  5063  
  5064  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5065  		t.Error("Unexpected request")
  5066  	}), func(ts *httptest.Server) {
  5067  		ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
  5068  			logf("%s", p)
  5069  			return len(p), nil
  5070  		}), "", 0)
  5071  	}).ts
  5072  
  5073  	certpool := x509.NewCertPool()
  5074  	certpool.AddCert(ts.Certificate())
  5075  
  5076  	c := &Client{Transport: &Transport{
  5077  		TLSClientConfig: &tls.Config{
  5078  			ServerName: "dns-is-faked.golang",
  5079  			RootCAs:    certpool,
  5080  		},
  5081  	}}
  5082  
  5083  	trace := &httptrace.ClientTrace{
  5084  		TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
  5085  		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
  5086  			logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
  5087  		},
  5088  	}
  5089  
  5090  	req, _ := NewRequest("GET", ts.URL, nil)
  5091  	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
  5092  	_, err := c.Do(req)
  5093  	if err == nil {
  5094  		t.Error("Expected request to fail TLS verification")
  5095  	}
  5096  
  5097  	mu.Lock()
  5098  	got := buf.String()
  5099  	mu.Unlock()
  5100  
  5101  	wantOnce := func(sub string) {
  5102  		if strings.Count(got, sub) != 1 {
  5103  			t.Errorf("expected substring %q exactly once in output.", sub)
  5104  		}
  5105  	}
  5106  
  5107  	wantOnce("TLSHandshakeStart")
  5108  	wantOnce("TLSHandshakeDone")
  5109  	wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
  5110  
  5111  	if t.Failed() {
  5112  		t.Errorf("Output:\n%s", got)
  5113  	}
  5114  }
  5115  
  5116  var (
  5117  	isDNSHijackedOnce sync.Once
  5118  	isDNSHijacked     bool
  5119  )
  5120  
  5121  func skipIfDNSHijacked(t *testing.T) {
  5122  	// Skip this test if the user is using a shady/ISP
  5123  	// DNS server hijacking queries.
  5124  	// See issues 16732, 16716.
  5125  	isDNSHijackedOnce.Do(func() {
  5126  		addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
  5127  		isDNSHijacked = len(addrs) != 0
  5128  	})
  5129  	if isDNSHijacked {
  5130  		t.Skip("skipping; test requires non-hijacking DNS server")
  5131  	}
  5132  }
  5133  
  5134  func TestTransportEventTraceRealDNS(t *testing.T) {
  5135  	skipIfDNSHijacked(t)
  5136  	defer afterTest(t)
  5137  	tr := &Transport{}
  5138  	defer tr.CloseIdleConnections()
  5139  	c := &Client{Transport: tr}
  5140  
  5141  	var mu sync.Mutex // guards buf
  5142  	var buf strings.Builder
  5143  	logf := func(format string, args ...any) {
  5144  		mu.Lock()
  5145  		defer mu.Unlock()
  5146  		fmt.Fprintf(&buf, format, args...)
  5147  		buf.WriteByte('\n')
  5148  	}
  5149  
  5150  	req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
  5151  	trace := &httptrace.ClientTrace{
  5152  		DNSStart:     func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
  5153  		DNSDone:      func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
  5154  		ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
  5155  		ConnectDone:  func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
  5156  	}
  5157  	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
  5158  
  5159  	resp, err := c.Do(req)
  5160  	if err == nil {
  5161  		resp.Body.Close()
  5162  		t.Fatal("expected error during DNS lookup")
  5163  	}
  5164  
  5165  	mu.Lock()
  5166  	got := buf.String()
  5167  	mu.Unlock()
  5168  
  5169  	wantSub := func(sub string) {
  5170  		if !strings.Contains(got, sub) {
  5171  			t.Errorf("expected substring %q in output.", sub)
  5172  		}
  5173  	}
  5174  	wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
  5175  	wantSub("DNSDone: {Addrs:[] Err:")
  5176  	if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
  5177  		t.Errorf("should not see Connect events")
  5178  	}
  5179  	if t.Failed() {
  5180  		t.Errorf("Output:\n%s", got)
  5181  	}
  5182  }
  5183  
  5184  // Issue 14353: port can only contain digits.
  5185  func TestTransportRejectsAlphaPort(t *testing.T) {
  5186  	res, err := Get("http://dummy.tld:123foo/bar")
  5187  	if err == nil {
  5188  		res.Body.Close()
  5189  		t.Fatal("unexpected success")
  5190  	}
  5191  	ue, ok := err.(*url.Error)
  5192  	if !ok {
  5193  		t.Fatalf("got %#v; want *url.Error", err)
  5194  	}
  5195  	got := ue.Err.Error()
  5196  	want := `invalid port ":123foo" after host`
  5197  	if got != want {
  5198  		t.Errorf("got error %q; want %q", got, want)
  5199  	}
  5200  }
  5201  
  5202  // Test the httptrace.TLSHandshake{Start,Done} hooks with an https http1
  5203  // connections. The http2 test is done in TestTransportEventTrace_h2
  5204  func TestTLSHandshakeTrace(t *testing.T) {
  5205  	run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
  5206  }
  5207  func testTLSHandshakeTrace(t *testing.T, mode testMode) {
  5208  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
  5209  
  5210  	var mu sync.Mutex
  5211  	var start, done bool
  5212  	trace := &httptrace.ClientTrace{
  5213  		TLSHandshakeStart: func() {
  5214  			mu.Lock()
  5215  			defer mu.Unlock()
  5216  			start = true
  5217  		},
  5218  		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
  5219  			mu.Lock()
  5220  			defer mu.Unlock()
  5221  			done = true
  5222  			if err != nil {
  5223  				t.Fatal("Expected error to be nil but was:", err)
  5224  			}
  5225  		},
  5226  	}
  5227  
  5228  	c := ts.Client()
  5229  	req, err := NewRequest("GET", ts.URL, nil)
  5230  	if err != nil {
  5231  		t.Fatal("Unable to construct test request:", err)
  5232  	}
  5233  	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
  5234  
  5235  	r, err := c.Do(req)
  5236  	if err != nil {
  5237  		t.Fatal("Unexpected error making request:", err)
  5238  	}
  5239  	r.Body.Close()
  5240  	mu.Lock()
  5241  	defer mu.Unlock()
  5242  	if !start {
  5243  		t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
  5244  	}
  5245  	if !done {
  5246  		t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
  5247  	}
  5248  }
  5249  
  5250  func TestTransportMaxIdleConns(t *testing.T) {
  5251  	run(t, testTransportMaxIdleConns, []testMode{http1Mode})
  5252  }
  5253  func testTransportMaxIdleConns(t *testing.T, mode testMode) {
  5254  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5255  		// No body for convenience.
  5256  	})).ts
  5257  	c := ts.Client()
  5258  	tr := c.Transport.(*Transport)
  5259  	tr.MaxIdleConns = 4
  5260  
  5261  	ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
  5262  	if err != nil {
  5263  		t.Fatal(err)
  5264  	}
  5265  	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
  5266  		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
  5267  	})
  5268  
  5269  	hitHost := func(n int) {
  5270  		req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
  5271  		req = req.WithContext(ctx)
  5272  		res, err := c.Do(req)
  5273  		if err != nil {
  5274  			t.Fatal(err)
  5275  		}
  5276  		res.Body.Close()
  5277  	}
  5278  	for i := 0; i < 4; i++ {
  5279  		hitHost(i)
  5280  	}
  5281  	want := []string{
  5282  		"|http|host-0.dns-is-faked.golang:" + port,
  5283  		"|http|host-1.dns-is-faked.golang:" + port,
  5284  		"|http|host-2.dns-is-faked.golang:" + port,
  5285  		"|http|host-3.dns-is-faked.golang:" + port,
  5286  	}
  5287  	if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
  5288  		t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
  5289  	}
  5290  
  5291  	// Now hitting the 5th host should kick out the first host:
  5292  	hitHost(4)
  5293  	want = []string{
  5294  		"|http|host-1.dns-is-faked.golang:" + port,
  5295  		"|http|host-2.dns-is-faked.golang:" + port,
  5296  		"|http|host-3.dns-is-faked.golang:" + port,
  5297  		"|http|host-4.dns-is-faked.golang:" + port,
  5298  	}
  5299  	if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
  5300  		t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
  5301  	}
  5302  }
  5303  
  5304  func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
  5305  func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
  5306  	if testing.Short() {
  5307  		t.Skip("skipping in short mode")
  5308  	}
  5309  
  5310  	timeout := 1 * time.Millisecond
  5311  timeoutLoop:
  5312  	for {
  5313  		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5314  			// No body for convenience.
  5315  		}))
  5316  		tr := cst.tr
  5317  		tr.IdleConnTimeout = timeout
  5318  		defer tr.CloseIdleConnections()
  5319  		c := &Client{Transport: tr}
  5320  
  5321  		idleConns := func() []string {
  5322  			if mode == http2Mode {
  5323  				return tr.IdleConnStrsForTesting_h2()
  5324  			} else {
  5325  				return tr.IdleConnStrsForTesting()
  5326  			}
  5327  		}
  5328  
  5329  		var conn string
  5330  		doReq := func(n int) (timeoutOk bool) {
  5331  			req, _ := NewRequest("GET", cst.ts.URL, nil)
  5332  			req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  5333  				PutIdleConn: func(err error) {
  5334  					if err != nil {
  5335  						t.Errorf("failed to keep idle conn: %v", err)
  5336  					}
  5337  				},
  5338  			}))
  5339  			res, err := c.Do(req)
  5340  			if err != nil {
  5341  				if strings.Contains(err.Error(), "use of closed network connection") {
  5342  					t.Logf("req %v: connection closed prematurely", n)
  5343  					return false
  5344  				}
  5345  			}
  5346  			res.Body.Close()
  5347  			conns := idleConns()
  5348  			if len(conns) != 1 {
  5349  				if len(conns) == 0 {
  5350  					t.Logf("req %v: no idle conns", n)
  5351  					return false
  5352  				}
  5353  				t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
  5354  			}
  5355  			if conn == "" {
  5356  				conn = conns[0]
  5357  			}
  5358  			if conn != conns[0] {
  5359  				t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
  5360  				return false
  5361  			}
  5362  			return true
  5363  		}
  5364  		for i := 0; i < 3; i++ {
  5365  			if !doReq(i) {
  5366  				t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
  5367  				timeout *= 2
  5368  				cst.close()
  5369  				continue timeoutLoop
  5370  			}
  5371  			time.Sleep(timeout / 2)
  5372  		}
  5373  
  5374  		waitCondition(t, timeout/2, func(d time.Duration) bool {
  5375  			if got := idleConns(); len(got) != 0 {
  5376  				if d >= timeout*3/2 {
  5377  					t.Logf("after %v, idle conns = %q", d, got)
  5378  				}
  5379  				return false
  5380  			}
  5381  			return true
  5382  		})
  5383  		break
  5384  	}
  5385  }
  5386  
// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an
// HTTP/2 connection was established but its caller no longer
// wanted it. (Assuming the connection cache was enabled, which it is
// by default)
//
// This test reproduced the crash by setting the IdleConnTimeout low
// (to make the test reasonable) and then making a request which is
// canceled by the DialTLS hook, which then also waits to return the
// real connection until after the RoundTrip saw the error.  Then we
// know the successful tls.Dial from DialTLS will need to go into the
// idle pool. Then we give it a bit of time to explode.
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// nothing
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is sent to after Do has returned its (expected) error; the
	// DialTLS hook waits for it before handing back the connection. It is
	// buffered so the send below never blocks.
	sawDoErr := make(chan bool, 1)
	// testDone releases the DialTLS hook if the test returns before
	// sawDoErr is ever sent to.
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// A usable connection exists; cancel the request's context so the
		// caller gives up on it.
		cancel()

		// Hold the connection back until Do has observed the cancellation,
		// so the returned conn is one nobody is waiting for.
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Wait for the explosion: sleep well past IdleConnTimeout so any
	// idle-timeout handling of the unwanted connection has a chance to run.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
  5448  
  5449  type funcConn struct {
  5450  	net.Conn
  5451  	read  func([]byte) (int, error)
  5452  	write func([]byte) (int, error)
  5453  }
  5454  
  5455  func (c funcConn) Read(p []byte) (int, error)  { return c.read(p) }
  5456  func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
  5457  func (c funcConn) Close() error                { return nil }
  5458  
  5459  // Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek
  5460  // back to the caller.
  5461  func TestTransportReturnsPeekError(t *testing.T) {
  5462  	errValue := errors.New("specific error value")
  5463  
  5464  	wrote := make(chan struct{})
  5465  	var wroteOnce sync.Once
  5466  
  5467  	tr := &Transport{
  5468  		Dial: func(network, addr string) (net.Conn, error) {
  5469  			c := funcConn{
  5470  				read: func([]byte) (int, error) {
  5471  					<-wrote
  5472  					return 0, errValue
  5473  				},
  5474  				write: func(p []byte) (int, error) {
  5475  					wroteOnce.Do(func() { close(wrote) })
  5476  					return len(p), nil
  5477  				},
  5478  			}
  5479  			return c, nil
  5480  		},
  5481  	}
  5482  	_, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
  5483  	if err != errValue {
  5484  		t.Errorf("error = %#v; want %v", err, errValue)
  5485  	}
  5486  }
  5487  
// Issue 13835: international domain names should work
//
// The request URL uses a Unicode host; the test checks that the Punycode
// form is what reaches the Host header, the TLS ServerName (in HTTP/2
// mode), the GetConn trace hook, and the DNS lookup.
func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
func testTransportIDNA(t *testing.T, mode testMode) {
	// uniDomain is the Unicode host used in the request URL; punyDomain is
	// its Punycode (ASCII) encoding, which every check below expects.
	const uniDomain = "гофер.го"
	const punyDomain = "xn--c1ae0ajs.xn--c1aw"

	// port is filled in by the net.SplitHostPort call below. That `:=`
	// declares only ip and err as new and assigns to this same variable,
	// so the handler closure sees the server's real port at request time.
	var port string
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		want := punyDomain + ":" + port
		if r.Host != want {
			t.Errorf("Host header = %q; want %q", r.Host, want)
		}
		if mode == http2Mode {
			if r.TLS == nil {
				t.Errorf("r.TLS == nil")
			} else if r.TLS.ServerName != punyDomain {
				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
			}
		}
		// Marker so the client can tell the response came from this handler.
		w.Header().Set("Hit-Handler", "1")
	}), func(tr *Transport) {
		// Skip certificate verification: the request targets a made-up
		// domain rather than the test server's own name (presumably not
		// covered by its certificate — confirm if changing).
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})

	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	// Install a fake DNS server.
	// The resolver hook answers a lookup for punyDomain with the test
	// server's IP and flags any lookup for a different host.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != punyDomain {
			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
	// Both trace hooks must observe the Punycode host, not the Unicode one.
	trace := &httptrace.ClientTrace{
		GetConn: func(hostPort string) {
			want := net.JoinHostPort(punyDomain, port)
			if hostPort != want {
				t.Errorf("getting conn for %q; want %q", hostPort, want)
			}
		},
		DNSStart: func(e httptrace.DNSStartInfo) {
			if e.Host != punyDomain {
				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	res, err := cst.tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.Header.Get("Hit-Handler") != "1" {
		out, err := httputil.DumpResponse(res, true)
		if err != nil {
			t.Fatal(err)
		}
		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
	}
}
  5557  
  5558  // Issue 13290: send User-Agent in proxy CONNECT
  5559  func TestTransportProxyConnectHeader(t *testing.T) {
  5560  	run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
  5561  }
  5562  func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
  5563  	reqc := make(chan *Request, 1)
  5564  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5565  		if r.Method != "CONNECT" {
  5566  			t.Errorf("method = %q; want CONNECT", r.Method)
  5567  		}
  5568  		reqc <- r
  5569  		c, _, err := w.(Hijacker).Hijack()
  5570  		if err != nil {
  5571  			t.Errorf("Hijack: %v", err)
  5572  			return
  5573  		}
  5574  		c.Close()
  5575  	})).ts
  5576  
  5577  	c := ts.Client()
  5578  	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
  5579  		return url.Parse(ts.URL)
  5580  	}
  5581  	c.Transport.(*Transport).ProxyConnectHeader = Header{
  5582  		"User-Agent": {"foo"},
  5583  		"Other":      {"bar"},
  5584  	}
  5585  
  5586  	res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
  5587  	if err == nil {
  5588  		res.Body.Close()
  5589  		t.Errorf("unexpected success")
  5590  	}
  5591  
  5592  	r := <-reqc
  5593  	if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
  5594  		t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
  5595  	}
  5596  	if got, want := r.Header.Get("Other"), "bar"; got != want {
  5597  		t.Errorf("CONNECT request Other = %q; want %q", got, want)
  5598  	}
  5599  }
  5600  
  5601  func TestTransportProxyGetConnectHeader(t *testing.T) {
  5602  	run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
  5603  }
  5604  func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
  5605  	reqc := make(chan *Request, 1)
  5606  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5607  		if r.Method != "CONNECT" {
  5608  			t.Errorf("method = %q; want CONNECT", r.Method)
  5609  		}
  5610  		reqc <- r
  5611  		c, _, err := w.(Hijacker).Hijack()
  5612  		if err != nil {
  5613  			t.Errorf("Hijack: %v", err)
  5614  			return
  5615  		}
  5616  		c.Close()
  5617  	})).ts
  5618  
  5619  	c := ts.Client()
  5620  	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
  5621  		return url.Parse(ts.URL)
  5622  	}
  5623  	// These should be ignored:
  5624  	c.Transport.(*Transport).ProxyConnectHeader = Header{
  5625  		"User-Agent": {"foo"},
  5626  		"Other":      {"bar"},
  5627  	}
  5628  	c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
  5629  		return Header{
  5630  			"User-Agent": {"foo2"},
  5631  			"Other":      {"bar2"},
  5632  		}, nil
  5633  	}
  5634  
  5635  	res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
  5636  	if err == nil {
  5637  		res.Body.Close()
  5638  		t.Errorf("unexpected success")
  5639  	}
  5640  
  5641  	r := <-reqc
  5642  	if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
  5643  		t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
  5644  	}
  5645  	if got, want := r.Header.Get("Other"), "bar2"; got != want {
  5646  		t.Errorf("CONNECT request Other = %q; want %q", got, want)
  5647  	}
  5648  }
  5649  
  5650  var errFakeRoundTrip = errors.New("fake roundtrip")
  5651  
  5652  type funcRoundTripper func()
  5653  
  5654  func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
  5655  	fn()
  5656  	return nil, errFakeRoundTrip
  5657  }
  5658  
  5659  func wantBody(res *Response, err error, want string) error {
  5660  	if err != nil {
  5661  		return err
  5662  	}
  5663  	slurp, err := io.ReadAll(res.Body)
  5664  	if err != nil {
  5665  		return fmt.Errorf("error reading body: %v", err)
  5666  	}
  5667  	if string(slurp) != want {
  5668  		return fmt.Errorf("body = %q; want %q", slurp, want)
  5669  	}
  5670  	if err := res.Body.Close(); err != nil {
  5671  		return fmt.Errorf("body Close = %v", err)
  5672  	}
  5673  	return nil
  5674  }
  5675  
  5676  func newLocalListener(t *testing.T) net.Listener {
  5677  	ln, err := net.Listen("tcp", "127.0.0.1:0")
  5678  	if err != nil {
  5679  		ln, err = net.Listen("tcp6", "[::1]:0")
  5680  	}
  5681  	if err != nil {
  5682  		t.Fatal(err)
  5683  	}
  5684  	return ln
  5685  }
  5686  
  5687  type countCloseReader struct {
  5688  	n *int
  5689  	io.Reader
  5690  }
  5691  
  5692  func (cr countCloseReader) Close() error {
  5693  	(*cr.n)++
  5694  	return nil
  5695  }
  5696  
// rgz is a gzip quine that uncompresses to itself.
//
// Layout of the header bytes: 0x1f 0x8b (gzip magic), 0x08 (deflate),
// 0x08 (FNAME flag), a zero mtime, XFL and OS bytes, then the NUL-terminated
// original file name "recursive" (0x72 0x65 0x63 0x75 0x72 0x73 0x69 0x76
// 0x65 0x00). The remaining bytes are the deflate stream and trailer.
// Because decompressing it yields these same bytes, it is presumably used to
// check that a gzip body is decoded only once — confirm at the call sites.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
  5732  
  5733  // Ensure that a missing status doesn't make the server panic
  5734  // See Issue https://golang.org/issues/21701
  5735  func TestMissingStatusNoPanic(t *testing.T) {
  5736  	t.Parallel()
  5737  
  5738  	const want = "unknown status code"
  5739  
  5740  	ln := newLocalListener(t)
  5741  	addr := ln.Addr().String()
  5742  	done := make(chan bool)
  5743  	fullAddrURL := fmt.Sprintf("http://%s", addr)
  5744  	raw := "HTTP/1.1 400\r\n" +
  5745  		"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
  5746  		"Content-Type: text/html; charset=utf-8\r\n" +
  5747  		"Content-Length: 10\r\n" +
  5748  		"Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
  5749  		"Vary: Accept-Encoding\r\n\r\n" +
  5750  		"Aloha Olaa"
  5751  
  5752  	go func() {
  5753  		defer close(done)
  5754  
  5755  		conn, _ := ln.Accept()
  5756  		if conn != nil {
  5757  			io.WriteString(conn, raw)
  5758  			io.ReadAll(conn)
  5759  			conn.Close()
  5760  		}
  5761  	}()
  5762  
  5763  	proxyURL, err := url.Parse(fullAddrURL)
  5764  	if err != nil {
  5765  		t.Fatalf("proxyURL: %v", err)
  5766  	}
  5767  
  5768  	tr := &Transport{Proxy: ProxyURL(proxyURL)}
  5769  
  5770  	req, _ := NewRequest("GET", "https://golang.org/", nil)
  5771  	res, err, panicked := doFetchCheckPanic(tr, req)
  5772  	if panicked {
  5773  		t.Error("panicked, expecting an error")
  5774  	}
  5775  	if res != nil && res.Body != nil {
  5776  		io.Copy(io.Discard, res.Body)
  5777  		res.Body.Close()
  5778  	}
  5779  
  5780  	if err == nil || !strings.Contains(err.Error(), want) {
  5781  		t.Errorf("got=%v want=%q", err, want)
  5782  	}
  5783  
  5784  	ln.Close()
  5785  	<-done
  5786  }
  5787  
  5788  func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
  5789  	defer func() {
  5790  		if r := recover(); r != nil {
  5791  			panicked = true
  5792  		}
  5793  	}()
  5794  	res, err = tr.RoundTrip(req)
  5795  	return
  5796  }
  5797  
  5798  // Issue 22330: do not allow the response body to be read when the status code
  5799  // forbids a response body.
  5800  func TestNoBodyOnChunked304Response(t *testing.T) {
  5801  	run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
  5802  }
  5803  func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
  5804  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5805  		conn, buf, _ := w.(Hijacker).Hijack()
  5806  		buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
  5807  		buf.Flush()
  5808  		conn.Close()
  5809  	}))
  5810  
  5811  	// Our test server above is sending back bogus data after the
  5812  	// response (the "0\r\n\r\n" part), which causes the Transport
  5813  	// code to log spam. Disable keep-alives so we never even try
  5814  	// to reuse the connection.
  5815  	cst.tr.DisableKeepAlives = true
  5816  
  5817  	res, err := cst.c.Get(cst.ts.URL)
  5818  	if err != nil {
  5819  		t.Fatal(err)
  5820  	}
  5821  
  5822  	if res.Body != NoBody {
  5823  		t.Errorf("Unexpected body on 304 response")
  5824  	}
  5825  }
  5826  
  5827  type funcWriter func([]byte) (int, error)
  5828  
  5829  func (f funcWriter) Write(p []byte) (int, error) { return f(p) }
  5830  
  5831  type doneContext struct {
  5832  	context.Context
  5833  	err error
  5834  }
  5835  
  5836  func (doneContext) Done() <-chan struct{} {
  5837  	c := make(chan struct{})
  5838  	close(c)
  5839  	return c
  5840  }
  5841  
  5842  func (d doneContext) Err() error { return d.err }
  5843  
  5844  // Issue 25852: Transport should check whether Context is done early.
  5845  func TestTransportCheckContextDoneEarly(t *testing.T) {
  5846  	tr := &Transport{}
  5847  	req, _ := NewRequest("GET", "http://fake.example/", nil)
  5848  	wantErr := errors.New("some error")
  5849  	req = req.WithContext(doneContext{context.Background(), wantErr})
  5850  	_, err := tr.RoundTrip(req)
  5851  	if err != wantErr {
  5852  		t.Errorf("error = %v; want %v", err, wantErr)
  5853  	}
  5854  }
  5855  
// Issue 23399: verify that if a client request times out, the Transport's
// conn is closed so that it's not reused.
//
// This is the test variant that times out before the server replies with
// any response headers.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	// Start with a very short client timeout and double it on each attempt
	// in which the request never reaches the handler.
	timeout := 1 * time.Millisecond
	for {
		// Fresh channels and server for every attempt, so each attempt is
		// independent of any abandoned earlier one.
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Wait until the request's context is done, which is expected
			// once the client times out and drops the connection.
			<-r.Context().Done()

			// Either announce that the handler was reached, or give up if
			// the test has abandoned this attempt.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// Read from the conn until EOF to verify that it was correctly closed.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		// The handler never responds, so Get must fail with a timeout.
		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// If we didn't get into the Handler, that probably means the builder was
			// just slow and the Get failed in that time but never made it to the
			// server. That's fine; we'll try again with a longer timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			// The handler ran; wait for its EOF check to finish.
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
  5919  
// Issue 23399: verify that if a client request times out, the Transport's
// conn is closed so that it's not reused.
//
// This is the test variant that has the server send response headers
// first, and time out during the write of the response body.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise a 100-byte body and flush the headers, so the client's Do
		// returns while the body is still outstanding.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		// Block until the client side receives from inHandler (which it does
		// only after canceling the request), unless the test bailed out.
		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		// Send only 3 of the 100 promised bytes.
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// The error should be io.EOF or "read tcp
		// 127.0.0.1:35827->127.0.0.1:40290: read: connection
		// reset by peer" depending on timing. Really we just
		// care that it returns at all. But if it returns with
		// data, that's weird.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// Set Timeout to something very long but non-zero to exercise
	// the codepaths that check for it. But rather than wait for it to fire
	// (which would make the test slow), we close the req.Cancel channel instead,
	// which happens to exercise the same code paths.
	cst.c.Timeout = 24 * time.Hour // just to be non-zero, not to hit it.
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Cancel the request while the handler is still blocked on sending to the
	// inHandler channel. Then read it until it fails, to verify that the
	// connection is broken before the handler itself closes it.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Now unblock the handler and wait for it to complete.
	<-inHandler
	<-handlerDone
}
  5990  
  5991  func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
  5992  	run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
  5993  }
  5994  func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
  5995  	done := make(chan struct{})
  5996  	defer close(done)
  5997  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5998  		conn, _, err := w.(Hijacker).Hijack()
  5999  		if err != nil {
  6000  			t.Error(err)
  6001  			return
  6002  		}
  6003  		defer conn.Close()
  6004  		io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
  6005  		bs := bufio.NewScanner(conn)
  6006  		bs.Scan()
  6007  		fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
  6008  		<-done
  6009  	}))
  6010  
  6011  	req, _ := NewRequest("GET", cst.ts.URL, nil)
  6012  	req.Header.Set("Upgrade", "foo")
  6013  	req.Header.Set("Connection", "upgrade")
  6014  	res, err := cst.c.Do(req)
  6015  	if err != nil {
  6016  		t.Fatal(err)
  6017  	}
  6018  	if res.StatusCode != 101 {
  6019  		t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
  6020  	}
  6021  	rwc, ok := res.Body.(io.ReadWriteCloser)
  6022  	if !ok {
  6023  		t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
  6024  	}
  6025  	defer rwc.Close()
  6026  	bs := bufio.NewScanner(rwc)
  6027  	if !bs.Scan() {
  6028  		t.Fatalf("expected readable input")
  6029  	}
  6030  	if got, want := bs.Text(), "Some buffered data"; got != want {
  6031  		t.Errorf("read %q; want %q", got, want)
  6032  	}
  6033  	io.WriteString(rwc, "echo\n")
  6034  	if !bs.Scan() {
  6035  		t.Fatalf("expected another line")
  6036  	}
  6037  	if got, want := bs.Text(), "ECHO"; got != want {
  6038  		t.Errorf("read %q; want %q", got, want)
  6039  	}
  6040  }
  6041  
  6042  func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
  6043  func testTransportCONNECTBidi(t *testing.T, mode testMode) {
  6044  	const target = "backend:443"
  6045  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6046  		if r.Method != "CONNECT" {
  6047  			t.Errorf("unexpected method %q", r.Method)
  6048  			w.WriteHeader(500)
  6049  			return
  6050  		}
  6051  		if r.RequestURI != target {
  6052  			t.Errorf("unexpected CONNECT target %q", r.RequestURI)
  6053  			w.WriteHeader(500)
  6054  			return
  6055  		}
  6056  		nc, brw, err := w.(Hijacker).Hijack()
  6057  		if err != nil {
  6058  			t.Error(err)
  6059  			return
  6060  		}
  6061  		defer nc.Close()
  6062  		nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
  6063  		// Switch to a little protocol that capitalize its input lines:
  6064  		for {
  6065  			line, err := brw.ReadString('\n')
  6066  			if err != nil {
  6067  				if err != io.EOF {
  6068  					t.Error(err)
  6069  				}
  6070  				return
  6071  			}
  6072  			io.WriteString(brw, strings.ToUpper(line))
  6073  			brw.Flush()
  6074  		}
  6075  	}))
  6076  	pr, pw := io.Pipe()
  6077  	defer pw.Close()
  6078  	req, err := NewRequest("CONNECT", cst.ts.URL, pr)
  6079  	if err != nil {
  6080  		t.Fatal(err)
  6081  	}
  6082  	req.URL.Opaque = target
  6083  	res, err := cst.c.Do(req)
  6084  	if err != nil {
  6085  		t.Fatal(err)
  6086  	}
  6087  	defer res.Body.Close()
  6088  	if res.StatusCode != 200 {
  6089  		t.Fatalf("status code = %d; want 200", res.StatusCode)
  6090  	}
  6091  	br := bufio.NewReader(res.Body)
  6092  	for _, str := range []string{"foo", "bar", "baz"} {
  6093  		fmt.Fprintf(pw, "%s\n", str)
  6094  		got, err := br.ReadString('\n')
  6095  		if err != nil {
  6096  			t.Fatal(err)
  6097  		}
  6098  		got = strings.TrimSpace(got)
  6099  		want := strings.ToUpper(str)
  6100  		if got != want {
  6101  			t.Fatalf("got %q; want %q", got, want)
  6102  		}
  6103  	}
  6104  }
  6105  
  6106  func TestTransportRequestReplayable(t *testing.T) {
  6107  	someBody := io.NopCloser(strings.NewReader(""))
  6108  	tests := []struct {
  6109  		name string
  6110  		req  *Request
  6111  		want bool
  6112  	}{
  6113  		{
  6114  			name: "GET",
  6115  			req:  &Request{Method: "GET"},
  6116  			want: true,
  6117  		},
  6118  		{
  6119  			name: "GET_http.NoBody",
  6120  			req:  &Request{Method: "GET", Body: NoBody},
  6121  			want: true,
  6122  		},
  6123  		{
  6124  			name: "GET_body",
  6125  			req:  &Request{Method: "GET", Body: someBody},
  6126  			want: false,
  6127  		},
  6128  		{
  6129  			name: "POST",
  6130  			req:  &Request{Method: "POST"},
  6131  			want: false,
  6132  		},
  6133  		{
  6134  			name: "POST_idempotency-key",
  6135  			req:  &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
  6136  			want: true,
  6137  		},
  6138  		{
  6139  			name: "POST_x-idempotency-key",
  6140  			req:  &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
  6141  			want: true,
  6142  		},
  6143  		{
  6144  			name: "POST_body",
  6145  			req:  &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
  6146  			want: false,
  6147  		},
  6148  	}
  6149  	for _, tt := range tests {
  6150  		t.Run(tt.name, func(t *testing.T) {
  6151  			got := tt.req.ExportIsReplayable()
  6152  			if got != tt.want {
  6153  				t.Errorf("replyable = %v; want %v", got, tt.want)
  6154  			}
  6155  		})
  6156  	}
  6157  }
  6158  
  6159  // testMockTCPConn is a mock TCP connection used to test that
  6160  // ReadFrom is called when sending the request body.
  6161  type testMockTCPConn struct {
  6162  	*net.TCPConn
  6163  
  6164  	ReadFromCalled bool
  6165  }
  6166  
  6167  func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
  6168  	c.ReadFromCalled = true
  6169  	return c.TCPConn.ReadFrom(r)
  6170  }
  6171  
  6172  func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
  6173  func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
  6174  	nBytes := int64(1 << 10)
  6175  	newFileFunc := func() (r io.Reader, done func(), err error) {
  6176  		f, err := os.CreateTemp("", "net-http-newfilefunc")
  6177  		if err != nil {
  6178  			return nil, nil, err
  6179  		}
  6180  
  6181  		// Write some bytes to the file to enable reading.
  6182  		if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
  6183  			return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
  6184  		}
  6185  		if _, err := f.Seek(0, 0); err != nil {
  6186  			return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
  6187  		}
  6188  
  6189  		done = func() {
  6190  			f.Close()
  6191  			os.Remove(f.Name())
  6192  		}
  6193  
  6194  		return f, done, nil
  6195  	}
  6196  
  6197  	newBufferFunc := func() (io.Reader, func(), error) {
  6198  		return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
  6199  	}
  6200  
  6201  	cases := []struct {
  6202  		name             string
  6203  		readerFunc       func() (io.Reader, func(), error)
  6204  		contentLength    int64
  6205  		expectedReadFrom bool
  6206  	}{
  6207  		{
  6208  			name:             "file, length",
  6209  			readerFunc:       newFileFunc,
  6210  			contentLength:    nBytes,
  6211  			expectedReadFrom: true,
  6212  		},
  6213  		{
  6214  			name:       "file, no length",
  6215  			readerFunc: newFileFunc,
  6216  		},
  6217  		{
  6218  			name:          "file, negative length",
  6219  			readerFunc:    newFileFunc,
  6220  			contentLength: -1,
  6221  		},
  6222  		{
  6223  			name:          "buffer",
  6224  			contentLength: nBytes,
  6225  			readerFunc:    newBufferFunc,
  6226  		},
  6227  		{
  6228  			name:       "buffer, no length",
  6229  			readerFunc: newBufferFunc,
  6230  		},
  6231  		{
  6232  			name:          "buffer, length -1",
  6233  			contentLength: -1,
  6234  			readerFunc:    newBufferFunc,
  6235  		},
  6236  	}
  6237  
  6238  	for _, tc := range cases {
  6239  		t.Run(tc.name, func(t *testing.T) {
  6240  			r, cleanup, err := tc.readerFunc()
  6241  			if err != nil {
  6242  				t.Fatal(err)
  6243  			}
  6244  			defer cleanup()
  6245  
  6246  			tConn := &testMockTCPConn{}
  6247  			trFunc := func(tr *Transport) {
  6248  				tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  6249  					var d net.Dialer
  6250  					conn, err := d.DialContext(ctx, network, addr)
  6251  					if err != nil {
  6252  						return nil, err
  6253  					}
  6254  
  6255  					tcpConn, ok := conn.(*net.TCPConn)
  6256  					if !ok {
  6257  						return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
  6258  					}
  6259  
  6260  					tConn.TCPConn = tcpConn
  6261  					return tConn, nil
  6262  				}
  6263  			}
  6264  
  6265  			cst := newClientServerTest(
  6266  				t,
  6267  				mode,
  6268  				HandlerFunc(func(w ResponseWriter, r *Request) {
  6269  					io.Copy(io.Discard, r.Body)
  6270  					r.Body.Close()
  6271  					w.WriteHeader(200)
  6272  				}),
  6273  				trFunc,
  6274  			)
  6275  
  6276  			req, err := NewRequest("PUT", cst.ts.URL, r)
  6277  			if err != nil {
  6278  				t.Fatal(err)
  6279  			}
  6280  			req.ContentLength = tc.contentLength
  6281  			req.Header.Set("Content-Type", "application/octet-stream")
  6282  			resp, err := cst.c.Do(req)
  6283  			if err != nil {
  6284  				t.Fatal(err)
  6285  			}
  6286  			defer resp.Body.Close()
  6287  			if resp.StatusCode != 200 {
  6288  				t.Fatalf("status code = %d; want 200", resp.StatusCode)
  6289  			}
  6290  
  6291  			expectedReadFrom := tc.expectedReadFrom
  6292  			if mode != http1Mode {
  6293  				expectedReadFrom = false
  6294  			}
  6295  			if !tConn.ReadFromCalled && expectedReadFrom {
  6296  				t.Fatalf("did not call ReadFrom")
  6297  			}
  6298  
  6299  			if tConn.ReadFromCalled && !expectedReadFrom {
  6300  				t.Fatalf("ReadFrom was unexpectedly invoked")
  6301  			}
  6302  		})
  6303  	}
  6304  }
  6305  
  6306  func TestTransportClone(t *testing.T) {
  6307  	tr := &Transport{
  6308  		Proxy: func(*Request) (*url.URL, error) { panic("") },
  6309  		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
  6310  			return nil
  6311  		},
  6312  		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
  6313  		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
  6314  		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
  6315  		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
  6316  		TLSClientConfig:        new(tls.Config),
  6317  		TLSHandshakeTimeout:    time.Second,
  6318  		DisableKeepAlives:      true,
  6319  		DisableCompression:     true,
  6320  		MaxIdleConns:           1,
  6321  		MaxIdleConnsPerHost:    1,
  6322  		MaxConnsPerHost:        1,
  6323  		IdleConnTimeout:        time.Second,
  6324  		ResponseHeaderTimeout:  time.Second,
  6325  		ExpectContinueTimeout:  time.Second,
  6326  		ProxyConnectHeader:     Header{},
  6327  		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
  6328  		MaxResponseHeaderBytes: 1,
  6329  		ForceAttemptHTTP2:      true,
  6330  		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
  6331  			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
  6332  		},
  6333  		ReadBufferSize:  1,
  6334  		WriteBufferSize: 1,
  6335  	}
  6336  	tr2 := tr.Clone()
  6337  	rv := reflect.ValueOf(tr2).Elem()
  6338  	rt := rv.Type()
  6339  	for i := 0; i < rt.NumField(); i++ {
  6340  		sf := rt.Field(i)
  6341  		if !token.IsExported(sf.Name) {
  6342  			continue
  6343  		}
  6344  		if rv.Field(i).IsZero() {
  6345  			t.Errorf("cloned field t2.%s is zero", sf.Name)
  6346  		}
  6347  	}
  6348  
  6349  	if _, ok := tr2.TLSNextProto["foo"]; !ok {
  6350  		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
  6351  	}
  6352  
  6353  	// But test that a nil TLSNextProto is kept nil:
  6354  	tr = new(Transport)
  6355  	tr2 = tr.Clone()
  6356  	if tr2.TLSNextProto != nil {
  6357  		t.Errorf("Transport.TLSNextProto unexpected non-nil")
  6358  	}
  6359  }
  6360  
  6361  func TestIs408(t *testing.T) {
  6362  	tests := []struct {
  6363  		in   string
  6364  		want bool
  6365  	}{
  6366  		{"HTTP/1.0 408", true},
  6367  		{"HTTP/1.1 408", true},
  6368  		{"HTTP/1.8 408", true},
  6369  		{"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now.
  6370  		{"HTTP/1.1 408 ", true},
  6371  		{"HTTP/1.1 40", false},
  6372  		{"http/1.0 408", false},
  6373  		{"HTTP/1-1 408", false},
  6374  	}
  6375  	for _, tt := range tests {
  6376  		if got := Export_is408Message([]byte(tt.in)); got != tt.want {
  6377  			t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
  6378  		}
  6379  	}
  6380  }
  6381  
  6382  func TestTransportIgnores408(t *testing.T) {
  6383  	run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
  6384  }
  6385  func testTransportIgnores408(t *testing.T, mode testMode) {
  6386  	// Not parallel. Relies on mutating the log package's global Output.
  6387  	defer log.SetOutput(log.Writer())
  6388  
  6389  	var logout strings.Builder
  6390  	log.SetOutput(&logout)
  6391  
  6392  	const target = "backend:443"
  6393  
  6394  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6395  		nc, _, err := w.(Hijacker).Hijack()
  6396  		if err != nil {
  6397  			t.Error(err)
  6398  			return
  6399  		}
  6400  		defer nc.Close()
  6401  		nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
  6402  		nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail
  6403  	}))
  6404  	req, err := NewRequest("GET", cst.ts.URL, nil)
  6405  	if err != nil {
  6406  		t.Fatal(err)
  6407  	}
  6408  	res, err := cst.c.Do(req)
  6409  	if err != nil {
  6410  		t.Fatal(err)
  6411  	}
  6412  	slurp, err := io.ReadAll(res.Body)
  6413  	if err != nil {
  6414  		t.Fatal(err)
  6415  	}
  6416  	if err != nil {
  6417  		t.Fatal(err)
  6418  	}
  6419  	if string(slurp) != "ok" {
  6420  		t.Fatalf("got %q; want ok", slurp)
  6421  	}
  6422  
  6423  	waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
  6424  		if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
  6425  			if d > 0 {
  6426  				t.Logf("%v idle conns still present after %v", n, d)
  6427  			}
  6428  			return false
  6429  		}
  6430  		return true
  6431  	})
  6432  	if got := logout.String(); got != "" {
  6433  		t.Fatalf("expected no log output; got: %s", got)
  6434  	}
  6435  }
  6436  
  6437  func TestInvalidHeaderResponse(t *testing.T) {
  6438  	run(t, testInvalidHeaderResponse, []testMode{http1Mode})
  6439  }
  6440  func testInvalidHeaderResponse(t *testing.T, mode testMode) {
  6441  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6442  		conn, buf, _ := w.(Hijacker).Hijack()
  6443  		buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
  6444  			"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
  6445  			"Content-Type: text/html; charset=utf-8\r\n" +
  6446  			"Content-Length: 0\r\n" +
  6447  			"Foo : bar\r\n\r\n"))
  6448  		buf.Flush()
  6449  		conn.Close()
  6450  	}))
  6451  	res, err := cst.c.Get(cst.ts.URL)
  6452  	if err != nil {
  6453  		t.Fatal(err)
  6454  	}
  6455  	defer res.Body.Close()
  6456  	if v := res.Header.Get("Foo"); v != "" {
  6457  		t.Errorf(`unexpected "Foo" header: %q`, v)
  6458  	}
  6459  	if v := res.Header.Get("Foo "); v != "bar" {
  6460  		t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
  6461  	}
  6462  }
  6463  
  6464  type bodyCloser bool
  6465  
  6466  func (bc *bodyCloser) Close() error {
  6467  	*bc = true
  6468  	return nil
  6469  }
  6470  func (bc *bodyCloser) Read(b []byte) (n int, err error) {
  6471  	return 0, io.EOF
  6472  }
  6473  
  6474  // Issue 35015: ensure that Transport closes the body on any error
  6475  // with an invalid request, as promised by Client.Do docs.
  6476  func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
  6477  	run(t, testTransportClosesBodyOnInvalidRequests)
  6478  }
  6479  func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
  6480  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6481  		t.Errorf("Should not have been invoked")
  6482  	})).ts
  6483  
  6484  	u, _ := url.Parse(cst.URL)
  6485  
  6486  	tests := []struct {
  6487  		name    string
  6488  		req     *Request
  6489  		wantErr string
  6490  	}{
  6491  		{
  6492  			name: "invalid method",
  6493  			req: &Request{
  6494  				Method: " ",
  6495  				URL:    u,
  6496  			},
  6497  			wantErr: `invalid method " "`,
  6498  		},
  6499  		{
  6500  			name: "nil URL",
  6501  			req: &Request{
  6502  				Method: "GET",
  6503  			},
  6504  			wantErr: `nil Request.URL`,
  6505  		},
  6506  		{
  6507  			name: "invalid header key",
  6508  			req: &Request{
  6509  				Method: "GET",
  6510  				Header: Header{"💡": {"emoji"}},
  6511  				URL:    u,
  6512  			},
  6513  			wantErr: `invalid header field name "💡"`,
  6514  		},
  6515  		{
  6516  			name: "invalid header value",
  6517  			req: &Request{
  6518  				Method: "POST",
  6519  				Header: Header{"key": {"\x19"}},
  6520  				URL:    u,
  6521  			},
  6522  			wantErr: `invalid header field value for "key"`,
  6523  		},
  6524  		{
  6525  			name: "non HTTP(s) scheme",
  6526  			req: &Request{
  6527  				Method: "POST",
  6528  				URL:    &url.URL{Scheme: "faux"},
  6529  			},
  6530  			wantErr: `unsupported protocol scheme "faux"`,
  6531  		},
  6532  		{
  6533  			name: "no Host in URL",
  6534  			req: &Request{
  6535  				Method: "POST",
  6536  				URL:    &url.URL{Scheme: "http"},
  6537  			},
  6538  			wantErr: `no Host in request URL`,
  6539  		},
  6540  	}
  6541  
  6542  	for _, tt := range tests {
  6543  		t.Run(tt.name, func(t *testing.T) {
  6544  			var bc bodyCloser
  6545  			req := tt.req
  6546  			req.Body = &bc
  6547  			_, err := cst.Client().Do(tt.req)
  6548  			if err == nil {
  6549  				t.Fatal("Expected an error")
  6550  			}
  6551  			if !bc {
  6552  				t.Fatal("Expected body to have been closed")
  6553  			}
  6554  			if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
  6555  				t.Fatalf("Error mismatch: %q does not end with %q", g, w)
  6556  			}
  6557  		})
  6558  	}
  6559  }
  6560  
  6561  // breakableConn is a net.Conn wrapper with a Write method
  6562  // that will fail when its brokenState is true.
  6563  type breakableConn struct {
  6564  	net.Conn
  6565  	*brokenState
  6566  }
  6567  
  6568  type brokenState struct {
  6569  	sync.Mutex
  6570  	broken bool
  6571  }
  6572  
  6573  func (w *breakableConn) Write(b []byte) (n int, err error) {
  6574  	w.Lock()
  6575  	defer w.Unlock()
  6576  	if w.broken {
  6577  		return 0, errors.New("some write error")
  6578  	}
  6579  	return w.Conn.Write(b)
  6580  }
  6581  
  6582  // Issue 34978: don't cache a broken HTTP/2 connection
  6583  func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
  6584  	run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
  6585  }
  6586  func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
  6587  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
  6588  
  6589  	var brokenState brokenState
  6590  
  6591  	const numReqs = 5
  6592  	var numDials, gotConns uint32 // atomic
  6593  
  6594  	cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
  6595  		atomic.AddUint32(&numDials, 1)
  6596  		c, err := net.Dial(netw, addr)
  6597  		if err != nil {
  6598  			t.Errorf("unexpected Dial error: %v", err)
  6599  			return nil, err
  6600  		}
  6601  		return &breakableConn{c, &brokenState}, err
  6602  	}
  6603  
  6604  	for i := 1; i <= numReqs; i++ {
  6605  		brokenState.Lock()
  6606  		brokenState.broken = false
  6607  		brokenState.Unlock()
  6608  
  6609  		// doBreak controls whether we break the TCP connection after the TLS
  6610  		// handshake (before the HTTP/2 handshake). We test a few failures
  6611  		// in a row followed by a final success.
  6612  		doBreak := i != numReqs
  6613  
  6614  		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  6615  			GotConn: func(info httptrace.GotConnInfo) {
  6616  				t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
  6617  				atomic.AddUint32(&gotConns, 1)
  6618  			},
  6619  			TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
  6620  				brokenState.Lock()
  6621  				defer brokenState.Unlock()
  6622  				if doBreak {
  6623  					brokenState.broken = true
  6624  				}
  6625  			},
  6626  		})
  6627  		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
  6628  		if err != nil {
  6629  			t.Fatal(err)
  6630  		}
  6631  		_, err = cst.c.Do(req)
  6632  		if doBreak != (err != nil) {
  6633  			t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
  6634  		}
  6635  	}
  6636  	if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
  6637  		t.Errorf("GotConn calls = %v; want %v", got, want)
  6638  	}
  6639  	if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
  6640  		t.Errorf("Dials = %v; want %v", got, want)
  6641  	}
  6642  }
  6643  
  6644  // Issue 34941
  6645  // When the client has too many concurrent requests on a single connection,
  6646  // http.http2noCachedConnError is reported on multiple requests. There should
  6647  // only be one decrement regardless of the number of failures.
  6648  func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
  6649  	run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
  6650  }
  6651  func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
  6652  	CondSkipHTTP2(t)
  6653  
  6654  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
  6655  		_, err := w.Write([]byte("foo"))
  6656  		if err != nil {
  6657  			t.Fatalf("Write: %v", err)
  6658  		}
  6659  	})
  6660  
  6661  	ts := newClientServerTest(t, mode, h).ts
  6662  
  6663  	c := ts.Client()
  6664  	tr := c.Transport.(*Transport)
  6665  	tr.MaxConnsPerHost = 1
  6666  
  6667  	errCh := make(chan error, 300)
  6668  	doReq := func() {
  6669  		resp, err := c.Get(ts.URL)
  6670  		if err != nil {
  6671  			errCh <- fmt.Errorf("request failed: %v", err)
  6672  			return
  6673  		}
  6674  		defer resp.Body.Close()
  6675  		_, err = io.ReadAll(resp.Body)
  6676  		if err != nil {
  6677  			errCh <- fmt.Errorf("read body failed: %v", err)
  6678  		}
  6679  	}
  6680  
  6681  	var wg sync.WaitGroup
  6682  	for i := 0; i < 300; i++ {
  6683  		wg.Add(1)
  6684  		go func() {
  6685  			defer wg.Done()
  6686  			doReq()
  6687  		}()
  6688  	}
  6689  	wg.Wait()
  6690  	close(errCh)
  6691  
  6692  	for err := range errCh {
  6693  		t.Errorf("error occurred: %v", err)
  6694  	}
  6695  }
  6696  
  6697  // Issue 36820
  6698  // Test that we use the older backward compatible cancellation protocol
  6699  // when a RoundTripper is registered via RegisterProtocol.
  6700  func TestAltProtoCancellation(t *testing.T) {
  6701  	defer afterTest(t)
  6702  	tr := &Transport{}
  6703  	c := &Client{
  6704  		Transport: tr,
  6705  		Timeout:   time.Millisecond,
  6706  	}
  6707  	tr.RegisterProtocol("cancel", cancelProto{})
  6708  	_, err := c.Get("cancel://bar.com/path")
  6709  	if err == nil {
  6710  		t.Error("request unexpectedly succeeded")
  6711  	} else if !strings.Contains(err.Error(), errCancelProto.Error()) {
  6712  		t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
  6713  	}
  6714  }
  6715  
  6716  var errCancelProto = errors.New("canceled as expected")
  6717  
  6718  type cancelProto struct{}
  6719  
  6720  func (cancelProto) RoundTrip(req *Request) (*Response, error) {
  6721  	<-req.Cancel
  6722  	return nil, errCancelProto
  6723  }
  6724  
  6725  type roundTripFunc func(r *Request) (*Response, error)
  6726  
  6727  func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) }
  6728  
  6729  // Issue 32441: body is not reset after ErrSkipAltProtocol
  6730  func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
  6731  func testIssue32441(t *testing.T, mode testMode) {
  6732  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6733  		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
  6734  			t.Error("body length is zero")
  6735  		}
  6736  	})).ts
  6737  	c := ts.Client()
  6738  	c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
  6739  		// Draining body to trigger failure condition on actual request to server.
  6740  		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
  6741  			t.Error("body length is zero during round trip")
  6742  		}
  6743  		return nil, ErrSkipAltProtocol
  6744  	}))
  6745  	if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
  6746  		t.Error(err)
  6747  	}
  6748  }
  6749  
  6750  // Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers
  6751  // that contain a sign (eg. "+3"), per RFC 2616, Section 14.13.
  6752  func TestTransportRejectsSignInContentLength(t *testing.T) {
  6753  	run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
  6754  }
  6755  func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
  6756  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6757  		w.Header().Set("Content-Length", "+3")
  6758  		w.Write([]byte("abc"))
  6759  	})).ts
  6760  
  6761  	c := cst.Client()
  6762  	res, err := c.Get(cst.URL)
  6763  	if err == nil || res != nil {
  6764  		t.Fatal("Expected a non-nil error and a nil http.Response")
  6765  	}
  6766  	if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
  6767  		t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
  6768  	}
  6769  }
  6770  
  6771  // dumpConn is a net.Conn which writes to Writer and reads from Reader
  6772  type dumpConn struct {
  6773  	io.Writer
  6774  	io.Reader
  6775  }
  6776  
  6777  func (c *dumpConn) Close() error                       { return nil }
  6778  func (c *dumpConn) LocalAddr() net.Addr                { return nil }
  6779  func (c *dumpConn) RemoteAddr() net.Addr               { return nil }
  6780  func (c *dumpConn) SetDeadline(t time.Time) error      { return nil }
  6781  func (c *dumpConn) SetReadDeadline(t time.Time) error  { return nil }
  6782  func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
  6783  
  6784  // delegateReader is a reader that delegates to another reader,
  6785  // once it arrives on a channel.
  6786  type delegateReader struct {
  6787  	c chan io.Reader
  6788  	r io.Reader // nil until received from c
  6789  }
  6790  
  6791  func (r *delegateReader) Read(p []byte) (int, error) {
  6792  	if r.r == nil {
  6793  		var ok bool
  6794  		if r.r, ok = <-r.c; !ok {
  6795  			return 0, errors.New("delegate closed")
  6796  		}
  6797  	}
  6798  	return r.r.Read(p)
  6799  }
  6800  
  6801  func testTransportRace(req *Request) {
  6802  	save := req.Body
  6803  	pr, pw := io.Pipe()
  6804  	defer pr.Close()
  6805  	defer pw.Close()
  6806  	dr := &delegateReader{c: make(chan io.Reader)}
  6807  
  6808  	t := &Transport{
  6809  		Dial: func(net, addr string) (net.Conn, error) {
  6810  			return &dumpConn{pw, dr}, nil
  6811  		},
  6812  	}
  6813  	defer t.CloseIdleConnections()
  6814  
  6815  	quitReadCh := make(chan struct{})
  6816  	// Wait for the request before replying with a dummy response:
  6817  	go func() {
  6818  		defer close(quitReadCh)
  6819  
  6820  		req, err := ReadRequest(bufio.NewReader(pr))
  6821  		if err == nil {
  6822  			// Ensure all the body is read; otherwise
  6823  			// we'll get a partial dump.
  6824  			io.Copy(io.Discard, req.Body)
  6825  			req.Body.Close()
  6826  		}
  6827  		select {
  6828  		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
  6829  		case quitReadCh <- struct{}{}:
  6830  			// Ensure delegate is closed so Read doesn't block forever.
  6831  			close(dr.c)
  6832  		}
  6833  	}()
  6834  
  6835  	t.RoundTrip(req)
  6836  
  6837  	// Ensure the reader returns before we reset req.Body to prevent
  6838  	// a data race on req.Body.
  6839  	pw.Close()
  6840  	<-quitReadCh
  6841  
  6842  	req.Body = save
  6843  }
  6844  
  6845  // Issue 37669
  6846  // Test that a cancellation doesn't result in a data race due to the writeLoop
  6847  // goroutine being left running, if the caller mutates the processed Request
  6848  // upon completion.
  6849  func TestErrorWriteLoopRace(t *testing.T) {
  6850  	if testing.Short() {
  6851  		return
  6852  	}
  6853  	t.Parallel()
  6854  	for i := 0; i < 1000; i++ {
  6855  		delay := time.Duration(mrand.Intn(5)) * time.Millisecond
  6856  		ctx, cancel := context.WithTimeout(context.Background(), delay)
  6857  		defer cancel()
  6858  
  6859  		r := bytes.NewBuffer(make([]byte, 10000))
  6860  		req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
  6861  		if err != nil {
  6862  			t.Fatal(err)
  6863  		}
  6864  
  6865  		testTransportRace(req)
  6866  	}
  6867  }
  6868  
  6869  // Issue 41600
  6870  // Test that a new request which uses the connection of an active request
  6871  // cannot cause it to be canceled as well.
  6872  func TestCancelRequestWhenSharingConnection(t *testing.T) {
  6873  	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
  6874  }
  6875  func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
  6876  	reqc := make(chan chan struct{}, 2)
  6877  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
  6878  		ch := make(chan struct{}, 1)
  6879  		reqc <- ch
  6880  		<-ch
  6881  		w.Header().Add("Content-Length", "0")
  6882  	})).ts
  6883  
  6884  	client := ts.Client()
  6885  	transport := client.Transport.(*Transport)
  6886  	transport.MaxIdleConns = 1
  6887  	transport.MaxConnsPerHost = 1
  6888  
  6889  	var wg sync.WaitGroup
  6890  
  6891  	wg.Add(1)
  6892  	putidlec := make(chan chan struct{}, 1)
  6893  	reqerrc := make(chan error, 1)
  6894  	go func() {
  6895  		defer wg.Done()
  6896  		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  6897  			PutIdleConn: func(error) {
  6898  				// Signal that the idle conn has been returned to the pool,
  6899  				// and wait for the order to proceed.
  6900  				ch := make(chan struct{})
  6901  				putidlec <- ch
  6902  				close(putidlec) // panic if PutIdleConn runs twice for some reason
  6903  				<-ch
  6904  			},
  6905  		})
  6906  		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
  6907  		res, err := client.Do(req)
  6908  		if err != nil {
  6909  			reqerrc <- err
  6910  		} else {
  6911  			res.Body.Close()
  6912  		}
  6913  	}()
  6914  
  6915  	// Wait for the first request to receive a response and return the
  6916  	// connection to the idle pool.
  6917  	select {
  6918  	case err := <-reqerrc:
  6919  		t.Fatalf("request 1: got err %v, want nil", err)
  6920  	case r1c := <-reqc:
  6921  		close(r1c)
  6922  	}
  6923  	var idlec chan struct{}
  6924  	select {
  6925  	case err := <-reqerrc:
  6926  		t.Fatalf("request 1: got err %v, want nil", err)
  6927  	case idlec = <-putidlec:
  6928  	}
  6929  
  6930  	wg.Add(1)
  6931  	cancelctx, cancel := context.WithCancel(context.Background())
  6932  	go func() {
  6933  		defer wg.Done()
  6934  		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
  6935  		res, err := client.Do(req)
  6936  		if err == nil {
  6937  			res.Body.Close()
  6938  		}
  6939  		if !errors.Is(err, context.Canceled) {
  6940  			t.Errorf("request 2: got err %v, want Canceled", err)
  6941  		}
  6942  
  6943  		// Unblock the first request.
  6944  		close(idlec)
  6945  	}()
  6946  
  6947  	// Wait for the second request to arrive at the server, and then cancel
  6948  	// the request context.
  6949  	r2c := <-reqc
  6950  	cancel()
  6951  
  6952  	<-idlec
  6953  
  6954  	close(r2c)
  6955  	wg.Wait()
  6956  }
  6957  
  6958  func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
  6959  func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
  6960  	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  6961  		go io.Copy(io.Discard, req.Body)
  6962  		panic(ErrAbortHandler)
  6963  	})).ts
  6964  
  6965  	var wg sync.WaitGroup
  6966  	for i := 0; i < 2; i++ {
  6967  		wg.Add(1)
  6968  		go func() {
  6969  			defer wg.Done()
  6970  			for j := 0; j < 10; j++ {
  6971  				const reqLen = 6 * 1024 * 1024
  6972  				req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
  6973  				req.ContentLength = reqLen
  6974  				resp, _ := ts.Client().Transport.RoundTrip(req)
  6975  				if resp != nil {
  6976  					resp.Body.Close()
  6977  				}
  6978  			}
  6979  		}()
  6980  	}
  6981  	wg.Wait()
  6982  }
  6983  
  6984  func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
  6985  func testRequestSanitization(t *testing.T, mode testMode) {
  6986  	if mode == http2Mode {
  6987  		// Remove this after updating x/net.
  6988  		t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
  6989  	}
  6990  	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  6991  		if h, ok := req.Header["X-Evil"]; ok {
  6992  			t.Errorf("request has X-Evil header: %q", h)
  6993  		}
  6994  	})).ts
  6995  	req, _ := NewRequest("GET", ts.URL, nil)
  6996  	req.Host = "go.dev\r\nX-Evil:evil"
  6997  	resp, _ := ts.Client().Do(req)
  6998  	if resp != nil {
  6999  		resp.Body.Close()
  7000  	}
  7001  }
  7002  
  7003  func TestProxyAuthHeader(t *testing.T) {
  7004  	// Not parallel: Sets an environment variable.
  7005  	run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
  7006  }
  7007  func testProxyAuthHeader(t *testing.T, mode testMode) {
  7008  	const username = "u"
  7009  	const password = "@/?!"
  7010  	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  7011  		// Copy the Proxy-Authorization header to a new Request,
  7012  		// since Request.BasicAuth only parses the Authorization header.
  7013  		var r2 Request
  7014  		r2.Header = Header{
  7015  			"Authorization": req.Header["Proxy-Authorization"],
  7016  		}
  7017  		gotuser, gotpass, ok := r2.BasicAuth()
  7018  		if !ok || gotuser != username || gotpass != password {
  7019  			t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
  7020  		}
  7021  	}))
  7022  	u, err := url.Parse(cst.ts.URL)
  7023  	if err != nil {
  7024  		t.Fatal(err)
  7025  	}
  7026  	u.User = url.UserPassword(username, password)
  7027  	t.Setenv("HTTP_PROXY", u.String())
  7028  	cst.tr.Proxy = ProxyURL(u)
  7029  	resp, err := cst.c.Get("http://_/")
  7030  	if err != nil {
  7031  		t.Fatal(err)
  7032  	}
  7033  	resp.Body.Close()
  7034  }
  7035  
  7036  // Issue 61708
  7037  func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
  7038  	ln := newLocalListener(t)
  7039  	addr := ln.Addr().String()
  7040  
  7041  	done := make(chan struct{})
  7042  	go func() {
  7043  		conn, err := ln.Accept()
  7044  		if err != nil {
  7045  			t.Errorf("ln.Accept: %v", err)
  7046  			return
  7047  		}
  7048  		// Start reading request before sending response to avoid
  7049  		// "Unsolicited response received on idle HTTP channel" RoundTrip error.
  7050  		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
  7051  			t.Errorf("conn.Read: %v", err)
  7052  			return
  7053  		}
  7054  		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
  7055  		<-done
  7056  		conn.Close()
  7057  	}()
  7058  
  7059  	didRead := make(chan bool)
  7060  	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
  7061  	defer SetReadLoopBeforeNextReadHook(nil)
  7062  
  7063  	tr := &Transport{}
  7064  
  7065  	// Send a request with a body guaranteed to fail on write.
  7066  	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
  7067  	if err != nil {
  7068  		t.Fatalf("NewRequest: %v", err)
  7069  	}
  7070  
  7071  	resp, err := tr.RoundTrip(req)
  7072  	if err != nil {
  7073  		t.Fatalf("tr.RoundTrip: %v", err)
  7074  	}
  7075  
  7076  	close(done)
  7077  
  7078  	// Before closing response body wait for readLoopDone goroutine
  7079  	// to complete due to closed connection by writeLoop.
  7080  	<-didRead
  7081  
  7082  	resp.Body.Close()
  7083  
  7084  	// Verify no outstanding requests after readLoop/writeLoop
  7085  	// goroutines shut down.
  7086  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
  7087  		n := tr.NumPendingRequestsForTesting()
  7088  		if n > 0 {
  7089  			if d > 0 {
  7090  				t.Logf("pending requests = %d after %v (want 0)", n, d)
  7091  			}
  7092  			return false
  7093  		}
  7094  		return true
  7095  	})
  7096  }
  7097  
  7098  func TestValidateClientRequestTrailers(t *testing.T) {
  7099  	run(t, testValidateClientRequestTrailers)
  7100  }
  7101  
  7102  func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
  7103  	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  7104  		rw.Write([]byte("Hello"))
  7105  	})).ts
  7106  
  7107  	cases := []struct {
  7108  		trailer Header
  7109  		wantErr string
  7110  	}{
  7111  		{Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
  7112  		{Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
  7113  	}
  7114  
  7115  	for i, tt := range cases {
  7116  		testName := fmt.Sprintf("%s%d", mode, i)
  7117  		t.Run(testName, func(t *testing.T) {
  7118  			req, err := NewRequest("GET", cst.URL, nil)
  7119  			if err != nil {
  7120  				t.Fatal(err)
  7121  			}
  7122  			req.Trailer = tt.trailer
  7123  			res, err := cst.Client().Do(req)
  7124  			if err == nil {
  7125  				t.Fatal("Expected an error")
  7126  			}
  7127  			if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
  7128  				t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
  7129  			}
  7130  			if res != nil {
  7131  				t.Fatal("Unexpected non-nil response")
  7132  			}
  7133  		})
  7134  	}
  7135  }
  7136  

View as plain text