1
2
3
4
5
6
7 package httputil
8
9 import (
10 "bufio"
11 "bytes"
12 "context"
13 "errors"
14 "fmt"
15 "io"
16 "log"
17 "net/http"
18 "net/http/httptest"
19 "net/http/httptrace"
20 "net/http/internal/ascii"
21 "net/textproto"
22 "net/url"
23 "os"
24 "reflect"
25 "slices"
26 "strconv"
27 "strings"
28 "sync"
29 "testing"
30 "time"
31 )
32
33 const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
34
35 func init() {
36 inOurTests = true
37 hopHeaders = append(hopHeaders, fakeHopHeader)
38 }
39
40 func TestReverseProxy(t *testing.T) {
41 const backendResponse = "I am the backend"
42 const backendStatus = 404
43 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
44 if r.Method == "GET" && r.FormValue("mode") == "hangup" {
45 c, _, _ := w.(http.Hijacker).Hijack()
46 c.Close()
47 return
48 }
49 if len(r.TransferEncoding) > 0 {
50 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
51 }
52 if r.Header.Get("X-Forwarded-For") == "" {
53 t.Errorf("didn't get X-Forwarded-For header")
54 }
55 if c := r.Header.Get("Connection"); c != "" {
56 t.Errorf("handler got Connection header value %q", c)
57 }
58 if c := r.Header.Get("Te"); c != "trailers" {
59 t.Errorf("handler got Te header value %q; want 'trailers'", c)
60 }
61 if c := r.Header.Get("Upgrade"); c != "" {
62 t.Errorf("handler got Upgrade header value %q", c)
63 }
64 if c := r.Header.Get("Proxy-Connection"); c != "" {
65 t.Errorf("handler got Proxy-Connection header value %q", c)
66 }
67 if g, e := r.Host, "some-name"; g != e {
68 t.Errorf("backend got Host header %q, want %q", g, e)
69 }
70 w.Header().Set("Trailers", "not a special header field name")
71 w.Header().Set("Trailer", "X-Trailer")
72 w.Header().Set("X-Foo", "bar")
73 w.Header().Set("Upgrade", "foo")
74 w.Header().Set(fakeHopHeader, "foo")
75 w.Header().Add("X-Multi-Value", "foo")
76 w.Header().Add("X-Multi-Value", "bar")
77 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
78 w.WriteHeader(backendStatus)
79 w.Write([]byte(backendResponse))
80 w.Header().Set("X-Trailer", "trailer_value")
81 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
82 }))
83 defer backend.Close()
84 backendURL, err := url.Parse(backend.URL)
85 if err != nil {
86 t.Fatal(err)
87 }
88 proxyHandler := NewSingleHostReverseProxy(backendURL)
89 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
90 frontend := httptest.NewServer(proxyHandler)
91 defer frontend.Close()
92 frontendClient := frontend.Client()
93
94 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
95 getReq.Host = "some-name"
96 getReq.Header.Set("Connection", "close, TE")
97 getReq.Header.Add("Te", "foo")
98 getReq.Header.Add("Te", "bar, trailers")
99 getReq.Header.Set("Proxy-Connection", "should be deleted")
100 getReq.Header.Set("Upgrade", "foo")
101 getReq.Close = true
102 res, err := frontendClient.Do(getReq)
103 if err != nil {
104 t.Fatalf("Get: %v", err)
105 }
106 if g, e := res.StatusCode, backendStatus; g != e {
107 t.Errorf("got res.StatusCode %d; expected %d", g, e)
108 }
109 if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
110 t.Errorf("got X-Foo %q; expected %q", g, e)
111 }
112 if c := res.Header.Get(fakeHopHeader); c != "" {
113 t.Errorf("got %s header value %q", fakeHopHeader, c)
114 }
115 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
116 t.Errorf("header Trailers = %q; want %q", g, e)
117 }
118 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
119 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
120 }
121 if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
122 t.Fatalf("got %d SetCookies, want %d", g, e)
123 }
124 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
125 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
126 }
127 if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
128 t.Errorf("unexpected cookie %q", cookie.Name)
129 }
130 bodyBytes, _ := io.ReadAll(res.Body)
131 if g, e := string(bodyBytes), backendResponse; g != e {
132 t.Errorf("got body %q; expected %q", g, e)
133 }
134 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
135 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
136 }
137 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
138 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
139 }
140 res.Body.Close()
141
142
143
144 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
145 getReq.Close = true
146 res, err = frontendClient.Do(getReq)
147 if err != nil {
148 t.Fatal(err)
149 }
150 res.Body.Close()
151 if res.StatusCode != http.StatusBadGateway {
152 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
153 }
154
155 }
156
157
158
159 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
160 const fakeConnectionToken = "X-Fake-Connection-Token"
161 const backendResponse = "I am the backend"
162
163
164
165 const someConnHeader = "X-Some-Conn-Header"
166
167 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
168 if c := r.Header.Get("Connection"); c != "" {
169 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
170 }
171 if c := r.Header.Get(fakeConnectionToken); c != "" {
172 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
173 }
174 if c := r.Header.Get(someConnHeader); c != "" {
175 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
176 }
177 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
178 w.Header().Add("Connection", someConnHeader)
179 w.Header().Set(someConnHeader, "should be deleted")
180 w.Header().Set(fakeConnectionToken, "should be deleted")
181 io.WriteString(w, backendResponse)
182 }))
183 defer backend.Close()
184 backendURL, err := url.Parse(backend.URL)
185 if err != nil {
186 t.Fatal(err)
187 }
188 proxyHandler := NewSingleHostReverseProxy(backendURL)
189 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
190 proxyHandler.ServeHTTP(w, r)
191 if c := r.Header.Get(someConnHeader); c != "should be deleted" {
192 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
193 }
194 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
195 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
196 }
197 c := r.Header["Connection"]
198 var cf []string
199 for _, f := range c {
200 for _, sf := range strings.Split(f, ",") {
201 if sf = strings.TrimSpace(sf); sf != "" {
202 cf = append(cf, sf)
203 }
204 }
205 }
206 slices.Sort(cf)
207 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
208 slices.Sort(expectedValues)
209 if !slices.Equal(cf, expectedValues) {
210 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
211 }
212 }))
213 defer frontend.Close()
214
215 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
216 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
217 getReq.Header.Add("Connection", someConnHeader)
218 getReq.Header.Set(someConnHeader, "should be deleted")
219 getReq.Header.Set(fakeConnectionToken, "should be deleted")
220 res, err := frontend.Client().Do(getReq)
221 if err != nil {
222 t.Fatalf("Get: %v", err)
223 }
224 defer res.Body.Close()
225 bodyBytes, err := io.ReadAll(res.Body)
226 if err != nil {
227 t.Fatalf("reading body: %v", err)
228 }
229 if got, want := string(bodyBytes), backendResponse; got != want {
230 t.Errorf("got body %q; want %q", got, want)
231 }
232 if c := res.Header.Get("Connection"); c != "" {
233 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
234 }
235 if c := res.Header.Get(someConnHeader); c != "" {
236 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
237 }
238 if c := res.Header.Get(fakeConnectionToken); c != "" {
239 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
240 }
241 }
242
243 func TestReverseProxyStripEmptyConnection(t *testing.T) {
244
245 const backendResponse = "I am the backend"
246
247
248
249 const someConnHeader = "X-Some-Conn-Header"
250
251 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
252 if c := r.Header.Values("Connection"); len(c) != 0 {
253 t.Errorf("handler got header %q = %v; want empty", "Connection", c)
254 }
255 if c := r.Header.Get(someConnHeader); c != "" {
256 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
257 }
258 w.Header().Add("Connection", "")
259 w.Header().Add("Connection", someConnHeader)
260 w.Header().Set(someConnHeader, "should be deleted")
261 io.WriteString(w, backendResponse)
262 }))
263 defer backend.Close()
264 backendURL, err := url.Parse(backend.URL)
265 if err != nil {
266 t.Fatal(err)
267 }
268 proxyHandler := NewSingleHostReverseProxy(backendURL)
269 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
270 proxyHandler.ServeHTTP(w, r)
271 if c := r.Header.Get(someConnHeader); c != "should be deleted" {
272 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
273 }
274 }))
275 defer frontend.Close()
276
277 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
278 getReq.Header.Add("Connection", "")
279 getReq.Header.Add("Connection", someConnHeader)
280 getReq.Header.Set(someConnHeader, "should be deleted")
281 res, err := frontend.Client().Do(getReq)
282 if err != nil {
283 t.Fatalf("Get: %v", err)
284 }
285 defer res.Body.Close()
286 bodyBytes, err := io.ReadAll(res.Body)
287 if err != nil {
288 t.Fatalf("reading body: %v", err)
289 }
290 if got, want := string(bodyBytes), backendResponse; got != want {
291 t.Errorf("got body %q; want %q", got, want)
292 }
293 if c := res.Header.Get("Connection"); c != "" {
294 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
295 }
296 if c := res.Header.Get(someConnHeader); c != "" {
297 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
298 }
299 }
300
301 func TestXForwardedFor(t *testing.T) {
302 const prevForwardedFor = "client ip"
303 const backendResponse = "I am the backend"
304 const backendStatus = 404
305 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
306 if r.Header.Get("X-Forwarded-For") == "" {
307 t.Errorf("didn't get X-Forwarded-For header")
308 }
309 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
310 t.Errorf("X-Forwarded-For didn't contain prior data")
311 }
312 w.WriteHeader(backendStatus)
313 w.Write([]byte(backendResponse))
314 }))
315 defer backend.Close()
316 backendURL, err := url.Parse(backend.URL)
317 if err != nil {
318 t.Fatal(err)
319 }
320 proxyHandler := NewSingleHostReverseProxy(backendURL)
321 frontend := httptest.NewServer(proxyHandler)
322 defer frontend.Close()
323
324 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
325 getReq.Header.Set("Connection", "close")
326 getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
327 getReq.Close = true
328 res, err := frontend.Client().Do(getReq)
329 if err != nil {
330 t.Fatalf("Get: %v", err)
331 }
332 defer res.Body.Close()
333 if g, e := res.StatusCode, backendStatus; g != e {
334 t.Errorf("got res.StatusCode %d; expected %d", g, e)
335 }
336 bodyBytes, _ := io.ReadAll(res.Body)
337 if g, e := string(bodyBytes), backendResponse; g != e {
338 t.Errorf("got body %q; expected %q", g, e)
339 }
340 }
341
342
343 func TestXForwardedFor_Omit(t *testing.T) {
344 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
345 if v := r.Header.Get("X-Forwarded-For"); v != "" {
346 t.Errorf("got X-Forwarded-For header: %q", v)
347 }
348 w.Write([]byte("hi"))
349 }))
350 defer backend.Close()
351 backendURL, err := url.Parse(backend.URL)
352 if err != nil {
353 t.Fatal(err)
354 }
355 proxyHandler := NewSingleHostReverseProxy(backendURL)
356 frontend := httptest.NewServer(proxyHandler)
357 defer frontend.Close()
358
359 oldDirector := proxyHandler.Director
360 proxyHandler.Director = func(r *http.Request) {
361 r.Header["X-Forwarded-For"] = nil
362 oldDirector(r)
363 }
364
365 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
366 getReq.Host = "some-name"
367 getReq.Close = true
368 res, err := frontend.Client().Do(getReq)
369 if err != nil {
370 t.Fatalf("Get: %v", err)
371 }
372 res.Body.Close()
373 }
374
375 func TestReverseProxyRewriteStripsForwarded(t *testing.T) {
376 headers := []string{
377 "Forwarded",
378 "X-Forwarded-For",
379 "X-Forwarded-Host",
380 "X-Forwarded-Proto",
381 }
382 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
383 for _, h := range headers {
384 if v := r.Header.Get(h); v != "" {
385 t.Errorf("got %v header: %q", h, v)
386 }
387 }
388 }))
389 defer backend.Close()
390 backendURL, err := url.Parse(backend.URL)
391 if err != nil {
392 t.Fatal(err)
393 }
394 proxyHandler := &ReverseProxy{
395 Rewrite: func(r *ProxyRequest) {
396 r.SetURL(backendURL)
397 },
398 }
399 frontend := httptest.NewServer(proxyHandler)
400 defer frontend.Close()
401
402 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
403 getReq.Host = "some-name"
404 getReq.Close = true
405 for _, h := range headers {
406 getReq.Header.Set(h, "x")
407 }
408 res, err := frontend.Client().Do(getReq)
409 if err != nil {
410 t.Fatalf("Get: %v", err)
411 }
412 res.Body.Close()
413 }
414
// proxyQueryTests describes how a single-host proxy should combine the
// target URL's query with the incoming request's query. want is the RawQuery
// the backend is expected to receive.
var proxyQueryTests = []struct {
	baseSuffix string // appended to the backend URL the proxy targets
	reqSuffix  string // appended to the frontend URL the client requests
	want       string
}{
	{baseSuffix: "", reqSuffix: "", want: ""},
	{baseSuffix: "?sta=tic", reqSuffix: "?us=er", want: "sta=tic&us=er"},
	{baseSuffix: "", reqSuffix: "?us=er", want: "us=er"},
	{baseSuffix: "?sta=tic", reqSuffix: "", want: "sta=tic"},
}
425
426 func TestReverseProxyQuery(t *testing.T) {
427 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
428 w.Header().Set("X-Got-Query", r.URL.RawQuery)
429 w.Write([]byte("hi"))
430 }))
431 defer backend.Close()
432
433 for i, tt := range proxyQueryTests {
434 backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
435 if err != nil {
436 t.Fatal(err)
437 }
438 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
439 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
440 req.Close = true
441 res, err := frontend.Client().Do(req)
442 if err != nil {
443 t.Fatalf("%d. Get: %v", i, err)
444 }
445 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
446 t.Errorf("%d. got query %q; expected %q", i, g, e)
447 }
448 res.Body.Close()
449 frontend.Close()
450 }
451 }
452
453 func TestReverseProxyFlushInterval(t *testing.T) {
454 const expected = "hi"
455 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
456 w.Write([]byte(expected))
457 }))
458 defer backend.Close()
459
460 backendURL, err := url.Parse(backend.URL)
461 if err != nil {
462 t.Fatal(err)
463 }
464
465 proxyHandler := NewSingleHostReverseProxy(backendURL)
466 proxyHandler.FlushInterval = time.Microsecond
467
468 frontend := httptest.NewServer(proxyHandler)
469 defer frontend.Close()
470
471 req, _ := http.NewRequest("GET", frontend.URL, nil)
472 req.Close = true
473 res, err := frontend.Client().Do(req)
474 if err != nil {
475 t.Fatalf("Get: %v", err)
476 }
477 defer res.Body.Close()
478 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
479 t.Errorf("got body %q; expected %q", bodyBytes, expected)
480 }
481 }
482
// mockFlusher embeds a ResponseWriter and records whether Flush was called.
// It does not forward the flush to the embedded writer.
type mockFlusher struct {
	http.ResponseWriter
	flushed bool // set to true by Flush
}

