Source file src/net/http/httputil/reverseproxy_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"log"
    17  	"net/http"
    18  	"net/http/httptest"
    19  	"net/http/httptrace"
    20  	"net/http/internal/ascii"
    21  	"net/textproto"
    22  	"net/url"
    23  	"os"
    24  	"reflect"
    25  	"slices"
    26  	"strconv"
    27  	"strings"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  )
    32  
    33  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    34  
    35  func init() {
    36  	inOurTests = true
    37  	hopHeaders = append(hopHeaders, fakeHopHeader)
    38  }
    39  
// TestReverseProxy exercises the basic path through NewSingleHostReverseProxy.
// On the way to the backend, hop-by-hop request headers (Connection,
// Proxy-Connection, Upgrade, and Te values other than "trailers") must be
// removed, X-Forwarded-For added, and the client's Host preserved. On the way
// back, status, headers, cookies, and trailers must be relayed while
// hop-by-hop response headers are dropped. A backend that hangs up without
// responding must yield 502 Bad Gateway.
func TestReverseProxy(t *testing.T) {
	const backendResponse = "I am the backend"
	const backendStatus = 404
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// mode=hangup simulates a backend that drops the connection
		// without sending any response.
		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
			c, _, _ := w.(http.Hijacker).Hijack()
			c.Close()
			return
		}
		if len(r.TransferEncoding) > 0 {
			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
		}
		if r.Header.Get("X-Forwarded-For") == "" {
			t.Errorf("didn't get X-Forwarded-For header")
		}
		if c := r.Header.Get("Connection"); c != "" {
			t.Errorf("handler got Connection header value %q", c)
		}
		// The client sent Te values "foo" and "bar, trailers"; only
		// "trailers" is expected to survive.
		if c := r.Header.Get("Te"); c != "trailers" {
			t.Errorf("handler got Te header value %q; want 'trailers'", c)
		}
		if c := r.Header.Get("Upgrade"); c != "" {
			t.Errorf("handler got Upgrade header value %q", c)
		}
		if c := r.Header.Get("Proxy-Connection"); c != "" {
			t.Errorf("handler got Proxy-Connection header value %q", c)
		}
		if g, e := r.Host, "some-name"; g != e {
			t.Errorf("backend got Host header %q, want %q", g, e)
		}
		// "Trailers" (plural) is an ordinary header that must pass through
		// untouched; "Trailer" announces X-Trailer as a trailer.
		w.Header().Set("Trailers", "not a special header field name")
		w.Header().Set("Trailer", "X-Trailer")
		w.Header().Set("X-Foo", "bar")
		w.Header().Set("Upgrade", "foo")
		w.Header().Set(fakeHopHeader, "foo")
		w.Header().Add("X-Multi-Value", "foo")
		w.Header().Add("X-Multi-Value", "bar")
		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
		w.WriteHeader(backendStatus)
		w.Write([]byte(backendResponse))
		// Headers set after the body has been written are sent as trailers:
		// one announced above, and one unannounced via http.TrailerPrefix.
		w.Header().Set("X-Trailer", "trailer_value")
		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
	}))
	defer backend.Close()
	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatal(err)
	}
	proxyHandler := NewSingleHostReverseProxy(backendURL)
	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	frontend := httptest.NewServer(proxyHandler)
	defer frontend.Close()
	frontendClient := frontend.Client()

	// Send a request that carries every kind of hop-by-hop header the
	// backend handler checks for.
	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
	getReq.Host = "some-name"
	getReq.Header.Set("Connection", "close, TE")
	getReq.Header.Add("Te", "foo")
	getReq.Header.Add("Te", "bar, trailers")
	getReq.Header.Set("Proxy-Connection", "should be deleted")
	getReq.Header.Set("Upgrade", "foo")
	getReq.Close = true
	res, err := frontendClient.Do(getReq)
	if err != nil {
		t.Fatalf("Get: %v", err)
	}
	if g, e := res.StatusCode, backendStatus; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
		t.Errorf("got X-Foo %q; expected %q", g, e)
	}
	if c := res.Header.Get(fakeHopHeader); c != "" {
		t.Errorf("got %s header value %q", fakeHopHeader, c)
	}
	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
		t.Errorf("header Trailers = %q; want %q", g, e)
	}
	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
	}
	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
		t.Fatalf("got %d SetCookies, want %d", g, e)
	}
	// Before the body is consumed, only the announced trailer key is
	// present, and it has no value yet.
	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
	}
	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
		t.Errorf("unexpected cookie %q", cookie.Name)
	}
	bodyBytes, _ := io.ReadAll(res.Body)
	if g, e := string(bodyBytes), backendResponse; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	// After the body is read, both trailers carry their values.
	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
	}
	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
	}
	res.Body.Close()

	// Test that a backend failing to be reached or one which doesn't return
	// a response results in a StatusBadGateway.
	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
	getReq.Close = true
	res, err = frontendClient.Do(getReq)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if res.StatusCode != http.StatusBadGateway {
		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
	}

}
   156  
