1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "io"
14 "log"
15 "mime"
16 "net"
17 "net/http"
18 "net/http/httptrace"
19 "net/http/internal/ascii"
20 "net/textproto"
21 "net/url"
22 "strings"
23 "sync"
24 "sync/atomic"
25 "time"
26
27 "golang.org/x/net/http/httpguts"
28 )
29
30
// A ProxyRequest contains a request to be rewritten by a ReverseProxy.
// ServeHTTP builds one and passes it to ReverseProxy.Rewrite.
type ProxyRequest struct {
	// In is the request received by the proxy, exactly as ServeHTTP was
	// given it.
	In *http.Request

	// Out is the request the proxy will send. Rewrite may modify it or
	// replace it with a different request; whatever Out holds when Rewrite
	// returns is what gets sent. Before Rewrite runs, hop-by-hop headers and
	// the Forwarded / X-Forwarded-* headers have already been removed from it.
	Out *http.Request
}
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57 func (r *ProxyRequest) SetURL(target *url.URL) {
58 rewriteRequestURL(r.Out, target)
59 r.Out.Host = ""
60 }
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81 func (r *ProxyRequest) SetXForwarded() {
82 clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
83 if err == nil {
84 prior := r.Out.Header["X-Forwarded-For"]
85 if len(prior) > 0 {
86 clientIP = strings.Join(prior, ", ") + ", " + clientIP
87 }
88 r.Out.Header.Set("X-Forwarded-For", clientIP)
89 } else {
90 r.Out.Header.Del("X-Forwarded-For")
91 }
92 r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
93 if r.In.TLS == nil {
94 r.Out.Header.Set("X-Forwarded-Proto", "http")
95 } else {
96 r.Out.Header.Set("X-Forwarded-Proto", "https")
97 }
98 }
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// ReverseProxy is an HTTP Handler that takes an incoming request and sends
// it to another server, proxying the response back to the client.
//
// Exactly one of Rewrite or Director must be set. Otherwise ServeHTTP
// reports an error through the error handler and sends nothing upstream.
type ReverseProxy struct {
	// Rewrite modifies the outbound request (ProxyRequest.Out) before it is
	// sent using Transport. The request left in pr.Out when Rewrite returns
	// is the one that is sent.
	//
	// Before Rewrite is called, ServeHTTP strips the following from the
	// outbound request:
	//   - hop-by-hop headers (Te: trailers and the Connection/Upgrade pair
	//     for protocol upgrades are re-added when the client asked for them);
	//   - the Forwarded, X-Forwarded-For, X-Forwarded-Host and
	//     X-Forwarded-Proto headers;
	//   - any unparsable query parameters.
	// Rewrite may call ProxyRequest.SetXForwarded to set fresh forwarding
	// headers.
	//
	// At most one of Rewrite or Director may be set.
	Rewrite func(*ProxyRequest)

	// Transport is the RoundTripper used to perform proxy requests.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval controls flushing of the response body to the client
	// while it is being copied:
	//   - zero: no periodic flushing;
	//   - negative: flush after every write;
	//   - positive: written data is flushed no later than FlushInterval
	//     after it was written.
	// Responses with Content-Type text/event-stream, or with an unknown
	// length (ContentLength == -1), are flushed after every write regardless
	// of this value.
	FlushInterval time.Duration

	// ErrorLog is an optional logger for errors that occur while proxying.
	// If nil, the standard logger of the log package is used.
	ErrorLog *log.Logger

	// BufferPool optionally supplies the scratch buffers used when copying
	// response bodies. If nil, or if Get returns an empty slice, a 32 KiB
	// buffer is allocated for each copy.
	BufferPool BufferPool

	// ModifyResponse optionally modifies the backend's response. It is
	// called whenever the backend returns a response, whatever its status
	// code (including 101 Switching Protocols). It is not called when the
	// round trip itself fails.
	//
	// If ModifyResponse returns an error, the response body is closed and
	// the error handler is called with that error.
	ModifyResponse func(*http.Response) error

	// ErrorHandler is an optional function called with any error that
	// prevents proxying. This includes transport errors, ModifyResponse
	// errors, invalid or mismatched protocol upgrades, and a proxy that does
	// not have exactly one of Rewrite or Director set. If nil, the error is
	// logged and the client receives 502 Bad Gateway.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)

	// Director modifies the outbound request in place before it is sent
	// using Transport. It is the older alternative to Rewrite; at most one
	// of the two may be set.
	//
	// Unlike Rewrite, Director runs before hop-by-hop headers are removed,
	// so any hop-by-hop headers it sets are stripped afterwards. After
	// Director returns, ServeHTTP appends the client IP to X-Forwarded-For,
	// unless Director set the outbound Header["X-Forwarded-For"] to nil.
	// If Director parses the form (leaving Form non-nil), unparsable query
	// parameters are removed from the outbound URL.
	Director func(*http.Request)
}
267
268
269
// A BufferPool supplies temporary byte slices that ReverseProxy uses as
// scratch space when copying response bodies. The slice obtained from Get
// is handed back to Put once the copy finishes. If Get returns an empty
// slice, the proxy allocates its own buffer for that copy, but Put still
// receives the slice Get returned.
type BufferPool interface {
	Get() []byte
	Put([]byte)
}
274
// singleJoiningSlash concatenates a and b with exactly one "/" at the seam:
//   - a doubled slash (a ends with "/" and b starts with "/") is collapsed;
//   - a missing one is inserted;
//   - when exactly one side already supplies the slash, the strings are
//     simply concatenated.
func singleJoiningSlash(a, b string) string {
	left := strings.HasSuffix(a, "/")
	right := strings.HasPrefix(b, "/")
	if left != right {
		// Exactly one side already provides the separator.
		return a + b
	}
	if left {
		// Both sides provide one; drop b's.
		return a + b[1:]
	}
	// Neither side provides one; add it.
	return a + "/" + b
}
286
287 func joinURLPath(a, b *url.URL) (path, rawpath string) {
288 if a.RawPath == "" && b.RawPath == "" {
289 return singleJoiningSlash(a.Path, b.Path), ""
290 }
291
292
293 apath := a.EscapedPath()
294 bpath := b.EscapedPath()
295
296 aslash := strings.HasSuffix(apath, "/")
297 bslash := strings.HasPrefix(bpath, "/")
298
299 switch {
300 case aslash && bslash:
301 return a.Path + b.Path[1:], apath + bpath[1:]
302 case !aslash && !bslash:
303 return a.Path + "/" + b.Path, apath + "/" + bpath
304 }
305 return a.Path + b.Path, apath + bpath
306 }
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
333 director := func(req *http.Request) {
334 rewriteRequestURL(req, target)
335 }
336 return &ReverseProxy{Director: director}
337 }
338
339 func rewriteRequestURL(req *http.Request, target *url.URL) {
340 targetQuery := target.RawQuery
341 req.URL.Scheme = target.Scheme
342 req.URL.Host = target.Host
343 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
344 if targetQuery == "" || req.URL.RawQuery == "" {
345 req.URL.RawQuery = targetQuery + req.URL.RawQuery
346 } else {
347 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
348 }
349 }
350
// copyHeader appends every value of every field in src to dst. Values
// already present in dst are kept. Field names are canonicalized by
// http.Header.Add.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		for i := range values {
			dst.Add(name, values[i])
		}
	}
}
358
359
360
361
362
363
// hopHeaders lists hop-by-hop header fields. They describe a single
// transport-level connection and are not forwarded. removeHopByHopHeaders
// deletes them from the outbound request and from the backend response.
// ServeHTTP then re-adds Te: trailers and Connection/Upgrade when the
// client requested them. See RFC 9110, section 7.6.1 (and RFC 2616,
// section 13.5.1).
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // non-standard, but still sent by some clients
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",      // canonicalized form of "TE"
	"Trailer", // the field is "Trailer", not "Trailers"
	"Transfer-Encoding",
	"Upgrade",
}
375
// defaultErrorHandler is the error handler used when p.ErrorHandler is nil.
// It logs err via p.logf and replies with 502 Bad Gateway and an empty
// body. req is unused.
func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
	p.logf("http: proxy error: %v", err)
	rw.WriteHeader(http.StatusBadGateway)
}
380
381 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
382 if p.ErrorHandler != nil {
383 return p.ErrorHandler
384 }
385 return p.defaultErrorHandler
386 }
387
388
389
390 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
391 if p.ModifyResponse == nil {
392 return true
393 }
394 if err := p.ModifyResponse(res); err != nil {
395 res.Body.Close()
396 p.getErrorHandler()(rw, req, err)
397 return false
398 }
399 return true
400 }
401
// ServeHTTP proxies req to a backend and streams the backend's response
// back to the client through rw.
//
// How the outbound request is built:
//   - It is a clone of req. Exactly one of Director or Rewrite must be set
//     to retarget it; otherwise the error handler is invoked and nothing is
//     sent.
//   - Hop-by-hop headers are removed from both the outbound request and the
//     backend response.
//
// How the response is handled:
//   - 1xx informational responses from the backend are relayed while the
//     round trip is in progress.
//   - 101 Switching Protocols responses are handed to handleUpgradeResponse.
//
// Errors that occur before the response header is written go to the error
// handler. After the header has been written, a failure while copying the
// body aborts the handler with panic(http.ErrAbortHandler) when
// shouldPanicOnCopyError allows it; otherwise it is only logged.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	ctx := req.Context()
	if ctx.Done() != nil {
		// The request context is already cancellable, so it alone governs
		// the lifetime of the outbound request; the deprecated
		// CloseNotifier fallback below is not needed.
	} else if cn, ok := rw.(http.CloseNotifier); ok {
		// The request context can never be cancelled. Fall back to the
		// deprecated CloseNotifier: derive a cancellable context and cancel
		// it when the notifier fires. The goroutine exits on either event,
		// and the deferred cancel ensures it ends when ServeHTTP returns.
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		defer cancel()
		notifyChan := cn.CloseNotify()
		go func() {
			select {
			case <-notifyChan:
				cancel()
			case <-ctx.Done():
			}
		}()
	}

	outreq := req.Clone(ctx)
	if req.ContentLength == 0 {
		// Send a nil Body rather than an empty one, so the outbound request
		// is treated as having no body. NOTE(review): presumably this
		// matters for http.Transport's retry rules for bodiless requests —
		// confirm.
		outreq.Body = nil
	}
	if outreq.Body != nil {
		// A RoundTripper may keep reading the request body from another
		// goroutine after RoundTrip returns. A handler, however, must not
		// read the inbound body once ServeHTTP has returned. After the
		// deferred Close below runs, this wrapper turns any such late Read
		// into an error. It never closes the inbound body itself.
		outreq.Body = &noopCloseReader{readCloser: outreq.Body}

		// Runs when ServeHTTP returns; from then on the wrapper rejects Reads.
		defer outreq.Body.Close()
	}
	if outreq.Header == nil {
		// Director and Rewrite can always write to a non-nil header map.
		outreq.Header = make(http.Header)
	}

	if (p.Director != nil) == (p.Rewrite != nil) {
		p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
		return
	}

	if p.Director != nil {
		p.Director(outreq)
		if outreq.Form != nil {
			// The Director parsed the form. Re-encode the query if it holds
			// parameters url.ParseQuery rejects (see cleanQueryParams), so
			// the backend sees the same parameters the Director did.
			outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
		}
	}
	// Whether the client connection closes after this request is a
	// hop-by-hop decision; it must not force the outbound connection closed.
	outreq.Close = false

	// Record the requested upgrade protocol before the Connection and
	// Upgrade headers are stripped as hop-by-hop headers just below.
	reqUpType := upgradeType(outreq.Header)
	if !ascii.IsPrint(reqUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
		return
	}
	removeHopByHopHeaders(outreq.Header)

	// Te is hop-by-hop and was just removed. If the client accepts
	// trailers, tell the backend this proxy does too: trailers are relayed
	// to the client at the end of this function.
	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}

	// Now that the hop-by-hop headers are gone, restore the ones the
	// protocol-upgrade handshake (e.g. WebSocket) needs.
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}

	if p.Rewrite != nil {
		// Strip client-provided forwarding headers. Rewrite may set fresh
		// values with ProxyRequest.SetXForwarded or copy them from pr.In.
		outreq.Header.Del("Forwarded")
		outreq.Header.Del("X-Forwarded-For")
		outreq.Header.Del("X-Forwarded-Host")
		outreq.Header.Del("X-Forwarded-Proto")

		// Remove unparsable query parameters from the outbound request.
		outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)

		pr := &ProxyRequest{
			In:  req,
			Out: outreq,
		}
		p.Rewrite(pr)
		// Rewrite may have replaced pr.Out entirely; send whatever it left.
		outreq = pr.Out
	} else {
		if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			// Append the client IP to any X-Forwarded-For values from
			// earlier proxies, folding them into one comma+space separated
			// header. A Director that set the header to a nil slice opts
			// out of X-Forwarded-For entirely.
			prior, ok := outreq.Header["X-Forwarded-For"]
			omit := ok && prior == nil
			if len(prior) > 0 {
				clientIP = strings.Join(prior, ", ") + ", " + clientIP
			}
			if !omit {
				outreq.Header.Set("X-Forwarded-For", clientIP)
			}
		}
	}

	if _, ok := outreq.Header["User-Agent"]; !ok {
		// An explicitly empty User-Agent stops the Go HTTP client from
		// sending its own default User-Agent on the client's behalf.
		outreq.Header.Set("User-Agent", "")
	}

	// Relay 1xx informational responses (e.g. 103 Early Hints) to the
	// client while the round trip is in progress. Once RoundTrip returns,
	// roundTripDone (guarded by roundTripMutex) keeps the callback away from
	// rw's header map, which then belongs to the final response.
	var (
		roundTripMutex sync.Mutex
		roundTripDone  bool
	)
	trace := &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			roundTripMutex.Lock()
			defer roundTripMutex.Unlock()
			if roundTripDone {
				// RoundTrip has already returned; leave rw alone.
				return nil
			}
			h := rw.Header()
			copyHeader(h, http.Header(header))
			rw.WriteHeader(code)

			// rw's header map is not reset after a 1xx is written, so clear
			// it to keep these headers off the following (final) response.
			clear(h)
			return nil
		},
	}
	outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))

	res, err := transport.RoundTrip(outreq)
	roundTripMutex.Lock()
	roundTripDone = true
	roundTripMutex.Unlock()
	if err != nil {
		// Note: the error handler receives the outbound request here.
		p.getErrorHandler()(rw, outreq, err)
		return
	}

	// 101 Switching Protocols (WebSocket, h2c, ...): keep the hop-by-hop
	// Connection/Upgrade headers, which handleUpgradeResponse inspects.
	if res.StatusCode == http.StatusSwitchingProtocols {
		if !p.modifyResponse(rw, res, outreq) {
			return
		}
		p.handleUpgradeResponse(rw, outreq, res)
		return
	}

	removeHopByHopHeaders(res.Header)

	if !p.modifyResponse(rw, res, outreq) {
		return
	}

	copyHeader(rw.Header(), res.Header)

	// The Trailer header is hop-by-hop and was stripped above, but the
	// announced trailer names are still the keys of res.Trailer. Re-announce
	// them to the client before the header is written.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}

	rw.WriteHeader(res.StatusCode)

	err = p.copyResponse(rw, res.Body, p.flushInterval(res))
	if err != nil {
		defer res.Body.Close()
		// The status and headers are already sent, so the error can no
		// longer be reported to the client normally. Abort the response
		// instead, provided something will recover the panic.
		if !shouldPanicOnCopyError(req) {
			p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
			return
		}
		panic(http.ErrAbortHandler)
	}
	// Close now rather than deferring, so res.Trailer is complete before
	// the trailer handling below.
	res.Body.Close()

	if len(res.Trailer) > 0 {
		// Force a chunked response when trailers follow. Otherwise net/http
		// may compute a Content-Length for a short body, which would leave
		// no way to send trailers.
		http.NewResponseController(rw).Flush()
	}

	if len(res.Trailer) == announcedTrailers {
		// No new trailer keys appeared while the body was read. Every key
		// was declared in the Trailer header above, so set them directly.
		copyHeader(rw.Header(), res.Trailer)
		return
	}

	// Some trailers were not announced up front. http.TrailerPrefix marks
	// each key as a trailer even without a Trailer declaration.
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
635
// inOurTests forces shouldPanicOnCopyError to report true.
// NOTE(review): nothing in this file sets it; presumably this package's own
// tests do, since they expect and recover the resulting panic — confirm.
var inOurTests bool
637
638
639
640
641
642
643 func shouldPanicOnCopyError(req *http.Request) bool {
644 if inOurTests {
645
646 return true
647 }
648 if req.Context().Value(http.ServerContextKey) != nil {
649
650
651 return true
652 }
653
654
655 return false
656 }
657
658
659 func removeHopByHopHeaders(h http.Header) {
660
661 for _, f := range h["Connection"] {
662 for sf := range strings.SplitSeq(f, ",") {
663 if sf = textproto.TrimString(sf); sf != "" {
664 h.Del(sf)
665 }
666 }
667 }
668
669
670
671 for _, f := range hopHeaders {
672 h.Del(f)
673 }
674 }
675
676
677
678 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
679 resCT := res.Header.Get("Content-Type")
680
681
682
683 if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
684 return -1
685 }
686
687
688 if res.ContentLength == -1 {
689 return -1
690 }
691
692 return p.FlushInterval
693 }
694
695 func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
696 var w io.Writer = dst
697
698 if flushInterval != 0 {
699 mlw := &maxLatencyWriter{
700 dst: dst,
701 flush: http.NewResponseController(dst).Flush,
702 latency: flushInterval,
703 }
704 defer mlw.stop()
705
706
707 mlw.flushPending = true
708 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
709
710 w = mlw
711 }
712
713 var buf []byte
714 if p.BufferPool != nil {
715 buf = p.BufferPool.Get()
716 defer p.BufferPool.Put(buf)
717 }
718 _, err := p.copyBuffer(w, src, buf)
719 return err
720 }
721
722
723
724 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
725 if len(buf) == 0 {
726 buf = make([]byte, 32*1024)
727 }
728 var written int64
729 for {
730 nr, rerr := src.Read(buf)
731 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
732 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
733 }
734 if nr > 0 {
735 nw, werr := dst.Write(buf[:nr])
736 if nw > 0 {
737 written += int64(nw)
738 }
739 if werr != nil {
740 return written, werr
741 }
742 if nr != nw {
743 return written, io.ErrShortWrite
744 }
745 }
746 if rerr != nil {
747 if rerr == io.EOF {
748 rerr = nil
749 }
750 return written, rerr
751 }
752 }
753 }
754
755 func (p *ReverseProxy) logf(format string, args ...any) {
756 if p.ErrorLog != nil {
757 p.ErrorLog.Printf(format, args...)
758 } else {
759 log.Printf(format, args...)
760 }
761 }
762
// maxLatencyWriter wraps dst and bounds how long written data can sit
// unflushed:
//   - negative latency: every Write is followed by a synchronous flush;
//   - otherwise: the first Write after a flush arms a timer, and the timer
//     calls delayedFlush once latency has elapsed.
// mu serializes Write, delayedFlush and stop.
type maxLatencyWriter struct {
	dst     io.Writer
	flush   func() error
	latency time.Duration // negative means flush after every Write

	mu           sync.Mutex // serializes writes, flushes and timer state
	t            *time.Timer
	flushPending bool
}