// Flush records that it was called; nothing is written or flushed.
func (mf *mockFlusher) Flush() { mf.flushed = true }
491
// wrappedRW embeds a ResponseWriter as a plain interface value. Optional
// interfaces of the underlying writer, such as http.Flusher, are therefore
// not visible on wrappedRW itself; the underlying writer is reachable only
// through Unwrap.
type wrappedRW struct {
	http.ResponseWriter
}

// Unwrap returns the embedded ResponseWriter.
func (rw *wrappedRW) Unwrap() http.ResponseWriter { return rw.ResponseWriter }
499
500 func TestReverseProxyResponseControllerFlushInterval(t *testing.T) {
501 const expected = "hi"
502 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
503 w.Write([]byte(expected))
504 }))
505 defer backend.Close()
506
507 backendURL, err := url.Parse(backend.URL)
508 if err != nil {
509 t.Fatal(err)
510 }
511
512 mf := &mockFlusher{}
513 proxyHandler := NewSingleHostReverseProxy(backendURL)
514 proxyHandler.FlushInterval = -1
515 proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
516 mf.ResponseWriter = w
517 w = &wrappedRW{mf}
518 proxyHandler.ServeHTTP(w, r)
519 })
520
521 frontend := httptest.NewServer(proxyWithMiddleware)
522 defer frontend.Close()
523
524 req, _ := http.NewRequest("GET", frontend.URL, nil)
525 req.Close = true
526 res, err := frontend.Client().Do(req)
527 if err != nil {
528 t.Fatalf("Get: %v", err)
529 }
530 defer res.Body.Close()
531 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
532 t.Errorf("got body %q; expected %q", bodyBytes, expected)
533 }
534 if !mf.flushed {
535 t.Errorf("response writer was not flushed")
536 }
537 }
538
539 func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
540 const expected = "hi"
541 stopCh := make(chan struct{})
542 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
543 w.Header().Add("MyHeader", expected)
544 w.WriteHeader(200)
545 w.(http.Flusher).Flush()
546 <-stopCh
547 }))
548 defer backend.Close()
549 defer close(stopCh)
550
551 backendURL, err := url.Parse(backend.URL)
552 if err != nil {
553 t.Fatal(err)
554 }
555
556 proxyHandler := NewSingleHostReverseProxy(backendURL)
557 proxyHandler.FlushInterval = time.Microsecond
558
559 frontend := httptest.NewServer(proxyHandler)
560 defer frontend.Close()
561
562 req, _ := http.NewRequest("GET", frontend.URL, nil)
563 req.Close = true
564
565 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
566 defer cancel()
567 req = req.WithContext(ctx)
568
569 res, err := frontend.Client().Do(req)
570 if err != nil {
571 t.Fatalf("Get: %v", err)
572 }
573 defer res.Body.Close()
574
575 if res.Header.Get("MyHeader") != expected {
576 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
577 }
578 }
579
580 func TestReverseProxyCancellation(t *testing.T) {
581 const backendResponse = "I am the backend"
582
583 reqInFlight := make(chan struct{})
584 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
585 close(reqInFlight)
586
587 select {
588 case <-time.After(10 * time.Second):
589
590
591 t.Error("Handler never saw CloseNotify")
592 return
593 case <-w.(http.CloseNotifier).CloseNotify():
594 }
595
596 w.WriteHeader(http.StatusOK)
597 w.Write([]byte(backendResponse))
598 }))
599
600 defer backend.Close()
601
602 backend.Config.ErrorLog = log.New(io.Discard, "", 0)
603
604 backendURL, err := url.Parse(backend.URL)
605 if err != nil {
606 t.Fatal(err)
607 }
608
609 proxyHandler := NewSingleHostReverseProxy(backendURL)
610
611
612
613 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
614
615 frontend := httptest.NewServer(proxyHandler)
616 defer frontend.Close()
617 frontendClient := frontend.Client()
618
619 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
620 go func() {
621 <-reqInFlight
622 frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
623 }()
624 res, err := frontendClient.Do(getReq)
625 if res != nil {
626 t.Errorf("got response %v; want nil", res.Status)
627 }
628 if err == nil {
629
630
631
632 t.Error("Server.Client().Do() returned nil error; want non-nil error")
633 }
634 }
635
// req parses v as a raw HTTP/1.x request and returns it. If v cannot be
// parsed, req stops the test with t.Fatal, so it must only be called from
// the goroutine running the test.
func req(t *testing.T, v string) *http.Request {
	parsed, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
	if err != nil {
		t.Fatal(err)
	}
	return parsed
}
643
644
645 func TestNilBody(t *testing.T) {
646 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
647 w.Write([]byte("hi"))
648 }))
649 defer backend.Close()
650
651 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
652 backURL, _ := url.Parse(backend.URL)
653 rp := NewSingleHostReverseProxy(backURL)
654 r := req(t, "GET / HTTP/1.0\r\n\r\n")
655 r.Body = nil
656 rp.ServeHTTP(w, r)
657 }))
658 defer frontend.Close()
659
660 res, err := http.Get(frontend.URL)
661 if err != nil {
662 t.Fatal(err)
663 }
664 defer res.Body.Close()
665 slurp, err := io.ReadAll(res.Body)
666 if err != nil {
667 t.Fatal(err)
668 }
669 if string(slurp) != "hi" {
670 t.Errorf("Got %q; want %q", slurp, "hi")
671 }
672 }
673
674
675 func TestUserAgentHeader(t *testing.T) {
676 var gotUA string
677 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
678 gotUA = r.Header.Get("User-Agent")
679 }))
680 defer backend.Close()
681 backendURL, err := url.Parse(backend.URL)
682 if err != nil {
683 t.Fatal(err)
684 }
685
686 proxyHandler := new(ReverseProxy)
687 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
688 proxyHandler.Director = func(req *http.Request) {
689 req.URL = backendURL
690 }
691 frontend := httptest.NewServer(proxyHandler)
692 defer frontend.Close()
693 frontendClient := frontend.Client()
694
695 for _, sentUA := range []string{"explicit UA", ""} {
696 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
697 getReq.Header.Set("User-Agent", sentUA)
698 getReq.Close = true
699 res, err := frontendClient.Do(getReq)
700 if err != nil {
701 t.Fatalf("Get: %v", err)
702 }
703 res.Body.Close()
704 if got, want := gotUA, sentUA; got != want {
705 t.Errorf("got forwarded User-Agent %q, want %q", got, want)
706 }
707 }
708 }
709
// bufferPool adapts a pair of functions to a Get/Put buffer-pool shape, so a
// test can observe when buffers are taken and returned.
type bufferPool struct {
	get func() []byte  // called by Get
	put func([]byte) // called by Put
}

