1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "io"
14 "log"
15 "mime"
16 "net"
17 "net/http"
18 "net/http/httptrace"
19 "net/http/internal/ascii"
20 "net/textproto"
21 "net/url"
22 "strings"
23 "sync"
24 "time"
25
26 "golang.org/x/net/http/httpguts"
27 )
28
29
// ProxyRequest is the value ServeHTTP passes to a ReverseProxy's Rewrite
// function. It pairs the request the proxy received with the request the
// proxy is about to send to the backend.
type ProxyRequest struct {
	// In is the request received by the proxy.
	// NOTE(review): Rewrite is presumably expected not to modify In; this
	// is not enforced by ServeHTTP — confirm.
	In *http.Request

	// Out is the request the proxy will send. ServeHTTP initializes it as
	// a clone of In with hop-by-hop headers and the Forwarded /
	// X-Forwarded-* headers removed and unparsable query parameters
	// stripped. Rewrite may modify Out or replace it; whatever Out holds
	// when Rewrite returns is what gets sent.
	Out *http.Request
}
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56 func (r *ProxyRequest) SetURL(target *url.URL) {
57 rewriteRequestURL(r.Out, target)
58 r.Out.Host = ""
59 }
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80 func (r *ProxyRequest) SetXForwarded() {
81 clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
82 if err == nil {
83 prior := r.Out.Header["X-Forwarded-For"]
84 if len(prior) > 0 {
85 clientIP = strings.Join(prior, ", ") + ", " + clientIP
86 }
87 r.Out.Header.Set("X-Forwarded-For", clientIP)
88 } else {
89 r.Out.Header.Del("X-Forwarded-For")
90 }
91 r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
92 if r.In.TLS == nil {
93 r.Out.Header.Set("X-Forwarded-Proto", "http")
94 } else {
95 r.Out.Header.Set("X-Forwarded-Proto", "https")
96 }
97 }
98
99
100
101
102
103
104
105
106
107
108
109
110
111
// ReverseProxy is an http.Handler that forwards each incoming request to a
// backend and copies the backend's response back to the client.
//
// Exactly one of Rewrite or Director must be set. ServeHTTP reports an
// error through the error handler when both or neither are set.
type ReverseProxy struct {
	// Rewrite, if set, turns the inbound request into the outbound one by
	// modifying ProxyRequest.Out.
	//
	// Before Rewrite is called, ServeHTTP has already done three things to
	// Out:
	//   - removed the hop-by-hop headers;
	//   - deleted the Forwarded, X-Forwarded-For, X-Forwarded-Host and
	//     X-Forwarded-Proto headers;
	//   - removed unparsable query parameters.
	// Rewrite may call ProxyRequest.SetXForwarded to set forwarding headers.
	Rewrite func(*ProxyRequest)

	// Transport performs the outbound requests.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval controls how often the response body is flushed to
	// the client while it is being copied:
	//   - zero: no periodic flushing;
	//   - negative: flush after every write;
	//   - positive: flush at most this long after data is written.
	// Responses with Content-Type text/event-stream, or with an unknown
	// length (ContentLength == -1), are always flushed after every write,
	// whatever this value is.
	FlushInterval time.Duration

	// ErrorLog receives errors logged by the proxy. If nil, messages go
	// to the log package's standard logger.
	ErrorLog *log.Logger

	// BufferPool, if set, supplies the byte slices used to copy response
	// bodies. A zero-length slice from Get causes a fresh 32 KiB buffer to
	// be allocated for that copy instead.
	BufferPool BufferPool

	// ModifyResponse, if set, is called with the backend's response.
	// For ordinary responses it runs after hop-by-hop headers have been
	// removed and before anything is written to the client. For
	// 101 Switching Protocols responses it runs before the upgrade is
	// handled.
	//
	// If it returns an error, the response body is closed and the error
	// handler is called with that error.
	ModifyResponse func(*http.Response) error

	// ErrorHandler, if set, handles proxying errors. These include:
	//   - misconfiguration;
	//   - invalid upgrade protocols;
	//   - errors returned by Transport.RoundTrip;
	//   - errors returned by ModifyResponse;
	//   - failures during a protocol switch.
	// If nil, the error is logged and the client receives
	// 502 Bad Gateway.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)

	// Director, if set, modifies the outbound request (a clone of the
	// inbound request) in place.
	//
	// After Director returns, ServeHTTP does the following:
	//   - if the request's Form was populated, removes unparsable query
	//     parameters;
	//   - removes hop-by-hop headers, so Director cannot add them (only
	//     the Connection/Upgrade pair is restored for upgrade requests);
	//   - appends the client IP to X-Forwarded-For, unless that header is
	//     present with a nil value, which suppresses it.
	Director func(*http.Request)
}
266
267
268
// BufferPool is a source of byte slices used by ReverseProxy when copying
// response bodies. Each slice obtained with Get is returned with Put once
// the copy finishes.
type BufferPool interface {
	// Get returns a buffer to copy into. A zero-length result makes the
	// proxy allocate its own 32 KiB buffer for that copy.
	Get() []byte
	// Put returns a buffer previously obtained from Get.
	Put([]byte)
}
273
274 func singleJoiningSlash(a, b string) string {
275 aslash := strings.HasSuffix(a, "/")
276 bslash := strings.HasPrefix(b, "/")
277 switch {
278 case aslash && bslash:
279 return a + b[1:]
280 case !aslash && !bslash:
281 return a + "/" + b
282 }
283 return a + b
284 }
285
286 func joinURLPath(a, b *url.URL) (path, rawpath string) {
287 if a.RawPath == "" && b.RawPath == "" {
288 return singleJoiningSlash(a.Path, b.Path), ""
289 }
290
291
292 apath := a.EscapedPath()
293 bpath := b.EscapedPath()
294
295 aslash := strings.HasSuffix(apath, "/")
296 bslash := strings.HasPrefix(bpath, "/")
297
298 switch {
299 case aslash && bslash:
300 return a.Path + b.Path[1:], apath + bpath[1:]
301 case !aslash && !bslash:
302 return a.Path + "/" + b.Path, apath + "/" + bpath
303 }
304 return a.Path + b.Path, apath + bpath
305 }
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
332 director := func(req *http.Request) {
333 rewriteRequestURL(req, target)
334 }
335 return &ReverseProxy{Director: director}
336 }
337
338 func rewriteRequestURL(req *http.Request, target *url.URL) {
339 targetQuery := target.RawQuery
340 req.URL.Scheme = target.Scheme
341 req.URL.Host = target.Host
342 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
343 if targetQuery == "" || req.URL.RawQuery == "" {
344 req.URL.RawQuery = targetQuery + req.URL.RawQuery
345 } else {
346 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
347 }
348 }
349
350 func copyHeader(dst, src http.Header) {
351 for k, vv := range src {
352 for _, v := range vv {
353 dst.Add(k, v)
354 }
355 }
356 }
357
358
359
360
361
362
// hopHeaders lists header fields that apply to a single connection only.
// removeHopByHopHeaders deletes them from outbound requests and from
// backend responses. It does so after first deleting any fields named in
// the Connection header.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // NOTE(review): non-standard; presumably kept for legacy clients
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te", // canonicalized form of "TE"; ServeHTTP may re-add "Te: trailers"
	"Trailer",
	"Transfer-Encoding",
	"Upgrade", // ServeHTTP re-adds Connection/Upgrade for protocol upgrades
}
374
375 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
376 p.logf("http: proxy error: %v", err)
377 rw.WriteHeader(http.StatusBadGateway)
378 }
379
380 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
381 if p.ErrorHandler != nil {
382 return p.ErrorHandler
383 }
384 return p.defaultErrorHandler
385 }
386
387
388
389 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
390 if p.ModifyResponse == nil {
391 return true
392 }
393 if err := p.ModifyResponse(res); err != nil {
394 res.Body.Close()
395 p.getErrorHandler()(rw, req, err)
396 return false
397 }
398 return true
399 }
400
// ServeHTTP proxies req to the backend chosen by Director or Rewrite and
// copies the backend's response back to rw. In order, it:
//
//  1. derives the outbound context, adding CloseNotifier-based
//     cancellation only when req's context has no Done channel;
//  2. clones req into outreq;
//  3. checks that exactly one of Director/Rewrite is set;
//  4. runs Director, if set;
//  5. strips hop-by-hop headers, restoring "Te: trailers" and the
//     Connection/Upgrade pair when appropriate;
//  6. runs Rewrite, or for Director appends X-Forwarded-For;
//  7. sends outreq through the Transport;
//  8. copies the response headers, body and trailers to rw.
//
// 101 Switching Protocols responses are handed to handleUpgradeResponse.
//
// Failures before the response header is written go to the error handler.
// A failure while copying the body panics with http.ErrAbortHandler when
// shouldPanicOnCopyError reports true; otherwise it is only logged.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	ctx := req.Context()
	if ctx.Done() != nil {
		// req's context already carries a cancellation signal, so there
		// is no need to watch a CloseNotifier. Intentionally empty.
	} else if cn, ok := rw.(http.CloseNotifier); ok {
		// The context can never be canceled (nil Done channel), so fall
		// back to the deprecated CloseNotifier. A client disconnect then
		// cancels the outbound request. The goroutine exits on either
		// signal, and the deferred cancel fires ctx.Done() on return.
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		defer cancel()
		notifyChan := cn.CloseNotify()
		go func() {
			select {
			case <-notifyChan:
				cancel()
			case <-ctx.Done():
			}
		}()
	}

	// outreq is the request sent upstream: a copy of req bound to ctx
	// (per Request.Clone docs, the Body is shared, not copied).
	outreq := req.Clone(ctx)
	if req.ContentLength == 0 {
		// NOTE(review): a nil Body presumably lets the Transport treat the
		// request as bodiless (e.g. for retries) — confirm.
		outreq.Body = nil
	}
	if outreq.Body != nil {
		// Ensure the shared body is closed when the handler returns.
		// NOTE(review): presumably guards against the Transport reading it
		// after this handler has exited — confirm.
		defer outreq.Body.Close()
	}
	if outreq.Header == nil {
		// Header.Set/Del below would panic on a nil map.
		outreq.Header = make(http.Header)
	}

	if (p.Director != nil) == (p.Rewrite != nil) {
		p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
		return
	}

	if p.Director != nil {
		p.Director(outreq)
		if outreq.Form != nil {
			// Director parsed the form; drop query parameters that do not
			// parse so that what is forwarded matches what was parsed.
			outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
		}
	}
	outreq.Close = false

	// Capture the requested upgrade protocol before the Connection and
	// Upgrade headers are stripped below.
	reqUpType := upgradeType(outreq.Header)
	if !ascii.IsPrint(reqUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
		return
	}
	removeHopByHopHeaders(outreq.Header)

	// Tell the backend that trailers are supported, but only if the client
	// said so. This reads req.Header because outreq's Te header was just
	// removed as hop-by-hop.
	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}

	// Restore the hop-by-hop headers needed for a protocol upgrade
	// (e.g. WebSocket).
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}

	if p.Rewrite != nil {
		// Drop client-supplied forwarding headers. Rewrite can set fresh
		// values, e.g. via ProxyRequest.SetXForwarded.
		outreq.Header.Del("Forwarded")
		outreq.Header.Del("X-Forwarded-For")
		outreq.Header.Del("X-Forwarded-Host")
		outreq.Header.Del("X-Forwarded-Proto")

		// Remove unparsable query parameters from the outbound request.
		outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)

		pr := &ProxyRequest{
			In:  req,
			Out: outreq,
		}
		p.Rewrite(pr)
		outreq = pr.Out // Rewrite may have replaced Out entirely
	} else {
		if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			// Append the client IP to any prior X-Forwarded-For values,
			// folded into one comma+space separated header. A header key
			// present with a nil value means "do not send X-Forwarded-For".
			prior, ok := outreq.Header["X-Forwarded-For"]
			omit := ok && prior == nil
			if len(prior) > 0 {
				clientIP = strings.Join(prior, ", ") + ", " + clientIP
			}
			if !omit {
				outreq.Header.Set("X-Forwarded-For", clientIP)
			}
		}
	}

	if _, ok := outreq.Header["User-Agent"]; !ok {
		// Send an explicitly empty User-Agent rather than letting the
		// client insert its default one.
		outreq.Header.Set("User-Agent", "")
	}

	// roundTripMutex guards roundTripDone. Got1xxResponse may run
	// concurrently with the code after RoundTrip returns.
	var (
		roundTripMutex sync.Mutex
		roundTripDone  bool
	)
	trace := &httptrace.ClientTrace{
		// Relay each 1xx interim response from the backend to the client.
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			roundTripMutex.Lock()
			defer roundTripMutex.Unlock()
			if roundTripDone {
				// RoundTrip has returned; rw's header map now belongs to
				// the final response and must not be touched.
				return nil
			}
			h := rw.Header()
			copyHeader(h, http.Header(header))
			rw.WriteHeader(code)

			// Clear the 1xx headers so they do not carry into the final
			// response. NOTE(review): relies on WriteHeader not resetting
			// the map for 1xx codes.
			clear(h)
			return nil
		},
	}
	outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))

	res, err := transport.RoundTrip(outreq)
	roundTripMutex.Lock()
	roundTripDone = true
	roundTripMutex.Unlock()
	if err != nil {
		p.getErrorHandler()(rw, outreq, err)
		return
	}

	// Protocol switch (WebSocket, h2c, ...): hand the connection off.
	if res.StatusCode == http.StatusSwitchingProtocols {
		if !p.modifyResponse(rw, res, outreq) {
			return
		}
		p.handleUpgradeResponse(rw, outreq, res)
		return
	}

	removeHopByHopHeaders(res.Header)

	if !p.modifyResponse(rw, res, outreq) {
		return
	}

	copyHeader(rw.Header(), res.Header)

	// Announce the backend's trailer keys with a "Trailer" header. It is
	// rebuilt from res.Trailer rather than taken from res.Header.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}

	rw.WriteHeader(res.StatusCode)

	err = p.copyResponse(rw, res.Body, p.flushInterval(res))
	if err != nil {
		defer res.Body.Close()
		// The header is already sent, so the only way to signal failure is
		// to abort the response via panic(http.ErrAbortHandler).
		if !shouldPanicOnCopyError(req) {
			p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
			return
		}
		panic(http.ErrAbortHandler)
	}
	// Close now rather than deferring. res.Trailer is read below and holds
	// trailer values once the body has been fully read.
	res.Body.Close()

	if len(res.Trailer) > 0 {
		// NOTE(review): flushing here presumably forces a chunked response
		// (no computed Content-Length) so trailers can be sent — confirm.
		http.NewResponseController(rw).Flush()
	}

	// Every trailer was announced up front: set them as ordinary header
	// keys, which become trailers after WriteHeader.
	if len(res.Trailer) == announcedTrailers {
		copyHeader(rw.Header(), res.Trailer)
		return
	}

	// Unannounced trailers appeared; send all of them with the
	// http.TrailerPrefix convention.
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
622
// inOurTests makes shouldPanicOnCopyError always report true.
// NOTE(review): presumably set by this package's own tests (see the name
// and ServeHTTP's "in test" log message) — confirm.
var inOurTests bool
624
625
626
627
628
629
630 func shouldPanicOnCopyError(req *http.Request) bool {
631 if inOurTests {
632
633 return true
634 }
635 if req.Context().Value(http.ServerContextKey) != nil {
636
637
638 return true
639 }
640
641
642 return false
643 }
644
645
646 func removeHopByHopHeaders(h http.Header) {
647
648 for _, f := range h["Connection"] {
649 for sf := range strings.SplitSeq(f, ",") {
650 if sf = textproto.TrimString(sf); sf != "" {
651 h.Del(sf)
652 }
653 }
654 }
655
656
657
658 for _, f := range hopHeaders {
659 h.Del(f)
660 }
661 }
662
663
664
665 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
666 resCT := res.Header.Get("Content-Type")
667
668
669
670 if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
671 return -1
672 }
673
674
675 if res.ContentLength == -1 {
676 return -1
677 }
678
679 return p.FlushInterval
680 }
681
682 func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
683 var w io.Writer = dst
684
685 if flushInterval != 0 {
686 mlw := &maxLatencyWriter{
687 dst: dst,
688 flush: http.NewResponseController(dst).Flush,
689 latency: flushInterval,
690 }
691 defer mlw.stop()
692
693
694 mlw.flushPending = true
695 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
696
697 w = mlw
698 }
699
700 var buf []byte
701 if p.BufferPool != nil {
702 buf = p.BufferPool.Get()
703 defer p.BufferPool.Put(buf)
704 }
705 _, err := p.copyBuffer(w, src, buf)
706 return err
707 }
708
709
710
711 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
712 if len(buf) == 0 {
713 buf = make([]byte, 32*1024)
714 }
715 var written int64
716 for {
717 nr, rerr := src.Read(buf)
718 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
719 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
720 }
721 if nr > 0 {
722 nw, werr := dst.Write(buf[:nr])
723 if nw > 0 {
724 written += int64(nw)
725 }
726 if werr != nil {
727 return written, werr
728 }
729 if nr != nw {
730 return written, io.ErrShortWrite
731 }
732 }
733 if rerr != nil {
734 if rerr == io.EOF {
735 rerr = nil
736 }
737 return written, rerr
738 }
739 }
740 }
741
742 func (p *ReverseProxy) logf(format string, args ...any) {
743 if p.ErrorLog != nil {
744 p.ErrorLog.Printf(format, args...)
745 } else {
746 log.Printf(format, args...)
747 }
748 }
749
// maxLatencyWriter wraps dst so that written data is flushed (by calling
// flush) no later than latency after it was written. A negative latency
// flushes synchronously after every Write. The timer callback,
// delayedFlush, runs on the timer's own goroutine, so state is guarded by
// mu.
type maxLatencyWriter struct {
	dst     io.Writer
	flush   func() error // its error is ignored by this type
	latency time.Duration

	mu           sync.Mutex // guards t and flushPending; serializes dst writes and flushes
	t            *time.Timer
	flushPending bool
}