// Write writes p to dst and then schedules a flush:
//   - immediately when latency is negative;
//   - otherwise via the timer, unless a flush is already pending.
// Flush errors are ignored; n and err are those returned by dst.Write.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	n, err = m.dst.Write(p)
	switch {
	case m.latency < 0:
		m.flush()
	case m.flushPending:
		// A timer is already armed and will flush this data as well.
	case m.t == nil:
		m.t = time.AfterFunc(m.latency, m.delayedFlush)
		m.flushPending = true
	default:
		m.t.Reset(m.latency)
		m.flushPending = true
	}
	return n, err
}

// delayedFlush is the timer callback. It flushes only while a flush is
// still pending, so a callback that runs after stop is a no-op.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.flushPending {
		m.flush()
		m.flushPending = false
	}
}

// stop cancels any pending flush and stops the timer. It does not flush.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.flushPending = false
	if t := m.t; t != nil {
		t.Stop()
	}
}
811
812 func upgradeType(h http.Header) string {
813 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
814 return ""
815 }
816 return h.Get("Upgrade")
817 }
818
819 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
820 reqUpType := upgradeType(req.Header)
821 resUpType := upgradeType(res.Header)
822 if !ascii.IsPrint(resUpType) {
823 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
824 return
825 }
826 if !ascii.EqualFold(reqUpType, resUpType) {
827 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
828 return
829 }
830
831 backConn, ok := res.Body.(io.ReadWriteCloser)
832 if !ok {
833 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
834 return
835 }
836
837 rc := http.NewResponseController(rw)
838 conn, brw, hijackErr := rc.Hijack()
839 if errors.Is(hijackErr, http.ErrNotSupported) {
840 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
841 return
842 }
843
844 backConnCloseCh := make(chan bool)
845 go func() {
846
847
848 select {
849 case <-req.Context().Done():
850 case <-backConnCloseCh:
851 }
852 backConn.Close()
853 }()
854 defer close(backConnCloseCh)
855
856 if hijackErr != nil {
857 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
858 return
859 }
860 defer conn.Close()
861
862 copyHeader(rw.Header(), res.Header)
863
864 res.Header = rw.Header()
865 res.Body = nil
866 if err := res.Write(brw); err != nil {
867 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
868 return
869 }
870 if err := brw.Flush(); err != nil {
871 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
872 return
873 }
874 errc := make(chan error, 1)
875 spc := switchProtocolCopier{user: conn, backend: backConn}
876 go spc.copyToBackend(errc)
877 go spc.copyFromBackend(errc)
878
879
880
881 err := <-errc
882 if err == nil {
883 err = <-errc
884 }
885 }
886
// errCopyDone is sent by a switchProtocolCopier direction that finished
// cleanly but could not half-close its destination (no CloseWrite method).
// It is non-nil so that handleUpgradeResponse stops waiting and tears the
// session down.
var errCopyDone = errors.New("hijacked connection copy complete")