// Issue 16875: remove any proxied headers mentioned in the "Connection"
// header value.
//
// Both the client request and the backend response declare extra hop-by-hop
// headers through their Connection headers. The test checks that the proxy
// strips those headers in both directions. It also checks that the proxy does
// not mutate the frontend handler's own *http.Request, which must still hold
// the original headers after ServeHTTP returns.
func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
	const fakeConnectionToken = "X-Fake-Connection-Token"
	const backendResponse = "I am the backend"

	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
	// in the Request's Connection header.
	const someConnHeader = "X-Some-Conn-Header"

	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Neither Connection nor anything it named may reach the backend.
		if c := r.Header.Get("Connection"); c != "" {
			t.Errorf("handler got header %q = %q; want empty", "Connection", c)
		}
		if c := r.Header.Get(fakeConnectionToken); c != "" {
			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
		}
		if c := r.Header.Get(someConnHeader); c != "" {
			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
		}
		// Declare hop-by-hop headers on the response side as well; the
		// client-side checks at the end expect them to be gone.
		w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
		w.Header().Add("Connection", someConnHeader)
		w.Header().Set(someConnHeader, "should be deleted")
		w.Header().Set(fakeConnectionToken, "should be deleted")
		io.WriteString(w, backendResponse)
	}))
	defer backend.Close()
	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatal(err)
	}
	proxyHandler := NewSingleHostReverseProxy(backendURL)
	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxyHandler.ServeHTTP(w, r)
		// The inbound request must be unchanged after proxying.
		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
		}
		if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
			t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
		}
		// Flatten the comma-separated Connection values into individual
		// tokens and compare them, order-independently, with what the
		// client sent.
		c := r.Header["Connection"]
		var cf []string
		for _, f := range c {
			for _, sf := range strings.Split(f, ",") {
				if sf = strings.TrimSpace(sf); sf != "" {
					cf = append(cf, sf)
				}
			}
		}
		slices.Sort(cf)
		expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
		slices.Sort(expectedValues)
		if !slices.Equal(cf, expectedValues) {
			t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
		}
	}))
	defer frontend.Close()

	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
	getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
	getReq.Header.Add("Connection", someConnHeader)
	getReq.Header.Set(someConnHeader, "should be deleted")
	getReq.Header.Set(fakeConnectionToken, "should be deleted")
	res, err := frontend.Client().Do(getReq)
	if err != nil {
		t.Fatalf("Get: %v", err)
	}
	defer res.Body.Close()
	bodyBytes, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	if got, want := string(bodyBytes), backendResponse; got != want {
		t.Errorf("got body %q; want %q", got, want)
	}
	// The response side: headers named in the backend's Connection header
	// must not reach the client.
	if c := res.Header.Get("Connection"); c != "" {
		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
	}
	if c := res.Header.Get(someConnHeader); c != "" {
		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
	}
	if c := res.Header.Get(fakeConnectionToken); c != "" {
		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
	}
}
   242  
   243  func TestReverseProxyStripEmptyConnection(t *testing.T) {
   244  	// See Issue 46313.
   245  	const backendResponse = "I am the backend"
   246  
   247  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   248  	// in the Request's Connection header.
   249  	const someConnHeader = "X-Some-Conn-Header"
   250  
   251  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   252  		if c := r.Header.Values("Connection"); len(c) != 0 {
   253  			t.Errorf("handler got header %q = %v; want empty", "Connection", c)
   254  		}
   255  		if c := r.Header.Get(someConnHeader); c != "" {
   256  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   257  		}
   258  		w.Header().Add("Connection", "")
   259  		w.Header().Add("Connection", someConnHeader)
   260  		w.Header().Set(someConnHeader, "should be deleted")
   261  		io.WriteString(w, backendResponse)
   262  	}))
   263  	defer backend.Close()
   264  	backendURL, err := url.Parse(backend.URL)
   265  	if err != nil {
   266  		t.Fatal(err)
   267  	}
   268  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   269  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   270  		proxyHandler.ServeHTTP(w, r)
   271  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   272  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   273  		}
   274  	}))
   275  	defer frontend.Close()
   276  
   277  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   278  	getReq.Header.Add("Connection", "")
   279  	getReq.Header.Add("Connection", someConnHeader)
   280  	getReq.Header.Set(someConnHeader, "should be deleted")
   281  	res, err := frontend.Client().Do(getReq)
   282  	if err != nil {
   283  		t.Fatalf("Get: %v", err)
   284  	}
   285  	defer res.Body.Close()
   286  	bodyBytes, err := io.ReadAll(res.Body)
   287  	if err != nil {
   288  		t.Fatalf("reading body: %v", err)
   289  	}
   290  	if got, want := string(bodyBytes), backendResponse; got != want {
   291  		t.Errorf("got body %q; want %q", got, want)
   292  	}
   293  	if c := res.Header.Get("Connection"); c != "" {
   294  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   295  	}
   296  	if c := res.Header.Get(someConnHeader); c != "" {
   297  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   298  	}
   299  }
   300  
   301  func TestXForwardedFor(t *testing.T) {
   302  	const prevForwardedFor = "client ip"
   303  	const backendResponse = "I am the backend"
   304  	const backendStatus = 404
   305  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   306  		if r.Header.Get("X-Forwarded-For") == "" {
   307  			t.Errorf("didn't get X-Forwarded-For header")
   308  		}
   309  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   310  			t.Errorf("X-Forwarded-For didn't contain prior data")
   311  		}
   312  		w.WriteHeader(backendStatus)
   313  		w.Write([]byte(backendResponse))
   314  	}))
   315  	defer backend.Close()
   316  	backendURL, err := url.Parse(backend.URL)
   317  	if err != nil {
   318  		t.Fatal(err)
   319  	}
   320  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   321  	frontend := httptest.NewServer(proxyHandler)
   322  	defer frontend.Close()
   323  
   324  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   325  	getReq.Header.Set("Connection", "close")
   326  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   327  	getReq.Close = true
   328  	res, err := frontend.Client().Do(getReq)
   329  	if err != nil {
   330  		t.Fatalf("Get: %v", err)
   331  	}
   332  	defer res.Body.Close()
   333  	if g, e := res.StatusCode, backendStatus; g != e {
   334  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   335  	}
   336  	bodyBytes, _ := io.ReadAll(res.Body)
   337  	if g, e := string(bodyBytes), backendResponse; g != e {
   338  		t.Errorf("got body %q; expected %q", g, e)
   339  	}
   340  }
   341  
   342  // Issue 38079: don't append to X-Forwarded-For if it's present but nil
   343  func TestXForwardedFor_Omit(t *testing.T) {
   344  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   345  		if v := r.Header.Get("X-Forwarded-For"); v != "" {
   346  			t.Errorf("got X-Forwarded-For header: %q", v)
   347  		}
   348  		w.Write([]byte("hi"))
   349  	}))
   350  	defer backend.Close()
   351  	backendURL, err := url.Parse(backend.URL)
   352  	if err != nil {
   353  		t.Fatal(err)
   354  	}
   355  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   356  	frontend := httptest.NewServer(proxyHandler)
   357  	defer frontend.Close()
   358  
   359  	oldDirector := proxyHandler.Director
   360  	proxyHandler.Director = func(r *http.Request) {
   361  		r.Header["X-Forwarded-For"] = nil
   362  		oldDirector(r)
   363  	}
   364  
   365  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   366  	getReq.Host = "some-name"
   367  	getReq.Close = true
   368  	res, err := frontend.Client().Do(getReq)
   369  	if err != nil {
   370  		t.Fatalf("Get: %v", err)
   371  	}
   372  	res.Body.Close()
   373  }
   374  
   375  func TestReverseProxyRewriteStripsForwarded(t *testing.T) {
   376  	headers := []string{
   377  		"Forwarded",
   378  		"X-Forwarded-For",
   379  		"X-Forwarded-Host",
   380  		"X-Forwarded-Proto",
   381  	}
   382  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   383  		for _, h := range headers {
   384  			if v := r.Header.Get(h); v != "" {
   385  				t.Errorf("got %v header: %q", h, v)
   386  			}
   387  		}
   388  	}))
   389  	defer backend.Close()
   390  	backendURL, err := url.Parse(backend.URL)
   391  	if err != nil {
   392  		t.Fatal(err)
   393  	}
   394  	proxyHandler := &ReverseProxy{
   395  		Rewrite: func(r *ProxyRequest) {
   396  			r.SetURL(backendURL)
   397  		},
   398  	}
   399  	frontend := httptest.NewServer(proxyHandler)
   400  	defer frontend.Close()
   401  
   402  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   403  	getReq.Host = "some-name"
   404  	getReq.Close = true
   405  	for _, h := range headers {
   406  		getReq.Header.Set(h, "x")
   407  	}
   408  	res, err := frontend.Client().Do(getReq)
   409  	if err != nil {
   410  		t.Fatalf("Get: %v", err)
   411  	}
   412  	res.Body.Close()
   413  }
   414  
   415  var proxyQueryTests = []struct {
   416  	baseSuffix string // suffix to add to backend URL
   417  	reqSuffix  string // suffix to add to frontend's request URL
   418  	want       string // what backend should see for final request URL (without ?)
   419  }{
   420  	{"", "", ""},
   421  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   422  	{"", "?us=er", "us=er"},
   423  	{"?sta=tic", "", "sta=tic"},
   424  }
   425  
   426  func TestReverseProxyQuery(t *testing.T) {
   427  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   428  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   429  		w.Write([]byte("hi"))
   430  	}))
   431  	defer backend.Close()
   432  
   433  	for i, tt := range proxyQueryTests {
   434  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   435  		if err != nil {
   436  			t.Fatal(err)
   437  		}
   438  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   439  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   440  		req.Close = true
   441  		res, err := frontend.Client().Do(req)
   442  		if err != nil {
   443  			t.Fatalf("%d. Get: %v", i, err)
   444  		}
   445  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   446  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   447  		}
   448  		res.Body.Close()
   449  		frontend.Close()
   450  	}
   451  }
   452  
   453  func TestReverseProxyFlushInterval(t *testing.T) {
   454  	const expected = "hi"
   455  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   456  		w.Write([]byte(expected))
   457  	}))
   458  	defer backend.Close()
   459  
   460  	backendURL, err := url.Parse(backend.URL)
   461  	if err != nil {
   462  		t.Fatal(err)
   463  	}
   464  
   465  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   466  	proxyHandler.FlushInterval = time.Microsecond
   467  
   468  	frontend := httptest.NewServer(proxyHandler)
   469  	defer frontend.Close()
   470  
   471  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   472  	req.Close = true
   473  	res, err := frontend.Client().Do(req)
   474  	if err != nil {
   475  		t.Fatalf("Get: %v", err)
   476  	}
   477  	defer res.Body.Close()
   478  	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
   479  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   480  	}
   481  }
   482  
   483  type mockFlusher struct {
   484  	http.ResponseWriter
   485  	flushed bool
   486  }
   487  
   488  func (m *mockFlusher) Flush() {
   489  	m.flushed = true
   490  }
   491  
   492  type wrappedRW struct {
   493  	http.ResponseWriter
   494  }
   495  
   496  func (w *wrappedRW) Unwrap() http.ResponseWriter {
   497  	return w.ResponseWriter
   498  }
   499  