// Write writes p to dst and returns dst's result.
//   - With a negative latency it then flushes immediately.
//   - Otherwise, unless a flush is already pending, it arms the timer so
//     that a flush happens after latency.
// Flush errors are ignored.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	n, err = m.dst.Write(p)
	switch {
	case m.latency < 0:
		m.flush()
	case m.flushPending:
		// A flush is already scheduled and will cover these bytes too.
	case m.t == nil:
		m.t = time.AfterFunc(m.latency, m.delayedFlush)
		m.flushPending = true
	default:
		m.t.Reset(m.latency)
		m.flushPending = true
	}
	return n, err
}

// delayedFlush is the timer callback. It flushes only if a flush is still
// pending (stop or an earlier flush may have cleared it), then clears the
// pending flag.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.flushPending {
		m.flush()
		m.flushPending = false
	}
}

// stop cancels any pending flush and stops the timer. It does not flush.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.flushPending = false
	if timer := m.t; timer != nil {
		timer.Stop()
	}
}
798
799 func upgradeType(h http.Header) string {
800 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
801 return ""
802 }
803 return h.Get("Upgrade")
804 }
805
806 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
807 reqUpType := upgradeType(req.Header)
808 resUpType := upgradeType(res.Header)
809 if !ascii.IsPrint(resUpType) {
810 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
811 return
812 }
813 if !ascii.EqualFold(reqUpType, resUpType) {
814 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
815 return
816 }
817
818 backConn, ok := res.Body.(io.ReadWriteCloser)
819 if !ok {
820 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
821 return
822 }
823
824 rc := http.NewResponseController(rw)
825 conn, brw, hijackErr := rc.Hijack()
826 if errors.Is(hijackErr, http.ErrNotSupported) {
827 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
828 return
829 }
830
831 backConnCloseCh := make(chan bool)
832 go func() {
833
834
835 select {
836 case <-req.Context().Done():
837 case <-backConnCloseCh:
838 }
839 backConn.Close()
840 }()
841 defer close(backConnCloseCh)
842
843 if hijackErr != nil {
844 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
845 return
846 }
847 defer conn.Close()
848
849 copyHeader(rw.Header(), res.Header)
850
851 res.Header = rw.Header()
852 res.Body = nil
853 if err := res.Write(brw); err != nil {
854 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
855 return
856 }
857 if err := brw.Flush(); err != nil {
858 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
859 return
860 }
861 errc := make(chan error, 1)
862 spc := switchProtocolCopier{user: conn, backend: backConn}
863 go spc.copyToBackend(errc)
864 go spc.copyFromBackend(errc)
865
866
867
868 err := <-errc
869 if err == nil {
870 err = <-errc
871 }
872 }
873
// errCopyDone is reported by a switchProtocolCopier direction that copied
// until EOF but could not half-close its destination, because the
// destination has no CloseWrite method.
var errCopyDone = errors.New("hijacked connection copy complete")