// Get returns whatever bp.get returns.
func (bp bufferPool) Get() []byte {
	return bp.get()
}

// Put hands v to bp.put.
func (bp bufferPool) Put(v []byte) {
	bp.put(v)
}
717
718 func TestReverseProxyGetPutBuffer(t *testing.T) {
719 const msg = "hi"
720 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
721 io.WriteString(w, msg)
722 }))
723 defer backend.Close()
724
725 backendURL, err := url.Parse(backend.URL)
726 if err != nil {
727 t.Fatal(err)
728 }
729
730 var (
731 mu sync.Mutex
732 log []string
733 )
734 addLog := func(event string) {
735 mu.Lock()
736 defer mu.Unlock()
737 log = append(log, event)
738 }
739 rp := NewSingleHostReverseProxy(backendURL)
740 const size = 1234
741 rp.BufferPool = bufferPool{
742 get: func() []byte {
743 addLog("getBuf")
744 return make([]byte, size)
745 },
746 put: func(p []byte) {
747 addLog("putBuf-" + strconv.Itoa(len(p)))
748 },
749 }
750 frontend := httptest.NewServer(rp)
751 defer frontend.Close()
752
753 req, _ := http.NewRequest("GET", frontend.URL, nil)
754 req.Close = true
755 res, err := frontend.Client().Do(req)
756 if err != nil {
757 t.Fatalf("Get: %v", err)
758 }
759 slurp, err := io.ReadAll(res.Body)
760 res.Body.Close()
761 if err != nil {
762 t.Fatalf("reading body: %v", err)
763 }
764 if string(slurp) != msg {
765 t.Errorf("msg = %q; want %q", slurp, msg)
766 }
767 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
768 mu.Lock()
769 defer mu.Unlock()
770 if !slices.Equal(log, wantLog) {
771 t.Errorf("Log events = %q; want %q", log, wantLog)
772 }
773 }
774
775 func TestReverseProxy_Post(t *testing.T) {
776 const backendResponse = "I am the backend"
777 const backendStatus = 200
778 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
779 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
780 slurp, err := io.ReadAll(r.Body)
781 if err != nil {
782 t.Errorf("Backend body read = %v", err)
783 }
784 if len(slurp) != len(requestBody) {
785 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
786 }
787 if !bytes.Equal(slurp, requestBody) {
788 t.Error("Backend read wrong request body.")
789 }
790 w.Write([]byte(backendResponse))
791 }))
792 defer backend.Close()
793 backendURL, err := url.Parse(backend.URL)
794 if err != nil {
795 t.Fatal(err)
796 }
797 proxyHandler := NewSingleHostReverseProxy(backendURL)
798 frontend := httptest.NewServer(proxyHandler)
799 defer frontend.Close()
800
801 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
802 res, err := frontend.Client().Do(postReq)
803 if err != nil {
804 t.Fatalf("Do: %v", err)
805 }
806 defer res.Body.Close()
807 if g, e := res.StatusCode, backendStatus; g != e {
808 t.Errorf("got res.StatusCode %d; expected %d", g, e)
809 }
810 bodyBytes, _ := io.ReadAll(res.Body)
811 if g, e := string(bodyBytes), backendResponse; g != e {
812 t.Errorf("got body %q; expected %q", g, e)
813 }
814 }
815
// RoundTripperFunc adapts an ordinary function to http.RoundTripper, so
// tests can stub a proxy's Transport inline.
type RoundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time check that RoundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = RoundTripperFunc(nil)