// TestReverseProxyResponseControllerFlushInterval checks that with a negative
// FlushInterval the proxy still flushes when the ResponseWriter it is given
// exposes Flush only behind an Unwrap method. The middleware hands the proxy
// a wrappedRW around mockFlusher, and the test expects mockFlusher.Flush to
// have been called.
func TestReverseProxyResponseControllerFlushInterval(t *testing.T) {
	const expected = "hi"
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(expected))
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatal(err)
	}

	mf := &mockFlusher{}
	proxyHandler := NewSingleHostReverseProxy(backendURL)
	proxyHandler.FlushInterval = -1 // flush immediately
	proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// wrappedRW's own method set has no Flush, so the proxy can reach
		// mf.Flush only by unwrapping.
		mf.ResponseWriter = w
		w = &wrappedRW{mf}
		proxyHandler.ServeHTTP(w, r)
	})

	frontend := httptest.NewServer(proxyWithMiddleware)
	defer frontend.Close()

	req, _ := http.NewRequest("GET", frontend.URL, nil)
	req.Close = true
	res, err := frontend.Client().Do(req)
	if err != nil {
		t.Fatalf("Get: %v", err)
	}
	defer res.Body.Close()
	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
		t.Errorf("got body %q; expected %q", bodyBytes, expected)
	}
	// NOTE(review): mf is written on the server's handler goroutine and read
	// here without explicit synchronization. This presumably relies on the
	// flush having happened by the time the whole body was read; confirm
	// under -race.
	if !mf.flushed {
		t.Errorf("response writer was not flushed")
	}
}
   538  
   539  func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
   540  	const expected = "hi"
   541  	stopCh := make(chan struct{})
   542  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   543  		w.Header().Add("MyHeader", expected)
   544  		w.WriteHeader(200)
   545  		w.(http.Flusher).Flush()
   546  		<-stopCh
   547  	}))
   548  	defer backend.Close()
   549  	defer close(stopCh)
   550  
   551  	backendURL, err := url.Parse(backend.URL)
   552  	if err != nil {
   553  		t.Fatal(err)
   554  	}
   555  
   556  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   557  	proxyHandler.FlushInterval = time.Microsecond
   558  
   559  	frontend := httptest.NewServer(proxyHandler)
   560  	defer frontend.Close()
   561  
   562  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   563  	req.Close = true
   564  
   565  	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
   566  	defer cancel()
   567  	req = req.WithContext(ctx)
   568  
   569  	res, err := frontend.Client().Do(req)
   570  	if err != nil {
   571  		t.Fatalf("Get: %v", err)
   572  	}
   573  	defer res.Body.Close()
   574  
   575  	if res.Header.Get("MyHeader") != expected {
   576  		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
   577  	}
   578  }
   579  
   580  func TestReverseProxyCancellation(t *testing.T) {
   581  	const backendResponse = "I am the backend"
   582  
   583  	reqInFlight := make(chan struct{})
   584  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   585  		close(reqInFlight) // cause the client to cancel its request
   586  
   587  		select {
   588  		case <-time.After(10 * time.Second):
   589  			// Note: this should only happen in broken implementations, and the
   590  			// closenotify case should be instantaneous.
   591  			t.Error("Handler never saw CloseNotify")
   592  			return
   593  		case <-w.(http.CloseNotifier).CloseNotify():
   594  		}
   595  
   596  		w.WriteHeader(http.StatusOK)
   597  		w.Write([]byte(backendResponse))
   598  	}))
   599  
   600  	defer backend.Close()
   601  
   602  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
   603  
   604  	backendURL, err := url.Parse(backend.URL)
   605  	if err != nil {
   606  		t.Fatal(err)
   607  	}
   608  
   609  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   610  
   611  	// Discards errors of the form:
   612  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   613  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
   614  
   615  	frontend := httptest.NewServer(proxyHandler)
   616  	defer frontend.Close()
   617  	frontendClient := frontend.Client()
   618  
   619  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   620  	go func() {
   621  		<-reqInFlight
   622  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   623  	}()
   624  	res, err := frontendClient.Do(getReq)
   625  	if res != nil {
   626  		t.Errorf("got response %v; want nil", res.Status)
   627  	}
   628  	if err == nil {
   629  		// This should be an error like:
   630  		// Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
   631  		//    use of closed network connection
   632  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   633  	}
   634  }
   635  
   636  func req(t *testing.T, v string) *http.Request {
   637  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   638  	if err != nil {
   639  		t.Fatal(err)
   640  	}
   641  	return req
   642  }
   643  
   644  // Issue 12344
   645  func TestNilBody(t *testing.T) {
   646  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   647  		w.Write([]byte("hi"))
   648  	}))
   649  	defer backend.Close()
   650  
   651  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   652  		backURL, _ := url.Parse(backend.URL)
   653  		rp := NewSingleHostReverseProxy(backURL)
   654  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   655  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   656  		rp.ServeHTTP(w, r)
   657  	}))
   658  	defer frontend.Close()
   659  
   660  	res, err := http.Get(frontend.URL)
   661  	if err != nil {
   662  		t.Fatal(err)
   663  	}
   664  	defer res.Body.Close()
   665  	slurp, err := io.ReadAll(res.Body)
   666  	if err != nil {
   667  		t.Fatal(err)
   668  	}
   669  	if string(slurp) != "hi" {
   670  		t.Errorf("Got %q; want %q", slurp, "hi")
   671  	}
   672  }
   673  
   674  // Issue 15524
   675  func TestUserAgentHeader(t *testing.T) {
   676  	var gotUA string
   677  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   678  		gotUA = r.Header.Get("User-Agent")
   679  	}))
   680  	defer backend.Close()
   681  	backendURL, err := url.Parse(backend.URL)
   682  	if err != nil {
   683  		t.Fatal(err)
   684  	}
   685  
   686  	proxyHandler := new(ReverseProxy)
   687  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   688  	proxyHandler.Director = func(req *http.Request) {
   689  		req.URL = backendURL
   690  	}
   691  	frontend := httptest.NewServer(proxyHandler)
   692  	defer frontend.Close()
   693  	frontendClient := frontend.Client()
   694  
   695  	for _, sentUA := range []string{"explicit UA", ""} {
   696  		getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   697  		getReq.Header.Set("User-Agent", sentUA)
   698  		getReq.Close = true
   699  		res, err := frontendClient.Do(getReq)
   700  		if err != nil {
   701  			t.Fatalf("Get: %v", err)
   702  		}
   703  		res.Body.Close()
   704  		if got, want := gotUA, sentUA; got != want {
   705  			t.Errorf("got forwarded User-Agent %q, want %q", got, want)
   706  		}
   707  	}
   708  }
   709  
   710  type bufferPool struct {
   711  	get func() []byte
   712  	put func([]byte)
   713  }
   714  
   715  func (bp bufferPool) Get() []byte  { return bp.get() }
   716  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   717  
   718  func TestReverseProxyGetPutBuffer(t *testing.T) {
   719  	const msg = "hi"
   720  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   721  		io.WriteString(w, msg)
   722  	}))
   723  	defer backend.Close()
   724  
   725  	backendURL, err := url.Parse(backend.URL)
   726  	if err != nil {
   727  		t.Fatal(err)
   728  	}
   729  
   730  	var (
   731  		mu  sync.Mutex
   732  		log []string
   733  	)
   734  	addLog := func(event string) {
   735  		mu.Lock()
   736  		defer mu.Unlock()
   737  		log = append(log, event)
   738  	}
   739  	rp := NewSingleHostReverseProxy(backendURL)
   740  	const size = 1234
   741  	rp.BufferPool = bufferPool{
   742  		get: func() []byte {
   743  			addLog("getBuf")
   744  			return make([]byte, size)
   745  		},
   746  		put: func(p []byte) {
   747  			addLog("putBuf-" + strconv.Itoa(len(p)))
   748  		},
   749  	}
   750  	frontend := httptest.NewServer(rp)
   751  	defer frontend.Close()
   752  
   753  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   754  	req.Close = true
   755  	res, err := frontend.Client().Do(req)
   756  	if err != nil {
   757  		t.Fatalf("Get: %v", err)
   758  	}
   759  	slurp, err := io.ReadAll(res.Body)
   760  	res.Body.Close()
   761  	if err != nil {
   762  		t.Fatalf("reading body: %v", err)
   763  	}
   764  	if string(slurp) != msg {
   765  		t.Errorf("msg = %q; want %q", slurp, msg)
   766  	}
   767  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   768  	mu.Lock()
   769  	defer mu.Unlock()
   770  	if !slices.Equal(log, wantLog) {
   771  		t.Errorf("Log events = %q; want %q", log, wantLog)
   772  	}
   773  }
   774  
   775  func TestReverseProxy_Post(t *testing.T) {
   776  	const backendResponse = "I am the backend"
   777  	const backendStatus = 200
   778  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   779  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   780  		slurp, err := io.ReadAll(r.Body)
   781  		if err != nil {
   782  			t.Errorf("Backend body read = %v", err)
   783  		}
   784  		if len(slurp) != len(requestBody) {
   785  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   786  		}
   787  		if !bytes.Equal(slurp, requestBody) {
   788  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   789  		}
   790  		w.Write([]byte(backendResponse))
   791  	}))
   792  	defer backend.Close()
   793  	backendURL, err := url.Parse(backend.URL)
   794  	if err != nil {
   795  		t.Fatal(err)
   796  	}
   797  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   798  	frontend := httptest.NewServer(proxyHandler)
   799  	defer frontend.Close()
   800  
   801  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   802  	res, err := frontend.Client().Do(postReq)
   803  	if err != nil {
   804  		t.Fatalf("Do: %v", err)
   805  	}
   806  	defer res.Body.Close()
   807  	if g, e := res.StatusCode, backendStatus; g != e {
   808  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   809  	}
   810  	bodyBytes, _ := io.ReadAll(res.Body)
   811  	if g, e := string(bodyBytes), backendResponse; g != e {
   812  		t.Errorf("got body %q; expected %q", g, e)
   813  	}
   814  }
   815  
   816  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   817  
   818  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   819  	return fn(req)
   820  }
   821  
   822  // Issue 16036: send a Request with a nil Body when possible
   823  func TestReverseProxy_NilBody(t *testing.T) {
   824  	backendURL, _ := url.Parse("http://fake.tld/")
   825  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   826  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   827  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   828  		if req.Body != nil {
   829  			t.Error("Body != nil; want a nil Body")
   830  		}
   831  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   832  	})
   833  	frontend := httptest.NewServer(proxyHandler)
   834  	defer frontend.Close()
   835  
   836  	res, err := frontend.Client().Get(frontend.URL)
   837  	if err != nil {
   838  		t.Fatal(err)
   839  	}
   840  	defer res.Body.Close()
   841  	if res.StatusCode != 502 {
   842  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   843  	}
   844  }
   845  
   846  // Issue 33142: always allocate the request headers
   847  func TestReverseProxy_AllocatedHeader(t *testing.T) {
   848  	proxyHandler := new(ReverseProxy)
   849  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   850  	proxyHandler.Director = func(*http.Request) {}     // noop
   851  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   852  		if req.Header == nil {
   853  			t.Error("Header == nil; want a non-nil Header")
   854  		}
   855  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   856  	})
   857  
   858  	proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
   859  		Method:     "GET",
   860  		URL:        &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
   861  		Proto:      "HTTP/1.0",
   862  		ProtoMajor: 1,
   863  	})
   864  }
   865  
   866  // Issue 14237. Test ModifyResponse and that an error from it
   867  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   868  func TestReverseProxyModifyResponse(t *testing.T) {
   869  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   870  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   871  	}))
   872  	defer backendServer.Close()
   873  
   874  	rpURL, _ := url.Parse(backendServer.URL)
   875  	rproxy := NewSingleHostReverseProxy(rpURL)
   876  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   877  	rproxy.ModifyResponse = func(resp *http.Response) error {
   878  		if resp.Header.Get("X-Hit-Mod") != "true" {
   879  			return fmt.Errorf("tried to by-pass proxy")
   880  		}
   881  		return nil
   882  	}
   883  
   884  	frontendProxy := httptest.NewServer(rproxy)
   885  	defer frontendProxy.Close()
   886  
   887  	tests := []struct {
   888  		url      string
   889  		wantCode int
   890  	}{
   891  		{frontendProxy.URL + "/mod", http.StatusOK},
   892  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   893  	}
   894  
   895  	for i, tt := range tests {
   896  		resp, err := http.Get(tt.url)
   897  		if err != nil {
   898  			t.Fatalf("failed to reach proxy: %v", err)
   899  		}
   900  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   901  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   902  		}
   903  		resp.Body.Close()
   904  	}
   905  }
   906  
   907  type failingRoundTripper struct{}
   908  
   909  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   910  	return nil, errors.New("some error")
   911  }
   912  
   913  type staticResponseRoundTripper struct{ res *http.Response }
   914  
   915  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   916  	return rt.res, nil
   917  }
   918  
   919  func TestReverseProxyErrorHandler(t *testing.T) {
   920  	tests := []struct {
   921  		name           string
   922  		wantCode       int
   923  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   924  		transport      http.RoundTripper // defaults to failingRoundTripper
   925  		modifyResponse func(*http.Response) error
   926  	}{
   927  		{
   928  			name:     "default",
   929  			wantCode: http.StatusBadGateway,
   930  		},
   931  		{
   932  			name:         "errorhandler",
   933  			wantCode:     http.StatusTeapot,
   934  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   935  		},
   936  		{
   937  			name: "modifyresponse_noerr",
   938  			transport: staticResponseRoundTripper{
   939  				&http.Response{StatusCode: 345, Body: http.NoBody},
   940  			},
   941  			modifyResponse: func(res *http.Response) error {
   942  				res.StatusCode++
   943  				return nil
   944  			},
   945  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   946  			wantCode:     346,
   947  		},
   948  		{
   949  			name: "modifyresponse_err",
   950  			transport: staticResponseRoundTripper{
   951  				&http.Response{StatusCode: 345, Body: http.NoBody},
   952  			},
   953  			modifyResponse: func(res *http.Response) error {
   954  				res.StatusCode++
   955  				return errors.New("some error to trigger errorHandler")
   956  			},
   957  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   958  			wantCode:     http.StatusTeapot,
   959  		},
   960  	}
   961  
   962  	for _, tt := range tests {
   963  		t.Run(tt.name, func(t *testing.T) {
   964  			target := &url.URL{
   965  				Scheme: "http",
   966  				Host:   "dummy.tld",
   967  				Path:   "/",
   968  			}
   969  			rproxy := NewSingleHostReverseProxy(target)
   970  			rproxy.Transport = tt.transport
   971  			rproxy.ModifyResponse = tt.modifyResponse
   972  			if rproxy.Transport == nil {
   973  				rproxy.Transport = failingRoundTripper{}
   974  			}
   975  			rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   976  			if tt.errorHandler != nil {
   977  				rproxy.ErrorHandler = tt.errorHandler
   978  			}
   979  			frontendProxy := httptest.NewServer(rproxy)
   980  			defer frontendProxy.Close()
   981  
   982  			resp, err := http.Get(frontendProxy.URL + "/test")
   983  			if err != nil {
   984  				t.Fatalf("failed to reach proxy: %v", err)
   985  			}
   986  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   987  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   988  			}
   989  			resp.Body.Close()
   990  		})
   991  	}
   992  }
   993  
   994  // Issue 16659: log errors from short read
   995  func TestReverseProxy_CopyBuffer(t *testing.T) {
   996  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   997  		out := "this call was relayed by the reverse proxy"
   998  		// Coerce a wrong content length to induce io.UnexpectedEOF
   999  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1000  		fmt.Fprintln(w, out)
  1001  	}))
  1002  	defer backendServer.Close()
  1003  
  1004  	rpURL, err := url.Parse(backendServer.URL)
  1005  	if err != nil {
  1006  		t.Fatal(err)
  1007  	}
  1008  
  1009  	var proxyLog bytes.Buffer
  1010  	rproxy := NewSingleHostReverseProxy(rpURL)
  1011  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
  1012  	donec := make(chan bool, 1)
  1013  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1014  		defer func() { donec <- true }()
  1015  		rproxy.ServeHTTP(w, r)
  1016  	}))
  1017  	defer frontendProxy.Close()
  1018  
  1019  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
  1020  		t.Fatalf("want non-nil error")
  1021  	}
  1022  	// The race detector complains about the proxyLog usage in logf in copyBuffer
  1023  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
  1024  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
  1025  	// continue after Get.
  1026  	<-donec
  1027  
  1028  	expected := []string{
  1029  		"EOF",
  1030  		"read",
  1031  	}
  1032  	for _, phrase := range expected {
  1033  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
  1034  			t.Errorf("expected log to contain phrase %q", phrase)
  1035  		}
  1036  	}
  1037  }
  1038  
  1039  type staticTransport struct {
  1040  	res *http.Response
  1041  }
  1042  
  1043  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
  1044  	return t.res, nil
  1045  }
  1046  
  1047  func BenchmarkServeHTTP(b *testing.B) {
  1048  	res := &http.Response{
  1049  		StatusCode: 200,
  1050  		Body:       io.NopCloser(strings.NewReader("")),
  1051  	}
  1052  	proxy := &ReverseProxy{
  1053  		Director:  func(*http.Request) {},
  1054  		Transport: &staticTransport{res},
  1055  	}
  1056  
  1057  	w := httptest.NewRecorder()
  1058  	r := httptest.NewRequest("GET", "/", nil)
  1059  
  1060  	b.ReportAllocs()
  1061  	for i := 0; i < b.N; i++ {
  1062  		proxy.ServeHTTP(w, r)
  1063  	}
  1064  }
  1065  
  1066  func TestServeHTTPDeepCopy(t *testing.T) {
  1067  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1068  		w.Write([]byte("Hello Gopher!"))
  1069  	}))
  1070  	defer backend.Close()
  1071  	backendURL, err := url.Parse(backend.URL)
  1072  	if err != nil {
  1073  		t.Fatal(err)
  1074  	}
  1075  
  1076  	type result struct {
  1077  		before, after string
  1078  	}
  1079  
  1080  	resultChan := make(chan result, 1)
  1081  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1082  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1083  		before := r.URL.String()
  1084  		proxyHandler.ServeHTTP(w, r)
  1085  		after := r.URL.String()
  1086  		resultChan <- result{before: before, after: after}
  1087  	}))
  1088  	defer frontend.Close()
  1089  
  1090  	want := result{before: "/", after: "/"}
  1091  
  1092  	res, err := frontend.Client().Get(frontend.URL)
  1093  	if err != nil {
  1094  		t.Fatalf("Do: %v", err)
  1095  	}
  1096  	res.Body.Close()
  1097  
  1098  	got := <-resultChan
  1099  	if got != want {
  1100  		t.Errorf("got = %+v; want = %+v", got, want)
  1101  	}
  1102  }
  1103  
  1104  // Issue 18327: verify we always do a deep copy of the Request.Header map
  1105  // before any mutations.
  1106  func TestClonesRequestHeaders(t *testing.T) {
  1107  	log.SetOutput(io.Discard)
  1108  	defer log.SetOutput(os.Stderr)
  1109  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1110  	req.RemoteAddr = "1.2.3.4:56789"
  1111  	rp := &ReverseProxy{
  1112  		Director: func(req *http.Request) {
  1113  			req.Header.Set("From-Director", "1")
  1114  		},
  1115  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
  1116  			if v := req.Header.Get("From-Director"); v != "1" {
  1117  				t.Errorf("From-Directory value = %q; want 1", v)
  1118  			}
  1119  			return nil, io.EOF
  1120  		}),
  1121  	}
  1122  	rp.ServeHTTP(httptest.NewRecorder(), req)
  1123  
  1124  	for _, h := range []string{
  1125  		"From-Director",
  1126  		"X-Forwarded-For",
  1127  	} {
  1128  		if req.Header.Get(h) != "" {
  1129  			t.Errorf("%v header mutation modified caller's request", h)
  1130  		}
  1131  	}
  1132  }
  1133  
  1134  type roundTripperFunc func(req *http.Request) (*http.Response, error)
  1135  
  1136  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
  1137  	return fn(req)
  1138  }
  1139  
  1140  func TestModifyResponseClosesBody(t *testing.T) {
  1141  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1142  	req.RemoteAddr = "1.2.3.4:56789"
  1143  	closeCheck := new(checkCloser)
  1144  	logBuf := new(strings.Builder)
  1145  	outErr := errors.New("ModifyResponse error")
  1146  	rp := &ReverseProxy{
  1147  		Director: func(req *http.Request) {},
  1148  		Transport: &staticTransport{&http.Response{
  1149  			StatusCode: 200,
  1150  			Body:       closeCheck,
  1151  		}},
  1152  		ErrorLog: log.New(logBuf, "", 0),
  1153  		ModifyResponse: func(*http.Response) error {
  1154  			return outErr
  1155  		},
  1156  	}
  1157  	rec := httptest.NewRecorder()
  1158  	rp.ServeHTTP(rec, req)
  1159  	res := rec.Result()
  1160  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
  1161  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
  1162  	}
  1163  	if !closeCheck.closed {
  1164  		t.Errorf("body should have been closed")
  1165  	}
  1166  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
  1167  		t.Errorf("ErrorLog %q does not contain %q", g, e)
  1168  	}
  1169  }
  1170  
  1171  type checkCloser struct {
  1172  	closed bool
  1173  }
  1174  
  1175  func (cc *checkCloser) Close() error {
  1176  	cc.closed = true
  1177  	return nil
  1178  }
  1179  
  1180  func (cc *checkCloser) Read(b []byte) (int, error) {
  1181  	return len(b), nil
  1182  }
  1183  
  1184  // Issue 23643: panic on body copy error
  1185  func TestReverseProxy_PanicBodyError(t *testing.T) {
  1186  	log.SetOutput(io.Discard)
  1187  	defer log.SetOutput(os.Stderr)
  1188  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1189  		out := "this call was relayed by the reverse proxy"
  1190  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1191  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1192  		fmt.Fprintln(w, out)
  1193  	}))
  1194  	defer backendServer.Close()
  1195  
  1196  	rpURL, err := url.Parse(backendServer.URL)
  1197  	if err != nil {
  1198  		t.Fatal(err)
  1199  	}
  1200  
  1201  	rproxy := NewSingleHostReverseProxy(rpURL)
  1202  
  1203  	// Ensure that the handler panics when the body read encounters an
  1204  	// io.ErrUnexpectedEOF
  1205  	defer func() {
  1206  		err := recover()
  1207  		if err == nil {
  1208  			t.Fatal("handler should have panicked")
  1209  		}
  1210  		if err != http.ErrAbortHandler {
  1211  			t.Fatal("expected ErrAbortHandler, got", err)
  1212  		}
  1213  	}()
  1214  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1215  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
  1216  }
  1217  
  1218  // Issue #46866: panic without closing incoming request body causes a panic
  1219  func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
  1220  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1221  		out := "this call was relayed by the reverse proxy"
  1222  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1223  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1224  		fmt.Fprintln(w, out)
  1225  	}))
  1226  	defer backend.Close()
  1227  	backendURL, err := url.Parse(backend.URL)
  1228  	if err != nil {
  1229  		t.Fatal(err)
  1230  	}
  1231  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1232  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1233  	frontend := httptest.NewServer(proxyHandler)
  1234  	defer frontend.Close()
  1235  	frontendClient := frontend.Client()
  1236  
  1237  	var wg sync.WaitGroup
  1238  	for i := 0; i < 2; i++ {
  1239  		wg.Add(1)
  1240  		go func() {
  1241  			defer wg.Done()
  1242  			for j := 0; j < 10; j++ {
  1243  				const reqLen = 6 * 1024 * 1024
  1244  				req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
  1245  				req.ContentLength = reqLen
  1246  				resp, _ := frontendClient.Transport.RoundTrip(req)
  1247  				if resp != nil {
  1248  					io.Copy(io.Discard, resp.Body)
  1249  					resp.Body.Close()
  1250  				}
  1251  			}
  1252  		}()
  1253  	}
  1254  	wg.Wait()
  1255  }
  1256  
  1257  func TestSelectFlushInterval(t *testing.T) {
  1258  	tests := []struct {
  1259  		name string
  1260  		p    *ReverseProxy
  1261  		res  *http.Response
  1262  		want time.Duration
  1263  	}{
  1264  		{
  1265  			name: "default",
  1266  			res:  &http.Response{},
  1267  			p:    &ReverseProxy{FlushInterval: 123},
  1268  			want: 123,
  1269  		},
  1270  		{
  1271  			name: "server-sent events overrides non-zero",
  1272  			res: &http.Response{
  1273  				Header: http.Header{
  1274  					"Content-Type": {"text/event-stream"},
  1275  				},
  1276  			},
  1277  			p:    &ReverseProxy{FlushInterval: 123},
  1278  			want: -1,
  1279  		},
  1280  		{
  1281  			name: "server-sent events overrides zero",
  1282  			res: &http.Response{
  1283  				Header: http.Header{
  1284  					"Content-Type": {"text/event-stream"},
  1285  				},
  1286  			},
  1287  			p:    &ReverseProxy{FlushInterval: 0},
  1288  			want: -1,
  1289  		},
  1290  		{
  1291  			name: "server-sent events with media-type parameters overrides non-zero",
  1292  			res: &http.Response{
  1293  				Header: http.Header{
  1294  					"Content-Type": {"text/event-stream;charset=utf-8"},
  1295  				},
  1296  			},
  1297  			p:    &ReverseProxy{FlushInterval: 123},
  1298  			want: -1,
  1299  		},
  1300  		{
  1301  			name: "server-sent events with media-type parameters overrides zero",
  1302  			res: &http.Response{
  1303  				Header: http.Header{
  1304  					"Content-Type": {"text/event-stream;charset=utf-8"},
  1305  				},
  1306  			},
  1307  			p:    &ReverseProxy{FlushInterval: 0},
  1308  			want: -1,
  1309  		},
  1310  		{
  1311  			name: "Content-Length: -1, overrides non-zero",
  1312  			res: &http.Response{
  1313  				ContentLength: -1,
  1314  			},
  1315  			p:    &ReverseProxy{FlushInterval: 123},
  1316  			want: -1,
  1317  		},
  1318  		{
  1319  			name: "Content-Length: -1, overrides zero",
  1320  			res: &http.Response{
  1321  				ContentLength: -1,
  1322  			},
  1323  			p:    &ReverseProxy{FlushInterval: 0},
  1324  			want: -1,
  1325  		},
  1326  	}
  1327  	for _, tt := range tests {
  1328  		t.Run(tt.name, func(t *testing.T) {
  1329  			got := tt.p.flushInterval(tt.res)
  1330  			if got != tt.want {
  1331  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
  1332  			}
  1333  		})
  1334  	}
  1335  }
  1336  
  1337  func TestReverseProxyWebSocket(t *testing.T) {
  1338  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1339  		if upgradeType(r.Header) != "websocket" {
  1340  			t.Error("unexpected backend request")
  1341  			http.Error(w, "unexpected request", 400)
  1342  			return
  1343  		}
  1344  		c, _, err := w.(http.Hijacker).Hijack()
  1345  		if err != nil {
  1346  			t.Error(err)
  1347  			return
  1348  		}
  1349  		defer c.Close()
  1350  		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
  1351  		bs := bufio.NewScanner(c)
  1352  		if !bs.Scan() {
  1353  			t.Errorf("backend failed to read line from client: %v", bs.Err())
  1354  			return
  1355  		}
  1356  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
  1357  	}))
  1358  	defer backendServer.Close()
  1359  
  1360  	backURL, _ := url.Parse(backendServer.URL)
  1361  	rproxy := NewSingleHostReverseProxy(backURL)
  1362  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1363  	rproxy.ModifyResponse = func(res *http.Response) error {
  1364  		res.Header.Add("X-Modified", "true")
  1365  		return nil
  1366  	}
  1367  
  1368  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1369  		rw.Header().Set("X-Header", "X-Value")
  1370  		rproxy.ServeHTTP(rw, req)
  1371  		if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
  1372  			t.Errorf("response writer X-Modified header = %q; want %q", got, want)
  1373  		}
  1374  	})
  1375  
  1376  	frontendProxy := httptest.NewServer(handler)
  1377  	defer frontendProxy.Close()
  1378  
  1379  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1380  	req.Header.Set("Connection", "Upgrade")
  1381  	req.Header.Set("Upgrade", "websocket")
  1382  
  1383  	c := frontendProxy.Client()
  1384  	res, err := c.Do(req)
  1385  	if err != nil {
  1386  		t.Fatal(err)
  1387  	}
  1388  	if res.StatusCode != 101 {
  1389  		t.Fatalf("status = %v; want 101", res.Status)
  1390  	}
  1391  
  1392  	got := res.Header.Get("X-Header")
  1393  	want := "X-Value"
  1394  	if got != want {
  1395  		t.Errorf("Header(XHeader) = %q; want %q", got, want)
  1396  	}
  1397  
  1398  	if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
  1399  		t.Fatalf("not websocket upgrade; got %#v", res.Header)
  1400  	}
  1401  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1402  	if !ok {
  1403  		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
  1404  	}
  1405  	defer rwc.Close()
  1406  
  1407  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1408  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1409  	}
  1410  
  1411  	io.WriteString(rwc, "Hello\n")
  1412  	bs := bufio.NewScanner(rwc)
  1413  	if !bs.Scan() {
  1414  		t.Fatalf("Scan: %v", bs.Err())
  1415  	}
  1416  	got = bs.Text()
  1417  	want = `backend got "Hello"`
  1418  	if got != want {
  1419  		t.Errorf("got %#q, want %#q", got, want)
  1420  	}
  1421  }
  1422  