// switchProtocolCopier moves bytes between the hijacked client connection
// (user) and the backend connection (backend) after a protocol switch.
type switchProtocolCopier struct {
	user, backend io.ReadWriter
}

// copyFromBackend copies backend to user and sends exactly one result on
// errc (see copyThenCloseWrite).
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	errc <- copyThenCloseWrite(c.user, c.backend)
}

// copyToBackend copies user to backend and sends exactly one result on
// errc (see copyThenCloseWrite).
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	errc <- copyThenCloseWrite(c.backend, c.user)
}

// copyThenCloseWrite copies src into dst until EOF or an error. It returns,
// in order of precedence:
//   - the copy error, if any;
//   - otherwise dst.CloseWrite()'s result, when dst supports half-close;
//   - otherwise errCopyDone.
func copyThenCloseWrite(dst io.Writer, src io.Reader) error {
	if _, err := io.Copy(dst, src); err != nil {
		return err
	}
	if hc, ok := dst.(interface{ CloseWrite() error }); ok {
		return hc.CloseWrite()
	}
	return errCopyDone
}
911
// cleanQueryParams returns s unchanged when it is a well-formed raw query.
// A query is malformed if it contains a ';' or a '%' that is not followed
// by two hex digits. A malformed query is re-encoded with url.ParseQuery
// and Values.Encode, which keeps only the parameters that parse (and sorts
// them by key).
func cleanQueryParams(s string) string {
	for i := 0; i < len(s); i++ {
		c := s[i]
		malformed := c == ';' ||
			c == '%' && (i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]))
		if malformed {
			// ParseQuery's error is deliberately ignored: the point is to
			// keep the parameters that did parse.
			values, _ := url.ParseQuery(s)
			return values.Encode()
		}
		if c == '%' {
			i += 2 // skip the two hex digits just validated
		}
	}
	return s
}

// ishex reports whether c is an ASCII hexadecimal digit.
func ishex(c byte) bool {
	return '0' <= c && c <= '9' ||
		'a' <= c && c <= 'f' ||
		'A' <= c && c <= 'F'
}
944
View as plain text