// RoundTrip invokes fn with req and returns its results unchanged.
func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
821
822
823 func TestReverseProxy_NilBody(t *testing.T) {
824 backendURL, _ := url.Parse("http://fake.tld/")
825 proxyHandler := NewSingleHostReverseProxy(backendURL)
826 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
827 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
828 if req.Body != nil {
829 t.Error("Body != nil; want a nil Body")
830 }
831 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
832 })
833 frontend := httptest.NewServer(proxyHandler)
834 defer frontend.Close()
835
836 res, err := frontend.Client().Get(frontend.URL)
837 if err != nil {
838 t.Fatal(err)
839 }
840 defer res.Body.Close()
841 if res.StatusCode != 502 {
842 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
843 }
844 }
845
846
847 func TestReverseProxy_AllocatedHeader(t *testing.T) {
848 proxyHandler := new(ReverseProxy)
849 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
850 proxyHandler.Director = func(*http.Request) {}
851 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
852 if req.Header == nil {
853 t.Error("Header == nil; want a non-nil Header")
854 }
855 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
856 })
857
858 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
859 Method: "GET",
860 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
861 Proto: "HTTP/1.0",
862 ProtoMajor: 1,
863 })
864 }
865
866
867
868 func TestReverseProxyModifyResponse(t *testing.T) {
869 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
870 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
871 }))
872 defer backendServer.Close()
873
874 rpURL, _ := url.Parse(backendServer.URL)
875 rproxy := NewSingleHostReverseProxy(rpURL)
876 rproxy.ErrorLog = log.New(io.Discard, "", 0)
877 rproxy.ModifyResponse = func(resp *http.Response) error {
878 if resp.Header.Get("X-Hit-Mod") != "true" {
879 return fmt.Errorf("tried to by-pass proxy")
880 }
881 return nil
882 }
883
884 frontendProxy := httptest.NewServer(rproxy)
885 defer frontendProxy.Close()
886
887 tests := []struct {
888 url string
889 wantCode int
890 }{
891 {frontendProxy.URL + "/mod", http.StatusOK},
892 {frontendProxy.URL + "/schedule", http.StatusBadGateway},
893 }
894
895 for i, tt := range tests {
896 resp, err := http.Get(tt.url)
897 if err != nil {
898 t.Fatalf("failed to reach proxy: %v", err)
899 }
900 if g, e := resp.StatusCode, tt.wantCode; g != e {
901 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
902 }
903 resp.Body.Close()
904 }
905 }
906
// failingRoundTripper is an http.RoundTripper whose every round trip fails.
type failingRoundTripper struct{}

// RoundTrip always returns a nil response and an error reading "some error".
func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
	return nil, errors.New("some error")
}

// staticResponseRoundTripper is an http.RoundTripper that answers every
// request with the same preset response and a nil error.
type staticResponseRoundTripper struct {
	res *http.Response // returned as-is by every RoundTrip call
}