// TestReverseProxyWebSocketCancellation checks that canceling the context of
// a proxied, upgraded (WebSocket) request tears down the tunnel.
//
// The backend streams n lines, one per second, and then a terminal message.
// When the client reads the first line it closes triggerCancelCh. The
// frontend handler then cancels the request context passed to ServeHTTP. The
// client must see io.EOF before the terminal message arrives.
func TestReverseProxyWebSocketCancellation(t *testing.T) {
	// Number of lines the backend tries to send before terminalMsg.
	n := 5
	// Used purely as a broadcast signal: it is closed (never sent to) by the
	// client read loop below, which makes every receive on it succeed.
	triggerCancelCh := make(chan bool, n)
	nthResponse := func(i int) string {
		return fmt.Sprintf("backend response #%d\n", i)
	}
	// No trailing newline: the client's ReadString('\n') returns this text
	// together with io.EOF, which is why the read loop matches it first.
	terminalMsg := "final message"

	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
			http.Error(w, "Unexpected request", 400)
			return
		}
		conn, bufrw, err := w.(http.Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
			t.Error(err)
			return
		}
		// Wait for the client's first message ("Hello\n") before streaming.
		if _, _, err := bufrw.ReadLine(); err != nil {
			t.Errorf("Failed to read line from client: %v", err)
			return
		}

		for i := 0; i < n; i++ {
			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
				// A write error is expected once cancellation has been
				// triggered (the channel is closed); report it otherwise.
				select {
				case <-triggerCancelCh:
				default:
					t.Errorf("Writing response #%d failed: %v", i, err)
				}
				return
			}
			bufrw.Flush()
			// Pace the stream so cancellation can take effect mid-stream.
			time.Sleep(time.Second)
		}
		if _, err := bufrw.WriteString(terminalMsg); err != nil {
			select {
			case <-triggerCancelCh:
			default:
				t.Errorf("Failed to write terminal message: %v", err)
			}
		}
		bufrw.Flush()
	}))
	defer cst.Close()

	backendURL, _ := url.Parse(cst.URL)
	rproxy := NewSingleHostReverseProxy(backendURL)
	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	rproxy.ModifyResponse = func(res *http.Response) error {
		res.Header.Add("X-Modified", "true")
		return nil
	}

	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("X-Header", "X-Value")
		// Cancel the proxied request's context once the client signals
		// through triggerCancelCh.
		ctx, cancel := context.WithCancel(req.Context())
		go func() {
			<-triggerCancelCh
			cancel()
		}()
		rproxy.ServeHTTP(rw, req.WithContext(ctx))
	})

	frontendProxy := httptest.NewServer(handler)
	defer frontendProxy.Close()

	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")

	res, err := frontendProxy.Client().Do(req)
	if err != nil {
		t.Fatalf("Dialing to frontend proxy: %v", err)
	}
	defer res.Body.Close()
	if g, w := res.StatusCode, 101; g != w {
		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
	}

	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
		t.Errorf("X-Header mismatch\n\tgot:  %q\n\twant: %q", g, w)
	}

	if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
		t.Fatalf("Upgrade header mismatch\n\tgot:  %q\n\twant: %q", g, w)
	}

	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
	}

	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
		t.Errorf("response X-Modified header = %q; want %q", got, want)
	}

	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
		t.Fatalf("Failed to write first message: %v", err)
	}

	// Read loop.
	// Case order matters: terminalMsg arrives together with io.EOF, so it
	// must be checked before the EOF case. A clean io.EOF without it means
	// the tunnel was torn down by cancellation, which is success.

	br := bufio.NewReader(rwc)
	for {
		line, err := br.ReadString('\n')
		switch {
		case line == terminalMsg: // this case before "err == io.EOF"
			t.Fatalf("The websocket request was not canceled, unfortunately!")

		case err == io.EOF:
			return

		case err != nil:
			t.Fatalf("Unexpected error: %v", err)

		case line == nthResponse(0): // We've gotten the first response back
			// Let's trigger a cancel.
			close(triggerCancelCh)
		}
	}
}
  1553  
  1554  func TestUnannouncedTrailer(t *testing.T) {
  1555  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1556  		w.WriteHeader(http.StatusOK)
  1557  		w.(http.Flusher).Flush()
  1558  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1559  	}))
  1560  	defer backend.Close()
  1561  	backendURL, err := url.Parse(backend.URL)
  1562  	if err != nil {
  1563  		t.Fatal(err)
  1564  	}
  1565  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1566  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1567  	frontend := httptest.NewServer(proxyHandler)
  1568  	defer frontend.Close()
  1569  	frontendClient := frontend.Client()
  1570  
  1571  	res, err := frontendClient.Get(frontend.URL)
  1572  	if err != nil {
  1573  		t.Fatalf("Get: %v", err)
  1574  	}
  1575  
  1576  	io.ReadAll(res.Body)
  1577  	res.Body.Close()
  1578  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1579  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1580  	}
  1581  
  1582  }
  1583  
  1584  func TestSetURL(t *testing.T) {
  1585  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1586  		w.Write([]byte(r.Host))
  1587  	}))
  1588  	defer backend.Close()
  1589  	backendURL, err := url.Parse(backend.URL)
  1590  	if err != nil {
  1591  		t.Fatal(err)
  1592  	}
  1593  	proxyHandler := &ReverseProxy{
  1594  		Rewrite: func(r *ProxyRequest) {
  1595  			r.SetURL(backendURL)
  1596  		},
  1597  	}
  1598  	frontend := httptest.NewServer(proxyHandler)
  1599  	defer frontend.Close()
  1600  	frontendClient := frontend.Client()
  1601  
  1602  	res, err := frontendClient.Get(frontend.URL)
  1603  	if err != nil {
  1604  		t.Fatalf("Get: %v", err)
  1605  	}
  1606  	defer res.Body.Close()
  1607  
  1608  	body, err := io.ReadAll(res.Body)
  1609  	if err != nil {
  1610  		t.Fatalf("Reading body: %v", err)
  1611  	}
  1612  
  1613  	if got, want := string(body), backendURL.Host; got != want {
  1614  		t.Errorf("backend got Host %q, want %q", got, want)
  1615  	}
  1616  }
  1617  
  1618  func TestSingleJoinSlash(t *testing.T) {
  1619  	tests := []struct {
  1620  		slasha   string
  1621  		slashb   string
  1622  		expected string
  1623  	}{
  1624  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1625  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1626  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1627  		{"https://www.google.com", "", "https://www.google.com/"},
  1628  		{"", "favicon.ico", "/favicon.ico"},
  1629  	}
  1630  	for _, tt := range tests {
  1631  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1632  			t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
  1633  				tt.slasha,
  1634  				tt.slashb,
  1635  				tt.expected,
  1636  				got)
  1637  		}
  1638  	}
  1639  }
  1640  
  1641  func TestJoinURLPath(t *testing.T) {
  1642  	tests := []struct {
  1643  		a        *url.URL
  1644  		b        *url.URL
  1645  		wantPath string
  1646  		wantRaw  string
  1647  	}{
  1648  		{&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
  1649  		{&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
  1650  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1651  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1652  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
  1653  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
  1654  	}
  1655  
  1656  	for _, tt := range tests {
  1657  		p, rp := joinURLPath(tt.a, tt.b)
  1658  		if p != tt.wantPath || rp != tt.wantRaw {
  1659  			t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
  1660  				tt.a.Path, tt.a.RawPath,
  1661  				tt.b.Path, tt.b.RawPath,
  1662  				tt.wantPath, tt.wantRaw,
  1663  				p, rp)
  1664  		}
  1665  	}
  1666  }
  1667  
  1668  func TestReverseProxyRewriteReplacesOut(t *testing.T) {
  1669  	const content = "response_content"
  1670  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1671  		w.Write([]byte(content))
  1672  	}))
  1673  	defer backend.Close()
  1674  	proxyHandler := &ReverseProxy{
  1675  		Rewrite: func(r *ProxyRequest) {
  1676  			r.Out, _ = http.NewRequest("GET", backend.URL, nil)
  1677  		},
  1678  	}
  1679  	frontend := httptest.NewServer(proxyHandler)
  1680  	defer frontend.Close()
  1681  
  1682  	res, err := frontend.Client().Get(frontend.URL)
  1683  	if err != nil {
  1684  		t.Fatalf("Get: %v", err)
  1685  	}
  1686  	defer res.Body.Close()
  1687  	body, _ := io.ReadAll(res.Body)
  1688  	if got, want := string(body), content; got != want {
  1689  		t.Errorf("got response %q, want %q", got, want)
  1690  	}
  1691  }
  1692  
  1693  func Test1xxHeadersNotModifiedAfterRoundTrip(t *testing.T) {
  1694  	// https://go.dev/issue/65123: We use httptrace.Got1xxResponse to capture 1xx responses
  1695  	// and proxy them. httptrace handlers can execute after RoundTrip returns, in particular
  1696  	// after experiencing connection errors. When this happens, we shouldn't modify the
  1697  	// ResponseWriter headers after ReverseProxy.ServeHTTP returns.
  1698  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1699  		for i := 0; i < 5; i++ {
  1700  			w.WriteHeader(103)
  1701  		}
  1702  	}))
  1703  	defer backend.Close()
  1704  	backendURL, err := url.Parse(backend.URL)
  1705  	if err != nil {
  1706  		t.Fatal(err)
  1707  	}
  1708  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1709  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1710  
  1711  	rw := &testResponseWriter{}
  1712  	func() {
  1713  		// Cancel the request (and cause RoundTrip to return) immediately upon
  1714  		// seeing a 1xx response.
  1715  		ctx, cancel := context.WithCancel(context.Background())
  1716  		defer cancel()
  1717  		ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
  1718  			Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  1719  				cancel()
  1720  				return nil
  1721  			},
  1722  		})
  1723  
  1724  		req, _ := http.NewRequestWithContext(ctx, "GET", "http://go.dev/", nil)
  1725  		proxyHandler.ServeHTTP(rw, req)
  1726  	}()
  1727  	// Trigger data race while iterating over response headers.
  1728  	// When run with -race, this causes the condition in https://go.dev/issue/65123 often
  1729  	// enough to detect reliably.
  1730  	for _ = range rw.Header() {
  1731  	}
  1732  }
  1733  
  1734  func Test1xxResponses(t *testing.T) {
  1735  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1736  		h := w.Header()
  1737  		h.Add("Link", "</style.css>; rel=preload; as=style")
  1738  		h.Add("Link", "</script.js>; rel=preload; as=script")
  1739  		w.WriteHeader(http.StatusEarlyHints)
  1740  
  1741  		h.Add("Link", "</foo.js>; rel=preload; as=script")
  1742  		w.WriteHeader(http.StatusProcessing)
  1743  
  1744  		w.Write([]byte("Hello"))
  1745  	}))
  1746  	defer backend.Close()
  1747  	backendURL, err := url.Parse(backend.URL)
  1748  	if err != nil {
  1749  		t.Fatal(err)
  1750  	}
  1751  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1752  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1753  	frontend := httptest.NewServer(proxyHandler)
  1754  	defer frontend.Close()
  1755  	frontendClient := frontend.Client()
  1756  
  1757  	checkLinkHeaders := func(t *testing.T, expected, got []string) {
  1758  		t.Helper()
  1759  
  1760  		if len(expected) != len(got) {
  1761  			t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
  1762  		}
  1763  
  1764  		for i := range expected {
  1765  			if i >= len(got) {
  1766  				t.Errorf("Expected %q link header; got nothing", expected[i])
  1767  
  1768  				continue
  1769  			}
  1770  
  1771  			if expected[i] != got[i] {
  1772  				t.Errorf("Expected %q link header; got %q", expected[i], got[i])
  1773  			}
  1774  		}
  1775  	}
  1776  
  1777  	var respCounter uint8
  1778  	trace := &httptrace.ClientTrace{
  1779  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  1780  			switch code {
  1781  			case http.StatusEarlyHints:
  1782  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
  1783  			case http.StatusProcessing:
  1784  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
  1785  			default:
  1786  				t.Error("Unexpected 1xx response")
  1787  			}
  1788  
  1789  			respCounter++
  1790  
  1791  			return nil
  1792  		},
  1793  	}
  1794  	req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
  1795  
  1796  	res, err := frontendClient.Do(req)
  1797  	if err != nil {
  1798  		t.Fatalf("Get: %v", err)
  1799  	}
  1800  
  1801  	defer res.Body.Close()
  1802  
  1803  	if respCounter != 2 {
  1804  		t.Errorf("Expected 2 1xx responses; got %d", respCounter)
  1805  	}
  1806  	checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
  1807  
  1808  	body, _ := io.ReadAll(res.Body)
  1809  	if string(body) != "Hello" {
  1810  		t.Errorf("Read body %q; want Hello", body)
  1811  	}
  1812  }
  1813  
  1814  const (
  1815  	testWantsCleanQuery = true
  1816  	testWantsRawQuery   = false
  1817  )
  1818  
  1819  func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) {
  1820  	testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
  1821  		proxyHandler := NewSingleHostReverseProxy(u)
  1822  		oldDirector := proxyHandler.Director
  1823  		proxyHandler.Director = func(r *http.Request) {
  1824  			oldDirector(r)
  1825  		}
  1826  		return proxyHandler
  1827  	})
  1828  }
  1829  
  1830  func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) {
  1831  	testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
  1832  		proxyHandler := NewSingleHostReverseProxy(u)
  1833  		oldDirector := proxyHandler.Director
  1834  		proxyHandler.Director = func(r *http.Request) {
  1835  			// Parsing the form causes ReverseProxy to remove unparsable
  1836  			// query parameters before forwarding.
  1837  			r.FormValue("a")
  1838  			oldDirector(r)
  1839  		}
  1840  		return proxyHandler
  1841  	})
  1842  }
  1843  
  1844  func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) {
  1845  	testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
  1846  		return &ReverseProxy{
  1847  			Rewrite: func(r *ProxyRequest) {
  1848  				r.SetURL(u)
  1849  			},
  1850  		}
  1851  	})
  1852  }
  1853  
  1854  func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) {
  1855  	testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
  1856  		return &ReverseProxy{
  1857  			Rewrite: func(r *ProxyRequest) {
  1858  				r.SetURL(u)
  1859  				r.Out.URL.RawQuery = r.In.URL.RawQuery
  1860  			},
  1861  		}
  1862  	})
  1863  }
  1864  
  1865  func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) {
  1866  	const content = "response_content"
  1867  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1868  		w.Write([]byte(r.URL.RawQuery))
  1869  	}))
  1870  	defer backend.Close()
  1871  	backendURL, err := url.Parse(backend.URL)
  1872  	if err != nil {
  1873  		t.Fatal(err)
  1874  	}
  1875  	proxyHandler := newProxy(backendURL)
  1876  	frontend := httptest.NewServer(proxyHandler)
  1877  	defer frontend.Close()
  1878  
  1879  	// Don't spam output with logs of queries containing semicolons.
  1880  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
  1881  	frontend.Config.ErrorLog = log.New(io.Discard, "", 0)
  1882  
  1883  	for _, test := range []struct {
  1884  		rawQuery   string
  1885  		cleanQuery string
  1886  	}{{
  1887  		rawQuery:   "a=1&a=2;b=3",
  1888  		cleanQuery: "a=1",
  1889  	}, {
  1890  		rawQuery:   "a=1&a=%zz&b=3",
  1891  		cleanQuery: "a=1&b=3",
  1892  	}} {
  1893  		res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery)
  1894  		if err != nil {
  1895  			t.Fatalf("Get: %v", err)
  1896  		}
  1897  		defer res.Body.Close()
  1898  		body, _ := io.ReadAll(res.Body)
  1899  		wantQuery := test.rawQuery
  1900  		if wantCleanQuery {
  1901  			wantQuery = test.cleanQuery
  1902  		}
  1903  		if got, want := string(body), wantQuery; got != want {
  1904  			t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want)
  1905  		}
  1906  	}
  1907  }
  1908  
  1909  type testResponseWriter struct {
  1910  	h           http.Header
  1911  	writeHeader func(int)
  1912  	write       func([]byte) (int, error)
  1913  }
  1914  
  1915  func (rw *testResponseWriter) Header() http.Header {
  1916  	if rw.h == nil {
  1917  		rw.h = make(http.Header)
  1918  	}
  1919  	return rw.h
  1920  }
  1921  
  1922  func (rw *testResponseWriter) WriteHeader(statusCode int) {
  1923  	if rw.writeHeader != nil {
  1924  		rw.writeHeader(statusCode)
  1925  	}
  1926  }
  1927  
  1928  func (rw *testResponseWriter) Write(p []byte) (int, error) {
  1929  	if rw.write != nil {
  1930  		return rw.write(p)
  1931  	}
  1932  	return len(p), nil
  1933  }
  1934  

View as plain text