// switchProtocolCopier moves bytes between the hijacked client connection
// (user) and the backend connection after a protocol switch. Each copy
// method handles one direction and sends exactly one value on errc:
//   - the copy error, if the copy failed;
//   - otherwise, the result of CloseWrite on the destination when it
//     supports half-close;
//   - otherwise, errCopyDone.
type switchProtocolCopier struct {
	user, backend io.ReadWriter
}

// copyFromBackend copies backend -> user.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	_, err := io.Copy(c.user, c.backend)
	if err == nil {
		// The backend reached EOF: pass the end of stream on to the user
		// with CloseWrite when the connection supports it.
		if w, ok := c.user.(interface{ CloseWrite() error }); ok {
			err = w.CloseWrite()
		} else {
			err = errCopyDone
		}
	}
	errc <- err
}

// copyToBackend copies user -> backend.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	_, err := io.Copy(c.backend, c.user)
	if err == nil {
		// The user reached EOF: pass the end of stream on to the backend
		// with CloseWrite when the connection supports it.
		if w, ok := c.backend.(interface{ CloseWrite() error }); ok {
			err = w.CloseWrite()
		} else {
			err = errCopyDone
		}
	}
	errc <- err
}
924
925 func cleanQueryParams(s string) string {
926 reencode := func(s string) string {
927 v, _ := url.ParseQuery(s)
928 return v.Encode()
929 }
930 for i := 0; i < len(s); {
931 switch s[i] {
932 case ';':
933 return reencode(s)
934 case '%':
935 if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
936 return reencode(s)
937 }
938 i += 3
939 default:
940 i++
941 }
942 }
943 return s
944 }
945
// ishex reports whether c is an ASCII hexadecimal digit (0-9, a-f, A-F).
func ishex(c byte) bool {
	isDigit := '0' <= c && c <= '9'
	isLower := 'a' <= c && c <= 'f'
	isUpper := 'A' <= c && c <= 'F'
	return isDigit || isLower || isUpper
}
957
// noopCloseReader wraps the outbound request body. Close only marks the
// wrapper closed and never closes the underlying body. Any Read after Close
// fails without touching the underlying body, so reads attempted after
// ServeHTTP has returned (and closed the wrapper) are rejected.
// NOTE(review): closed is atomic, presumably because the transport may call
// Read from a different goroutine than the one calling Close.
type noopCloseReader struct {
	readCloser io.ReadCloser
	closed     atomic.Bool
}

// Close marks the reader closed. It always returns nil and leaves the
// wrapped body open.
func (r *noopCloseReader) Close() error {
	r.closed.Store(true)
	return nil
}

// Read reads from the wrapped body until Close has been called, and fails
// from then on.
func (r *noopCloseReader) Read(p []byte) (int, error) {
	if !r.closed.Load() {
		return r.readCloser.Read(p)
	}
	return 0, errors.New("ReverseProxy does an invalid Read on closed Body")
}
974
View as plain text