// RoundTrip ignores the request and returns srt.res.
func (srt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
	return srt.res, nil
}
918
919 func TestReverseProxyErrorHandler(t *testing.T) {
920 tests := []struct {
921 name string
922 wantCode int
923 errorHandler func(http.ResponseWriter, *http.Request, error)
924 transport http.RoundTripper
925 modifyResponse func(*http.Response) error
926 }{
927 {
928 name: "default",
929 wantCode: http.StatusBadGateway,
930 },
931 {
932 name: "errorhandler",
933 wantCode: http.StatusTeapot,
934 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
935 },
936 {
937 name: "modifyresponse_noerr",
938 transport: staticResponseRoundTripper{
939 &http.Response{StatusCode: 345, Body: http.NoBody},
940 },
941 modifyResponse: func(res *http.Response) error {
942 res.StatusCode++
943 return nil
944 },
945 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
946 wantCode: 346,
947 },
948 {
949 name: "modifyresponse_err",
950 transport: staticResponseRoundTripper{
951 &http.Response{StatusCode: 345, Body: http.NoBody},
952 },
953 modifyResponse: func(res *http.Response) error {
954 res.StatusCode++
955 return errors.New("some error to trigger errorHandler")
956 },
957 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
958 wantCode: http.StatusTeapot,
959 },
960 }
961
962 for _, tt := range tests {
963 t.Run(tt.name, func(t *testing.T) {
964 target := &url.URL{
965 Scheme: "http",
966 Host: "dummy.tld",
967 Path: "/",
968 }
969 rproxy := NewSingleHostReverseProxy(target)
970 rproxy.Transport = tt.transport
971 rproxy.ModifyResponse = tt.modifyResponse
972 if rproxy.Transport == nil {
973 rproxy.Transport = failingRoundTripper{}
974 }
975 rproxy.ErrorLog = log.New(io.Discard, "", 0)
976 if tt.errorHandler != nil {
977 rproxy.ErrorHandler = tt.errorHandler
978 }
979 frontendProxy := httptest.NewServer(rproxy)
980 defer frontendProxy.Close()
981
982 resp, err := http.Get(frontendProxy.URL + "/test")
983 if err != nil {
984 t.Fatalf("failed to reach proxy: %v", err)
985 }
986 if g, e := resp.StatusCode, tt.wantCode; g != e {
987 t.Errorf("got res.StatusCode %d; expected %d", g, e)
988 }
989 resp.Body.Close()
990 })
991 }
992 }
993
994
995 func TestReverseProxy_CopyBuffer(t *testing.T) {
996 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
997 out := "this call was relayed by the reverse proxy"
998
999 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
1000 fmt.Fprintln(w, out)
1001 }))
1002 defer backendServer.Close()
1003
1004 rpURL, err := url.Parse(backendServer.URL)
1005 if err != nil {
1006 t.Fatal(err)
1007 }
1008
1009 var proxyLog bytes.Buffer
1010 rproxy := NewSingleHostReverseProxy(rpURL)
1011 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
1012 donec := make(chan bool, 1)
1013 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1014 defer func() { donec <- true }()
1015 rproxy.ServeHTTP(w, r)
1016 }))
1017 defer frontendProxy.Close()
1018
1019 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
1020 t.Fatalf("want non-nil error")
1021 }
1022
1023
1024
1025
1026 <-donec
1027
1028 expected := []string{
1029 "EOF",
1030 "read",
1031 }
1032 for _, phrase := range expected {
1033 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
1034 t.Errorf("expected log to contain phrase %q", phrase)
1035 }
1036 }
1037 }
1038
// staticTransport is an http.RoundTripper that returns the same preset
// response for every request.
type staticTransport struct {
	res *http.Response // returned by every RoundTrip call
}

// RoundTrip ignores the request and returns st.res with a nil error.
func (st *staticTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	return st.res, nil
}
1046
1047 func BenchmarkServeHTTP(b *testing.B) {
1048 res := &http.Response{
1049 StatusCode: 200,
1050 Body: io.NopCloser(strings.NewReader("")),
1051 }
1052 proxy := &ReverseProxy{
1053 Director: func(*http.Request) {},
1054 Transport: &staticTransport{res},
1055 }
1056
1057 w := httptest.NewRecorder()
1058 r := httptest.NewRequest("GET", "/", nil)
1059
1060 b.ReportAllocs()
1061 for i := 0; i < b.N; i++ {
1062 proxy.ServeHTTP(w, r)
1063 }
1064 }
1065
1066 func TestServeHTTPDeepCopy(t *testing.T) {
1067 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1068 w.Write([]byte("Hello Gopher!"))
1069 }))
1070 defer backend.Close()
1071 backendURL, err := url.Parse(backend.URL)
1072 if err != nil {
1073 t.Fatal(err)
1074 }
1075
1076 type result struct {
1077 before, after string
1078 }
1079
1080 resultChan := make(chan result, 1)
1081 proxyHandler := NewSingleHostReverseProxy(backendURL)
1082 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1083 before := r.URL.String()
1084 proxyHandler.ServeHTTP(w, r)
1085 after := r.URL.String()
1086 resultChan <- result{before: before, after: after}
1087 }))
1088 defer frontend.Close()
1089
1090 want := result{before: "/", after: "/"}
1091
1092 res, err := frontend.Client().Get(frontend.URL)
1093 if err != nil {
1094 t.Fatalf("Do: %v", err)
1095 }
1096 res.Body.Close()
1097
1098 got := <-resultChan
1099 if got != want {
1100 t.Errorf("got = %+v; want = %+v", got, want)
1101 }
1102 }
1103
1104
1105
1106 func TestClonesRequestHeaders(t *testing.T) {
1107 log.SetOutput(io.Discard)
1108 defer log.SetOutput(os.Stderr)
1109 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1110 req.RemoteAddr = "1.2.3.4:56789"
1111 rp := &ReverseProxy{
1112 Director: func(req *http.Request) {
1113 req.Header.Set("From-Director", "1")
1114 },
1115 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
1116 if v := req.Header.Get("From-Director"); v != "1" {
1117 t.Errorf("From-Directory value = %q; want 1", v)
1118 }
1119 return nil, io.EOF
1120 }),
1121 }
1122 rp.ServeHTTP(httptest.NewRecorder(), req)
1123
1124 for _, h := range []string{
1125 "From-Director",
1126 "X-Forwarded-For",
1127 } {
1128 if req.Header.Get(h) != "" {
1129 t.Errorf("%v header mutation modified caller's request", h)
1130 }
1131 }
1132 }
1133
// roundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface, so a test can supply a transport inline.
type roundTripperFunc func(req *http.Request) (*http.Response, error)

// RoundTrip passes req to f and returns exactly what f returns.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := f(req)
	return resp, err
}
1139
1140 func TestModifyResponseClosesBody(t *testing.T) {
1141 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1142 req.RemoteAddr = "1.2.3.4:56789"
1143 closeCheck := new(checkCloser)
1144 logBuf := new(strings.Builder)
1145 outErr := errors.New("ModifyResponse error")
1146 rp := &ReverseProxy{
1147 Director: func(req *http.Request) {},
1148 Transport: &staticTransport{&http.Response{
1149 StatusCode: 200,
1150 Body: closeCheck,
1151 }},
1152 ErrorLog: log.New(logBuf, "", 0),
1153 ModifyResponse: func(*http.Response) error {
1154 return outErr
1155 },
1156 }
1157 rec := httptest.NewRecorder()
1158 rp.ServeHTTP(rec, req)
1159 res := rec.Result()
1160 if g, e := res.StatusCode, http.StatusBadGateway; g != e {
1161 t.Errorf("got res.StatusCode %d; expected %d", g, e)
1162 }
1163 if !closeCheck.closed {
1164 t.Errorf("body should have been closed")
1165 }
1166 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
1167 t.Errorf("ErrorLog %q does not contain %q", g, e)
1168 }
1169 }
1170
// checkCloser is an io.ReadCloser stub that records whether Close has been
// called. Read never reports EOF. Every call claims to have filled the whole
// buffer, but the buffer's contents are left untouched.
type checkCloser struct {
	closed bool
}

// Close marks cc as closed. It always succeeds.
func (cc *checkCloser) Close() error {
	cc.closed = true
	return nil
}

// Read reports len(b) bytes read and a nil error without modifying b.
func (cc *checkCloser) Read(b []byte) (n int, err error) {
	n = len(b)
	return n, nil
}
1183
1184
1185 func TestReverseProxy_PanicBodyError(t *testing.T) {
1186 log.SetOutput(io.Discard)
1187 defer log.SetOutput(os.Stderr)
1188 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1189 out := "this call was relayed by the reverse proxy"
1190
1191 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
1192 fmt.Fprintln(w, out)
1193 }))
1194 defer backendServer.Close()
1195
1196 rpURL, err := url.Parse(backendServer.URL)
1197 if err != nil {
1198 t.Fatal(err)
1199 }
1200
1201 rproxy := NewSingleHostReverseProxy(rpURL)
1202
1203
1204
1205 defer func() {
1206 err := recover()
1207 if err == nil {
1208 t.Fatal("handler should have panicked")
1209 }
1210 if err != http.ErrAbortHandler {
1211 t.Fatal("expected ErrAbortHandler, got", err)
1212 }
1213 }()
1214 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1215 rproxy.ServeHTTP(httptest.NewRecorder(), req)
1216 }
1217
1218
1219 func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
1220 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1221 out := "this call was relayed by the reverse proxy"
1222
1223 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
1224 fmt.Fprintln(w, out)
1225 }))
1226 defer backend.Close()
1227 backendURL, err := url.Parse(backend.URL)
1228 if err != nil {
1229 t.Fatal(err)
1230 }
1231 proxyHandler := NewSingleHostReverseProxy(backendURL)
1232 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1233 frontend := httptest.NewServer(proxyHandler)
1234 defer frontend.Close()
1235 frontendClient := frontend.Client()
1236
1237 var wg sync.WaitGroup
1238 for i := 0; i < 2; i++ {
1239 wg.Add(1)
1240 go func() {
1241 defer wg.Done()
1242 for j := 0; j < 10; j++ {
1243 const reqLen = 6 * 1024 * 1024
1244 req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
1245 req.ContentLength = reqLen
1246 resp, _ := frontendClient.Transport.RoundTrip(req)
1247 if resp != nil {
1248 io.Copy(io.Discard, resp.Body)
1249 resp.Body.Close()
1250 }
1251 }
1252 }()
1253 }
1254 wg.Wait()
1255 }
1256
1257 func TestSelectFlushInterval(t *testing.T) {
1258 tests := []struct {
1259 name string
1260 p *ReverseProxy
1261 res *http.Response
1262 want time.Duration
1263 }{
1264 {
1265 name: "default",
1266 res: &http.Response{},
1267 p: &ReverseProxy{FlushInterval: 123},
1268 want: 123,
1269 },
1270 {
1271 name: "server-sent events overrides non-zero",
1272 res: &http.Response{
1273 Header: http.Header{
1274 "Content-Type": {"text/event-stream"},
1275 },
1276 },
1277 p: &ReverseProxy{FlushInterval: 123},
1278 want: -1,
1279 },
1280 {
1281 name: "server-sent events overrides zero",
1282 res: &http.Response{
1283 Header: http.Header{
1284 "Content-Type": {"text/event-stream"},
1285 },
1286 },
1287 p: &ReverseProxy{FlushInterval: 0},
1288 want: -1,
1289 },
1290 {
1291 name: "server-sent events with media-type parameters overrides non-zero",
1292 res: &http.Response{
1293 Header: http.Header{
1294 "Content-Type": {"text/event-stream;charset=utf-8"},
1295 },
1296 },
1297 p: &ReverseProxy{FlushInterval: 123},
1298 want: -1,
1299 },
1300 {
1301 name: "server-sent events with media-type parameters overrides zero",
1302 res: &http.Response{
1303 Header: http.Header{
1304 "Content-Type": {"text/event-stream;charset=utf-8"},
1305 },
1306 },
1307 p: &ReverseProxy{FlushInterval: 0},
1308 want: -1,
1309 },
1310 {
1311 name: "Content-Length: -1, overrides non-zero",
1312 res: &http.Response{
1313 ContentLength: -1,
1314 },
1315 p: &ReverseProxy{FlushInterval: 123},
1316 want: -1,
1317 },
1318 {
1319 name: "Content-Length: -1, overrides zero",
1320 res: &http.Response{
1321 ContentLength: -1,
1322 },
1323 p: &ReverseProxy{FlushInterval: 0},
1324 want: -1,
1325 },
1326 }
1327 for _, tt := range tests {
1328 t.Run(tt.name, func(t *testing.T) {
1329 got := tt.p.flushInterval(tt.res)
1330 if got != tt.want {
1331 t.Errorf("flushLatency = %v; want %v", got, tt.want)
1332 }
1333 })
1334 }
1335 }
1336
1337 func TestReverseProxyWebSocket(t *testing.T) {
1338 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1339 if upgradeType(r.Header) != "websocket" {
1340 t.Error("unexpected backend request")
1341 http.Error(w, "unexpected request", 400)
1342 return
1343 }
1344 c, _, err := w.(http.Hijacker).Hijack()
1345 if err != nil {
1346 t.Error(err)
1347 return
1348 }
1349 defer c.Close()
1350 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
1351 bs := bufio.NewScanner(c)
1352 if !bs.Scan() {
1353 t.Errorf("backend failed to read line from client: %v", bs.Err())
1354 return
1355 }
1356 fmt.Fprintf(c, "backend got %q\n", bs.Text())
1357 }))
1358 defer backendServer.Close()
1359
1360 backURL, _ := url.Parse(backendServer.URL)
1361 rproxy := NewSingleHostReverseProxy(backURL)
1362 rproxy.ErrorLog = log.New(io.Discard, "", 0)
1363 rproxy.ModifyResponse = func(res *http.Response) error {
1364 res.Header.Add("X-Modified", "true")
1365 return nil
1366 }
1367
1368 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
1369 rw.Header().Set("X-Header", "X-Value")
1370 rproxy.ServeHTTP(rw, req)
1371 if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
1372 t.Errorf("response writer X-Modified header = %q; want %q", got, want)
1373 }
1374 })
1375
1376 frontendProxy := httptest.NewServer(handler)
1377 defer frontendProxy.Close()
1378
1379 req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
1380 req.Header.Set("Connection", "Upgrade")
1381 req.Header.Set("Upgrade", "websocket")
1382
1383 c := frontendProxy.Client()
1384 res, err := c.Do(req)
1385 if err != nil {
1386 t.Fatal(err)
1387 }
1388 if res.StatusCode != 101 {
1389 t.Fatalf("status = %v; want 101", res.Status)
1390 }
1391
1392 got := res.Header.Get("X-Header")
1393 want := "X-Value"
1394 if got != want {
1395 t.Errorf("Header(XHeader) = %q; want %q", got, want)
1396 }
1397
1398 if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
1399 t.Fatalf("not websocket upgrade; got %#v", res.Header)
1400 }
1401 rwc, ok := res.Body.(io.ReadWriteCloser)
1402 if !ok {
1403 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
1404 }
1405 defer rwc.Close()
1406
1407 if got, want := res.Header.Get("X-Modified"), "true"; got != want {
1408 t.Errorf("response X-Modified header = %q; want %q", got, want)
1409 }
1410
1411 io.WriteString(rwc, "Hello\n")
1412 bs := bufio.NewScanner(rwc)
1413 if !bs.Scan() {
1414 t.Fatalf("Scan: %v", bs.Err())
1415 }
1416 got = bs.Text()
1417 want = `backend got "Hello"`
1418 if got != want {
1419 t.Errorf("got %#q, want %#q", got, want)
1420 }
1421 }
1422
// TestReverseProxyWebSocketCancellation checks that canceling the context of a
// proxied, upgraded (WebSocket) request tears down the tunnel. The backend
// streams n numbered lines, one per second, followed by terminalMsg. After
// reading the first line, the client triggers cancellation of the context the
// frontend passes to the proxy. It then expects its connection to reach EOF
// before terminalMsg arrives.
func TestReverseProxyWebSocketCancellation(t *testing.T) {
	n := 5
	// Only ever closed (never sent on), by the client after the first backend
	// line. Its closure triggers the frontend's cancel(). It also tells the
	// backend that write failures are expected from then on.
	triggerCancelCh := make(chan bool, n)
	nthResponse := func(i int) string {
		return fmt.Sprintf("backend response #%d\n", i)
	}
	// Deliberately has no trailing newline. If it is ever delivered, the
	// client's ReadString('\n') returns it together with io.EOF.
	terminalMsg := "final message"

	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
			http.Error(w, "Unexpected request", 400)
			return
		}
		// Take over the connection and complete the upgrade by hand.
		conn, bufrw, err := w.(http.Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
			t.Error(err)
			return
		}
		// Wait for the client's first line before streaming responses.
		if _, _, err := bufrw.ReadLine(); err != nil {
			t.Errorf("Failed to read line from client: %v", err)
			return
		}

		for i := 0; i < n; i++ {
			// NOTE(review): bufrw buffers writes, so a broken connection
			// usually shows up here only after an earlier Flush (whose error
			// is ignored) has failed and left a sticky error.
			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
				// A write failure is only acceptable once the client has
				// closed triggerCancelCh; a receive on it succeeds only then.
				select {
				case <-triggerCancelCh:
				default:
					t.Errorf("Writing response #%d failed: %v", i, err)
				}
				return
			}
			bufrw.Flush()
			// Space the responses out so cancellation can take effect
			// before terminalMsg would be written.
			time.Sleep(time.Second)
		}
		if _, err := bufrw.WriteString(terminalMsg); err != nil {
			select {
			case <-triggerCancelCh:
			default:
				t.Errorf("Failed to write terminal message: %v", err)
			}
		}
		bufrw.Flush()
	}))
	defer cst.Close()

	backendURL, _ := url.Parse(cst.URL)
	rproxy := NewSingleHostReverseProxy(backendURL)
	rproxy.ErrorLog = log.New(io.Discard, "", 0)
	rproxy.ModifyResponse = func(res *http.Response) error {
		res.Header.Add("X-Modified", "true")
		return nil
	}

	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("X-Header", "X-Value")
		// Proxy under a cancelable context that is canceled once the client
		// closes triggerCancelCh. If the channel is never closed (e.g. the
		// test fails early), this goroutine stays blocked.
		ctx, cancel := context.WithCancel(req.Context())
		go func() {
			<-triggerCancelCh
			cancel()
		}()
		rproxy.ServeHTTP(rw, req.WithContext(ctx))
	})

	frontendProxy := httptest.NewServer(handler)
	defer frontendProxy.Close()

	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")

	res, err := frontendProxy.Client().Do(req)
	if err != nil {
		t.Fatalf("Dialing to frontend proxy: %v", err)
	}
	defer res.Body.Close()
	if g, w := res.StatusCode, 101; g != w {
		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
	}

	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
		t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
	}

	if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
		t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
	}

	// After the upgrade, the response body is the bidirectional tunnel.
	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
	}

	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
		t.Errorf("response X-Modified header = %q; want %q", got, want)
	}

	// Unblocks the backend's ReadLine so it starts streaming responses.
	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
		t.Fatalf("Failed to write first message: %v", err)
	}

	// Read lines until the tunnel closes. Seeing the first response triggers
	// cancellation. Seeing terminalMsg means cancellation did not close the
	// tunnel in time, so the test fails. Reaching io.EOF is success.
	br := bufio.NewReader(rwc)
	for {
		line, err := br.ReadString('\n')
		switch {
		case line == terminalMsg:
			t.Fatalf("The websocket request was not canceled, unfortunately!")

		case err == io.EOF:
			return

		case err != nil:
			t.Fatalf("Unexpected error: %v", err)

		case line == nthResponse(0):
			// First backend response received: cancel the proxied request.
			close(triggerCancelCh)
		}
	}
}
1553
1554 func TestUnannouncedTrailer(t *testing.T) {
1555 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1556 w.WriteHeader(http.StatusOK)
1557 w.(http.Flusher).Flush()
1558 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
1559 }))
1560 defer backend.Close()
1561 backendURL, err := url.Parse(backend.URL)
1562 if err != nil {
1563 t.Fatal(err)
1564 }
1565 proxyHandler := NewSingleHostReverseProxy(backendURL)
1566 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1567 frontend := httptest.NewServer(proxyHandler)
1568 defer frontend.Close()
1569 frontendClient := frontend.Client()
1570
1571 res, err := frontendClient.Get(frontend.URL)
1572 if err != nil {
1573 t.Fatalf("Get: %v", err)
1574 }
1575
1576 io.ReadAll(res.Body)
1577 res.Body.Close()
1578 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
1579 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
1580 }
1581
1582 }
1583
1584 func TestSetURL(t *testing.T) {
1585 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1586 w.Write([]byte(r.Host))
1587 }))
1588 defer backend.Close()
1589 backendURL, err := url.Parse(backend.URL)
1590 if err != nil {
1591 t.Fatal(err)
1592 }
1593 proxyHandler := &ReverseProxy{
1594 Rewrite: func(r *ProxyRequest) {
1595 r.SetURL(backendURL)
1596 },
1597 }
1598 frontend := httptest.NewServer(proxyHandler)
1599 defer frontend.Close()
1600 frontendClient := frontend.Client()
1601
1602 res, err := frontendClient.Get(frontend.URL)
1603 if err != nil {
1604 t.Fatalf("Get: %v", err)
1605 }
1606 defer res.Body.Close()
1607
1608 body, err := io.ReadAll(res.Body)
1609 if err != nil {
1610 t.Fatalf("Reading body: %v", err)
1611 }
1612
1613 if got, want := string(body), backendURL.Host; got != want {
1614 t.Errorf("backend got Host %q, want %q", got, want)
1615 }
1616 }
1617
1618 func TestSingleJoinSlash(t *testing.T) {
1619 tests := []struct {
1620 slasha string
1621 slashb string
1622 expected string
1623 }{
1624 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
1625 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
1626 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
1627 {"https://www.google.com", "", "https://www.google.com/"},
1628 {"", "favicon.ico", "/favicon.ico"},
1629 }
1630 for _, tt := range tests {
1631 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
1632 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
1633 tt.slasha,
1634 tt.slashb,
1635 tt.expected,
1636 got)
1637 }
1638 }
1639 }
1640
1641 func TestJoinURLPath(t *testing.T) {
1642 tests := []struct {
1643 a *url.URL
1644 b *url.URL
1645 wantPath string
1646 wantRaw string
1647 }{
1648 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
1649 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
1650 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
1651 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
1652 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
1653 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
1654 }
1655
1656 for _, tt := range tests {
1657 p, rp := joinURLPath(tt.a, tt.b)
1658 if p != tt.wantPath || rp != tt.wantRaw {
1659 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
1660 tt.a.Path, tt.a.RawPath,
1661 tt.b.Path, tt.b.RawPath,
1662 tt.wantPath, tt.wantRaw,
1663 p, rp)
1664 }
1665 }
1666 }
1667
1668 func TestReverseProxyRewriteReplacesOut(t *testing.T) {
1669 const content = "response_content"
1670 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1671 w.Write([]byte(content))
1672 }))
1673 defer backend.Close()
1674 proxyHandler := &ReverseProxy{
1675 Rewrite: func(r *ProxyRequest) {
1676 r.Out, _ = http.NewRequest("GET", backend.URL, nil)
1677 },
1678 }
1679 frontend := httptest.NewServer(proxyHandler)
1680 defer frontend.Close()
1681
1682 res, err := frontend.Client().Get(frontend.URL)
1683 if err != nil {
1684 t.Fatalf("Get: %v", err)
1685 }
1686 defer res.Body.Close()
1687 body, _ := io.ReadAll(res.Body)
1688 if got, want := string(body), content; got != want {
1689 t.Errorf("got response %q, want %q", got, want)
1690 }
1691 }
1692
1693 func Test1xxHeadersNotModifiedAfterRoundTrip(t *testing.T) {
1694
1695
1696
1697
1698 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1699 for i := 0; i < 5; i++ {
1700 w.WriteHeader(103)
1701 }
1702 }))
1703 defer backend.Close()
1704 backendURL, err := url.Parse(backend.URL)
1705 if err != nil {
1706 t.Fatal(err)
1707 }
1708 proxyHandler := NewSingleHostReverseProxy(backendURL)
1709 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1710
1711 rw := &testResponseWriter{}
1712 func() {
1713
1714
1715 ctx, cancel := context.WithCancel(context.Background())
1716 defer cancel()
1717 ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
1718 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1719 cancel()
1720 return nil
1721 },
1722 })
1723
1724 req, _ := http.NewRequestWithContext(ctx, "GET", "http://go.dev/", nil)
1725 proxyHandler.ServeHTTP(rw, req)
1726 }()
1727
1728
1729
1730 for _ = range rw.Header() {
1731 }
1732 }
1733
1734 func Test1xxResponses(t *testing.T) {
1735 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1736 h := w.Header()
1737 h.Add("Link", "</style.css>; rel=preload; as=style")
1738 h.Add("Link", "</script.js>; rel=preload; as=script")
1739 w.WriteHeader(http.StatusEarlyHints)
1740
1741 h.Add("Link", "</foo.js>; rel=preload; as=script")
1742 w.WriteHeader(http.StatusProcessing)
1743
1744 w.Write([]byte("Hello"))
1745 }))
1746 defer backend.Close()
1747 backendURL, err := url.Parse(backend.URL)
1748 if err != nil {
1749 t.Fatal(err)
1750 }
1751 proxyHandler := NewSingleHostReverseProxy(backendURL)
1752 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1753 frontend := httptest.NewServer(proxyHandler)
1754 defer frontend.Close()
1755 frontendClient := frontend.Client()
1756
1757 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1758 t.Helper()
1759
1760 if len(expected) != len(got) {
1761 t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
1762 }
1763
1764 for i := range expected {
1765 if i >= len(got) {
1766 t.Errorf("Expected %q link header; got nothing", expected[i])
1767
1768 continue
1769 }
1770
1771 if expected[i] != got[i] {
1772 t.Errorf("Expected %q link header; got %q", expected[i], got[i])
1773 }
1774 }
1775 }
1776
1777 var respCounter uint8
1778 trace := &httptrace.ClientTrace{
1779 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1780 switch code {
1781 case http.StatusEarlyHints:
1782 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1783 case http.StatusProcessing:
1784 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1785 default:
1786 t.Error("Unexpected 1xx response")
1787 }
1788
1789 respCounter++
1790
1791 return nil
1792 },
1793 }
1794 req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
1795
1796 res, err := frontendClient.Do(req)
1797 if err != nil {
1798 t.Fatalf("Get: %v", err)
1799 }
1800
1801 defer res.Body.Close()
1802
1803 if respCounter != 2 {
1804 t.Errorf("Expected 2 1xx responses; got %d", respCounter)
1805 }
1806 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1807
1808 body, _ := io.ReadAll(res.Body)
1809 if string(body) != "Hello" {
1810 t.Errorf("Read body %q; want Hello", body)
1811 }
1812 }
1813
// Expected outcomes passed to testReverseProxyQueryParameterSmuggling.
// testWantsCleanQuery means the backend should see the query with unparsable
// parameters removed. testWantsRawQuery means it should see the query exactly
// as the client sent it.
const (
	testWantsCleanQuery = true
	testWantsRawQuery   = !testWantsCleanQuery
)
1818
1819 func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) {
1820 testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
1821 proxyHandler := NewSingleHostReverseProxy(u)
1822 oldDirector := proxyHandler.Director
1823 proxyHandler.Director = func(r *http.Request) {
1824 oldDirector(r)
1825 }
1826 return proxyHandler
1827 })
1828 }
1829
1830 func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) {
1831 testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
1832 proxyHandler := NewSingleHostReverseProxy(u)
1833 oldDirector := proxyHandler.Director
1834 proxyHandler.Director = func(r *http.Request) {
1835
1836
1837 r.FormValue("a")
1838 oldDirector(r)
1839 }
1840 return proxyHandler
1841 })
1842 }
1843
1844 func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) {
1845 testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
1846 return &ReverseProxy{
1847 Rewrite: func(r *ProxyRequest) {
1848 r.SetURL(u)
1849 },
1850 }
1851 })
1852 }
1853
1854 func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) {
1855 testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
1856 return &ReverseProxy{
1857 Rewrite: func(r *ProxyRequest) {
1858 r.SetURL(u)
1859 r.Out.URL.RawQuery = r.In.URL.RawQuery
1860 },
1861 }
1862 })
1863 }
1864
1865 func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) {
1866 const content = "response_content"
1867 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1868 w.Write([]byte(r.URL.RawQuery))
1869 }))
1870 defer backend.Close()
1871 backendURL, err := url.Parse(backend.URL)
1872 if err != nil {
1873 t.Fatal(err)
1874 }
1875 proxyHandler := newProxy(backendURL)
1876 frontend := httptest.NewServer(proxyHandler)
1877 defer frontend.Close()
1878
1879
1880 backend.Config.ErrorLog = log.New(io.Discard, "", 0)
1881 frontend.Config.ErrorLog = log.New(io.Discard, "", 0)
1882
1883 for _, test := range []struct {
1884 rawQuery string
1885 cleanQuery string
1886 }{{
1887 rawQuery: "a=1&a=2;b=3",
1888 cleanQuery: "a=1",
1889 }, {
1890 rawQuery: "a=1&a=%zz&b=3",
1891 cleanQuery: "a=1&b=3",
1892 }} {
1893 res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery)
1894 if err != nil {
1895 t.Fatalf("Get: %v", err)
1896 }
1897 defer res.Body.Close()
1898 body, _ := io.ReadAll(res.Body)
1899 wantQuery := test.rawQuery
1900 if wantCleanQuery {
1901 wantQuery = test.cleanQuery
1902 }
1903 if got, want := string(body), wantQuery; got != want {
1904 t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want)
1905 }
1906 }
1907 }
1908
// testResponseWriter is a minimal http.ResponseWriter for tests. Header
// allocates its map lazily. WriteHeader and Write forward to the optional
// writeHeader and write hooks. When a hook is nil the call is a no-op, and
// Write reports all of p as written.
type testResponseWriter struct {
	h           http.Header
	writeHeader func(int)
	write       func([]byte) (int, error)
}

// Header returns the response header map, creating it on first use.
func (rw *testResponseWriter) Header() http.Header {
	if rw.h != nil {
		return rw.h
	}
	rw.h = make(http.Header)
	return rw.h
}

// WriteHeader passes statusCode to the writeHeader hook, if one is set.
func (rw *testResponseWriter) WriteHeader(statusCode int) {
	if hook := rw.writeHeader; hook != nil {
		hook(statusCode)
	}
}

// Write delegates to the write hook when set. Otherwise it discards p and
// reports len(p) bytes written.
func (rw *testResponseWriter) Write(p []byte) (int, error) {
	if hook := rw.write; hook != nil {
		return hook(p)
	}
	return len(p), nil
}
1934
View as plain text