Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 crand "crypto/rand"
16 "crypto/tls"
17 "crypto/x509"
18 "encoding/json"
19 "errors"
20 "fmt"
21 "internal/synctest"
22 "internal/testenv"
23 "io"
24 "log"
25 "math/rand"
26 "mime/multipart"
27 "net"
28 . "net/http"
29 "net/http/httptest"
30 "net/http/httptrace"
31 "net/http/httputil"
32 "net/http/internal"
33 "net/http/internal/testcert"
34 "net/url"
35 "os"
36 "path/filepath"
37 "reflect"
38 "regexp"
39 "runtime"
40 "slices"
41 "strconv"
42 "strings"
43 "sync"
44 "sync/atomic"
45 "syscall"
46 "testing"
47 "time"
48 )
49
// dummyAddr is a net.Addr whose network name and string form are both the
// underlying string value.
type dummyAddr string

// oneConnListener is a net.Listener that hands out a single preset
// connection and then reports io.EOF from every later Accept.
type oneConnListener struct {
	conn net.Conn
}

// Accept returns the stored connection exactly once, clearing it, and
// (nil, io.EOF) on every subsequent call.
func (l *oneConnListener) Accept() (c net.Conn, err error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	c, l.conn = l.conn, nil
	return c, nil
}

// Close is a no-op; the listener holds nothing to release.
func (l *oneConnListener) Close() error { return nil }

// Addr reports a fixed placeholder address.
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("test-address") }

// Network implements net.Addr.
func (a dummyAddr) Network() string { return string(a) }

// String implements net.Addr.
func (a dummyAddr) String() string { return string(a) }

// noopConn supplies the address and deadline methods of net.Conn as no-ops,
// so embedding types only need Read, Write and Close.
type noopConn struct{}

func (noopConn) LocalAddr() net.Addr {
	return dummyAddr("local-addr")
}

func (noopConn) RemoteAddr() net.Addr {
	return dummyAddr("remote-addr")
}

func (noopConn) SetDeadline(t time.Time) error { return nil }

func (noopConn) SetReadDeadline(t time.Time) error { return nil }

func (noopConn) SetWriteDeadline(t time.Time) error { return nil }

// rwTestConn is a net.Conn assembled from an arbitrary Reader and Writer.
type rwTestConn struct {
	io.Reader
	io.Writer
	noopConn

	closeFunc func() error // when non-nil, Close delegates to it
	closec    chan bool    // otherwise Close sends true here if that won't block
}

// Close calls closeFunc when set. Otherwise it attempts a non-blocking send
// on closec, letting a test observe the close, and returns nil.
func (c *rwTestConn) Close() error {
	if fn := c.closeFunc; fn != nil {
		return fn()
	}
	select {
	case c.closec <- true:
	default:
	}
	return nil
}

// testConn is an in-memory net.Conn: Read drains readBuf and Write appends
// to writeBuf.
type testConn struct {
	readMu   sync.Mutex // serializes Read calls on readBuf
	readBuf  bytes.Buffer
	writeBuf bytes.Buffer
	closec   chan bool // Close sends true here if that won't block
	noopConn
}

// newTestConn returns a testConn whose closec can buffer one close signal.
func newTestConn() *testConn {
	c := new(testConn)
	c.closec = make(chan bool, 1)
	return c
}

func (c *testConn) Read(b []byte) (int, error) {
	c.readMu.Lock()
	n, err := c.readBuf.Read(b)
	c.readMu.Unlock()
	return n, err
}

func (c *testConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }

// Close performs a non-blocking send on closec and always returns nil.
func (c *testConn) Close() error {
	select {
	case c.closec <- true:
	default:
	}
	return nil
}
139
140
141
// reqBytes turns a human-written request (LF line endings, arbitrary
// surrounding whitespace) into wire form: trimmed, CRLF line endings, and
// terminated by the blank line that ends the header block.
func reqBytes(req string) []byte {
	trimmed := strings.TrimSpace(req)
	crlf := strings.ReplaceAll(trimmed, "\n", "\r\n")
	return []byte(crlf + "\r\n\r\n")
}
145
146 type handlerTest struct {
147 logbuf bytes.Buffer
148 handler Handler
149 }
150
151 func newHandlerTest(h Handler) handlerTest {
152 return handlerTest{handler: h}
153 }
154
155 func (ht *handlerTest) rawResponse(req string) string {
156 reqb := reqBytes(req)
157 var output strings.Builder
158 conn := &rwTestConn{
159 Reader: bytes.NewReader(reqb),
160 Writer: &output,
161 closec: make(chan bool, 1),
162 }
163 ln := &oneConnListener{conn: conn}
164 srv := &Server{
165 ErrorLog: log.New(&ht.logbuf, "", 0),
166 Handler: ht.handler,
167 }
168 go srv.Serve(ln)
169 <-conn.closec
170 return output.String()
171 }
172
173 func TestConsumingBodyOnNextConn(t *testing.T) {
174 t.Parallel()
175 defer afterTest(t)
176 conn := new(testConn)
177 for i := 0; i < 2; i++ {
178 conn.readBuf.Write([]byte(
179 "POST / HTTP/1.1\r\n" +
180 "Host: test\r\n" +
181 "Content-Length: 11\r\n" +
182 "\r\n" +
183 "foo=1&bar=1"))
184 }
185
186 reqNum := 0
187 ch := make(chan *Request)
188 servech := make(chan error)
189 listener := &oneConnListener{conn}
190 handler := func(res ResponseWriter, req *Request) {
191 reqNum++
192 ch <- req
193 }
194
195 go func() {
196 servech <- Serve(listener, HandlerFunc(handler))
197 }()
198
199 var req *Request
200 req = <-ch
201 if req == nil {
202 t.Fatal("Got nil first request.")
203 }
204 if req.Method != "POST" {
205 t.Errorf("For request #1's method, got %q; expected %q",
206 req.Method, "POST")
207 }
208
209 req = <-ch
210 if req == nil {
211 t.Fatal("Got nil first request.")
212 }
213 if req.Method != "POST" {
214 t.Errorf("For request #2's method, got %q; expected %q",
215 req.Method, "POST")
216 }
217
218 if serveerr := <-servech; serveerr != io.EOF {
219 t.Errorf("Serve returned %q; expected EOF", serveerr)
220 }
221 }
222
// stringHandler is a Handler that reports its own value in the "Result"
// response header and writes no body.
type stringHandler string

func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Result", string(s))
}
228
// handlers lists the mux patterns that testHostHandlers registers. Each is
// paired with the value its stringHandler echoes in the "Result" header.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}
238
// vtests lists the URLs that testHostHandlers requests. expected is the
// "Result" header wanted for a 200 response, or the "Location" header wanted
// for a 301 redirect.
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Subtree roots requested without the trailing slash are redirected.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
254
255 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
256 func testHostHandlers(t *testing.T, mode testMode) {
257 mux := NewServeMux()
258 for _, h := range handlers {
259 mux.Handle(h.pattern, stringHandler(h.msg))
260 }
261 ts := newClientServerTest(t, mode, mux).ts
262
263 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
264 if err != nil {
265 t.Fatal(err)
266 }
267 defer conn.Close()
268 cc := httputil.NewClientConn(conn, nil)
269 for _, vt := range vtests {
270 var r *Response
271 var req Request
272 if req.URL, err = url.Parse(vt.url); err != nil {
273 t.Errorf("cannot parse url: %v", err)
274 continue
275 }
276 if err := cc.Write(&req); err != nil {
277 t.Errorf("writing request: %v", err)
278 continue
279 }
280 r, err := cc.Read(&req)
281 if err != nil {
282 t.Errorf("reading response: %v", err)
283 continue
284 }
285 switch r.StatusCode {
286 case StatusOK:
287 s := r.Header.Get("Result")
288 if s != vt.expected {
289 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
290 }
291 case StatusMovedPermanently:
292 s := r.Header.Get("Location")
293 if s != vt.expected {
294 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
295 }
296 default:
297 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
298 }
299 }
300 }
301
302 var serveMuxRegister = []struct {
303 pattern string
304 h Handler
305 }{
306 {"/dir/", serve(200)},
307 {"/search", serve(201)},
308 {"codesearch.google.com/search", serve(202)},
309 {"codesearch.google.com/", serve(203)},
310 {"example.com/", HandlerFunc(checkQueryStringHandler)},
311 }
312
313
// serve returns a handler that replies with the given status code and no
// body.
func serve(code int) HandlerFunc {
	writeStatus := func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	}
	return writeStatus
}
319
320
321
322
// checkQueryStringHandler rebuilds the request URL with scheme "http", host
// r.Host and no query. It replies 200 when that string equals "http://"
// followed by the raw query, and 500 otherwise. Tests use this to check that a
// query naming the request's own URL survives routing unchanged.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	rebuilt := *r.URL
	rebuilt.Scheme, rebuilt.Host, rebuilt.RawQuery = "http", r.Host, ""
	status := 500
	if rebuilt.String() == "http://"+r.URL.RawQuery {
		status = 200
	}
	w.WriteHeader(status)
}
334
// serveMuxTests lists requests routed through a ServeMux configured with
// serveMuxRegister. Each entry gives the status code the chosen handler must
// write and the pattern mux.Handler must report. An empty pattern means no
// registered pattern is reported.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 301, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},

	// Unclean paths are answered with a redirect for non-CONNECT methods.
	{"GET", "google.com", "/../search", 301, "/search"},
	{"GET", "google.com", "/dir/..", 301, ""},
	{"GET", "google.com", "/dir/./file", 301, "/dir/"},

	// The /dir -> /dir/ redirect also applies to CONNECT requests, but path
	// cleaning does not: unclean CONNECT paths are matched as-is.
	{"CONNECT", "google.com", "/dir", 301, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
370
371 func TestServeMuxHandler(t *testing.T) {
372 setParallel(t)
373 mux := NewServeMux()
374 for _, e := range serveMuxRegister {
375 mux.Handle(e.pattern, e.h)
376 }
377
378 for _, tt := range serveMuxTests {
379 r := &Request{
380 Method: tt.method,
381 Host: tt.host,
382 URL: &url.URL{
383 Path: tt.path,
384 },
385 }
386 h, pattern := mux.Handler(r)
387 rr := httptest.NewRecorder()
388 h.ServeHTTP(rr, r)
389 if pattern != tt.pattern || rr.Code != tt.code {
390 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
391 }
392 }
393 }
394
395
396 func TestServeMuxHandlerTrailingSlash(t *testing.T) {
397 setParallel(t)
398 mux := NewServeMux()
399 const original = "/{x}/"
400 mux.Handle(original, NotFoundHandler())
401 r, _ := NewRequest("POST", "/foo", nil)
402 _, p := mux.Handler(r)
403 if p != original {
404 t.Errorf("got %q, want %q", p, original)
405 }
406 }
407
408
409 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
410 setParallel(t)
411 defer func() {
412 if err := recover(); err == nil {
413 t.Error("expected call to mux.HandleFunc to panic")
414 }
415 }()
416 mux := NewServeMux()
417 mux.HandleFunc("/", nil)
418 }
419
// serveMuxTests2 drives TestServeMuxHandlerRedirects. url is the request
// target (possibly with a query), code is the final expected status, and
// redirOk says whether a redirect may be followed on the way there.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
431
432
433
434 func TestServeMuxHandlerRedirects(t *testing.T) {
435 setParallel(t)
436 mux := NewServeMux()
437 for _, e := range serveMuxRegister {
438 mux.Handle(e.pattern, e.h)
439 }
440
441 for _, tt := range serveMuxTests2 {
442 tries := 1
443 turl := tt.url
444 for {
445 u, e := url.Parse(turl)
446 if e != nil {
447 t.Fatal(e)
448 }
449 r := &Request{
450 Method: tt.method,
451 Host: tt.host,
452 URL: u,
453 }
454 h, _ := mux.Handler(r)
455 rr := httptest.NewRecorder()
456 h.ServeHTTP(rr, r)
457 if rr.Code != 301 {
458 if rr.Code != tt.code {
459 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
460 }
461 break
462 }
463 if !tt.redirOk {
464 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
465 break
466 }
467 turl = rr.HeaderMap.Get("Location")
468 tries--
469 }
470 if tries < 0 {
471 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
472 }
473 }
474 }
475
476
477 func TestMuxRedirectLeadingSlashes(t *testing.T) {
478 setParallel(t)
479 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
480 for _, path := range paths {
481 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
482 if err != nil {
483 t.Errorf("%s", err)
484 }
485 mux := NewServeMux()
486 resp := httptest.NewRecorder()
487
488 mux.ServeHTTP(resp, req)
489
490 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
491 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
492 return
493 }
494
495 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
496 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
497 return
498 }
499 }
500 }
501
502
503
504
505
506 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
507 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
508 }
509 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
510 writeBackQuery := func(w ResponseWriter, r *Request) {
511 fmt.Fprintf(w, "%s", r.URL.RawQuery)
512 }
513
514 mux := NewServeMux()
515 mux.HandleFunc("/testOne", writeBackQuery)
516 mux.HandleFunc("/testTwo/", writeBackQuery)
517 mux.HandleFunc("/testThree", writeBackQuery)
518 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
519 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
520 })
521
522 ts := newClientServerTest(t, mode, mux).ts
523
524 tests := [...]struct {
525 path string
526 method string
527 want string
528 statusOk bool
529 }{
530 0: {"/testOne?this=that", "GET", "this=that", true},
531 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
532 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
533 3: {"/testTwo?", "GET", "", true},
534 4: {"/testThree?foo", "GET", "foo", true},
535 5: {"/testThree/?foo", "GET", "foo:bar", true},
536 6: {"/testThree?foo", "CONNECT", "foo", true},
537 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
538
539
540 8: {"/testOne/foo/..?foo", "GET", "foo", true},
541 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
542 }
543
544 for i, tt := range tests {
545 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
546 res, err := ts.Client().Do(req)
547 if err != nil {
548 continue
549 }
550 slurp, _ := io.ReadAll(res.Body)
551 res.Body.Close()
552 if !tt.statusOk {
553 if got, want := res.StatusCode, 404; got != want {
554 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
555 }
556 }
557 if got, want := string(slurp), tt.want; got != want {
558 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
559 }
560 }
561 }
562
563 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
564 setParallel(t)
565
566 mux := NewServeMux()
567 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
568 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
569 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
570 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
571 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
572 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
573
574 tests := []struct {
575 method string
576 url string
577 code int
578 loc string
579 want string
580 }{
581 {"GET", "http://example.com/", 404, "", ""},
582 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
583 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
584 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
585 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
586 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
587 {"CONNECT", "http://example.com/", 404, "", ""},
588 {"CONNECT", "http://example.com:3000/", 404, "", ""},
589 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
590 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
591 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
592 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
593 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
594 }
595
596 for i, tt := range tests {
597 req, _ := NewRequest(tt.method, tt.url, nil)
598 w := httptest.NewRecorder()
599 mux.ServeHTTP(w, req)
600
601 if got, want := w.Code, tt.code; got != want {
602 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
603 }
604
605 if tt.code == 301 {
606 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
607 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
608 }
609 } else {
610 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
611 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
612 }
613 }
614 }
615 }
616
617
618
619
// TestMuxNoSlashRedirectWithTrailingSlash checks that a GET for "/" against a
// mux whose only pattern is "/{x}/" yields 404 rather than a redirect.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)
	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
632
633
634
635
// TestMuxNoSlash405WithTrailingSlash checks that a GET for "/" against a mux
// whose only pattern is "GET /{x}/" yields 404.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)
	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
648
649 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
650 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
651 mux := NewServeMux()
652 newClientServerTest(t, mode, mux)
653 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
654 }
655
656 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
657 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
658 func benchmarkServeMux(b *testing.B, runHandler bool) {
659 type test struct {
660 path string
661 code int
662 req *Request
663 }
664
665
666 var tests []test
667 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
668 for _, e := range endpoints {
669 for i := 200; i < 230; i++ {
670 p := fmt.Sprintf("/%s/%d/", e, i)
671 tests = append(tests, test{
672 path: p,
673 code: i,
674 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
675 })
676 }
677 }
678 mux := NewServeMux()
679 for _, tt := range tests {
680 mux.Handle(tt.path, serve(tt.code))
681 }
682
683 rw := httptest.NewRecorder()
684 b.ReportAllocs()
685 b.ResetTimer()
686 for i := 0; i < b.N; i++ {
687 for _, tt := range tests {
688 *rw = httptest.ResponseRecorder{}
689 h, pattern := mux.Handler(tt.req)
690 if runHandler {
691 h.ServeHTTP(rw, tt.req)
692 if pattern != tt.path || rw.Code != tt.code {
693 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
694 }
695 }
696 }
697 }
698 }
699
700 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
701 func testServerTimeouts(t *testing.T, mode testMode) {
702 runTimeSensitiveTest(t, []time.Duration{
703 10 * time.Millisecond,
704 50 * time.Millisecond,
705 100 * time.Millisecond,
706 500 * time.Millisecond,
707 1 * time.Second,
708 }, func(t *testing.T, timeout time.Duration) error {
709 return testServerTimeoutsWithTimeout(t, timeout, mode)
710 })
711 }
712
713 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
714 var reqNum atomic.Int32
715 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
716 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
717 }), func(ts *httptest.Server) {
718 ts.Config.ReadTimeout = timeout
719 ts.Config.WriteTimeout = timeout
720 })
721 defer cst.close()
722 ts := cst.ts
723
724
725 c := ts.Client()
726 r, err := c.Get(ts.URL)
727 if err != nil {
728 return fmt.Errorf("http Get #1: %v", err)
729 }
730 got, err := io.ReadAll(r.Body)
731 expected := "req=1"
732 if string(got) != expected || err != nil {
733 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
734 string(got), err, expected)
735 }
736
737
738 t1 := time.Now()
739 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
740 if err != nil {
741 return fmt.Errorf("Dial: %v", err)
742 }
743 buf := make([]byte, 1)
744 n, err := conn.Read(buf)
745 conn.Close()
746 latency := time.Since(t1)
747 if n != 0 || err != io.EOF {
748 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
749 }
750 minLatency := timeout / 5 * 4
751 if latency < minLatency {
752 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
753 }
754
755
756
757
758 r, err = c.Get(ts.URL)
759 if err != nil {
760 return fmt.Errorf("http Get #2: %v", err)
761 }
762 got, err = io.ReadAll(r.Body)
763 r.Body.Close()
764 expected = "req=2"
765 if string(got) != expected || err != nil {
766 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
767 }
768
769 if !testing.Short() {
770 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
771 if err != nil {
772 return fmt.Errorf("long Dial: %v", err)
773 }
774 defer conn.Close()
775 go io.Copy(io.Discard, conn)
776 for i := 0; i < 5; i++ {
777 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
778 if err != nil {
779 return fmt.Errorf("on write %d: %v", i, err)
780 }
781 time.Sleep(timeout / 2)
782 }
783 }
784 return nil
785 }
786
787 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
788 func testServerReadTimeout(t *testing.T, mode testMode) {
789 respBody := "response body"
790 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
791 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
792 _, err := io.Copy(io.Discard, req.Body)
793 if !errors.Is(err, os.ErrDeadlineExceeded) {
794 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
795 }
796 res.Write([]byte(respBody))
797 }), func(ts *httptest.Server) {
798 ts.Config.ReadHeaderTimeout = -1
799 ts.Config.ReadTimeout = timeout
800 t.Logf("Server.Config.ReadTimeout = %v", timeout)
801 })
802
803 var retries atomic.Int32
804 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
805 if retries.Add(1) != 1 {
806 return nil, errors.New("too many retries")
807 }
808 return nil, nil
809 }
810
811 pr, pw := io.Pipe()
812 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
813 if err != nil {
814 t.Logf("Get error, retrying: %v", err)
815 cst.close()
816 continue
817 }
818 defer res.Body.Close()
819 got, err := io.ReadAll(res.Body)
820 if string(got) != respBody || err != nil {
821 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
822 }
823 pw.Close()
824 break
825 }
826 }
827
828 func TestServerNoReadTimeout(t *testing.T) { run(t, testServerNoReadTimeout) }
829 func testServerNoReadTimeout(t *testing.T, mode testMode) {
830 reqBody := "Hello, Gophers!"
831 resBody := "Hi, Gophers!"
832 for _, timeout := range []time.Duration{0, -1} {
833 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
834 ctl := NewResponseController(res)
835 ctl.EnableFullDuplex()
836 res.WriteHeader(StatusOK)
837
838
839 if err := ctl.Flush(); err != nil {
840 t.Errorf("server flush response: %v", err)
841 return
842 }
843 got, err := io.ReadAll(req.Body)
844 if string(got) != reqBody || err != nil {
845 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
846 }
847 res.Write([]byte(resBody))
848 }), func(ts *httptest.Server) {
849 ts.Config.ReadTimeout = timeout
850 t.Logf("Server.Config.ReadTimeout = %d", timeout)
851 })
852
853 pr, pw := io.Pipe()
854 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
855 if err != nil {
856 t.Fatal(err)
857 }
858 defer res.Body.Close()
859
860
861 time.Sleep(10 * time.Millisecond)
862 pw.Write([]byte(reqBody))
863 pw.Close()
864
865 got, err := io.ReadAll(res.Body)
866 if string(got) != resBody || err != nil {
867 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
868 }
869 }
870 }
871
872 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
873 func testServerWriteTimeout(t *testing.T, mode testMode) {
874 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
875 errc := make(chan error, 2)
876 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
877 errc <- nil
878 _, err := io.Copy(res, neverEnding('a'))
879 errc <- err
880 }), func(ts *httptest.Server) {
881 ts.Config.WriteTimeout = timeout
882 t.Logf("Server.Config.WriteTimeout = %v", timeout)
883 })
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902 var retries atomic.Int32
903 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
904 if retries.Add(1) != 1 {
905 return nil, errors.New("too many retries")
906 }
907 return nil, nil
908 }
909
910 res, err := cst.c.Get(cst.ts.URL)
911 if err != nil {
912
913 t.Logf("Get error, retrying: %v", err)
914 cst.close()
915 continue
916 }
917 defer res.Body.Close()
918 _, err = io.Copy(io.Discard, res.Body)
919 if err == nil {
920 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
921 }
922 select {
923 case <-errc:
924 err = <-errc
925 if !errors.Is(err, os.ErrDeadlineExceeded) {
926 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
927 }
928 return
929 default:
930
931 t.Logf("handler didn't run, retrying")
932 cst.close()
933 }
934 }
935 }
936
937 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
938 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
939 for _, timeout := range []time.Duration{0, -1} {
940 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
941 _, err := io.Copy(res, neverEnding('a'))
942 t.Logf("server write response: %v", err)
943 }), func(ts *httptest.Server) {
944 ts.Config.WriteTimeout = timeout
945 t.Logf("Server.Config.WriteTimeout = %d", timeout)
946 })
947
948 res, err := cst.c.Get(cst.ts.URL)
949 if err != nil {
950 t.Fatal(err)
951 }
952 defer res.Body.Close()
953 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
954 if n != 1<<20 || err != nil {
955 t.Errorf("client read response body: %d, %v", n, err)
956 }
957
958
959 res.Body.Close()
960 cst.ts.Config.Shutdown(context.Background())
961 }
962 }
963
964
965 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
966 run(t, testWriteDeadlineExtendedOnNewRequest)
967 }
968 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
969 if testing.Short() {
970 t.Skip("skipping in short mode")
971 }
972 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
973 func(ts *httptest.Server) {
974 ts.Config.WriteTimeout = 250 * time.Millisecond
975 },
976 ).ts
977
978 c := ts.Client()
979
980 for i := 1; i <= 3; i++ {
981 req, err := NewRequest("GET", ts.URL, nil)
982 if err != nil {
983 t.Fatal(err)
984 }
985
986 r, err := c.Do(req)
987 if err != nil {
988 t.Fatalf("http2 Get #%d: %v", i, err)
989 }
990 r.Body.Close()
991 time.Sleep(ts.Config.WriteTimeout / 2)
992 }
993 }
994
995
996
// tryTimeouts calls testFunc with increasing timeouts (250ms, 500ms, 1s) and
// stops at the first call that returns nil. Each failure is logged. If every
// attempt fails, the test is failed with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	timeouts := [...]time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i := range timeouts {
		err := testFunc(timeouts[i])
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", timeouts[i], err)
		if next := i + 1; next < len(timeouts) {
			t.Logf("retrying at %v ...", timeouts[next])
		}
	}
	t.Fatal("all attempts failed")
}
1011
1012
1013 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
1014 if testing.Short() {
1015 t.Skip("skipping in short mode")
1016 }
1017 setParallel(t)
1018 run(t, func(t *testing.T, mode testMode) {
1019 tryTimeouts(t, func(timeout time.Duration) error {
1020 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1021 })
1022 })
1023 }
1024
1025 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
1026 firstRequest := make(chan bool, 1)
1027 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1028 select {
1029 case firstRequest <- true:
1030
1031 default:
1032
1033 time.Sleep(timeout)
1034 }
1035 }), func(ts *httptest.Server) {
1036 ts.Config.WriteTimeout = timeout / 2
1037 })
1038 defer cst.close()
1039 ts := cst.ts
1040
1041 c := ts.Client()
1042
1043 req, err := NewRequest("GET", ts.URL, nil)
1044 if err != nil {
1045 return fmt.Errorf("NewRequest: %v", err)
1046 }
1047 r, err := c.Do(req)
1048 if err != nil {
1049 return fmt.Errorf("Get #1: %v", err)
1050 }
1051 r.Body.Close()
1052
1053 req, err = NewRequest("GET", ts.URL, nil)
1054 if err != nil {
1055 return fmt.Errorf("NewRequest: %v", err)
1056 }
1057 r, err = c.Do(req)
1058 if err == nil {
1059 r.Body.Close()
1060 return fmt.Errorf("Get #2 expected error, got nil")
1061 }
1062 if mode == http2Mode {
1063 expected := "stream ID 3; INTERNAL_ERROR"
1064 if !strings.Contains(err.Error(), expected) {
1065 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
1066 }
1067 }
1068 return nil
1069 }
1070
1071
1072 func TestNoWriteDeadline(t *testing.T) {
1073 if testing.Short() {
1074 t.Skip("skipping in short mode")
1075 }
1076 setParallel(t)
1077 defer afterTest(t)
1078 run(t, func(t *testing.T, mode testMode) {
1079 tryTimeouts(t, func(timeout time.Duration) error {
1080 return testNoWriteDeadline(t, mode, timeout)
1081 })
1082 })
1083 }
1084
1085 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1086 firstRequest := make(chan bool, 1)
1087 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1088 select {
1089 case firstRequest <- true:
1090
1091 default:
1092
1093 time.Sleep(timeout)
1094 }
1095 }))
1096 defer cst.close()
1097 ts := cst.ts
1098
1099 c := ts.Client()
1100
1101 for i := 0; i < 2; i++ {
1102 req, err := NewRequest("GET", ts.URL, nil)
1103 if err != nil {
1104 return fmt.Errorf("NewRequest: %v", err)
1105 }
1106 r, err := c.Do(req)
1107 if err != nil {
1108 return fmt.Errorf("Get #%d: %v", i, err)
1109 }
1110 r.Body.Close()
1111 }
1112 return nil
1113 }
1114
1115
1116
1117
1118 func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
1119 func testOnlyWriteTimeout(t *testing.T, mode testMode) {
1120 var (
1121 mu sync.RWMutex
1122 conn net.Conn
1123 )
1124 var afterTimeoutErrc = make(chan error, 1)
1125 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
1126 buf := make([]byte, 512<<10)
1127 _, err := w.Write(buf)
1128 if err != nil {
1129 t.Errorf("handler Write error: %v", err)
1130 return
1131 }
1132 mu.RLock()
1133 defer mu.RUnlock()
1134 if conn == nil {
1135 t.Error("no established connection found")
1136 return
1137 }
1138 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
1139 _, err = w.Write(buf)
1140 afterTimeoutErrc <- err
1141 }), func(ts *httptest.Server) {
1142 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
1143 }).ts
1144
1145 c := ts.Client()
1146
1147 err := func() error {
1148 res, err := c.Get(ts.URL)
1149 if err != nil {
1150 return err
1151 }
1152 _, err = io.Copy(io.Discard, res.Body)
1153 res.Body.Close()
1154 return err
1155 }()
1156 if err == nil {
1157 t.Errorf("expected an error copying body from Get request")
1158 }
1159
1160 if err := <-afterTimeoutErrc; err == nil {
1161 t.Error("expected write error after timeout")
1162 }
1163 }
1164
1165
// trackLastConnListener wraps a net.Listener and records the most recently
// accepted connection in *last, under mu.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex // guards *last
	last *net.Conn     // set to each successfully accepted connection
}

// Accept delegates to the wrapped listener. On success it stores the new
// connection in *l.last while holding l.mu.
func (l trackLastConnListener) Accept() (net.Conn, error) {
	c, err := l.Listener.Accept()
	if err != nil {
		return c, err
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	*l.last = c
	return c, nil
}
1182
1183
1184 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1185 func testIdentityResponse(t *testing.T, mode testMode) {
1186 if mode == http2Mode {
1187 t.Skip("https://go.dev/issue/56019")
1188 }
1189
1190 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1191 rw.Header().Set("Content-Length", "3")
1192 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1193 switch {
1194 case req.FormValue("overwrite") == "1":
1195 _, err := rw.Write([]byte("foo TOO LONG"))
1196 if err != ErrContentLength {
1197 t.Errorf("expected ErrContentLength; got %v", err)
1198 }
1199 case req.FormValue("underwrite") == "1":
1200 rw.Header().Set("Content-Length", "500")
1201 rw.Write([]byte("too short"))
1202 default:
1203 rw.Write([]byte("foo"))
1204 }
1205 })
1206
1207 ts := newClientServerTest(t, mode, handler).ts
1208 c := ts.Client()
1209
1210
1211
1212
1213
1214 for _, te := range []string{"", "identity"} {
1215 url := ts.URL + "/?te=" + te
1216 res, err := c.Get(url)
1217 if err != nil {
1218 t.Fatalf("error with Get of %s: %v", url, err)
1219 }
1220 if cl, expected := res.ContentLength, int64(3); cl != expected {
1221 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1222 }
1223 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1224 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1225 }
1226 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1227 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1228 url, expected, tl, res.TransferEncoding)
1229 }
1230 res.Body.Close()
1231 }
1232
1233
1234 url := ts.URL + "/?overwrite=1"
1235 res, err := c.Get(url)
1236 if err != nil {
1237 t.Fatalf("error with Get of %s: %v", url, err)
1238 }
1239 res.Body.Close()
1240
1241 if mode != http1Mode {
1242 return
1243 }
1244
1245
1246
1247 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1248 if err != nil {
1249 t.Fatalf("error dialing: %v", err)
1250 }
1251 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1252 if err != nil {
1253 t.Fatalf("error writing: %v", err)
1254 }
1255
1256
1257 got, _ := io.ReadAll(conn)
1258 expectedSuffix := "\r\n\r\ntoo short"
1259 if !strings.HasSuffix(string(got), expectedSuffix) {
1260 t.Errorf("Expected output to end with %q; got response body %q",
1261 expectedSuffix, string(got))
1262 }
1263 }
1264
1265 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1266 setParallel(t)
1267 s := newClientServerTest(t, http1Mode, h).ts
1268
1269 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1270 if err != nil {
1271 t.Fatal("dial error:", err)
1272 }
1273 defer conn.Close()
1274
1275 _, err = fmt.Fprint(conn, req)
1276 if err != nil {
1277 t.Fatal("print error:", err)
1278 }
1279
1280 r := bufio.NewReader(conn)
1281 res, err := ReadResponse(r, &Request{Method: "GET"})
1282 if err != nil {
1283 t.Fatal("ReadResponse error:", err)
1284 }
1285
1286 _, err = io.ReadAll(r)
1287 if err != nil {
1288 t.Fatal("read error:", err)
1289 }
1290
1291 if !res.Close {
1292 t.Errorf("Response.Close = false; want true")
1293 }
1294 }
1295
1296 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1297 setParallel(t)
1298 ts := newClientServerTest(t, http1Mode, handler).ts
1299 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1300 if err != nil {
1301 t.Fatal(err)
1302 }
1303 defer conn.Close()
1304 br := bufio.NewReader(conn)
1305 for i := 0; i < 2; i++ {
1306 if _, err := io.WriteString(conn, req); err != nil {
1307 t.Fatal(err)
1308 }
1309 res, err := ReadResponse(br, nil)
1310 if err != nil {
1311 t.Fatalf("res %d: %v", i+1, err)
1312 }
1313 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1314 t.Fatalf("res %d body copy: %v", i+1, err)
1315 }
1316 res.Body.Close()
1317 }
1318 }
1319
1320
1321 func TestServeHTTP10Close(t *testing.T) {
1322 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1323 ServeFile(w, r, "testdata/file")
1324 }))
1325 }
1326
1327
1328 func TestClientCanClose(t *testing.T) {
1329 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1330
1331 }))
1332 }
1333
1334
1335
1336 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1337 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1338 w.Header().Set("Connection", "close")
1339 }))
1340 }
1341
1342 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1343 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1344 w.Header().Set("Connection", "close")
1345 }))
1346 }
1347
1348 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1349 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1350
1351
1352 }))
1353 }
1354
// send204 replies with status 204 No Content and writes no body.
func send204(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) }

// send304 replies with status 304 Not Modified and writes no body.
func send304(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) }
1357
1358
1359 func TestHTTP10KeepAlive204Response(t *testing.T) {
1360 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1361 }
1362
1363 func TestHTTP11KeepAlive204Response(t *testing.T) {
1364 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1365 }
1366
1367 func TestHTTP10KeepAlive304Response(t *testing.T) {
1368 testTCPConnectionStaysOpen(t,
1369 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1370 HandlerFunc(send304))
1371 }
1372
1373
1374 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1375 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1376 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1377 w.(Flusher).Flush()
1378 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1379 }))
1380 type data struct {
1381 Addr string
1382 }
1383 var addrs [2]data
1384 for i := range addrs {
1385 res, err := cst.c.Get(cst.ts.URL)
1386 if err != nil {
1387 t.Fatal(err)
1388 }
1389 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1390 t.Fatal(err)
1391 }
1392 if addrs[i].Addr == "" {
1393 t.Fatal("no address")
1394 }
1395 res.Body.Close()
1396 }
1397 if addrs[0] != addrs[1] {
1398 t.Fatalf("connection not reused")
1399 }
1400 }
1401
1402 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1403 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1404 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1405 fmt.Fprintf(w, "%s", r.RemoteAddr)
1406 }))
1407
1408 res, err := cst.c.Get(cst.ts.URL)
1409 if err != nil {
1410 t.Fatalf("Get error: %v", err)
1411 }
1412 body, err := io.ReadAll(res.Body)
1413 if err != nil {
1414 t.Fatalf("ReadAll error: %v", err)
1415 }
1416 ip := string(body)
1417 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1418 t.Fatalf("Expected local addr; got %q", ip)
1419 }
1420 }
1421
// blockingRemoteAddrListener wraps a net.Listener so that every accepted
// connection is a *blockingRemoteAddrConn. Each one is sent on conns before
// Accept returns it, so a test can supply its address later.
type blockingRemoteAddrListener struct {
	net.Listener
	conns chan<- net.Conn
}

func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
	inner, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	wrapped := &blockingRemoteAddrConn{Conn: inner, addrs: make(chan net.Addr, 1)}
	l.conns <- wrapped
	return wrapped, nil
}

// blockingRemoteAddrConn is a net.Conn whose RemoteAddr waits for, and then
// returns, the next address sent on addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr
}

func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
	addr := <-c.addrs
	return addr
}
1448
1449
1450 func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
1451 run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
1452 }
1453 func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
1454 conns := make(chan net.Conn)
1455 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1456 fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
1457 }), func(ts *httptest.Server) {
1458 ts.Listener = &blockingRemoteAddrListener{
1459 Listener: ts.Listener,
1460 conns: conns,
1461 }
1462 }).ts
1463
1464 c := ts.Client()
1465
1466 c.Transport.(*Transport).DisableKeepAlives = true
1467
1468 fetch := func(num int, response chan<- string) {
1469 resp, err := c.Get(ts.URL)
1470 if err != nil {
1471 t.Errorf("Request %d: %v", num, err)
1472 response <- ""
1473 return
1474 }
1475 defer resp.Body.Close()
1476 body, err := io.ReadAll(resp.Body)
1477 if err != nil {
1478 t.Errorf("Request %d: %v", num, err)
1479 response <- ""
1480 return
1481 }
1482 response <- string(body)
1483 }
1484
1485
1486 response1c := make(chan string, 1)
1487 go fetch(1, response1c)
1488
1489
1490 conn1 := <-conns
1491
1492
1493 response2c := make(chan string, 1)
1494 go fetch(2, response2c)
1495 conn2 := <-conns
1496
1497
1498 conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1499 IP: net.ParseIP("12.12.12.12"), Port: 12}
1500
1501
1502 response2 := <-response2c
1503 if g, e := response2, "RA:12.12.12.12:12"; g != e {
1504 t.Fatalf("response 2 addr = %q; want %q", g, e)
1505 }
1506
1507
1508 conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1509 IP: net.ParseIP("21.21.21.21"), Port: 21}
1510
1511
1512 response1 := <-response1c
1513 if g, e := response1, "RA:21.21.21.21:21"; g != e {
1514 t.Fatalf("response 1 addr = %q; want %q", g, e)
1515 }
1516 }
1517
1518
1519
1520 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1521 func testHeadResponses(t *testing.T, mode testMode) {
1522 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1523 _, err := w.Write([]byte("<html>"))
1524 if err != nil {
1525 t.Errorf("ResponseWriter.Write: %v", err)
1526 }
1527
1528
1529 _, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
1530 if err != nil {
1531 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1532 }
1533 }))
1534 res, err := cst.c.Head(cst.ts.URL)
1535 if err != nil {
1536 t.Error(err)
1537 }
1538 if len(res.TransferEncoding) > 0 {
1539 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1540 }
1541 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1542 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1543 }
1544 if v := res.ContentLength; v != 10 {
1545 t.Errorf("Content-Length: %d; want 10", v)
1546 }
1547 body, err := io.ReadAll(res.Body)
1548 if err != nil {
1549 t.Error(err)
1550 }
1551 if len(body) > 0 {
1552 t.Errorf("got unexpected body %q", string(body))
1553 }
1554 }
1555
1556
1557
1558 func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
1559 func testHeadReaderFrom(t *testing.T, mode testMode) {
1560
1561 wantBody := strings.Repeat("a", 4096)
1562 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1563 w.(io.ReaderFrom).ReadFrom(strings.NewReader(wantBody))
1564 }))
1565 res, err := cst.c.Head(cst.ts.URL)
1566 if err != nil {
1567 t.Fatal(err)
1568 }
1569 res.Body.Close()
1570 res, err = cst.c.Get(cst.ts.URL)
1571 if err != nil {
1572 t.Fatal(err)
1573 }
1574 gotBody, err := io.ReadAll(res.Body)
1575 res.Body.Close()
1576 if err != nil {
1577 t.Fatal(err)
1578 }
1579 if string(gotBody) != wantBody {
1580 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), len(wantBody))
1581 }
1582 }
1583
1584 func TestTLSHandshakeTimeout(t *testing.T) {
1585 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1586 }
1587 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1588 errLog := new(strings.Builder)
1589 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1590 func(ts *httptest.Server) {
1591 ts.Config.ReadTimeout = 250 * time.Millisecond
1592 ts.Config.ErrorLog = log.New(errLog, "", 0)
1593 },
1594 )
1595 ts := cst.ts
1596
1597 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1598 if err != nil {
1599 t.Fatalf("Dial: %v", err)
1600 }
1601 var buf [1]byte
1602 n, err := conn.Read(buf[:])
1603 if err == nil || n != 0 {
1604 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1605 }
1606 conn.Close()
1607
1608 cst.close()
1609 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1610 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1611 }
1612 }
1613
1614 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1615 func testTLSServer(t *testing.T, mode testMode) {
1616 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1617 if r.TLS != nil {
1618 w.Header().Set("X-TLS-Set", "true")
1619 if r.TLS.HandshakeComplete {
1620 w.Header().Set("X-TLS-HandshakeComplete", "true")
1621 }
1622 }
1623 }), func(ts *httptest.Server) {
1624 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1625 }).ts
1626
1627
1628
1629
1630
1631
1632 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1633 if err != nil {
1634 t.Fatalf("Dial: %v", err)
1635 }
1636 defer idleConn.Close()
1637
1638 if !strings.HasPrefix(ts.URL, "https://") {
1639 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1640 return
1641 }
1642 client := ts.Client()
1643 res, err := client.Get(ts.URL)
1644 if err != nil {
1645 t.Error(err)
1646 return
1647 }
1648 if res == nil {
1649 t.Errorf("got nil Response")
1650 return
1651 }
1652 defer res.Body.Close()
1653 if res.Header.Get("X-TLS-Set") != "true" {
1654 t.Errorf("expected X-TLS-Set response header")
1655 return
1656 }
1657 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1658 t.Errorf("expected X-TLS-HandshakeComplete header")
1659 }
1660 }
1661
// fakeConnectionStateConn is a net.Conn that is not a *tls.Conn but still
// provides a ConnectionState method. The reported state has only ServerName
// set, to "example.com".
type fakeConnectionStateConn struct {
	net.Conn
}

func (fcsc *fakeConnectionStateConn) ConnectionState() tls.ConnectionState {
	var st tls.ConnectionState
	st.ServerName = "example.com"
	return st
}
1671
1672 func TestTLSServerWithoutTLSConn(t *testing.T) {
1673
1674 pr, pw := net.Pipe()
1675 c := make(chan int)
1676 listener := &oneConnListener{&fakeConnectionStateConn{pr}}
1677 server := &Server{
1678 Handler: HandlerFunc(func(writer ResponseWriter, request *Request) {
1679 if request.TLS == nil {
1680 t.Fatal("request.TLS is nil, expected not nil")
1681 }
1682 if request.TLS.ServerName != "example.com" {
1683 t.Fatalf("request.TLS.ServerName is %s, expected %s", request.TLS.ServerName, "example.com")
1684 }
1685 writer.Header().Set("X-TLS-ServerName", "example.com")
1686 }),
1687 }
1688
1689
1690 go func() {
1691 req, _ := NewRequest(MethodGet, "https://example.com", nil)
1692 req.Write(pw)
1693
1694 resp, _ := ReadResponse(bufio.NewReader(pw), req)
1695 if hdr := resp.Header.Get("X-TLS-ServerName"); hdr != "example.com" {
1696 t.Errorf("response header X-TLS-ServerName is %s, expected %s", hdr, "example.com")
1697 }
1698 close(c)
1699 pw.Close()
1700 }()
1701
1702 server.Serve(listener)
1703
1704
1705 <-c
1706 pr.Close()
1707 }
1708
1709 func TestServeTLS(t *testing.T) {
1710 CondSkipHTTP2(t)
1711
1712 defer afterTest(t)
1713 defer SetTestHookServerServe(nil)
1714
1715 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1716 if err != nil {
1717 t.Fatal(err)
1718 }
1719 tlsConf := &tls.Config{
1720 Certificates: []tls.Certificate{cert},
1721 }
1722
1723 ln := newLocalListener(t)
1724 defer ln.Close()
1725 addr := ln.Addr().String()
1726
1727 serving := make(chan bool, 1)
1728 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1729 serving <- true
1730 })
1731 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1732 s := &Server{
1733 Addr: addr,
1734 TLSConfig: tlsConf,
1735 Handler: handler,
1736 }
1737 errc := make(chan error, 1)
1738 go func() { errc <- s.ServeTLS(ln, "", "") }()
1739 select {
1740 case err := <-errc:
1741 t.Fatalf("ServeTLS: %v", err)
1742 case <-serving:
1743 }
1744
1745 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1746 InsecureSkipVerify: true,
1747 NextProtos: []string{"h2", "http/1.1"},
1748 })
1749 if err != nil {
1750 t.Fatal(err)
1751 }
1752 defer c.Close()
1753 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1754 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1755 }
1756 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1757 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1758 }
1759 }
1760
1761
1762 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1763 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1764 }
1765 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1766 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1767 t.Error("unexpected HTTPS request")
1768 }), func(ts *httptest.Server) {
1769 var errBuf bytes.Buffer
1770 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1771 }).ts
1772 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1773 if err != nil {
1774 t.Fatal(err)
1775 }
1776 defer conn.Close()
1777 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1778 slurp, err := io.ReadAll(conn)
1779 if err != nil {
1780 t.Fatal(err)
1781 }
1782 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1783 if !strings.HasPrefix(string(slurp), wantPrefix) {
1784 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1785 }
1786 }
1787
1788
1789 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1790 testAutomaticHTTP2_Serve(t, nil, true)
1791 }
1792
1793 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1794 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1795 }
1796
1797 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1798 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1799 }
1800
1801 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1802 setParallel(t)
1803 defer afterTest(t)
1804 ln := newLocalListener(t)
1805 ln.Close()
1806 var s Server
1807 s.TLSConfig = tlsConf
1808 if err := s.Serve(ln); err == nil {
1809 t.Fatal("expected an error")
1810 }
1811 gotH2 := s.TLSNextProto["h2"] != nil
1812 if gotH2 != wantH2 {
1813 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1814 }
1815 }
1816
1817 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1818 setParallel(t)
1819 defer afterTest(t)
1820 ln := newLocalListener(t)
1821 ln.Close()
1822 var s Server
1823
1824
1825 s.TLSConfig = &tls.Config{
1826 NextProtos: []string{"h2"},
1827 }
1828 if err := s.Serve(ln); err == nil {
1829 t.Fatal("expected an error")
1830 }
1831 on := s.TLSNextProto["h2"] != nil
1832 if !on {
1833 t.Errorf("http2 wasn't automatically enabled")
1834 }
1835 }
1836
1837 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1838 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1839 if err != nil {
1840 t.Fatal(err)
1841 }
1842 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1843 Certificates: []tls.Certificate{cert},
1844 })
1845 }
1846
1847 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1848 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1849 if err != nil {
1850 t.Fatal(err)
1851 }
1852 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1853 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1854 return &cert, nil
1855 },
1856 })
1857 }
1858
1859 func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
1860 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1861 if err != nil {
1862 t.Fatal(err)
1863 }
1864 conf := &tls.Config{
1865
1866
1867 NextProtos: []string{"h2"},
1868 Certificates: []tls.Certificate{cert},
1869 }
1870 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1871 GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
1872 return conf, nil
1873 },
1874 })
1875 }
1876
// testAutomaticHTTP2_ListenAndServe starts a Server configured with tlsConf
// via ListenAndServeTLS("", ""), with no certificate files. It then checks
// that a TLS client offering "h2" and "http/1.1" negotiates "h2".
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)

	defer afterTest(t)
	defer SetTestHookServerServe(nil)
	var ok bool
	var s *Server
	const maxTries = 5
	var ln net.Listener
	// Each try reserves a local address by opening and immediately closing a
	// listener, then asks the server to listen there. The address can
	// presumably be taken by something else in between, hence the retries.
Try:
	for try := 0; try < maxTries; try++ {
		ln = newLocalListener(t)
		addr := ln.Addr().String()
		ln.Close()
		t.Logf("Got %v", addr)
		// The hook hands back the listener the server actually serves on.
		lnc := make(chan net.Listener, 1)
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		s = &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- s.ListenAndServeTLS("", "") }()
		// Whichever happens first: ListenAndServeTLS returning an error
		// (retry with a new address), or the hook reporting that the server
		// is serving (stop retrying).
		select {
		case err := <-errc:
			t.Logf("On try #%v: %v", try+1, err)
			continue
		case ln = <-lnc:
			ok = true
			t.Logf("Listening on %v", ln.Addr().String())
			break Try
		}
	}
	if !ok {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	// ln is now the serving listener reported by the hook.
	defer ln.Close()
	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1931
// serverExpectTest describes one request sent by testServerExpect and the
// text expected in the first line of the server's reply.
type serverExpectTest struct {
	contentLength    int    // body size; sent as Content-Length unless chunked
	chunked          bool   // send Transfer-Encoding: chunked instead of Content-Length
	expectation      string // value of the Expect header
	readBody         bool   // passed as ?readbody= to tell the handler to read the body
	expectedResponse string // substring expected in the first response line
}

// expectTest returns a non-chunked serverExpectTest with the given fields.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	st := serverExpectTest{expectation: expectation, expectedResponse: expectedResponse}
	st.contentLength = contentLength
	st.readBody = readBody
	return st
}
1948
// serverExpectTests are the cases run by testServerExpect. Its handler reads
// the body and replies "Hi" when readBody is true; otherwise it replies 401
// without reading the body.
var serverExpectTests = []serverExpectTest{
	// Normal 100-continues; the expectation value is matched case-insensitively.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// No Expect header: the first line is the handler's 200.
	expectTest(100, "", true, "200 OK"),

	// 100-continue, but the handler does not read the body and replies 401;
	// that 401 is the first line the client sees.
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// The same 401 without any Expect header.
	expectTest(100, "", false, "401 Unauthorized"),

	// An unrecognized expectation is rejected with 417.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// Content-Length 0 with 100-continue: the expected first line is the
	// handler's own response, not "100 Continue".
	expectTest(0, "100-continue", true, "200 OK"),
	// The same with the handler's 401.
	expectTest(0, "100-continue", false, "401 Unauthorized"),

	// Chunked body with 100-continue that the handler reads: expect 100 Continue.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1978
1979
1980
// TestServerExpect runs each serverExpectTests case over its own raw
// HTTP/1.1 connection. For each case it checks the first line of the
// server's reply to a request carrying the given Expect header.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Inspect RawQuery rather than calling r.FormValue. For a POST,
		// FormValue would parse and therefore read the body, which should only
		// happen in the readbody=true case.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest sends one request described by test on a fresh connection and
	// checks the first line the server sends back.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Send the body without waiting only when the client is not asking for
		// 100-continue (compared case-insensitively) and there is a body.
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		// Deferred after conn.Close, so it runs first: runTest waits for the
		// writer goroutine to finish before the connection is closed.
		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		// Writer goroutine: sends the headers and, if writeBody, the body,
		// while the response is read below.
		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// Plain writes go straight to conn; Close is a no-op unless
				// replaced by the chunked writer below.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler does not read the body, so the server has
						// presumably already replied and hung up; a write
						// error is tolerated here.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// Tolerated: we may still have been writing the body when the
				// server (which never reads it) closed the connection. That
				// can presumably surface as a read error (e.g. a TCP reset)
				// before the response line is read.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
2076
2077
2078
2079 func TestServerUnreadRequestBodyLittle(t *testing.T) {
2080 setParallel(t)
2081 defer afterTest(t)
2082 conn := new(testConn)
2083 body := strings.Repeat("x", 100<<10)
2084 conn.readBuf.Write([]byte(fmt.Sprintf(
2085 "POST / HTTP/1.1\r\n"+
2086 "Host: test\r\n"+
2087 "Content-Length: %d\r\n"+
2088 "\r\n", len(body))))
2089 conn.readBuf.Write([]byte(body))
2090
2091 done := make(chan bool)
2092
2093 readBufLen := func() int {
2094 conn.readMu.Lock()
2095 defer conn.readMu.Unlock()
2096 return conn.readBuf.Len()
2097 }
2098
2099 ls := &oneConnListener{conn}
2100 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2101 defer close(done)
2102 if bufLen := readBufLen(); bufLen < len(body)/2 {
2103 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
2104 }
2105 rw.WriteHeader(200)
2106 rw.(Flusher).Flush()
2107 if g, e := readBufLen(), 0; g != e {
2108 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
2109 }
2110 if c := rw.Header().Get("Connection"); c != "" {
2111 t.Errorf(`Connection header = %q; want ""`, c)
2112 }
2113 }))
2114 <-done
2115 }
2116
2117
2118
2119
2120 func TestServerUnreadRequestBodyLarge(t *testing.T) {
2121 setParallel(t)
2122 if testing.Short() && testenv.Builder() == "" {
2123 t.Log("skipping in short mode")
2124 }
2125 conn := new(testConn)
2126 body := strings.Repeat("x", 1<<20)
2127 conn.readBuf.Write([]byte(fmt.Sprintf(
2128 "POST / HTTP/1.1\r\n"+
2129 "Host: test\r\n"+
2130 "Content-Length: %d\r\n"+
2131 "\r\n", len(body))))
2132 conn.readBuf.Write([]byte(body))
2133 conn.closec = make(chan bool, 1)
2134
2135 ls := &oneConnListener{conn}
2136 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2137 if conn.readBuf.Len() < len(body)/2 {
2138 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2139 }
2140 rw.WriteHeader(200)
2141 rw.(Flusher).Flush()
2142 if conn.readBuf.Len() < len(body)/2 {
2143 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2144 }
2145 }))
2146 <-conn.closec
2147
2148 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
2149 t.Errorf("Expected a Connection: close header; got response: %s", res)
2150 }
2151 }
2152
// handlerBodyCloseTest describes one request used by testHandlerBodyClose
// and the expected effect of the handler calling req.Body.Close.
type handlerBodyCloseTest struct {
	bodySize     int  // number of body bytes sent
	bodyChunked  bool // send the body chunked rather than with Content-Length
	reqConnClose bool // send "Connection: close" (and no follow-up request)

	wantEOFSearch bool // whether Body.Close is expected to read more input
	wantNextReq   bool // whether the pipelined follow-up request is expected to be served
}

// connectionHeader returns the "Connection: close" header line, including
// its CRLF, when reqConnClose is set, and "" otherwise.
func (tc handlerBodyCloseTest) connectionHeader() string {
	hdr := ""
	if tc.reqConnClose {
		hdr = "Connection: close\r\n"
	}
	return hdr
}
2168
// handlerBodyCloseTests are the cases run by TestHandlerBodyClose.
// wantEOFSearch: req.Body.Close is expected to change how much input is left
// unread. wantNextReq: the pipelined GET that follows a keep-alive request is
// expected to be served. See testHandlerBodyClose for both checks.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// Small (20KB) body with Content-Length, keep-alive: Close reads past the
	// body and the next request is served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small (20KB) chunked body, keep-alive: Close reads past the body and the
	// next request is served.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small body with Content-Length, but the client sent Connection: close.
	// Close reads nothing more, and there is no next request.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Small chunked body with Connection: close. Close still reads on
	// (presumably to consume possible trailers), and there is no next request.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large (1MB) body with Content-Length, keep-alive: Close reads nothing
	// and the next request is not served.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Large chunked body, keep-alive: Close reads some input but the next
	// request is not served.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large chunked body with Connection: close. Close still reads some input
	// (presumably looking for trailers).
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large body with Content-Length and Connection: close: Close reads nothing.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2253
2254 func TestHandlerBodyClose(t *testing.T) {
2255 setParallel(t)
2256 if testing.Short() && testenv.Builder() == "" {
2257 t.Skip("skipping in -short mode")
2258 }
2259 for i, tt := range handlerBodyCloseTests {
2260 testHandlerBodyClose(t, i, tt)
2261 }
2262 }
2263
2264 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2265 conn := new(testConn)
2266 body := strings.Repeat("x", tt.bodySize)
2267 if tt.bodyChunked {
2268 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2269 "Host: test\r\n" +
2270 tt.connectionHeader() +
2271 "Transfer-Encoding: chunked\r\n" +
2272 "\r\n")
2273 cw := internal.NewChunkedWriter(&conn.readBuf)
2274 io.WriteString(cw, body)
2275 cw.Close()
2276 conn.readBuf.WriteString("\r\n")
2277 } else {
2278 conn.readBuf.Write([]byte(fmt.Sprintf(
2279 "POST / HTTP/1.1\r\n"+
2280 "Host: test\r\n"+
2281 tt.connectionHeader()+
2282 "Content-Length: %d\r\n"+
2283 "\r\n", len(body))))
2284 conn.readBuf.Write([]byte(body))
2285 }
2286 if !tt.reqConnClose {
2287 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2288 }
2289 conn.closec = make(chan bool, 1)
2290
2291 readBufLen := func() int {
2292 conn.readMu.Lock()
2293 defer conn.readMu.Unlock()
2294 return conn.readBuf.Len()
2295 }
2296
2297 ls := &oneConnListener{conn}
2298 var numReqs int
2299 var size0, size1 int
2300 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2301 numReqs++
2302 if numReqs == 1 {
2303 size0 = readBufLen()
2304 req.Body.Close()
2305 size1 = readBufLen()
2306 }
2307 }))
2308 <-conn.closec
2309 if numReqs < 1 || numReqs > 2 {
2310 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2311 }
2312 didSearch := size0 != size1
2313 if didSearch != tt.wantEOFSearch {
2314 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2315 }
2316 if tt.wantNextReq && numReqs != 2 {
2317 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2318 }
2319 }
2320
2321
2322
// testHandlerBodyConsumer is a named strategy for how a handler disposes of
// a request body.
type testHandlerBodyConsumer struct {
	name string
	f    func(io.ReadCloser)
}

// testHandlerBodyConsumers covers three strategies: ignore the body, close
// it, or read it to the end.
var testHandlerBodyConsumers = []testHandlerBodyConsumer{
	{name: "nil", f: func(io.ReadCloser) {}},
	{name: "close", f: func(rc io.ReadCloser) { rc.Close() }},
	{name: "discard", f: func(rc io.ReadCloser) { io.Copy(io.Discard, rc) }},
}
2333
2334 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2335 setParallel(t)
2336 defer afterTest(t)
2337 for _, handler := range testHandlerBodyConsumers {
2338 conn := new(testConn)
2339 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2340 "Host: test\r\n" +
2341 "Transfer-Encoding: chunked\r\n" +
2342 "\r\n" +
2343 "hax\r\n" +
2344 "GET /secret HTTP/1.1\r\n" +
2345 "Host: test\r\n" +
2346 "\r\n")
2347
2348 conn.closec = make(chan bool, 1)
2349 ls := &oneConnListener{conn}
2350 var numReqs int
2351 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2352 numReqs++
2353 if strings.Contains(req.URL.Path, "secret") {
2354 t.Error("Request for /secret encountered, should not have happened.")
2355 }
2356 handler.f(req.Body)
2357 }))
2358 <-conn.closec
2359 if numReqs != 1 {
2360 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2361 }
2362 }
2363 }
2364
2365 func TestInvalidTrailerClosesConnection(t *testing.T) {
2366 setParallel(t)
2367 defer afterTest(t)
2368 for _, handler := range testHandlerBodyConsumers {
2369 conn := new(testConn)
2370 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2371 "Host: test\r\n" +
2372 "Trailer: hack\r\n" +
2373 "Transfer-Encoding: chunked\r\n" +
2374 "\r\n" +
2375 "3\r\n" +
2376 "hax\r\n" +
2377 "0\r\n" +
2378 "I'm not a valid trailer\r\n" +
2379 "GET /secret HTTP/1.1\r\n" +
2380 "Host: test\r\n" +
2381 "\r\n")
2382
2383 conn.closec = make(chan bool, 1)
2384 ln := &oneConnListener{conn}
2385 var numReqs int
2386 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2387 numReqs++
2388 if strings.Contains(req.URL.Path, "secret") {
2389 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2390 }
2391 handler.f(req.Body)
2392 }))
2393 <-conn.closec
2394 if numReqs != 1 {
2395 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2396 }
2397 }
2398 }
2399
2400
2401
2402
2403 type slowTestConn struct {
2404
2405 script []any
2406 closec chan bool
2407
2408 mu sync.Mutex
2409 rd, wd time.Time
2410 noopConn
2411 }
2412
2413 func (c *slowTestConn) SetDeadline(t time.Time) error {
2414 c.SetReadDeadline(t)
2415 c.SetWriteDeadline(t)
2416 return nil
2417 }
2418
2419 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2420 c.mu.Lock()
2421 defer c.mu.Unlock()
2422 c.rd = t
2423 return nil
2424 }
2425
2426 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2427 c.mu.Lock()
2428 defer c.mu.Unlock()
2429 c.wd = t
2430 return nil
2431 }
2432
2433 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2434 c.mu.Lock()
2435 defer c.mu.Unlock()
2436 restart:
2437 if !c.rd.IsZero() && time.Now().After(c.rd) {
2438 return 0, syscall.ETIMEDOUT
2439 }
2440 if len(c.script) == 0 {
2441 return 0, io.EOF
2442 }
2443
2444 switch cue := c.script[0].(type) {
2445 case time.Duration:
2446 if !c.rd.IsZero() {
2447
2448
2449 if remaining := time.Until(c.rd); remaining < cue {
2450 c.script[0] = cue - remaining
2451 time.Sleep(remaining)
2452 return 0, syscall.ETIMEDOUT
2453 }
2454 }
2455 c.script = c.script[1:]
2456 time.Sleep(cue)
2457 goto restart
2458
2459 case string:
2460 n = copy(b, cue)
2461
2462 if len(cue) > n {
2463 c.script[0] = cue[n:]
2464 } else {
2465 c.script = c.script[1:]
2466 }
2467
2468 default:
2469 panic("unknown cue in slowTestConn script")
2470 }
2471
2472 return
2473 }
2474
2475 func (c *slowTestConn) Close() error {
2476 select {
2477 case c.closec <- true:
2478 default:
2479 }
2480 return nil
2481 }
2482
2483 func (c *slowTestConn) Write(b []byte) (int, error) {
2484 if !c.wd.IsZero() && time.Now().After(c.wd) {
2485 return 0, syscall.ETIMEDOUT
2486 }
2487 return len(b), nil
2488 }
2489
2490 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2491 if testing.Short() {
2492 t.Skip("skipping in -short mode")
2493 }
2494 defer afterTest(t)
2495 for _, handler := range testHandlerBodyConsumers {
2496 conn := &slowTestConn{
2497 script: []any{
2498 "POST /public HTTP/1.1\r\n" +
2499 "Host: test\r\n" +
2500 "Content-Length: 10000\r\n" +
2501 "\r\n",
2502 "foo bar baz",
2503 600 * time.Millisecond,
2504 "GET /secret HTTP/1.1\r\n" +
2505 "Host: test\r\n" +
2506 "\r\n",
2507 },
2508 closec: make(chan bool, 1),
2509 }
2510 ls := &oneConnListener{conn}
2511
2512 var numReqs int
2513 s := Server{
2514 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2515 numReqs++
2516 if strings.Contains(req.URL.Path, "secret") {
2517 t.Error("Request for /secret encountered, should not have happened.")
2518 }
2519 handler.f(req.Body)
2520 }),
2521 ReadTimeout: 400 * time.Millisecond,
2522 }
2523 go s.Serve(ls)
2524 <-conn.closec
2525
2526 if numReqs != 1 {
2527 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2528 }
2529 }
2530 }
2531
2532
// cancelableTimeoutContext wraps a Context so that any cancellation is
// reported as context.DeadlineExceeded. Tests use it to make TimeoutHandler
// time out on demand by calling cancel, instead of waiting for a real
// deadline. All other Context methods are inherited from the wrapped value.
type cancelableTimeoutContext struct {
	context.Context
}

// Err returns nil while the wrapped context's Err is nil, and
// context.DeadlineExceeded afterwards, whatever the real cause was.
func (c cancelableTimeoutContext) Err() error {
	if c.Context.Err() == nil {
		return nil
	}
	return context.DeadlineExceeded
}
2543
2544 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2545 func testTimeoutHandler(t *testing.T, mode testMode) {
2546 sendHi := make(chan bool, 1)
2547 writeErrors := make(chan error, 1)
2548 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2549 <-sendHi
2550 _, werr := w.Write([]byte("hi"))
2551 writeErrors <- werr
2552 })
2553 ctx, cancel := context.WithCancel(context.Background())
2554 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2555 cst := newClientServerTest(t, mode, h)
2556
2557
2558 sendHi <- true
2559 res, err := cst.c.Get(cst.ts.URL)
2560 if err != nil {
2561 t.Error(err)
2562 }
2563 if g, e := res.StatusCode, StatusOK; g != e {
2564 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2565 }
2566 body, _ := io.ReadAll(res.Body)
2567 if g, e := string(body), "hi"; g != e {
2568 t.Errorf("got body %q; expected %q", g, e)
2569 }
2570 if g := <-writeErrors; g != nil {
2571 t.Errorf("got unexpected Write error on first request: %v", g)
2572 }
2573
2574
2575 cancel()
2576
2577 res, err = cst.c.Get(cst.ts.URL)
2578 if err != nil {
2579 t.Error(err)
2580 }
2581 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2582 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2583 }
2584 body, _ = io.ReadAll(res.Body)
2585 if !strings.Contains(string(body), "<title>Timeout</title>") {
2586 t.Errorf("expected timeout body; got %q", string(body))
2587 }
2588 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2589 t.Errorf("response content-type = %q; want %q", g, w)
2590 }
2591
2592
2593
2594 sendHi <- true
2595 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2596 t.Errorf("expected Write error of %v; got %v", e, g)
2597 }
2598 }
2599
2600
2601 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2602 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2603 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2604 ms, _ := strconv.Atoi(r.URL.Path[1:])
2605 if ms == 0 {
2606 ms = 1
2607 }
2608 for i := 0; i < ms; i++ {
2609 w.Write([]byte("hi"))
2610 time.Sleep(time.Millisecond)
2611 }
2612 })
2613
2614 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2615
2616 c := ts.Client()
2617
2618 var wg sync.WaitGroup
2619 gate := make(chan bool, 10)
2620 n := 50
2621 if testing.Short() {
2622 n = 10
2623 gate = make(chan bool, 3)
2624 }
2625 for i := 0; i < n; i++ {
2626 gate <- true
2627 wg.Add(1)
2628 go func() {
2629 defer wg.Done()
2630 defer func() { <-gate }()
2631 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2632 if err == nil {
2633 io.Copy(io.Discard, res.Body)
2634 res.Body.Close()
2635 }
2636 }()
2637 }
2638 wg.Wait()
2639 }
2640
2641
2642
2643 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2644 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2645 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2646 w.WriteHeader(204)
2647 })
2648
2649 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2650
2651 var wg sync.WaitGroup
2652 gate := make(chan bool, 50)
2653 n := 500
2654 if testing.Short() {
2655 n = 10
2656 }
2657
2658 c := ts.Client()
2659 for i := 0; i < n; i++ {
2660 gate <- true
2661 wg.Add(1)
2662 go func() {
2663 defer wg.Done()
2664 defer func() { <-gate }()
2665 res, err := c.Get(ts.URL)
2666 if err != nil {
2667
2668
2669 t.Log(err)
2670 return
2671 }
2672 defer res.Body.Close()
2673 io.Copy(io.Discard, res.Body)
2674 }()
2675 }
2676 wg.Wait()
2677 }
2678
2679
2680 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2681 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2682 sendHi := make(chan bool, 1)
2683 writeErrors := make(chan error, 1)
2684 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2685 w.Header().Set("Content-Type", "text/plain")
2686 <-sendHi
2687 _, werr := w.Write([]byte("hi"))
2688 writeErrors <- werr
2689 })
2690 ctx, cancel := context.WithCancel(context.Background())
2691 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2692 cst := newClientServerTest(t, mode, h)
2693
2694
2695 sendHi <- true
2696 res, err := cst.c.Get(cst.ts.URL)
2697 if err != nil {
2698 t.Error(err)
2699 }
2700 if g, e := res.StatusCode, StatusOK; g != e {
2701 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2702 }
2703 body, _ := io.ReadAll(res.Body)
2704 if g, e := string(body), "hi"; g != e {
2705 t.Errorf("got body %q; expected %q", g, e)
2706 }
2707 if g := <-writeErrors; g != nil {
2708 t.Errorf("got unexpected Write error on first request: %v", g)
2709 }
2710
2711
2712 cancel()
2713
2714 res, err = cst.c.Get(cst.ts.URL)
2715 if err != nil {
2716 t.Error(err)
2717 }
2718 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2719 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2720 }
2721 body, _ = io.ReadAll(res.Body)
2722 if !strings.Contains(string(body), "<title>Timeout</title>") {
2723 t.Errorf("expected timeout body; got %q", string(body))
2724 }
2725
2726
2727
2728 sendHi <- true
2729 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2730 t.Errorf("expected Write error of %v; got %v", e, g)
2731 }
2732 }
2733
2734
2735 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2736 run(t, testTimeoutHandlerStartTimerWhenServing)
2737 }
2738 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2739 if testing.Short() {
2740 t.Skip("skipping sleeping test in -short mode")
2741 }
2742 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2743 w.WriteHeader(StatusNoContent)
2744 }
2745 timeout := 300 * time.Millisecond
2746 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2747 defer ts.Close()
2748
2749 c := ts.Client()
2750
2751
2752
2753
2754 time.Sleep(2 * timeout)
2755 res, err := c.Get(ts.URL)
2756 if err != nil {
2757 t.Fatal(err)
2758 }
2759 defer res.Body.Close()
2760 if res.StatusCode != StatusNoContent {
2761 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2762 }
2763 }
2764
2765 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2766 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2767 writeErrors := make(chan error, 1)
2768 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2769 w.Header().Set("Content-Type", "text/plain")
2770 var err error
2771
2772
2773
2774 for i := 0; i < 100; i++ {
2775 _, err = w.Write([]byte("a"))
2776 if err != nil {
2777 break
2778 }
2779 time.Sleep(1 * time.Millisecond)
2780 }
2781 writeErrors <- err
2782 })
2783 ctx, cancel := context.WithCancel(context.Background())
2784 cancel()
2785 h := NewTestTimeoutHandler(sayHi, ctx)
2786 cst := newClientServerTest(t, mode, h)
2787 defer cst.close()
2788
2789 res, err := cst.c.Get(cst.ts.URL)
2790 if err != nil {
2791 t.Error(err)
2792 }
2793 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2794 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2795 }
2796 body, _ := io.ReadAll(res.Body)
2797 if g, e := string(body), ""; g != e {
2798 t.Errorf("got body %q; expected %q", g, e)
2799 }
2800 if g, e := <-writeErrors, context.Canceled; g != e {
2801 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2802 }
2803 }
2804
2805
2806 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2807 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2808 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2809
2810 }
2811 timeout := 300 * time.Millisecond
2812 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2813
2814 c := ts.Client()
2815
2816 res, err := c.Get(ts.URL)
2817 if err != nil {
2818 t.Fatal(err)
2819 }
2820 defer res.Body.Close()
2821 if res.StatusCode != StatusOK {
2822 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2823 }
2824 }
2825
2826
2827 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2828 wrapper := func(h Handler) Handler {
2829 return TimeoutHandler(h, time.Second, "")
2830 }
2831 run(t, func(t *testing.T, mode testMode) {
2832 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2833 }, testNotParallel)
2834 }
2835
// TestRedirectBadPath calls Redirect with an empty target on a request whose
// URL path has no leading slash. Redirect must still complete and record the
// requested status code.
func TestRedirectBadPath(t *testing.T) {
	req := &Request{
		Method: "GET",
		URL:    &url.URL{Scheme: "http", Path: "not-empty-but-no-leading-slash"},
	}
	rec := httptest.NewRecorder()
	Redirect(rec, req, "", 304)
	if got := rec.Code; got != 304 {
		t.Errorf("Code = %d; want 304", got)
	}
}
2852
2853
// TestRedirect checks the Location header Redirect produces for a request to
// http://example.com/qux/ across a range of target forms.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	tests := []struct {
		in, want string
	}{
		// Targets with a scheme are used verbatim.
		{"http://foobar.com/baz", "http://foobar.com/baz"},
		{"https://foobar.com/baz", "https://foobar.com/baz"},
		{"test://foobar.com/baz", "test://foobar.com/baz"},
		// A scheme-relative target keeps its host.
		{"//foobar.com/baz", "//foobar.com/baz"},
		// An absolute path is left alone.
		{"/foobar.com/baz", "/foobar.com/baz"},
		// Relative paths are resolved against the request path /qux/.
		{"foobar.com/baz", "/qux/foobar.com/baz"},
		{"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
		// Three leading slashes collapse to a single one.
		{"///foobar.com/baz", "/foobar.com/baz"},
		// Query strings that themselves contain URLs are preserved.
		{"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
		{"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			"http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
		// Non-ASCII bytes are percent-encoded.
		{"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	for _, tc := range tests {
		rec := httptest.NewRecorder()
		Redirect(rec, req, tc.in, 302)
		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q) generated status code %v; want %v", tc.in, got, want)
		}
		if got := rec.Header().Get("Location"); got != tc.want {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", tc.in, got, tc.want)
		}
	}
}
2898
2899
2900
// TestRedirectContentTypeAndBody checks the Content-Type header and body that
// Redirect writes, depending on the request method and on any Content-Type the
// handler already set. ct == nil leaves the header absent; a non-nil ct presets
// it, possibly to an empty or nil value list.
func TestRedirectContentTypeAndBody(t *testing.T) {
	type ctHeader struct {
		Values []string
	}

	tests := []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{method: MethodGet, wantCT: "text/html; charset=utf-8", wantBody: "<a href=\"/foo\">Found</a>.\n\n"},
		{method: MethodHead, wantCT: "text/html; charset=utf-8"},
		{method: MethodPost},
		{method: MethodDelete},
		{method: "foo"},
		{method: MethodGet, ct: &ctHeader{[]string{"application/test"}}, wantCT: "application/test"},
		{method: MethodGet, ct: &ctHeader{[]string{}}},
		{method: MethodGet, ct: &ctHeader{nil}},
	}
	for _, tc := range tests {
		req := httptest.NewRequest(tc.method, "http://example.com/qux/", nil)
		rec := httptest.NewRecorder()
		if tc.ct != nil {
			rec.Header()["Content-Type"] = tc.ct.Values
		}

		Redirect(rec, req, "/foo", 302)

		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, got, want)
		}
		if got, want := rec.Header().Get("Content-Type"), tc.wantCT; got != want {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, got, want)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got, want := string(body), tc.wantBody; got != want {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, got, want)
		}
	}
}
2944
2945
2946
2947
2948
2949
2950
2951 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
2952
2953 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
2954 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
2955 all, err := io.ReadAll(r.Body)
2956 if err != nil {
2957 t.Fatalf("handler ReadAll: %v", err)
2958 }
2959 if len(all) != 0 {
2960 t.Errorf("handler got %d bytes; expected 0", len(all))
2961 }
2962 rw.Header().Set("Content-Length", "0")
2963 }))
2964
2965 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2966 if err != nil {
2967 t.Fatal(err)
2968 }
2969 req.ContentLength = 0
2970
2971 var resp [5]*Response
2972 for i := range resp {
2973 resp[i], err = cst.c.Do(req)
2974 if err != nil {
2975 t.Fatalf("client post #%d: %v", i, err)
2976 }
2977 }
2978
2979 for i := range resp {
2980 all, err := io.ReadAll(resp[i].Body)
2981 if err != nil {
2982 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2983 }
2984 if len(all) != 0 {
2985 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2986 }
2987 }
2988 }
2989
2990 func TestHandlerPanicNil(t *testing.T) {
2991 run(t, func(t *testing.T, mode testMode) {
2992 testHandlerPanic(t, false, mode, nil, nil)
2993 }, testNotParallel)
2994 }
2995
2996 func TestHandlerPanic(t *testing.T) {
2997 run(t, func(t *testing.T, mode testMode) {
2998 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
2999 }, testNotParallel)
3000 }
3001
3002 func TestHandlerPanicWithHijack(t *testing.T) {
3003
3004 run(t, func(t *testing.T, mode testMode) {
3005 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
3006 }, []testMode{http1Mode})
3007 }
3008
3009 func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
3010
3011
3012
3013
3014
3015
3016
3017
3018 pr, pw := io.Pipe()
3019 defer pw.Close()
3020
3021 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
3022 if withHijack {
3023 rwc, _, err := w.(Hijacker).Hijack()
3024 if err != nil {
3025 t.Logf("unexpected error: %v", err)
3026 }
3027 defer rwc.Close()
3028 }
3029 panic(panicValue)
3030 })
3031 if wrapper != nil {
3032 handler = wrapper(handler)
3033 }
3034 cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
3035 ts.Config.ErrorLog = log.New(pw, "", 0)
3036 })
3037
3038
3039 done := make(chan bool, 1)
3040 go func() {
3041 buf := make([]byte, 4<<10)
3042 _, err := pr.Read(buf)
3043 pr.Close()
3044 if err != nil && err != io.EOF {
3045 t.Error(err)
3046 }
3047 done <- true
3048 }()
3049
3050 _, err := cst.c.Get(cst.ts.URL)
3051 if err == nil {
3052 t.Logf("expected an error")
3053 }
3054
3055 if panicValue == nil {
3056 return
3057 }
3058
3059 <-done
3060 }
3061
// terrorWriter is an io.Writer that turns every write into a test error. It is
// used as a server ErrorLog sink to assert that nothing gets logged.
type terrorWriter struct{ t *testing.T }

// Write reports p as a test error and claims the whole slice was written.
func (w terrorWriter) Write(p []byte) (int, error) {
	written := len(p)
	w.t.Errorf("%s", p)
	return written, nil
}
3068
3069
3070
3071 func TestServerWriteHijackZeroBytes(t *testing.T) {
3072 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
3073 }
3074 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
3075 done := make(chan struct{})
3076 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3077 defer close(done)
3078 w.(Flusher).Flush()
3079 conn, _, err := w.(Hijacker).Hijack()
3080 if err != nil {
3081 t.Errorf("Hijack: %v", err)
3082 return
3083 }
3084 defer conn.Close()
3085 _, err = w.Write(nil)
3086 if err != ErrHijacked {
3087 t.Errorf("Write error = %v; want ErrHijacked", err)
3088 }
3089 }), func(ts *httptest.Server) {
3090 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
3091 }).ts
3092
3093 c := ts.Client()
3094 res, err := c.Get(ts.URL)
3095 if err != nil {
3096 t.Fatal(err)
3097 }
3098 res.Body.Close()
3099 <-done
3100 }
3101
3102 func TestServerNoDate(t *testing.T) {
3103 run(t, func(t *testing.T, mode testMode) {
3104 testServerNoHeader(t, mode, "Date")
3105 })
3106 }
3107
3108 func TestServerContentType(t *testing.T) {
3109 run(t, func(t *testing.T, mode testMode) {
3110 testServerNoHeader(t, mode, "Content-Type")
3111 })
3112 }
3113
3114 func testServerNoHeader(t *testing.T, mode testMode, header string) {
3115 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3116 w.Header()[header] = nil
3117 io.WriteString(w, "<html>foo</html>")
3118 }))
3119 res, err := cst.c.Get(cst.ts.URL)
3120 if err != nil {
3121 t.Fatal(err)
3122 }
3123 res.Body.Close()
3124 if got, ok := res.Header[header]; ok {
3125 t.Fatalf("Expected no %s header; got %q", header, got)
3126 }
3127 }
3128
3129 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3130 func testStripPrefix(t *testing.T, mode testMode) {
3131 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3132 w.Header().Set("X-Path", r.URL.Path)
3133 w.Header().Set("X-RawPath", r.URL.RawPath)
3134 })
3135 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3136
3137 c := ts.Client()
3138
3139 cases := []struct {
3140 reqPath string
3141 path string
3142 rawPath string
3143 }{
3144 {"/foo/bar/qux", "/qux", ""},
3145 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3146 {"/foo%2Fbar/qux", "", ""},
3147 {"/bar", "", ""},
3148 }
3149 for _, tc := range cases {
3150 t.Run(tc.reqPath, func(t *testing.T) {
3151 res, err := c.Get(ts.URL + tc.reqPath)
3152 if err != nil {
3153 t.Fatal(err)
3154 }
3155 res.Body.Close()
3156 if tc.path == "" {
3157 if res.StatusCode != StatusNotFound {
3158 t.Errorf("got %q, want 404 Not Found", res.Status)
3159 }
3160 return
3161 }
3162 if res.StatusCode != StatusOK {
3163 t.Fatalf("got %q, want 200 OK", res.Status)
3164 }
3165 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3166 t.Errorf("got Path %q, want %q", g, w)
3167 }
3168 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3169 t.Errorf("got RawPath %q, want %q", g, w)
3170 }
3171 })
3172 }
3173 }
3174
3175
// TestStripPrefixNotModifyRequest checks that StripPrefix leaves the URL path
// of the Request it is given unchanged.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	req := httptest.NewRequest("GET", "/foo/bar", nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)
	if got := req.URL.Path; got != "/foo/bar" {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
3184
3185 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
3186 func testRequestLimit(t *testing.T, mode testMode) {
3187 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3188 t.Fatalf("didn't expect to get request in Handler")
3189 }), optQuietLog)
3190 req, _ := NewRequest("GET", cst.ts.URL, nil)
3191 var bytesPerHeader = len("header12345: val12345\r\n")
3192 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
3193 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
3194 }
3195 res, err := cst.c.Do(req)
3196 if res != nil {
3197 defer res.Body.Close()
3198 }
3199 if mode == http2Mode {
3200
3201
3202
3203
3204 if err == nil && res.StatusCode != 431 {
3205 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3206 }
3207 } else {
3208
3209
3210
3211
3212 if err != nil {
3213 t.Fatalf("Do: %v", err)
3214 }
3215 if res.StatusCode != 431 {
3216 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3217 }
3218 }
3219 }
3220
// neverEnding is an io.Reader that yields an endless stream of its own byte
// value. Read always fills the whole buffer and never returns an error.
type neverEnding byte

func (b neverEnding) Read(p []byte) (n int, err error) {
	fill := byte(b)
	for i := range p {
		p[i] = fill
	}
	n = len(p)
	return n, nil
}
3229
// bodyLimitReader is an io.ReadCloser that fills each Read buffer with 'a'.
// It stops with an error once count exceeds limit or once Close has been
// called; a closed reader reports "closed" even when it is also over the
// limit. count is the total number of bytes handed out so far. mu guards
// count and serializes Read with Close.
type bodyLimitReader struct {
	mu     sync.Mutex
	count  int
	limit  int
	closed chan struct{}
}

func (r *bodyLimitReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	isClosed := false
	select {
	case <-r.closed:
		isClosed = true
	default:
	}
	switch {
	case isClosed:
		return 0, errors.New("closed")
	case r.count > r.limit:
		return 0, errors.New("at limit")
	}

	for i := range p {
		p[i] = 'a'
	}
	r.count += len(p)
	return len(p), nil
}

// Close marks the reader as closed, so later Reads fail. Calling Close twice
// panics, because it closes the channel again.
func (r *bodyLimitReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()
	close(r.closed)
	return nil
}
3261
3262 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3263 func testRequestBodyLimit(t *testing.T, mode testMode) {
3264 const limit = 1 << 20
3265 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3266 r.Body = MaxBytesReader(w, r.Body, limit)
3267 n, err := io.Copy(io.Discard, r.Body)
3268 if err == nil {
3269 t.Errorf("expected error from io.Copy")
3270 }
3271 if n != limit {
3272 t.Errorf("io.Copy = %d, want %d", n, limit)
3273 }
3274 mbErr, ok := err.(*MaxBytesError)
3275 if !ok {
3276 t.Errorf("expected MaxBytesError, got %T", err)
3277 }
3278 if mbErr.Limit != limit {
3279 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3280 }
3281 }))
3282
3283 body := &bodyLimitReader{
3284 closed: make(chan struct{}),
3285 limit: limit * 200,
3286 }
3287 req, _ := NewRequest("POST", cst.ts.URL, body)
3288
3289
3290
3291
3292
3293
3294
3295
3296
3297
3298 resp, err := cst.c.Do(req)
3299 if err == nil {
3300 resp.Body.Close()
3301 }
3302
3303
3304 <-body.closed
3305
3306 if body.count > limit*100 {
3307 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3308 limit, body.count)
3309 }
3310 }
3311
3312
3313
3314 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
3315 func testClientWriteShutdown(t *testing.T, mode testMode) {
3316 if runtime.GOOS == "plan9" {
3317 t.Skip("skipping test; see https://golang.org/issue/17906")
3318 }
3319 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3320 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3321 if err != nil {
3322 t.Fatalf("Dial: %v", err)
3323 }
3324 err = conn.(*net.TCPConn).CloseWrite()
3325 if err != nil {
3326 t.Fatalf("CloseWrite: %v", err)
3327 }
3328
3329 bs, err := io.ReadAll(conn)
3330 if err != nil {
3331 t.Errorf("ReadAll: %v", err)
3332 }
3333 got := string(bs)
3334 if got != "" {
3335 t.Errorf("read %q from server; want nothing", got)
3336 }
3337 }
3338
3339
3340
3341 func TestServerBufferedChunking(t *testing.T) {
3342 conn := new(testConn)
3343 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3344 conn.closec = make(chan bool, 1)
3345 ls := &oneConnListener{conn}
3346 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3347 rw.(Flusher).Flush()
3348 rw.Write([]byte{'x'})
3349 rw.Write([]byte{'y'})
3350 rw.Write([]byte{'z'})
3351 }))
3352 <-conn.closec
3353 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3354 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3355 conn.writeBuf.Bytes())
3356 }
3357 }
3358
3359
3360
3361
3362
3363 func TestServerGracefulClose(t *testing.T) {
3364
3365 run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
3366 }
3367 func testServerGracefulClose(t *testing.T, mode testMode) {
3368 runTimeSensitiveTest(t, []time.Duration{
3369 1 * time.Millisecond,
3370 5 * time.Millisecond,
3371 10 * time.Millisecond,
3372 50 * time.Millisecond,
3373 100 * time.Millisecond,
3374 500 * time.Millisecond,
3375 time.Second,
3376 5 * time.Second,
3377 }, func(t *testing.T, timeout time.Duration) error {
3378 SetRSTAvoidanceDelay(t, timeout)
3379 t.Logf("set RST avoidance delay to %v", timeout)
3380
3381 const bodySize = 5 << 20
3382 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3383 for i := 0; i < bodySize; i++ {
3384 req = append(req, 'x')
3385 }
3386
3387 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3388 Error(w, "bye", StatusUnauthorized)
3389 }))
3390
3391
3392 defer cst.close()
3393 ts := cst.ts
3394
3395 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3396 if err != nil {
3397 return err
3398 }
3399 writeErr := make(chan error)
3400 go func() {
3401 _, err := conn.Write(req)
3402 writeErr <- err
3403 }()
3404 defer func() {
3405 conn.Close()
3406
3407
3408
3409 <-writeErr
3410 }()
3411
3412 br := bufio.NewReader(conn)
3413 lineNum := 0
3414 for {
3415 line, err := br.ReadString('\n')
3416 if err == io.EOF {
3417 break
3418 }
3419 if err != nil {
3420 return fmt.Errorf("ReadLine: %v", err)
3421 }
3422 lineNum++
3423 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3424 t.Errorf("Response line = %q; want a 401", line)
3425 }
3426 }
3427 return nil
3428 })
3429 }
3430
3431 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3432 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3433 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3434 if r.Method != "get" {
3435 t.Errorf(`Got method %q; want "get"`, r.Method)
3436 }
3437 }))
3438 defer cst.close()
3439 req, _ := NewRequest("get", cst.ts.URL, nil)
3440 res, err := cst.c.Do(req)
3441 if err != nil {
3442 t.Error(err)
3443 return
3444 }
3445
3446 res.Body.Close()
3447 }
3448
3449
3450
3451
3452
3453 func TestContentLengthZero(t *testing.T) {
3454 run(t, testContentLengthZero, []testMode{http1Mode})
3455 }
3456 func testContentLengthZero(t *testing.T, mode testMode) {
3457 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3458
3459 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3460 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3461 if err != nil {
3462 t.Fatalf("error dialing: %v", err)
3463 }
3464 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3465 if err != nil {
3466 t.Fatalf("error writing: %v", err)
3467 }
3468 req, _ := NewRequest("GET", "/", nil)
3469 res, err := ReadResponse(bufio.NewReader(conn), req)
3470 if err != nil {
3471 t.Fatalf("error reading response: %v", err)
3472 }
3473 if te := res.TransferEncoding; len(te) > 0 {
3474 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3475 }
3476 if cl := res.ContentLength; cl != 0 {
3477 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3478 }
3479 conn.Close()
3480 }
3481 }
3482
3483 func TestCloseNotifier(t *testing.T) {
3484 run(t, testCloseNotifier, []testMode{http1Mode})
3485 }
3486 func testCloseNotifier(t *testing.T, mode testMode) {
3487 gotReq := make(chan bool, 1)
3488 sawClose := make(chan bool, 1)
3489 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3490 gotReq <- true
3491 cc := rw.(CloseNotifier).CloseNotify()
3492 <-cc
3493 sawClose <- true
3494 })).ts
3495 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3496 if err != nil {
3497 t.Fatalf("error dialing: %v", err)
3498 }
3499 diec := make(chan bool)
3500 go func() {
3501 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3502 if err != nil {
3503 t.Error(err)
3504 return
3505 }
3506 <-diec
3507 conn.Close()
3508 }()
3509 For:
3510 for {
3511 select {
3512 case <-gotReq:
3513 diec <- true
3514 case <-sawClose:
3515 break For
3516 }
3517 }
3518 ts.Close()
3519 }
3520
3521
3522
3523
3524
3525 func TestCloseNotifierPipelined(t *testing.T) {
3526 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3527 }
3528 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3529 gotReq := make(chan bool, 2)
3530 sawClose := make(chan bool, 2)
3531 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3532 gotReq <- true
3533 cc := rw.(CloseNotifier).CloseNotify()
3534 select {
3535 case <-cc:
3536 t.Error("unexpected CloseNotify")
3537 case <-time.After(100 * time.Millisecond):
3538 }
3539 sawClose <- true
3540 })).ts
3541 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3542 if err != nil {
3543 t.Fatalf("error dialing: %v", err)
3544 }
3545 diec := make(chan bool, 1)
3546 defer close(diec)
3547 go func() {
3548 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3549 _, err = io.WriteString(conn, req+req)
3550 if err != nil {
3551 t.Error(err)
3552 return
3553 }
3554 <-diec
3555 conn.Close()
3556 }()
3557 reqs := 0
3558 closes := 0
3559 for {
3560 select {
3561 case <-gotReq:
3562 reqs++
3563 if reqs > 2 {
3564 t.Fatal("too many requests")
3565 }
3566 case <-sawClose:
3567 closes++
3568 if closes > 1 {
3569 return
3570 }
3571 }
3572 }
3573 }
3574
3575 func TestCloseNotifierChanLeak(t *testing.T) {
3576 defer afterTest(t)
3577 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3578 for i := 0; i < 20; i++ {
3579 var output bytes.Buffer
3580 conn := &rwTestConn{
3581 Reader: bytes.NewReader(req),
3582 Writer: &output,
3583 closec: make(chan bool, 1),
3584 }
3585 ln := &oneConnListener{conn: conn}
3586 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3587
3588
3589
3590 _ = rw.(CloseNotifier).CloseNotify()
3591 })
3592 go Serve(ln, handler)
3593 <-conn.closec
3594 }
3595 }
3596
3597
3598
3599
3600
3601
3602
3603
3604
3605
3606 func TestHijackAfterCloseNotifier(t *testing.T) {
3607 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3608 }
3609 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3610 script := make(chan string, 2)
3611 script <- "closenotify"
3612 script <- "hijack"
3613 close(script)
3614 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3615 plan := <-script
3616 switch plan {
3617 default:
3618 panic("bogus plan; too many requests")
3619 case "closenotify":
3620 w.(CloseNotifier).CloseNotify()
3621 w.Header().Set("X-Addr", r.RemoteAddr)
3622 case "hijack":
3623 c, _, err := w.(Hijacker).Hijack()
3624 if err != nil {
3625 t.Errorf("Hijack in Handler: %v", err)
3626 return
3627 }
3628 if _, ok := c.(*net.TCPConn); !ok {
3629
3630
3631 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3632 }
3633 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3634 c.Close()
3635 return
3636 }
3637 })).ts
3638 res1, err := ts.Client().Get(ts.URL)
3639 if err != nil {
3640 log.Fatal(err)
3641 }
3642 res2, err := ts.Client().Get(ts.URL)
3643 if err != nil {
3644 log.Fatal(err)
3645 }
3646 addr1 := res1.Header.Get("X-Addr")
3647 addr2 := res2.Header.Get("X-Addr")
3648 if addr1 == "" || addr1 != addr2 {
3649 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3650 }
3651 }
3652
3653 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3654 run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
3655 }
3656 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3657 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3658 bodyOkay := make(chan bool, 1)
3659 gotCloseNotify := make(chan bool, 1)
3660 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3661 defer close(bodyOkay)
3662
3663 reqBody := r.Body
3664 r.Body = nil
3665
3666 gone := w.(CloseNotifier).CloseNotify()
3667 slurp, err := io.ReadAll(reqBody)
3668 if err != nil {
3669 t.Errorf("Body read: %v", err)
3670 return
3671 }
3672 if len(slurp) != len(requestBody) {
3673 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3674 return
3675 }
3676 if !bytes.Equal(slurp, requestBody) {
3677 t.Error("Backend read wrong request body.")
3678 return
3679 }
3680 bodyOkay <- true
3681 <-gone
3682 gotCloseNotify <- true
3683 })).ts
3684
3685 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3686 if err != nil {
3687 t.Fatal(err)
3688 }
3689 defer conn.Close()
3690
3691 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3692 len(requestBody), requestBody)
3693 if !<-bodyOkay {
3694
3695 return
3696 }
3697 conn.Close()
3698 <-gotCloseNotify
3699 }
3700
3701 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3702 func testOptions(t *testing.T, mode testMode) {
3703 uric := make(chan string, 2)
3704 mux := NewServeMux()
3705 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3706 uric <- r.RequestURI
3707 })
3708 ts := newClientServerTest(t, mode, mux).ts
3709
3710 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3711 if err != nil {
3712 t.Fatal(err)
3713 }
3714 defer conn.Close()
3715
3716
3717 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3718 if err != nil {
3719 t.Fatal(err)
3720 }
3721 br := bufio.NewReader(conn)
3722 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3723 if err != nil {
3724 t.Fatal(err)
3725 }
3726 if res.StatusCode != 200 {
3727 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3728 }
3729
3730
3731 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3732 if err != nil {
3733 t.Fatal(err)
3734 }
3735 res, err = ReadResponse(br, &Request{Method: "GET"})
3736 if err != nil {
3737 t.Fatal(err)
3738 }
3739 if res.StatusCode != 400 {
3740 t.Errorf("Got non-400 response to GET *: %#v", res)
3741 }
3742
3743 res, err = Get(ts.URL + "/second")
3744 if err != nil {
3745 t.Fatal(err)
3746 }
3747 res.Body.Close()
3748 if got := <-uric; got != "/second" {
3749 t.Errorf("Handler saw request for %q; want /second", got)
3750 }
3751 }
3752
3753 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3754 func testOptionsHandler(t *testing.T, mode testMode) {
3755 rc := make(chan *Request, 1)
3756
3757 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3758 rc <- r
3759 }), func(ts *httptest.Server) {
3760 ts.Config.DisableGeneralOptionsHandler = true
3761 }).ts
3762
3763 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3764 if err != nil {
3765 t.Fatal(err)
3766 }
3767 defer conn.Close()
3768
3769 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3770 if err != nil {
3771 t.Fatal(err)
3772 }
3773
3774 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3775 t.Errorf("Expected OPTIONS * request, got %v", got)
3776 }
3777 }
3778
3779
3780
3781
3782
3783
3784
3785
3786
3787
3788 func TestHeaderToWire(t *testing.T) {
3789 tests := []struct {
3790 name string
3791 handler func(ResponseWriter, *Request)
3792 check func(got, logs string) error
3793 }{
3794 {
3795 name: "write without Header",
3796 handler: func(rw ResponseWriter, r *Request) {
3797 rw.Write([]byte("hello world"))
3798 },
3799 check: func(got, logs string) error {
3800 if !strings.Contains(got, "Content-Length:") {
3801 return errors.New("no content-length")
3802 }
3803 if !strings.Contains(got, "Content-Type: text/plain") {
3804 return errors.New("no content-type")
3805 }
3806 return nil
3807 },
3808 },
3809 {
3810 name: "Header mutation before write",
3811 handler: func(rw ResponseWriter, r *Request) {
3812 h := rw.Header()
3813 h.Set("Content-Type", "some/type")
3814 rw.Write([]byte("hello world"))
3815 h.Set("Too-Late", "bogus")
3816 },
3817 check: func(got, logs string) error {
3818 if !strings.Contains(got, "Content-Length:") {
3819 return errors.New("no content-length")
3820 }
3821 if !strings.Contains(got, "Content-Type: some/type") {
3822 return errors.New("wrong content-type")
3823 }
3824 if strings.Contains(got, "Too-Late") {
3825 return errors.New("don't want too-late header")
3826 }
3827 return nil
3828 },
3829 },
3830 {
3831 name: "write then useless Header mutation",
3832 handler: func(rw ResponseWriter, r *Request) {
3833 rw.Write([]byte("hello world"))
3834 rw.Header().Set("Too-Late", "Write already wrote headers")
3835 },
3836 check: func(got, logs string) error {
3837 if strings.Contains(got, "Too-Late") {
3838 return errors.New("header appeared from after WriteHeader")
3839 }
3840 return nil
3841 },
3842 },
3843 {
3844 name: "flush then write",
3845 handler: func(rw ResponseWriter, r *Request) {
3846 rw.(Flusher).Flush()
3847 rw.Write([]byte("post-flush"))
3848 rw.Header().Set("Too-Late", "Write already wrote headers")
3849 },
3850 check: func(got, logs string) error {
3851 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3852 return errors.New("not chunked")
3853 }
3854 if strings.Contains(got, "Too-Late") {
3855 return errors.New("header appeared from after WriteHeader")
3856 }
3857 return nil
3858 },
3859 },
3860 {
3861 name: "header then flush",
3862 handler: func(rw ResponseWriter, r *Request) {
3863 rw.Header().Set("Content-Type", "some/type")
3864 rw.(Flusher).Flush()
3865 rw.Write([]byte("post-flush"))
3866 rw.Header().Set("Too-Late", "Write already wrote headers")
3867 },
3868 check: func(got, logs string) error {
3869 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3870 return errors.New("not chunked")
3871 }
3872 if strings.Contains(got, "Too-Late") {
3873 return errors.New("header appeared from after WriteHeader")
3874 }
3875 if !strings.Contains(got, "Content-Type: some/type") {
3876 return errors.New("wrong content-type")
3877 }
3878 return nil
3879 },
3880 },
3881 {
3882 name: "sniff-on-first-write content-type",
3883 handler: func(rw ResponseWriter, r *Request) {
3884 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3885 rw.Header().Set("Content-Type", "x/wrong")
3886 },
3887 check: func(got, logs string) error {
3888 if !strings.Contains(got, "Content-Type: text/html") {
3889 return errors.New("wrong content-type; want html")
3890 }
3891 return nil
3892 },
3893 },
3894 {
3895 name: "explicit content-type wins",
3896 handler: func(rw ResponseWriter, r *Request) {
3897 rw.Header().Set("Content-Type", "some/type")
3898 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3899 },
3900 check: func(got, logs string) error {
3901 if !strings.Contains(got, "Content-Type: some/type") {
3902 return errors.New("wrong content-type; want html")
3903 }
3904 return nil
3905 },
3906 },
3907 {
3908 name: "empty handler",
3909 handler: func(rw ResponseWriter, r *Request) {
3910 },
3911 check: func(got, logs string) error {
3912 if !strings.Contains(got, "Content-Length: 0") {
3913 return errors.New("want 0 content-length")
3914 }
3915 return nil
3916 },
3917 },
3918 {
3919 name: "only Header, no write",
3920 handler: func(rw ResponseWriter, r *Request) {
3921 rw.Header().Set("Some-Header", "some-value")
3922 },
3923 check: func(got, logs string) error {
3924 if !strings.Contains(got, "Some-Header") {
3925 return errors.New("didn't get header")
3926 }
3927 return nil
3928 },
3929 },
3930 {
3931 name: "WriteHeader call",
3932 handler: func(rw ResponseWriter, r *Request) {
3933 rw.WriteHeader(404)
3934 rw.Header().Set("Too-Late", "some-value")
3935 },
3936 check: func(got, logs string) error {
3937 if !strings.Contains(got, "404") {
3938 return errors.New("wrong status")
3939 }
3940 if strings.Contains(got, "Too-Late") {
3941 return errors.New("shouldn't have seen Too-Late")
3942 }
3943 return nil
3944 },
3945 },
3946 }
3947 for _, tc := range tests {
3948 ht := newHandlerTest(HandlerFunc(tc.handler))
3949 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3950 logs := ht.logbuf.String()
3951 if err := tc.check(got, logs); err != nil {
3952 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3953 }
3954 }
3955 }
3956
3957 type errorListener struct {
3958 errs []error
3959 }
3960
3961 func (l *errorListener) Accept() (c net.Conn, err error) {
3962 if len(l.errs) == 0 {
3963 return nil, io.EOF
3964 }
3965 err = l.errs[0]
3966 l.errs = l.errs[1:]
3967 return
3968 }
3969
3970 func (l *errorListener) Close() error {
3971 return nil
3972 }
3973
3974 func (l *errorListener) Addr() net.Addr {
3975 return dummyAddr("test-address")
3976 }
3977
3978 func TestAcceptMaxFds(t *testing.T) {
3979 setParallel(t)
3980
3981 ln := &errorListener{[]error{
3982 &net.OpError{
3983 Op: "accept",
3984 Err: syscall.EMFILE,
3985 }}}
3986 server := &Server{
3987 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3988 ErrorLog: log.New(io.Discard, "", 0),
3989 }
3990 err := server.Serve(ln)
3991 if err != io.EOF {
3992 t.Errorf("got error %v, want EOF", err)
3993 }
3994 }
3995
3996 func TestWriteAfterHijack(t *testing.T) {
3997 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3998 var buf strings.Builder
3999 wrotec := make(chan bool, 1)
4000 conn := &rwTestConn{
4001 Reader: bytes.NewReader(req),
4002 Writer: &buf,
4003 closec: make(chan bool, 1),
4004 }
4005 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4006 conn, bufrw, err := rw.(Hijacker).Hijack()
4007 if err != nil {
4008 t.Error(err)
4009 return
4010 }
4011 go func() {
4012 bufrw.Write([]byte("[hijack-to-bufw]"))
4013 bufrw.Flush()
4014 conn.Write([]byte("[hijack-to-conn]"))
4015 conn.Close()
4016 wrotec <- true
4017 }()
4018 })
4019 ln := &oneConnListener{conn: conn}
4020 go Serve(ln, handler)
4021 <-conn.closec
4022 <-wrotec
4023 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
4024 t.Errorf("wrote %q; want %q", g, w)
4025 }
4026 }
4027
4028 func TestDoubleHijack(t *testing.T) {
4029 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4030 var buf bytes.Buffer
4031 conn := &rwTestConn{
4032 Reader: bytes.NewReader(req),
4033 Writer: &buf,
4034 closec: make(chan bool, 1),
4035 }
4036 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4037 conn, _, err := rw.(Hijacker).Hijack()
4038 if err != nil {
4039 t.Error(err)
4040 return
4041 }
4042 _, _, err = rw.(Hijacker).Hijack()
4043 if err == nil {
4044 t.Errorf("got err = nil; want err != nil")
4045 }
4046 conn.Close()
4047 })
4048 ln := &oneConnListener{conn: conn}
4049 go Serve(ln, handler)
4050 <-conn.closec
4051 }
4052
4053
4054
4055
4056
4057
4058
4059 func TestHTTP10ConnectionHeader(t *testing.T) {
4060 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
4061 }
4062 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
4063 mux := NewServeMux()
4064 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
4065 ts := newClientServerTest(t, mode, mux).ts
4066
4067
4068 tests := []struct {
4069 req string
4070 expect []string
4071 }{
4072 {
4073 req: "GET / HTTP/1.0\r\n\r\n",
4074 expect: nil,
4075 },
4076 {
4077 req: "OPTIONS * HTTP/1.0\r\n\r\n",
4078 expect: nil,
4079 },
4080 {
4081 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
4082 expect: []string{"keep-alive"},
4083 },
4084 }
4085
4086 for _, tt := range tests {
4087 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
4088 if err != nil {
4089 t.Fatal("dial err:", err)
4090 }
4091
4092 _, err = fmt.Fprint(conn, tt.req)
4093 if err != nil {
4094 t.Fatal("conn write err:", err)
4095 }
4096
4097 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
4098 if err != nil {
4099 t.Fatal("ReadResponse err:", err)
4100 }
4101 conn.Close()
4102 resp.Body.Close()
4103
4104 got := resp.Header["Connection"]
4105 if !slices.Equal(got, tt.expect) {
4106 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
4107 }
4108 }
4109 }
4110
4111
4112 func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
4113 func testServerReaderFromOrder(t *testing.T, mode testMode) {
4114 pr, pw := io.Pipe()
4115 const size = 3 << 20
4116 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4117 rw.Header().Set("Content-Type", "text/plain")
4118 done := make(chan bool)
4119 go func() {
4120 io.Copy(rw, pr)
4121 close(done)
4122 }()
4123 time.Sleep(25 * time.Millisecond)
4124 n, err := io.Copy(io.Discard, req.Body)
4125 if err != nil {
4126 t.Errorf("handler Copy: %v", err)
4127 return
4128 }
4129 if n != size {
4130 t.Errorf("handler Copy = %d; want %d", n, size)
4131 }
4132 pw.Write([]byte("hi"))
4133 pw.Close()
4134 <-done
4135 }))
4136
4137 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
4138 if err != nil {
4139 t.Fatal(err)
4140 }
4141 res, err := cst.c.Do(req)
4142 if err != nil {
4143 t.Fatal(err)
4144 }
4145 all, err := io.ReadAll(res.Body)
4146 if err != nil {
4147 t.Fatal(err)
4148 }
4149 res.Body.Close()
4150 if string(all) != "hi" {
4151 t.Errorf("Body = %q; want hi", all)
4152 }
4153 }
4154
4155
4156 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4157 for _, code := range []int{StatusNotModified, StatusNoContent} {
4158 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4159 if r.URL.Path == "/header" {
4160 w.Header().Set("Content-Length", "123")
4161 }
4162 w.WriteHeader(code)
4163 if r.URL.Path == "/more" {
4164 w.Write([]byte("stuff"))
4165 }
4166 }))
4167 for _, req := range []string{
4168 "GET / HTTP/1.0",
4169 "GET /header HTTP/1.0",
4170 "GET /more HTTP/1.0",
4171 "GET / HTTP/1.1\nHost: foo",
4172 "GET /header HTTP/1.1\nHost: foo",
4173 "GET /more HTTP/1.1\nHost: foo",
4174 } {
4175 got := ht.rawResponse(req)
4176 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4177 if !strings.Contains(got, wantStatus) {
4178 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4179 } else if strings.Contains(got, "Content-Length") {
4180 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4181 } else if strings.Contains(got, "stuff") {
4182 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4183 }
4184 }
4185 }
4186 }
4187
4188 func TestContentTypeOkayOn204(t *testing.T) {
4189 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4190 w.Header().Set("Content-Length", "123")
4191 w.Header().Set("Content-Type", "foo/bar")
4192 w.WriteHeader(204)
4193 }))
4194 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4195 if !strings.Contains(got, "Content-Type: foo/bar") {
4196 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4197 }
4198 if strings.Contains(got, "Content-Length: 123") {
4199 t.Errorf("Response = %q; don't want a Content-Length", got)
4200 }
4201 }
4202
4203
4204
4205
4206
4207
4208
4209 func TestTransportAndServerSharedBodyRace(t *testing.T) {
4210 run(t, testTransportAndServerSharedBodyRace, testNotParallel)
4211 }
// testTransportAndServerSharedBodyRace runs a proxy that passes its incoming
// request body, unchanged, as the body of an outbound request to a backend.
// The backend copies that body back into its response. After reading half
// of the backend's response, the proxy cancels the outbound request and
// replies "OK".
//
// The proxy's incoming body is therefore used both by the proxy's server
// side and by its Transport at the same time. Presumably the test is meant
// to run under the race detector to catch unsynchronized access to it.
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	// runTimeSensitiveTest is defined elsewhere. Presumably it retries the
	// function with successive durations from this list until one attempt
	// succeeds; each attempt uses its duration as the RST avoidance delay.
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 1 << 20

		var wg sync.WaitGroup
		backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Track in-flight backend handler calls so the deferred cleanup
			// below can wait for them, including their t.Logf call, before
			// the backend is closed. On the success path, the outer request
			// cannot finish until the proxy has read from this handler's
			// response, so this Add happens before the deferred Wait.
			wg.Add(1)
			defer wg.Done()

			// Copy the request body back as the response; the proxy cancels
			// partway through.
			n, err := io.CopyN(rw, req.Body, bodySize)
			t.Logf("backend CopyN: %v, %v", n, err)
			// Keep the handler running until this request's context is done.
			<-req.Context().Done()
		}))
		// Wait for outstanding backend handlers before shutting down the backend.
		defer func() {
			wg.Wait()
			backend.close()
		}()

		// proxy is declared before it is assigned so that its own handler
		// can refer to proxy.c.
		var proxy *clientServerTest
		proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// The incoming req.Body becomes the outbound request's body, so
			// the same reader is shared between this server and proxy.c.
			req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
			req2.ContentLength = bodySize
			cancel := make(chan struct{})
			req2.Cancel = cancel

			bresp, err := proxy.c.Do(req2)
			if err != nil {
				t.Errorf("Proxy outbound request: %v", err)
				return
			}
			_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
			if err != nil {
				t.Errorf("Proxy copy error: %v", err)
				return
			}
			// Close the backend response body only at test cleanup.
			t.Cleanup(func() { bresp.Body.Close() })

			// Cancel the outbound request while the backend is still
			// sending. HTTP/2 mode uses the Request.Cancel channel; the
			// other modes use Transport.CancelRequest.
			if mode == http2Mode {
				close(cancel)
			} else {
				proxy.c.Transport.(*Transport).CancelRequest(req2)
			}
			rw.Write([]byte("OK"))
		}))
		defer proxy.close()

		req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
		res, err := proxy.c.Do(req)
		if err != nil {
			return fmt.Errorf("original request: %v", err)
		}
		res.Body.Close()
		return nil
	})
}
4298
4299
4300
4301
4302 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4303 run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
4304 }
4305 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4306 if testing.Short() {
4307 t.Skip("skipping in -short mode")
4308 }
4309
4310 readErrCh := make(chan error, 1)
4311 errCh := make(chan error, 2)
4312
4313 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4314 go func(body io.Reader) {
4315 _, err := body.Read(make([]byte, 100))
4316 readErrCh <- err
4317 }(req.Body)
4318 time.Sleep(500 * time.Millisecond)
4319 })).ts
4320
4321 closeConn := make(chan bool)
4322 defer close(closeConn)
4323 go func() {
4324 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4325 if err != nil {
4326 errCh <- err
4327 return
4328 }
4329 defer conn.Close()
4330 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4331 if err != nil {
4332 errCh <- err
4333 return
4334 }
4335
4336
4337 <-closeConn
4338 }()
4339 select {
4340 case err := <-readErrCh:
4341 if err == nil {
4342 t.Error("Read was nil. Expected error.")
4343 }
4344 case err := <-errCh:
4345 t.Error(err)
4346 }
4347 }
4348
4349
4350 func TestResponseWriterWriteString(t *testing.T) {
4351 okc := make(chan bool, 1)
4352 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4353 _, ok := w.(io.StringWriter)
4354 okc <- ok
4355 }))
4356 ht.rawResponse("GET / HTTP/1.0")
4357 select {
4358 case ok := <-okc:
4359 if !ok {
4360 t.Error("ResponseWriter did not implement io.StringWriter")
4361 }
4362 default:
4363 t.Error("handler was never called")
4364 }
4365 }
4366
4367 func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
// testServerConnState checks the sequence of ConnState transitions the
// server reports for several connection lifecycles:
//   - keep-alive followed by a server-requested close;
//   - keep-alive followed by a client-requested close;
//   - a hijack, with and without a handler panic afterwards;
//   - a client that connects and immediately disconnects;
//   - a malformed request;
//   - a single clean request followed by a client close.
func testServerConnState(t *testing.T, mode testMode) {
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// stateLog records the ConnState transitions of one connection during a
	// single wantLog call. The fields are used as follows:
	//   - active is the tracked connection, set on its StateNew transition;
	//   - got holds the states observed so far;
	//   - want holds the expected sequence;
	//   - complete is closed once got is as long as want or has stopped being
	//     a prefix of it.
	type stateLog struct {
		active   net.Conn
		got      []ConnState
		want     []ConnState
		complete chan<- struct{}
	}
	// activeLog holds the current stateLog. Because its buffer holds one
	// value, receiving from it acts as taking exclusive ownership of the log
	// and sending acts as releasing it. This serializes the ConnState hook
	// and wantLog.
	activeLog := make(chan *stateLog, 1)

	// wantLog installs a fresh stateLog expecting want, runs doRequests, waits
	// for the ConnState hook to signal completion, and then compares the
	// observed states with want.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		<-complete
		sl := <-activeLog
		if !slices.Equal(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
		}
		// sl is not put back, so activeLog stays empty until the next
		// wantLog call (or the deferred cleanup) installs a new log.
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
		ts.Config.ConnState = func(c net.Conn, state ConnState) {
			if c == nil {
				t.Errorf("nil conn seen in state %s", state)
				return
			}
			// Take the current log; every path below sends it back.
			sl := <-activeLog
			if sl.active == nil && state == StateNew {
				// First transition of a new connection: start tracking it.
				sl.active = c
			} else if sl.active != c {
				t.Errorf("unexpected conn in state %s", state)
				activeLog <- sl
				return
			}
			sl.got = append(sl.got, state)
			// Signal completion once enough states have arrived or the
			// observed sequence has already diverged from want.
			if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
				close(sl.complete)
				sl.complete = nil
			}
			activeLog <- sl
		}
	}).ts
	defer func() {
		// Leave an empty log in place so that a ConnState callback fired
		// during shutdown can take one instead of blocking on an empty channel.
		activeLog <- &stateLog{}
		ts.Close()
	}()

	c := ts.Client()

	// mustGet issues a GET to url, adds headers given as alternating
	// name/value pairs, and reads the response body in full.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// Keep-alive request, then a response whose handler sets Connection: close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Keep-alive request, then a request that itself asks for Connection: close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// The state sequence of a hijacked connection ends at StateHijacked.
	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	// The sequence is the same when the handler panics after hijacking.
	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// A connection closed without sending anything goes from new to closed.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A malformed request is still reported as active before the close.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// One complete request, after which the client closes the idle connection.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4529
4530 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4531 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4532 }
4533 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4534 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4535 }), func(ts *httptest.Server) {
4536 ts.Config.SetKeepAlivesEnabled(false)
4537 }).ts
4538 res, err := ts.Client().Get(ts.URL)
4539 if err != nil {
4540 t.Fatal(err)
4541 }
4542 defer res.Body.Close()
4543 if !res.Close {
4544 t.Errorf("Body.Close == false; want true")
4545 }
4546 }
4547
4548
4549 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4550 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4551 var n int32
4552 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4553 atomic.AddInt32(&n, 1)
4554 }), optQuietLog)
4555 var wg sync.WaitGroup
4556 const reqs = 20
4557 for i := 0; i < reqs; i++ {
4558 wg.Add(1)
4559 go func() {
4560 defer wg.Done()
4561 res, err := cst.c.Get(cst.ts.URL)
4562 if err != nil {
4563
4564
4565 time.Sleep(10 * time.Millisecond)
4566 res, err = cst.c.Get(cst.ts.URL)
4567 if err != nil {
4568 t.Error(err)
4569 return
4570 }
4571 }
4572 defer res.Body.Close()
4573 _, err = io.Copy(io.Discard, res.Body)
4574 if err != nil {
4575 t.Error(err)
4576 return
4577 }
4578 }()
4579 }
4580 wg.Wait()
4581 if got := atomic.LoadInt32(&n); got != reqs {
4582 t.Errorf("handler ran %d times; want %d", got, reqs)
4583 }
4584 }
4585
4586 func TestServerConnStateNew(t *testing.T) {
4587 sawNew := false
4588 srv := &Server{
4589 ConnState: func(c net.Conn, state ConnState) {
4590 if state == StateNew {
4591 sawNew = true
4592 }
4593 },
4594 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4595 }
4596 srv.Serve(&oneConnListener{
4597 conn: &rwTestConn{
4598 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4599 Writer: io.Discard,
4600 },
4601 })
4602 if !sawNew {
4603 t.Error("StateNew not seen")
4604 }
4605 }
4606
4607 type closeWriteTestConn struct {
4608 rwTestConn
4609 didCloseWrite bool
4610 }
4611
4612 func (c *closeWriteTestConn) CloseWrite() error {
4613 c.didCloseWrite = true
4614 return nil
4615 }
4616
4617 func TestCloseWrite(t *testing.T) {
4618 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4619
4620 var srv Server
4621 var testConn closeWriteTestConn
4622 c := ExportServerNewConn(&srv, &testConn)
4623 ExportCloseWriteAndWait(c)
4624 if !testConn.didCloseWrite {
4625 t.Error("didn't see CloseWrite call")
4626 }
4627 }
4628
4629
4630
4631
4632
4633
4634
4635
4636 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4637 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4638 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4639 io.WriteString(w, "Hello, ")
4640 w.(Flusher).Flush()
4641 conn, buf, _ := w.(Hijacker).Hijack()
4642 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4643 if err := buf.Flush(); err != nil {
4644 t.Error(err)
4645 }
4646 if err := conn.Close(); err != nil {
4647 t.Error(err)
4648 }
4649 })).ts
4650 res, err := Get(ts.URL)
4651 if err != nil {
4652 t.Fatal(err)
4653 }
4654 defer res.Body.Close()
4655 all, err := io.ReadAll(res.Body)
4656 if err != nil {
4657 t.Fatal(err)
4658 }
4659 if want := "Hello, world!"; string(all) != want {
4660 t.Errorf("Got %q; want %q", all, want)
4661 }
4662 }
4663
4664
4665
4666
4667
4668
4669
4670 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4671 run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
4672 }
4673 func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
4674 if testing.Short() {
4675 t.Skip("skipping in -short mode")
4676 }
4677 const numReq = 3
4678 addrc := make(chan string, numReq)
4679 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4680 addrc <- r.RemoteAddr
4681 time.Sleep(500 * time.Millisecond)
4682 w.(Flusher).Flush()
4683 }), func(ts *httptest.Server) {
4684 ts.Config.WriteTimeout = 250 * time.Millisecond
4685 }).ts
4686
4687 errc := make(chan error, numReq)
4688 go func() {
4689 defer close(errc)
4690 for i := 0; i < numReq; i++ {
4691 res, err := Get(ts.URL)
4692 if res != nil {
4693 res.Body.Close()
4694 }
4695 errc <- err
4696 }
4697 }()
4698
4699 addrSeen := map[string]bool{}
4700 numOkay := 0
4701 for {
4702 select {
4703 case v := <-addrc:
4704 addrSeen[v] = true
4705 case err, ok := <-errc:
4706 if !ok {
4707 if len(addrSeen) != numReq {
4708 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4709 }
4710 if numOkay != 0 {
4711 t.Errorf("got %d successful client requests; want 0", numOkay)
4712 }
4713 return
4714 }
4715 if err == nil {
4716 numOkay++
4717 }
4718 }
4719 }
4720 }
4721
4722
4723
4724 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4725 run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
4726 }
// testNoContentLengthIfTransferEncoding checks that when a handler sets its own
// Transfer-Encoding header, the server adds neither a Content-Length nor a
// sniffed Content-Type header to the response.
func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "foo")
		io.WriteString(w, "<html>")
	})
	ts := newClientServerTest(t, mode, handler).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer conn.Close()

	if _, err := io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
		t.Fatal(err)
	}

	// Collect the status line and header lines up to the first blank line.
	var headers strings.Builder
	scanner := bufio.NewScanner(conn)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.TrimSpace(line) == "" {
			break
		}
		headers.WriteString(line)
		headers.WriteByte('\n')
	}
	if err := scanner.Err(); err != nil {
		t.Fatal(err)
	}

	raw := headers.String()
	if strings.Contains(raw, "Content-Length") {
		t.Errorf("Unexpected Content-Length in response headers: %s", raw)
	}
	if strings.Contains(raw, "Content-Type") {
		t.Errorf("Unexpected Content-Type in response headers: %s", raw)
	}
}
4759
4760
4761
// TestTolerateCRLFBeforeRequestLine checks that the server skips an extra
// blank line (CRLF) between two pipelined requests and still serves both.
func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
	const first = "POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC"
	const second = "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n"
	input := []byte(first + "\r\n\r\n" + second)

	var out bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader(input),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	served := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(ResponseWriter, *Request) {
		served++
	}))
	// The connection is closed only after the serve loop finishes, so reading
	// served after this receive does not race with the handler.
	<-conn.closec
	if served != 2 {
		t.Errorf("num requests = %d; want 2", served)
		t.Logf("Res: %s", out.Bytes())
	}
}
4783
// TestIssue13893_Expect100 checks that an "Expect: 100-continue" request
// header is still visible in Request.Header when the handler runs.
func TestIssue13893_Expect100(t *testing.T) {
	// The request declares a 10-byte body ("HelloWorld") and carries an
	// Expect: 100-continue header, which the handler below looks for.
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorld

`)
	var buf bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
		if _, ok := r.Header["Expect"]; !ok {
			t.Error("Expect header should not be filtered out")
		}
	}))
	// Wait for the server to close the fake connection.
	<-conn.closec
}
4810
// TestIssue11549_Expect100 sends three pipelined requests. The first two use
// "Expect: 100-continue". The handler reads the body of /readbody but not of
// /noreadbody. The test expects exactly two requests to reach the handler
// (the trailing GET is never served) and a "Connection: close" header in the
// response output.
func TestIssue11549_Expect100(t *testing.T) {
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorldPUT /noreadbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

GET /should-be-ignored HTTP/1.1
Host: foo

`)
	var buf strings.Builder
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	numReq := 0
	go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
		numReq++
		// Only /readbody consumes its body; /noreadbody leaves it unread.
		if r.URL.Path == "/readbody" {
			io.ReadAll(r.Body)
		}
		io.WriteString(w, "Hello world!")
	}))
	// Wait for the server to close the fake connection before inspecting
	// numReq and buf.
	<-conn.closec
	if numReq != 2 {
		t.Errorf("num requests = %d; want 2", numReq)
	}
	if !strings.Contains(buf.String(), "Connection: close\r\n") {
		t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
	}
}
4853
4854
4855
// TestHandlerFinishSkipBigContentLengthRead checks that after a handler
// returns without reading a request body with a huge declared Content-Length,
// the server does not read any more of that body from the connection.
func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
	setParallel(t)
	conn := newTestConn()
	request := "POST / HTTP/1.1\r\n" +
		"Host: test\r\n" +
		"Content-Length: 9999999999\r\n" +
		"\r\n" + strings.Repeat("a", 1<<20)
	conn.readBuf.WriteString(request)

	var lenInHandler int
	go Serve(&oneConnListener{conn}, HandlerFunc(func(w ResponseWriter, _ *Request) {
		lenInHandler = conn.readBuf.Len()
		w.WriteHeader(404)
	}))
	<-conn.closec

	if lenAfter := conn.readBuf.Len(); lenAfter != lenInHandler {
		t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", lenInHandler, lenAfter)
	}
}
4878
// TestHandlerSetsBodyNil runs testHandlerSetsBodyNil in every test mode.
func TestHandlerSetsBodyNil(t *testing.T) {
	run(t, testHandlerSetsBodyNil)
}
// testHandlerSetsBodyNil checks that a handler setting r.Body to nil does not
// prevent connection reuse. The handler echoes r.RemoteAddr, so two sequential
// requests on the same connection return the same string.
func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		r.Body = nil
		fmt.Fprintf(w, "%v", r.RemoteAddr)
	}))
	fetch := func() string {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		defer res.Body.Close()
		body, readErr := io.ReadAll(res.Body)
		if readErr != nil {
			t.Fatal(readErr)
		}
		return string(body)
	}
	first := fetch()
	second := fetch()
	if first != second {
		t.Errorf("Failed to reuse connections between requests: %v vs %v", first, second)
	}
}
4902
4903
4904
// TestServerValidatesHostHeader feeds one request per table entry to a Server
// over a fake connection and checks the status code of the response. Each
// entry combines a protocol version (or a complete request line) with zero or
// more raw Host header lines.
func TestServerValidatesHostHeader(t *testing.T) {
	tests := []struct {
		proto string // HTTP version, or a complete request line if it lacks the "HTTP/" prefix
		host  string // raw Host header line(s), each ending in CRLF; may be empty
		want  int    // expected response status code
	}{
		// Unsupported HTTP version.
		{"HTTP/0.9", "", 505},

		// HTTP/1.1 requires a single Host header. An empty value is accepted,
		// but values with invalid characters are rejected.
		{"HTTP/1.1", "", 400},
		{"HTTP/1.1", "Host: \r\n", 200},
		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
		{"HTTP/1.1", "Host: foo.com\r\n", 200},
		{"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
		{"HTTP/1.1", "Host: foo.com:80\r\n", 200},
		{"HTTP/1.1", "Host: ::1\r\n", 200},
		{"HTTP/1.1", "Host: [::1]\r\n", 200},
		{"HTTP/1.1", "Host: [::1]:80\r\n", 200},
		{"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
		{"HTTP/1.1", "Host: \x06\r\n", 400},
		{"HTTP/1.1", "Host: \xff\r\n", 400},
		{"HTTP/1.1", "Host: {\r\n", 400},
		{"HTTP/1.1", "Host: }\r\n", 400},
		{"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},

		// HTTP/1.0 may omit Host, but a Host header that is present is still
		// validated.
		{"HTTP/1.0", "", 200},
		{"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
		{"HTTP/1.0", "Host: \xff\r\n", 400},

		// The HTTP/2 connection-preface request line is accepted without Host.
		{"PRI * HTTP/2.0", "", 200},

		// CONNECT requests are accepted without Host.
		{"CONNECT golang.org:443 HTTP/1.1", "", 200},

		// Other HTTP/2 and HTTP/3 request lines are rejected.
		{"PRI / HTTP/2.0", "", 505},
		{"GET / HTTP/2.0", "", 505},
		{"GET / HTTP/3.0", "", 505},
	}
	for _, tt := range tests {
		conn := newTestConn()
		// Entries whose proto is a full request line (PRI, CONNECT) are sent
		// as-is. Otherwise prefix a plain "GET / " method and target.
		methodTarget := "GET / "
		if !strings.HasPrefix(tt.proto, "HTTP/") {
			methodTarget = ""
		}
		io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the parse error itself; res is nil when err != nil.
			t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
		}
	}
}
4972
// TestServerHandlersCanHandleH2PRI runs testServerHandlersCanHandleH2PRI in
// HTTP/1 mode only.
func TestServerHandlersCanHandleH2PRI(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerHandlersCanHandleH2PRI, modes)
}
// testServerHandlersCanHandleH2PRI checks that an HTTP/1 server hands the
// "PRI * HTTP/2.0" preface request to the handler. The handler can then hijack
// the connection, read the remaining "SM\r\n\r\n" preface bytes, and write its
// own bytes back to the client.
func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
	const upgradeResponse = "upgrade here"
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, br, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		if r.Method != "PRI" || r.RequestURI != "*" {
			t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
			return
		}
		// The server is expected to mark the PRI request as closing. The
		// message reports the observed value (false) and the wanted one (true).
		if !r.Close {
			t.Errorf("Request.Close = false; want true")
		}
		// The rest of the preface must still be readable from the hijacked
		// buffered reader.
		const want = "SM\r\n\r\n"
		buf := make([]byte, len(want))
		n, err := io.ReadFull(br, buf)
		if err != nil || string(buf[:n]) != want {
			t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
			return
		}
		io.WriteString(conn, upgradeResponse)
	})).ts

	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	// Fail fast if the preface cannot be sent rather than blocking in ReadAll.
	if _, err := io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"); err != nil {
		t.Fatal(err)
	}
	slurp, err := io.ReadAll(c)
	if err != nil {
		t.Fatal(err)
	}
	if string(slurp) != upgradeResponse {
		t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
	}
}
5016
5017
5018
// TestServerValidatesHeaders sends a GET request with one extra raw header
// block per table entry over a fake connection. It checks the status code the
// server responds with.
func TestServerValidatesHeaders(t *testing.T) {
	setParallel(t)
	tests := []struct {
		header string // raw header line(s), each ending in CRLF; may be empty
		want   int    // expected response status code
	}{
		{"", 200},
		{"Foo: bar\r\n", 200},
		{"X-Foo: bar\r\n", 200},
		{"Foo: a space\r\n", 200},

		// Spaces or binary bytes in the header name are rejected. An
		// oversized header yields 431.
		{"A space: foo\r\n", 400},
		{"foo\xffbar: foo\r\n", 400},
		{"foo\x00bar: foo\r\n", 400},
		{"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},

		// Whitespace between the header name and the colon is rejected.
		{"Foo : bar\r\n", 400},
		{"Foo\t: bar\r\n", 400},

		// An empty header name is rejected.
		{": empty key\r\n", 400},

		// An invalid Content-Length is rejected, with or without a
		// Transfer-Encoding header.
		{"Content-Length: notdigits\r\n", 400},
		{"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},

		// Header values: spaces, tabs, and high (non-ASCII) bytes are
		// accepted. NUL and DEL (0x7f) are rejected.
		{"foo: foo foo\r\n", 200},
		{"foo: foo\tfoo\r\n", 200},
		{"foo: foo\x00foo\r\n", 400},
		{"foo: foo\x7ffoo\r\n", 400},
		{"foo: foo\xfffoo\r\n", 200},
	}
	for _, tt := range tests {
		conn := newTestConn()
		io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the parse error itself; res is nil when err != nil.
			t.Errorf("For %q, ReadResponse: %v", tt.header, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
		}
	}
}
5076
// TestServerRequestContextCancel_ServeHTTPDone runs
// testServerRequestContextCancel_ServeHTTPDone in every test mode.
func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
	test := testServerRequestContextCancel_ServeHTTPDone
	run(t, test)
}
// testServerRequestContextCancel_ServeHTTPDone checks that a request's context
// is not done while ServeHTTP runs, and is done by the time the client has
// received and closed the response.
func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
	// isDone reports, without blocking, whether ctx's Done channel is closed.
	isDone := func(ctx context.Context) bool {
		select {
		case <-ctx.Done():
			return true
		default:
			return false
		}
	}

	ctxc := make(chan context.Context, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		ctx := r.Context()
		if isDone(ctx) {
			t.Error("should not be Done in ServeHTTP")
		}
		ctxc <- ctx
	}))
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if !isDone(<-ctxc) {
		t.Error("context should be done after ServeHTTP completes")
	}
}
5103
5104
5105
5106
5107
// TestServerRequestContextCancel_ConnClose runs
// testServerRequestContextCancel_ConnClose in HTTP/1 mode only.
func TestServerRequestContextCancel_ConnClose(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerRequestContextCancel_ConnClose, modes)
}
// testServerRequestContextCancel_ConnClose checks that a request's context is
// canceled when the client closes its connection while the handler is still
// running. The test hangs (rather than fails) if the context is never done.
func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
	inHandler := make(chan struct{})
	handlerDone := make(chan struct{})
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		close(inHandler)
		// Block until the server cancels the request context.
		<-r.Context().Done()
		close(handlerDone)
	})).ts
	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
	<-inHandler
	// Close the client side while the handler is blocked. The deferred Close
	// above then becomes a harmless second close.
	c.Close()
	<-handlerDone
}
5129
// TestServerContext_ServerContextKey runs testServerContext_ServerContextKey
// in every test mode.
func TestServerContext_ServerContextKey(t *testing.T) {
	test := testServerContext_ServerContextKey
	run(t, test)
}
// testServerContext_ServerContextKey checks that the request context stores
// the serving *Server under ServerContextKey.
func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		v := r.Context().Value(ServerContextKey)
		if _, isServer := v.(*Server); !isServer {
			t.Errorf("context value = %T; want *http.Server", v)
		}
	})
	cst := newClientServerTest(t, mode, handler)
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
5147
// TestServerContext_LocalAddrContextKey runs
// testServerContext_LocalAddrContextKey in every test mode.
func TestServerContext_LocalAddrContextKey(t *testing.T) {
	test := testServerContext_LocalAddrContextKey
	run(t, test)
}
// testServerContext_LocalAddrContextKey checks that the request context stores
// a net.Addr under LocalAddrContextKey, and that it prints the same as the
// test listener's address.
func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
	ch := make(chan any, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		ch <- r.Context().Value(LocalAddrContextKey)
	}))
	res, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Every successful response's Body must be closed, even for HEAD.
	res.Body.Close()

	host := cst.ts.Listener.Addr().String()
	got := <-ch
	if addr, ok := got.(net.Addr); !ok {
		t.Errorf("local addr value = %T; want net.Addr", got)
	} else if fmt.Sprint(addr) != host {
		t.Errorf("local addr = %v; want %v", addr, host)
	}
}
5168
5169
// TestHandlerSetTransferEncodingChunked checks that a handler explicitly
// setting "Transfer-Encoding: chunked" does not make that header appear twice
// in the raw response.
func TestHandlerSetTransferEncodingChunked(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "chunked")
		w.Write([]byte("hello"))
	})
	ht := newHandlerTest(handler)
	resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
	const hdr = "Transfer-Encoding: chunked"
	if count := strings.Count(resp, hdr); count != 1 {
		t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, count, resp)
	}
}
5183
5184
// TestHandlerSetTransferEncodingGzip checks that a handler that sets
// "Transfer-Encoding: gzip" and writes a gzip stream produces a raw response
// containing "Transfer-Encoding: gzip" exactly once and
// "Transfer-Encoding: chunked" exactly once.
func TestHandlerSetTransferEncodingGzip(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "gzip")
		zw := gzip.NewWriter(w)
		zw.Write([]byte("hello"))
		zw.Close()
	}))
	resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
	wanted := []string{"Transfer-Encoding: gzip", "Transfer-Encoding: chunked"}
	for _, hdr := range wanted {
		if count := strings.Count(resp, hdr); count != 1 {
			t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, count, resp)
		}
	}
}
5202
// BenchmarkClientServer runs benchmarkClientServer over HTTP/1, HTTPS/1 and
// HTTP/2.
func BenchmarkClientServer(b *testing.B) {
	modes := []testMode{http1Mode, https1Mode, http2Mode}
	run(b, benchmarkClientServer, modes)
}
5206 func benchmarkClientServer(b *testing.B, mode testMode) {
5207 b.ReportAllocs()
5208 b.StopTimer()
5209 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5210 fmt.Fprintf(rw, "Hello world.\n")
5211 })).ts
5212 b.StartTimer()
5213
5214 c := ts.Client()
5215 for i := 0; i < b.N; i++ {
5216 res, err := c.Get(ts.URL)
5217 if err != nil {
5218 b.Fatal("Get:", err)
5219 }
5220 all, err := io.ReadAll(res.Body)
5221 res.Body.Close()
5222 if err != nil {
5223 b.Fatal("ReadAll:", err)
5224 }
5225 body := string(all)
5226 if body != "Hello world.\n" {
5227 b.Fatal("Got body:", body)
5228 }
5229 }
5230
5231 b.StopTimer()
5232 }
5233
// BenchmarkClientServerParallel runs benchmarkClientServerParallel with
// parallelism 4 and 64, each over HTTP/1, HTTPS/1 and HTTP/2.
func BenchmarkClientServerParallel(b *testing.B) {
	for _, p := range []int{4, 64} {
		name := fmt.Sprint(p)
		b.Run(name, func(b *testing.B) {
			bench := func(b *testing.B, mode testMode) {
				benchmarkClientServerParallel(b, p, mode)
			}
			run(b, bench, []testMode{http1Mode, https1Mode, http2Mode})
		})
	}
}
5243
5244 func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
5245 b.ReportAllocs()
5246 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5247 fmt.Fprintf(rw, "Hello world.\n")
5248 })).ts
5249 b.ResetTimer()
5250 b.SetParallelism(parallelism)
5251 b.RunParallel(func(pb *testing.PB) {
5252 c := ts.Client()
5253 for pb.Next() {
5254 res, err := c.Get(ts.URL)
5255 if err != nil {
5256 b.Logf("Get: %v", err)
5257 continue
5258 }
5259 all, err := io.ReadAll(res.Body)
5260 res.Body.Close()
5261 if err != nil {
5262 b.Logf("ReadAll: %v", err)
5263 continue
5264 }
5265 body := string(all)
5266 if body != "Hello world.\n" {
5267 panic("Got body: " + body)
5268 }
5269 }
5270 })
5271 }
5272
5273
5274
5275
5276
5277
5278
5279
5280
5281
// BenchmarkServer measures the server without counting HTTP client work in
// the same process. The client runs in a subprocess that re-executes this
// test binary with only this benchmark selected.
//
// In the child (GO_TEST_BENCH_SERVER_URL set), it performs
// GO_TEST_BENCH_CLIENT_N GETs against that URL, panicking on any failure, and
// then exits with status 0. In the parent, it starts an httptest server and
// runs the child with b.N and the server URL in its environment.
func BenchmarkServer(b *testing.B) {
	b.ReportAllocs()
	// Child (client) process mode.
	if url := os.Getenv("GO_TEST_BENCH_SERVER_URL"); url != "" {
		n, err := strconv.Atoi(os.Getenv("GO_TEST_BENCH_CLIENT_N"))
		if err != nil {
			panic(err)
		}
		for i := 0; i < n; i++ {
			res, err := Get(url)
			if err != nil {
				log.Panicf("Get: %v", err)
			}
			all, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				log.Panicf("ReadAll: %v", err)
			}
			body := string(all)
			if body != "Hello world.\n" {
				log.Panicf("Got body: %q", body)
			}
		}
		os.Exit(0)
		return
	}

	// Parent (server) process mode.
	var res = []byte("Hello world.\n")
	b.StopTimer()
	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(res)
	}))
	defer ts.Close()
	b.StartTimer()

	// The timed region covers the whole child run, including process startup.
	cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
	cmd.Env = append([]string{
		fmt.Sprintf("GO_TEST_BENCH_CLIENT_N=%d", b.N),
		fmt.Sprintf("GO_TEST_BENCH_SERVER_URL=%s", ts.URL),
	}, os.Environ()...)
	out, err := cmd.CombinedOutput()
	if err != nil {
		b.Errorf("Test failure: %v, with output: %s", err, out)
	}
}
5328
5329
// getNoBody issues a GET for urlStr with the default client and closes the
// response body without reading it. On error it returns a nil *Response.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5338
5339
5340
// BenchmarkClient measures sequential GETs with the default client against a
// server that startClientBenchmarkServer runs in a child process, so the
// server's work is not counted in this process.
func BenchmarkClient(b *testing.B) {
	var data = []byte("Hello world.\n")
	handler := HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.Write(data)
	})
	url := startClientBenchmarkServer(b, handler)

	// startClientBenchmarkServer returns with the timer stopped.
	b.StartTimer()
	for n := 0; n < b.N; n++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		body, readErr := io.ReadAll(res.Body)
		res.Body.Close()
		if readErr != nil {
			b.Fatalf("ReadAll: %v", readErr)
		}
		if !bytes.Equal(body, data) {
			b.Fatalf("Got body: %q", body)
		}
	}
	b.StopTimer()
}
5367
// startClientBenchmarkServer serves handler from a separate process so that
// client benchmarks measure only client-side work. It returns the base URL to
// request.
//
// The test binary is re-executed with GO_TEST_BENCH_SERVER=yes and the same
// benchmark selected. In that child, this function listens on localhost (port
// from GO_TEST_BENCH_SERVER_PORT, or any free port if unset) and prints the
// listening address on stdout. It then serves handler for every path and never
// returns: a request with a non-empty "stop" form value exits the child, and a
// Serve error ends it via log.Fatal.
//
// In the parent, it starts the child and reads the address from the child's
// stdout. It probes that address once, then registers a cleanup that stops the
// child and waits for it. The benchmark timer is left stopped on return.
func startClientBenchmarkServer(b *testing.B, handler Handler) string {
	b.ReportAllocs()
	b.StopTimer()

	if server := os.Getenv("GO_TEST_BENCH_SERVER"); server != "" {
		// Child (server) process mode.
		port := os.Getenv("GO_TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			log.Fatal(err)
		}
		// The parent reads this line to learn where to connect.
		fmt.Println(ln.Addr().String())

		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			handler.ServeHTTP(w, r)
		})
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent (client) process mode: re-run this benchmark in a child that acts
	// as the server.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^"+b.Name()+"$")
	cmd.Env = append(cmd.Environ(), "GO_TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}

	// Reap the child in the background. done receives cmd.Wait's result once
	// and is then closed, so any later receive returns nil immediately.
	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
		close(done)
	}()

	// The first line the child writes to stdout is its listening address.
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	// Make sure the child is serving before the benchmark starts.
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	b.Cleanup(func() {
		// Ask the child to exit. The result is ignored: the child exits while
		// handling this request, so the request itself may fail.
		getNoBody(url + "?stop=yes")
		if err := <-done; err != nil {
			b.Fatalf("subprocess failed: %v", err)
		}
		// NOTE(review): if the Fatalf above fires, cancel is never called.
		cancel()
		// done is already closed here, so this receive does not block.
		<-done

		afterTest(b)
	})

	return url
}
5440
// BenchmarkClientGzip measures GETs of a gzip-encoded response whose
// uncompressed size is responseSize random bytes. The server runs in a child
// process started by startClientBenchmarkServer. Each iteration checks that
// the client yields exactly responseSize bytes, i.e. the decompressed payload.
func BenchmarkClientGzip(b *testing.B) {
	const responseSize = 1024 * 1024

	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	if _, err := io.CopyN(gz, crand.Reader, responseSize); err != nil {
		b.Fatal(err)
	}
	// Close flushes the remaining compressed data and the gzip footer; a
	// failure here would otherwise surface later as a confusing body error.
	if err := gz.Close(); err != nil {
		b.Fatal(err)
	}

	data := buf.Bytes()

	url := startClientBenchmarkServer(b, HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write(data)
	}))

	// startClientBenchmarkServer returns with the timer stopped.
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		n, err := io.Copy(io.Discard, res.Body)
		res.Body.Close()
		// Messages name the operation actually performed (io.Copy).
		if err != nil {
			b.Fatalf("Copy: %v", err)
		}
		if n != responseSize {
			b.Fatalf("Copy: expected %d bytes, got %d", responseSize, n)
		}
	}
	b.StopTimer()
}
5476
// BenchmarkServerFakeConnNoKeepAlive measures serving one HTTP/1.0 request
// per iteration over a reused in-memory test connection. Serve is called
// synchronously each time and returns once the one-shot listener is exhausted.
func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.0
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	res := []byte("Hello world!\n")

	conn := newTestConn()
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(res)
	})
	ln := new(oneConnListener)
	for i := 0; i < b.N; i++ {
		// Reset the fake connection's buffers and hand it to the listener again.
		conn.readBuf.Reset()
		conn.writeBuf.Reset()
		conn.readBuf.Write(req)
		ln.conn = conn
		Serve(ln, handler)
		// Wait for the server to close the connection before the next round.
		<-conn.closec
	}
}
5504
5505
5506 type repeatReader struct {
5507 content []byte
5508 count int
5509 off int
5510 }
5511
5512 func (r *repeatReader) Read(p []byte) (n int, err error) {
5513 if r.count <= 0 {
5514 return 0, io.EOF
5515 }
5516 n = copy(p, r.content[r.off:])
5517 r.off += n
5518 if r.off == len(r.content) {
5519 r.count--
5520 r.off = 0
5521 }
5522 return
5523 }
5524
// BenchmarkServerFakeConnWithKeepAlive measures b.N keep-alive HTTP/1.1
// requests served over a single in-memory connection. The requests come from a
// repeatReader, and the responses are discarded.
func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
	b.ReportAllocs()

	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	res := []byte("Hello world!\n")

	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	handled := 0
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		handled++
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(res)
	})
	ln := &oneConnListener{conn: conn}
	go Serve(ln, handler)
	// The connection is closed after the reader hits EOF and the serve loop
	// ends, so handled is safe to read afterwards.
	<-conn.closec
	if b.N != handled {
		b.Errorf("b.N=%d but handled %d", b.N, handled)
	}
}
5556
5557
5558
// BenchmarkServerFakeConnWithKeepAliveLite is like
// BenchmarkServerFakeConnWithKeepAlive but uses a minimal request (only a Host
// header), and the handler sets no response headers.
func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
	b.ReportAllocs()

	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	res := []byte("Hello world!\n")

	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	handled := 0
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		handled++
		rw.Write(res)
	})
	ln := &oneConnListener{conn: conn}
	go Serve(ln, handler)
	// Wait for the server to close the connection once all requests are read.
	<-conn.closec
	if b.N != handled {
		b.Errorf("b.N=%d but handled %d", b.N, handled)
	}
}
5584
const someResponse = "<html>some response</html>"

// response is someResponse repeated as many whole times as fit within
// 2<<10 (2048) bytes. The BenchmarkServerHandler* benchmarks write it as a
// response body.
var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse))
5589
5590
// BenchmarkServerHandlerTypeLen benchmarks a handler that sets both
// Content-Type and Content-Length before writing response.
func BenchmarkServerHandlerTypeLen(b *testing.B) {
	h := func(w ResponseWriter, r *Request) {
		hdr := w.Header()
		hdr.Set("Content-Type", "text/html")
		hdr.Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	}
	benchmarkHandler(b, HandlerFunc(h))
}
5598
5599
// BenchmarkServerHandlerNoLen benchmarks a handler that sets Content-Type but
// leaves Content-Length for the server to determine.
func BenchmarkServerHandlerNoLen(b *testing.B) {
	h := func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/html")
		w.Write(response)
	}
	benchmarkHandler(b, HandlerFunc(h))
}
5606
5607
// BenchmarkServerHandlerNoType benchmarks a handler that sets Content-Length
// but leaves Content-Type for the server to determine.
func BenchmarkServerHandlerNoType(b *testing.B) {
	h := func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	}
	benchmarkHandler(b, HandlerFunc(h))
}
5614
5615
// BenchmarkServerHandlerNoHeader benchmarks a handler that sets no headers
// and only writes response.
func BenchmarkServerHandlerNoHeader(b *testing.B) {
	h := func(w ResponseWriter, r *Request) {
		w.Write(response)
	}
	benchmarkHandler(b, HandlerFunc(h))
}
5621
// benchmarkHandler measures h serving b.N keep-alive HTTP/1.1 GET requests
// over a single in-memory connection whose output is discarded. It fails if
// the handler was not invoked exactly b.N times.
func benchmarkHandler(b *testing.B, h Handler) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	handled := 0
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		handled++
		h.ServeHTTP(rw, r)
	})
	ln := &oneConnListener{conn: conn}
	go Serve(ln, handler)
	// Wait for the server to close the connection once all requests are read.
	<-conn.closec
	if b.N != handled {
		b.Errorf("b.N=%d but handled %d", b.N, handled)
	}
}
5644
// BenchmarkServerHijack measures serving one request whose handler hijacks
// and immediately closes the connection. Each iteration reuses the same fake
// connection with a fresh request reader and calls Serve synchronously.
func BenchmarkServerHijack(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			panic(err)
		}
		conn.Close()
	})
	conn := &rwTestConn{
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	for i := 0; i < b.N; i++ {
		conn.Reader = bytes.NewReader(req)
		ln.conn = conn
		Serve(ln, h)
		// The handler's conn.Close signals closec.
		<-conn.closec
	}
}
5669
// BenchmarkCloseNotifier runs benchmarkCloseNotifier in HTTP/1 mode only.
func BenchmarkCloseNotifier(b *testing.B) {
	modes := []testMode{http1Mode}
	run(b, benchmarkCloseNotifier, modes)
}
// benchmarkCloseNotifier measures how long it takes a handler blocked on
// CloseNotify to observe the client closing its connection. Each iteration
// dials a fresh connection, sends one request, and closes it.
func benchmarkCloseNotifier(b *testing.B, mode testMode) {
	b.ReportAllocs()
	b.StopTimer()
	sawClose := make(chan bool)
	handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
		<-rw.(CloseNotifier).CloseNotify()
		sawClose <- true
	})
	ts := newClientServerTest(b, mode, handler).ts
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			b.Fatalf("error dialing: %v", err)
		}
		if _, err := fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"); err != nil {
			b.Fatal(err)
		}
		conn.Close()
		<-sawClose
	}
	b.StopTimer()
}
5694
5695
// TestConcurrentServerServe repeatedly calls Serve concurrently on the same
// Server with two listeners that have no connection to hand out. It has no
// assertions of its own; presumably it targets data races in Serve (run with
// -race).
func TestConcurrentServerServe(t *testing.T) {
	setParallel(t)
	for i := 0; i < 100; i++ {
		srv := Server{}
		listeners := []*oneConnListener{{conn: nil}, {conn: nil}}
		for _, ln := range listeners {
			go srv.Serve(ln)
		}
	}
}
5706
// TestServerIdleTimeout runs testServerIdleTimeout in HTTP/1 mode only.
func TestServerIdleTimeout(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerIdleTimeout, modes)
}
// testServerIdleTimeout checks two timeouts:
//   - IdleTimeout: two quick requests share a connection, but a request sent
//     after the connection has idled past IdleTimeout gets a new one.
//   - ReadHeaderTimeout: a connection that sends only a partial request header
//     is closed by the server.
//
// runTimeSensitiveTest calls the body with each candidate ReadHeaderTimeout in
// turn, presumably moving on to the next longer one when the body returns an
// error.
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}, func(t *testing.T, readHeaderTimeout time.Duration) error {
		// The handler echoes the client address so connection reuse can be
		// detected.
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		c := ts.Client()

		get := func() (string, error) {
			res, err := c.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			slurp, err := io.ReadAll(res.Body)
			if err != nil {
				// The response headers were already received, so a body read
				// failure is not treated as a retryable timing problem: fail
				// immediately instead of returning the error.
				t.Fatal(err)
			}
			return string(slurp), nil
		}

		a1, err := get()
		if err != nil {
			return err
		}
		a2, err := get()
		if err != nil {
			return err
		}
		if a1 != a2 {
			return fmt.Errorf("did requests on different connections")
		}
		// Idle for 1.5x IdleTimeout; the next request must use a new connection.
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		a3, err := get()
		if err != nil {
			return err
		}
		if a2 == a3 {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Send a request header without the terminating blank line and wait
		// twice ReadHeaderTimeout. The server should then have closed the
		// connection, so reading a byte must fail.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5782
// get fetches url with c and returns the whole response body as a string. Any
// request or read error fails the test immediately.
func get(t *testing.T, c *Client, url string) string {
	res, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	body, readErr := io.ReadAll(res.Body)
	if readErr != nil {
		t.Fatal(readErr)
	}
	return string(body)
}
5795
5796
5797
// TestServerSetKeepAlivesEnabledClosesConns runs
// testServerSetKeepAlivesEnabledClosesConns in HTTP/1 mode only.
func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerSetKeepAlivesEnabledClosesConns, modes)
}
// testServerSetKeepAlivesEnabledClosesConns checks that calling
// SetKeepAlivesEnabled(false) on the server makes the client's idle connection
// go away. It waits until the transport reports no idle connections.
func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
	// The handler echoes the client address so connection reuse can be
	// detected.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	get := func() string { return get(t, c, ts.URL) }

	a1, a2 := get(), get()
	if a1 == a2 {
		t.Logf("made two requests from a single conn %q (as expected)", a1)
	} else {
		t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
	}

	// The first get had finished before the second started, so exactly one
	// connection should now be idle in the transport.
	if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
		t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
	}

	// Disabling keep-alives should make the server close idle connections.
	ts.Config.SetKeepAlivesEnabled(false)

	// Poll until the transport no longer reports any idle connection.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
			if d > 0 {
				t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
			}
			return false
		}
		return true
	})
}
5843
// TestServerShutdown runs testServerShutdown in every test mode.
func TestServerShutdown(t *testing.T) {
	run(t, testServerShutdown)
}
// testServerShutdown checks graceful shutdown. The first request's handler
// records connection states, starts Shutdown, and waits for the registered
// OnShutdown hook. It then verifies that new requests are refused while its
// own in-flight response still completes and Shutdown returns nil.
func testServerShutdown(t *testing.T, mode testMode) {
	// Assigned below; the handler only runs after that assignment.
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		once.Do(func() {
			// Snapshot connection states, then start Shutdown asynchronously.
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait until Shutdown has run the OnShutdown hook. From then on
			// new requests are expected to be refused, while this in-flight
			// request keeps going.
			<-gotOnShutdown
			// Issue requests until one fails. In HTTP/2 mode an unexpected
			// success is only logged and retried (go.dev/issue/59038); in
			// other modes it fails the test, which also ends the loop.
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	// The first request triggers the shutdown sequence and must still succeed.
	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown

	// When the first request started, exactly one connection was active.
	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
5904
// TestServerShutdownStateNew checks how Shutdown treats a connection that was
// opened but never sent a request. Shutdown must not return, and must not
// close that connection, before expectTimeout has elapsed. It must do both
// within a couple of seconds after that.
// NOTE(review): runSynctest presumably runs this inside a synctest bubble,
// where time.Sleep advances a fake clock rather than real time. Confirm.
func TestServerShutdownStateNew(t *testing.T) { runSynctest(t, testServerShutdownStateNew) }
func testServerShutdownStateNew(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("test takes 5-6 seconds; skipping in short mode")
	}

	listener := fakeNetListen()
	defer listener.Close()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Never reached: the client below never sends a request.
	}), func(ts *httptest.Server) {
		ts.Listener.Close()
		ts.Listener = listener
		// Silence server log output.
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Open a connection but send nothing on it.
	c := listener.connect()
	defer c.Close()
	synctest.Wait()

	shutdownRes := runAsync(func() (struct{}, error) {
		return struct{}{}, ts.Config.Shutdown(context.Background())
	})

	// How long Shutdown is expected to wait before giving up on the
	// request-less connection.
	const expectTimeout = 5 * time.Second

	// Advance to 1ns before the expected timeout: nothing may have happened yet.
	time.Sleep(expectTimeout - 1)
	synctest.Wait()
	if shutdownRes.done() {
		t.Fatal("shutdown too soon")
	}
	if c.IsClosedByPeer() {
		t.Fatal("connection was closed by server too soon")
	}

	// Advance past the timeout: Shutdown should now have returned without
	// error and the server should have closed the connection.
	time.Sleep(2 * time.Second)
	synctest.Wait()
	if _, err := shutdownRes.result(); err != nil {
		t.Fatalf("Shutdown() = %v, want complete", err)
	}
	if !c.IsClosedByPeer() {
		t.Fatalf("connection was not closed by server after shutdown")
	}
}
5960
5961
5962 func TestServerCloseDeadlock(t *testing.T) {
5963 var s Server
5964 s.Close()
5965 s.Close()
5966 }
5967
5968
5969
5970 func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }
5971 func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
5972 if mode == http2Mode {
5973 restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
5974 defer restore()
5975 }
5976
5977 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
5978 defer cst.close()
5979 srv := cst.ts.Config
5980 srv.SetKeepAlivesEnabled(false)
5981 for try := 0; try < 2; try++ {
5982 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5983 if !srv.ExportAllConnsIdle() {
5984 if d > 0 {
5985 t.Logf("test server still has active conns after %v", d)
5986 }
5987 return false
5988 }
5989 return true
5990 })
5991 conns := 0
5992 var info httptrace.GotConnInfo
5993 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5994 GotConn: func(v httptrace.GotConnInfo) {
5995 conns++
5996 info = v
5997 },
5998 })
5999 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6000 if err != nil {
6001 t.Fatal(err)
6002 }
6003 res, err := cst.c.Do(req)
6004 if err != nil {
6005 t.Fatal(err)
6006 }
6007 res.Body.Close()
6008 if conns != 1 {
6009 t.Fatalf("request %v: got %v conns, want 1", try, conns)
6010 }
6011 if info.Reused || info.WasIdle {
6012 t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
6013 }
6014 }
6015 }
6016
6017
6018
6019
6020 func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
6021 func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
6022 runTimeSensitiveTest(t, []time.Duration{
6023 10 * time.Millisecond,
6024 50 * time.Millisecond,
6025 250 * time.Millisecond,
6026 time.Second,
6027 2 * time.Second,
6028 }, func(t *testing.T, timeout time.Duration) error {
6029 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6030 select {
6031 case <-time.After(2 * timeout):
6032 fmt.Fprint(w, "ok")
6033 case <-r.Context().Done():
6034 fmt.Fprint(w, r.Context().Err())
6035 }
6036 }), func(ts *httptest.Server) {
6037 ts.Config.ReadTimeout = timeout
6038 t.Logf("Server.Config.ReadTimeout = %v", timeout)
6039 })
6040 defer cst.close()
6041 ts := cst.ts
6042
6043 var retries atomic.Int32
6044 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
6045 if retries.Add(1) != 1 {
6046 return nil, errors.New("too many retries")
6047 }
6048 return nil, nil
6049 }
6050
6051 c := ts.Client()
6052
6053 res, err := c.Get(ts.URL)
6054 if err != nil {
6055 return fmt.Errorf("Get: %v", err)
6056 }
6057 slurp, err := io.ReadAll(res.Body)
6058 res.Body.Close()
6059 if err != nil {
6060 return fmt.Errorf("Body ReadAll: %v", err)
6061 }
6062 if string(slurp) != "ok" {
6063 return fmt.Errorf("got: %q, want ok", slurp)
6064 }
6065 return nil
6066 })
6067 }
6068
6069
6070
6071
6072 func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
6073 run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
6074 }
6075 func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
6076 runTimeSensitiveTest(t, []time.Duration{
6077 10 * time.Millisecond,
6078 50 * time.Millisecond,
6079 250 * time.Millisecond,
6080 time.Second,
6081 2 * time.Second,
6082 }, func(t *testing.T, timeout time.Duration) error {
6083 cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
6084 ts.Config.ReadHeaderTimeout = timeout
6085 ts.Config.IdleTimeout = 0
6086 })
6087 defer cst.close()
6088 ts := cst.ts
6089
6090
6091
6092 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6093 if err != nil {
6094 t.Fatalf("dial failed: %v", err)
6095 }
6096 br := bufio.NewReader(conn)
6097 defer conn.Close()
6098
6099 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6100 return fmt.Errorf("writing first request failed: %v", err)
6101 }
6102
6103 if _, err := ReadResponse(br, nil); err != nil {
6104 return fmt.Errorf("first response (before timeout) failed: %v", err)
6105 }
6106
6107
6108
6109 time.Sleep(timeout * 3 / 2)
6110
6111 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6112 return fmt.Errorf("writing second request failed: %v", err)
6113 }
6114
6115 if _, err := ReadResponse(br, nil); err != nil {
6116 return fmt.Errorf("second response (after timeout) failed: %v", err)
6117 }
6118
6119 return nil
6120 })
6121 }
6122
6123
6124
// runTimeSensitiveTest calls test with each duration in durations, in order,
// and stops at the first call that returns nil.
//
// A non-nil error is treated as a possibly timing-related flake and retried
// with the next duration. The test fails with t.Fatalf in two cases:
//   - the last duration also returns an error;
//   - test has already marked t as failed, which is a real failure, not a flake.
//
// An empty durations list is a caller bug and fails the test rather than
// passing vacuously without running anything.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	t.Helper()
	if len(durations) == 0 {
		t.Fatal("runTimeSensitiveTest: no durations to try")
	}
	for i, d := range durations {
		err := test(t, d)
		if err == nil {
			return
		}
		if i == len(durations)-1 || t.Failed() {
			t.Fatalf("failed with duration %v: %v", d, err)
		}
		t.Logf("retrying after error with duration %v: %v", d, err)
	}
}
6137
6138
6139
6140 func TestServerDuplicateBackgroundRead(t *testing.T) {
6141 run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
6142 }
6143 func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
6144 if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
6145 testenv.SkipFlaky(t, 24826)
6146 }
6147
6148 goroutines := 5
6149 requests := 2000
6150 if testing.Short() {
6151 goroutines = 3
6152 requests = 100
6153 }
6154
6155 hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
6156
6157 reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6158
6159 var wg sync.WaitGroup
6160 for i := 0; i < goroutines; i++ {
6161 wg.Add(1)
6162 go func() {
6163 defer wg.Done()
6164 cn, err := net.Dial("tcp", hts.Listener.Addr().String())
6165 if err != nil {
6166 t.Error(err)
6167 return
6168 }
6169 defer cn.Close()
6170
6171 wg.Add(1)
6172 go func() {
6173 defer wg.Done()
6174 io.Copy(io.Discard, cn)
6175 }()
6176
6177 for j := 0; j < requests; j++ {
6178 if t.Failed() {
6179 return
6180 }
6181 _, err := cn.Write(reqBytes)
6182 if err != nil {
6183 t.Error(err)
6184 return
6185 }
6186 }
6187 }()
6188 }
6189 wg.Wait()
6190 }
6191
6192
6193
6194
6195
6196
6197 func TestServerHijackGetsBackgroundByte(t *testing.T) {
6198 run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
6199 }
6200 func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
6201 if runtime.GOOS == "plan9" {
6202 t.Skip("skipping test; see https://golang.org/issue/18657")
6203 }
6204 done := make(chan struct{})
6205 inHandler := make(chan bool, 1)
6206 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6207 defer close(done)
6208
6209
6210 inHandler <- true
6211
6212 conn, buf, err := w.(Hijacker).Hijack()
6213 if err != nil {
6214 t.Error(err)
6215 return
6216 }
6217 defer conn.Close()
6218
6219 peek, err := buf.Reader.Peek(3)
6220 if string(peek) != "foo" || err != nil {
6221 t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
6222 }
6223
6224 select {
6225 case <-r.Context().Done():
6226 t.Error("context unexpectedly canceled")
6227 default:
6228 }
6229 })).ts
6230
6231 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6232 if err != nil {
6233 t.Fatal(err)
6234 }
6235 defer cn.Close()
6236 if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6237 t.Fatal(err)
6238 }
6239 <-inHandler
6240 if _, err := cn.Write([]byte("foo")); err != nil {
6241 t.Fatal(err)
6242 }
6243
6244 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6245 t.Fatal(err)
6246 }
6247 <-done
6248 }
6249
6250
6251 func TestServerHijackGetsFullBody(t *testing.T) {
6252 run(t, testServerHijackGetsFullBody, []testMode{http1Mode})
6253 }
6254 func testServerHijackGetsFullBody(t *testing.T, mode testMode) {
6255 if runtime.GOOS == "plan9" {
6256 t.Skip("skipping test; see https://golang.org/issue/18657")
6257 }
6258 done := make(chan struct{})
6259 needle := strings.Repeat("x", 100*1024)
6260 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6261 defer close(done)
6262
6263 conn, buf, err := w.(Hijacker).Hijack()
6264 if err != nil {
6265 t.Error(err)
6266 return
6267 }
6268 defer conn.Close()
6269
6270 got := make([]byte, len(needle))
6271 n, err := io.ReadFull(buf.Reader, got)
6272 if n != len(needle) || string(got) != needle || err != nil {
6273 t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err)
6274 }
6275 })).ts
6276
6277 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6278 if err != nil {
6279 t.Fatal(err)
6280 }
6281 defer cn.Close()
6282 buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6283 buf = append(buf, []byte(needle)...)
6284 if _, err := cn.Write(buf); err != nil {
6285 t.Fatal(err)
6286 }
6287
6288 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6289 t.Fatal(err)
6290 }
6291 <-done
6292 }
6293
6294
6295
6296
6297 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
6298 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
6299 }
6300 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
6301 if runtime.GOOS == "plan9" {
6302 t.Skip("skipping test; see https://golang.org/issue/18657")
6303 }
6304 done := make(chan struct{})
6305 const size = 8 << 10
6306 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6307 defer close(done)
6308
6309 conn, buf, err := w.(Hijacker).Hijack()
6310 if err != nil {
6311 t.Error(err)
6312 return
6313 }
6314 defer conn.Close()
6315 slurp, err := io.ReadAll(buf.Reader)
6316 if err != nil {
6317 t.Errorf("Copy: %v", err)
6318 }
6319 allX := true
6320 for _, v := range slurp {
6321 if v != 'x' {
6322 allX = false
6323 }
6324 }
6325 if len(slurp) != size {
6326 t.Errorf("read %d; want %d", len(slurp), size)
6327 } else if !allX {
6328 t.Errorf("read %q; want %d 'x'", slurp, size)
6329 }
6330 })).ts
6331
6332 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6333 if err != nil {
6334 t.Fatal(err)
6335 }
6336 defer cn.Close()
6337 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6338 strings.Repeat("x", size)); err != nil {
6339 t.Fatal(err)
6340 }
6341 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6342 t.Fatal(err)
6343 }
6344
6345 <-done
6346 }
6347
6348
6349 func TestServerValidatesMethod(t *testing.T) {
6350 tests := []struct {
6351 method string
6352 want int
6353 }{
6354 {"GET", 200},
6355 {"GE(T", 400},
6356 }
6357 for _, tt := range tests {
6358 conn := newTestConn()
6359 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6360
6361 ln := &oneConnListener{conn}
6362 go Serve(ln, serve(200))
6363 <-conn.closec
6364 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6365 if err != nil {
6366 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6367 continue
6368 }
6369 if res.StatusCode != tt.want {
6370 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6371 }
6372 }
6373 }
6374
6375
// eofListenerNotComparable is a net.Listener whose underlying type is a
// slice, so its values cannot be compared with ==. Accept always returns
// io.EOF, Addr returns nil, and Close is a no-op that returns nil.
type eofListenerNotComparable []int

func (l eofListenerNotComparable) Accept() (net.Conn, error) {
	return nil, io.EOF
}

func (l eofListenerNotComparable) Addr() net.Addr {
	return nil
}

func (l eofListenerNotComparable) Close() error {
	return nil
}
6381
6382
6383 func TestServerListenNotComparableListener(t *testing.T) {
6384 var s Server
6385 s.Serve(make(eofListenerNotComparable, 1))
6386 }
6387
6388
// countCloseListener wraps a net.Listener and counts every Close call. Only
// the first Close is forwarded to the wrapped listener, and only if one is
// set; every other call returns nil.
type countCloseListener struct {
	net.Listener
	closes int32 // number of Close calls; updated atomically
}

func (p *countCloseListener) Close() error {
	if n := atomic.AddInt32(&p.closes, 1); n != 1 || p.Listener == nil {
		return nil
	}
	return p.Listener.Close()
}
6401
6402
6403 func TestServerCloseListenerOnce(t *testing.T) {
6404 setParallel(t)
6405 defer afterTest(t)
6406
6407 ln := newLocalListener(t)
6408 defer ln.Close()
6409
6410 cl := &countCloseListener{Listener: ln}
6411 server := &Server{}
6412 sdone := make(chan bool, 1)
6413
6414 go func() {
6415 server.Serve(cl)
6416 sdone <- true
6417 }()
6418 time.Sleep(10 * time.Millisecond)
6419 server.Shutdown(context.Background())
6420 ln.Close()
6421 <-sdone
6422
6423 nclose := atomic.LoadInt32(&cl.closes)
6424 if nclose != 1 {
6425 t.Errorf("Close calls = %v; want 1", nclose)
6426 }
6427 }
6428
6429
6430 func TestServerShutdownThenServe(t *testing.T) {
6431 var srv Server
6432 cl := &countCloseListener{Listener: nil}
6433 srv.Shutdown(context.Background())
6434 got := srv.Serve(cl)
6435 if got != ErrServerClosed {
6436 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6437 }
6438 nclose := atomic.LoadInt32(&cl.closes)
6439 if nclose != 1 {
6440 t.Errorf("Close calls = %v; want 1", nclose)
6441 }
6442 }
6443
6444
6445 func TestStripPortFromHost(t *testing.T) {
6446 mux := NewServeMux()
6447
6448 mux.HandleFunc("example.com/", func(w ResponseWriter, r *Request) {
6449 fmt.Fprintf(w, "OK")
6450 })
6451 mux.HandleFunc("example.com:9000/", func(w ResponseWriter, r *Request) {
6452 fmt.Fprintf(w, "uh-oh!")
6453 })
6454
6455 req := httptest.NewRequest("GET", "http://example.com:9000/", nil)
6456 rw := httptest.NewRecorder()
6457
6458 mux.ServeHTTP(rw, req)
6459
6460 response := rw.Body.String()
6461 if response != "OK" {
6462 t.Errorf("Response gotten was %q", response)
6463 }
6464 }
6465
6466 func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
6467 func testServerContexts(t *testing.T, mode testMode) {
6468 type baseKey struct{}
6469 type connKey struct{}
6470 ch := make(chan context.Context, 1)
6471 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6472 ch <- r.Context()
6473 }), func(ts *httptest.Server) {
6474 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6475 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6476 t.Errorf("unexpected onceClose listener type %T", ln)
6477 }
6478 return context.WithValue(context.Background(), baseKey{}, "base")
6479 }
6480 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6481 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6482 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6483 }
6484 return context.WithValue(ctx, connKey{}, "conn")
6485 }
6486 }).ts
6487 res, err := ts.Client().Get(ts.URL)
6488 if err != nil {
6489 t.Fatal(err)
6490 }
6491 res.Body.Close()
6492 ctx := <-ch
6493 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6494 t.Errorf("base context key = %#v; want %q", got, want)
6495 }
6496 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6497 t.Errorf("conn context key = %#v; want %q", got, want)
6498 }
6499 }
6500
6501
6502 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6503 run(t, testConnContextNotModifyingAllContexts)
6504 }
6505 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6506 type connKey struct{}
6507 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6508 rw.Header().Set("Connection", "close")
6509 }), func(ts *httptest.Server) {
6510 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6511 if got := ctx.Value(connKey{}); got != nil {
6512 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6513 }
6514 return context.WithValue(ctx, connKey{}, "conn")
6515 }
6516 }).ts
6517
6518 var res *Response
6519 var err error
6520
6521 res, err = ts.Client().Get(ts.URL)
6522 if err != nil {
6523 t.Fatal(err)
6524 }
6525 res.Body.Close()
6526
6527 res, err = ts.Client().Get(ts.URL)
6528 if err != nil {
6529 t.Fatal(err)
6530 }
6531 res.Body.Close()
6532 }
6533
6534
6535
6536 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6537 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6538 }
6539 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6540 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6541 w.Write([]byte("Hello, World!"))
6542 })).ts
6543
6544 serverURL, err := url.Parse(cst.URL)
6545 if err != nil {
6546 t.Fatalf("Failed to parse server URL: %v", err)
6547 }
6548
6549 unsupportedTEs := []string{
6550 "fugazi",
6551 "foo-bar",
6552 "unknown",
6553 `" chunked"`,
6554 }
6555
6556 for _, badTE := range unsupportedTEs {
6557 http1ReqBody := fmt.Sprintf(""+
6558 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6559 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6560
6561 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6562 if err != nil {
6563 t.Errorf("%q. unexpected error: %v", badTE, err)
6564 continue
6565 }
6566
6567 wantBody := fmt.Sprintf("" +
6568 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6569 "Connection: close\r\n\r\nUnsupported transfer encoding")
6570
6571 if string(gotBody) != wantBody {
6572 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6573 }
6574 }
6575 }
6576
6577
6578 func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
6579 func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
6580 type setting struct {
6581 name string
6582 body []byte
6583
6584
6585
6586
6587 contentEncoding any
6588 wantContentType string
6589 }
6590
6591 settings := []*setting{
6592 {
6593 name: "gzip content-encoding, gzipped",
6594 contentEncoding: "application/gzip",
6595 wantContentType: "",
6596 body: func() []byte {
6597 buf := new(bytes.Buffer)
6598 gzw := gzip.NewWriter(buf)
6599 gzw.Write([]byte("doctype html><p>Hello</p>"))
6600 gzw.Close()
6601 return buf.Bytes()
6602 }(),
6603 },
6604 {
6605 name: "zlib content-encoding, zlibbed",
6606 contentEncoding: "application/zlib",
6607 wantContentType: "",
6608 body: func() []byte {
6609 buf := new(bytes.Buffer)
6610 zw := zlib.NewWriter(buf)
6611 zw.Write([]byte("doctype html><p>Hello</p>"))
6612 zw.Close()
6613 return buf.Bytes()
6614 }(),
6615 },
6616 {
6617 name: "no content-encoding",
6618 wantContentType: "application/x-gzip",
6619 body: func() []byte {
6620 buf := new(bytes.Buffer)
6621 gzw := gzip.NewWriter(buf)
6622 gzw.Write([]byte("doctype html><p>Hello</p>"))
6623 gzw.Close()
6624 return buf.Bytes()
6625 }(),
6626 },
6627 {
6628 name: "phony content-encoding",
6629 contentEncoding: "foo/bar",
6630 body: []byte("doctype html><p>Hello</p>"),
6631 },
6632 {
6633 name: "empty but set content-encoding",
6634 contentEncoding: "",
6635 wantContentType: "audio/mpeg",
6636 body: []byte("ID3"),
6637 },
6638 }
6639
6640 for _, tt := range settings {
6641 t.Run(tt.name, func(t *testing.T) {
6642 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6643 if tt.contentEncoding != nil {
6644 rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
6645 }
6646 rw.Write(tt.body)
6647 }))
6648
6649 res, err := cst.c.Get(cst.ts.URL)
6650 if err != nil {
6651 t.Fatalf("Failed to fetch URL: %v", err)
6652 }
6653 defer res.Body.Close()
6654
6655 if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
6656 if w != nil {
6657 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
6658 } else if g != "" {
6659 t.Errorf("Unexpected Content-Encoding %q", g)
6660 }
6661 }
6662
6663 if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
6664 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
6665 }
6666 })
6667 }
6668 }
6669
6670
6671
// TestTimeoutHandlerSuperfluousLogs checks the "superfluous
// response.WriteHeader call" server log lines. They are produced when a
// handler wrapped in TimeoutHandler calls WriteHeader several times, both when
// the handler returns before the timeout and when it is still running after
// it. Each log line must name the handler closure and the exact file:line of
// the repeated call.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
}
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Capture this function's name and file so the expected log pattern can
	// refer to the handler closure defined inside it.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			lastLine := make(chan int, 1)

			// NOTE: keep the four WriteHeader calls and the runtime.Caller
			// line below on consecutive lines, with nothing in between. The
			// expected log line numbers are computed from that layout.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// In the no-timeout case, pre-load exitHandler so the handler
			// returns right away.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)
			// Short timeout for the case that must time out.
			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Generous timeout so a slow machine cannot push the
				// "return before timeout" case past the deadline.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop the Date and Content-Type headers so the dumped response
			// can be compared against a fixed string.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			// Dump the full response, including the body, for comparison.
			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Three of the four WriteHeader calls are superfluous, so expect
			// exactly three log entries.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", "  ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// The superfluous calls are the three lines just before the
			// runtime.Caller line inside the handler.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Log entries are expected in call order, one line apart.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6776
6777
6778
6779
// fetchWireResponse dials host over TCP and writes http1ReqBody verbatim.
// It returns every byte the server sends until it closes the connection,
// plus any dial, write, or read error.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	conn, dialErr := net.Dial("tcp", host)
	if dialErr != nil {
		return nil, dialErr
	}
	defer conn.Close()

	if _, writeErr := conn.Write(http1ReqBody); writeErr != nil {
		return nil, writeErr
	}
	// Read until the server closes its side.
	return io.ReadAll(conn)
}
6792
6793 func BenchmarkResponseStatusLine(b *testing.B) {
6794 b.ReportAllocs()
6795 b.RunParallel(func(pb *testing.PB) {
6796 bw := bufio.NewWriter(io.Discard)
6797 var buf3 [3]byte
6798 for pb.Next() {
6799 Export_writeStatusLine(bw, true, 200, buf3[:])
6800 }
6801 })
6802 }
6803
6804 func TestDisableKeepAliveUpgrade(t *testing.T) {
6805 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6806 }
6807 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6808 if testing.Short() {
6809 t.Skip("skipping in short mode")
6810 }
6811
6812 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6813 w.Header().Set("Connection", "Upgrade")
6814 w.Header().Set("Upgrade", "someProto")
6815 w.WriteHeader(StatusSwitchingProtocols)
6816 c, buf, err := w.(Hijacker).Hijack()
6817 if err != nil {
6818 return
6819 }
6820 defer c.Close()
6821
6822
6823
6824 io.Copy(c, buf)
6825 }), func(ts *httptest.Server) {
6826 ts.Config.SetKeepAlivesEnabled(false)
6827 }).ts
6828
6829 cl := s.Client()
6830 cl.Transport.(*Transport).DisableKeepAlives = true
6831
6832 resp, err := cl.Get(s.URL)
6833 if err != nil {
6834 t.Fatalf("failed to perform request: %v", err)
6835 }
6836 defer resp.Body.Close()
6837
6838 if resp.StatusCode != StatusSwitchingProtocols {
6839 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6840 }
6841
6842 rwc, ok := resp.Body.(io.ReadWriteCloser)
6843 if !ok {
6844 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6845 }
6846
6847 _, err = rwc.Write([]byte("hello"))
6848 if err != nil {
6849 t.Fatalf("failed to write to body: %v", err)
6850 }
6851
6852 b := make([]byte, 5)
6853 _, err = io.ReadFull(rwc, b)
6854 if err != nil {
6855 t.Fatalf("failed to read from body: %v", err)
6856 }
6857
6858 if string(b) != "hello" {
6859 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6860 }
6861 }
6862
// tlogWriter is an io.Writer that sends each write to t.Log. Tests use it to
// route server log output into the test log. Write always reports success.
type tlogWriter struct{ t *testing.T }

func (w tlogWriter) Write(p []byte) (int, error) {
	n := len(p)
	w.t.Log(string(p))
	return n, nil
}
6869
6870 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6871 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6872 }
6873 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6874 const wantBody = "want"
6875 const wantUpgrade = "someProto"
6876 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6877 w.Header().Set("Connection", "Upgrade")
6878 w.Header().Set("Upgrade", wantUpgrade)
6879 w.WriteHeader(StatusSwitchingProtocols)
6880 NewResponseController(w).Flush()
6881
6882
6883 w.WriteHeader(200)
6884 if _, err := w.Write([]byte("x")); err == nil {
6885 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6886 }
6887
6888 c, _, err := NewResponseController(w).Hijack()
6889 if err != nil {
6890 t.Errorf("Hijack: %v", err)
6891 return
6892 }
6893 defer c.Close()
6894 if _, err := c.Write([]byte(wantBody)); err != nil {
6895 t.Errorf("Write to hijacked body: %v", err)
6896 }
6897 }), func(ts *httptest.Server) {
6898
6899 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
6900 }).ts
6901
6902 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6903 if err != nil {
6904 t.Fatalf("net.Dial: %v", err)
6905 }
6906 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
6907 if err != nil {
6908 t.Fatalf("conn.Write: %v", err)
6909 }
6910 defer conn.Close()
6911
6912 r := bufio.NewReader(conn)
6913 res, err := ReadResponse(r, &Request{Method: "GET"})
6914 if err != nil {
6915 t.Fatal("ReadResponse error:", err)
6916 }
6917 if res.StatusCode != StatusSwitchingProtocols {
6918 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
6919 }
6920 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
6921 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
6922 }
6923 body, err := io.ReadAll(r)
6924 if err != nil {
6925 t.Error(err)
6926 }
6927 if string(body) != wantBody {
6928 t.Errorf("Response body = %q, want %q", string(body), wantBody)
6929 }
6930 }
6931
6932 func TestMuxRedirectRelative(t *testing.T) {
6933 setParallel(t)
6934 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
6935 if err != nil {
6936 t.Errorf("%s", err)
6937 }
6938 mux := NewServeMux()
6939 resp := httptest.NewRecorder()
6940 mux.ServeHTTP(resp, req)
6941 if got, want := resp.Header().Get("Location"), "/"; got != want {
6942 t.Errorf("Location header expected %q; got %q", want, got)
6943 }
6944 if got, want := resp.Code, StatusMovedPermanently; got != want {
6945 t.Errorf("Expected response code %d; got %d", want, got)
6946 }
6947 }
6948
6949
6950 func TestQuerySemicolon(t *testing.T) {
6951 t.Cleanup(func() { afterTest(t) })
6952
6953 tests := []struct {
6954 query string
6955 xNoSemicolons string
6956 xWithSemicolons string
6957 expectParseFormErr bool
6958 }{
6959 {"?a=1;x=bad&x=good", "good", "bad", true},
6960 {"?a=1;b=bad&x=good", "good", "good", true},
6961 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
6962 {"?a=1;x=good;x=bad", "", "good", true},
6963 }
6964
6965 run(t, func(t *testing.T, mode testMode) {
6966 for _, tt := range tests {
6967 t.Run(tt.query+"/allow=false", func(t *testing.T) {
6968 allowSemicolons := false
6969 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
6970 })
6971 t.Run(tt.query+"/allow=true", func(t *testing.T) {
6972 allowSemicolons, expectParseFormErr := true, false
6973 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
6974 })
6975 }
6976 })
6977 }
6978
6979 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
6980 writeBackX := func(w ResponseWriter, r *Request) {
6981 x := r.URL.Query().Get("x")
6982 if expectParseFormErr {
6983 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
6984 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
6985 }
6986 } else {
6987 if err := r.ParseForm(); err != nil {
6988 t.Errorf("expected no error from ParseForm, got %v", err)
6989 }
6990 }
6991 if got := r.FormValue("x"); x != got {
6992 t.Errorf("got %q from FormValue, want %q", got, x)
6993 }
6994 fmt.Fprintf(w, "%s", x)
6995 }
6996
6997 h := Handler(HandlerFunc(writeBackX))
6998 if allowSemicolons {
6999 h = AllowQuerySemicolons(h)
7000 }
7001
7002 logBuf := &strings.Builder{}
7003 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
7004 ts.Config.ErrorLog = log.New(logBuf, "", 0)
7005 }).ts
7006
7007 req, _ := NewRequest("GET", ts.URL+query, nil)
7008 res, err := ts.Client().Do(req)
7009 if err != nil {
7010 t.Fatal(err)
7011 }
7012 slurp, _ := io.ReadAll(res.Body)
7013 res.Body.Close()
7014 if got, want := res.StatusCode, 200; got != want {
7015 t.Errorf("Status = %d; want = %d", got, want)
7016 }
7017 if got, want := string(slurp), wantX; got != want {
7018 t.Errorf("Body = %q; want = %q", got, want)
7019 }
7020 }
7021
7022 func TestMaxBytesHandler(t *testing.T) {
7023
7024 defer afterTest(t)
7025
7026 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
7027 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
7028 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
7029 func(t *testing.T) {
7030 run(t, func(t *testing.T, mode testMode) {
7031 testMaxBytesHandler(t, mode, maxSize, requestSize)
7032 }, testNotParallel)
7033 })
7034 }
7035 }
7036 }
7037
7038 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
7039 runTimeSensitiveTest(t, []time.Duration{
7040 1 * time.Millisecond,
7041 5 * time.Millisecond,
7042 10 * time.Millisecond,
7043 50 * time.Millisecond,
7044 100 * time.Millisecond,
7045 500 * time.Millisecond,
7046 time.Second,
7047 5 * time.Second,
7048 }, func(t *testing.T, timeout time.Duration) error {
7049 SetRSTAvoidanceDelay(t, timeout)
7050 t.Logf("set RST avoidance delay to %v", timeout)
7051
7052 var (
7053 handlerN int64
7054 handlerErr error
7055 )
7056 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
7057 var buf bytes.Buffer
7058 handlerN, handlerErr = io.Copy(&buf, r.Body)
7059 io.Copy(w, &buf)
7060 })
7061
7062 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
7063
7064
7065 defer cst.close()
7066 ts := cst.ts
7067 c := ts.Client()
7068
7069 body := strings.Repeat("a", int(requestSize))
7070 var wg sync.WaitGroup
7071 defer wg.Wait()
7072 getBody := func() (io.ReadCloser, error) {
7073 wg.Add(1)
7074 body := &wgReadCloser{
7075 Reader: strings.NewReader(body),
7076 wg: &wg,
7077 }
7078 return body, nil
7079 }
7080 reqBody, _ := getBody()
7081 req, err := NewRequest("POST", ts.URL, reqBody)
7082 if err != nil {
7083 reqBody.Close()
7084 t.Fatal(err)
7085 }
7086 req.ContentLength = int64(len(body))
7087 req.GetBody = getBody
7088 req.Header.Set("Content-Type", "text/plain")
7089
7090 var buf strings.Builder
7091 res, err := c.Do(req)
7092 if err != nil {
7093 return fmt.Errorf("unexpected connection error: %v", err)
7094 } else {
7095 _, err = io.Copy(&buf, res.Body)
7096 res.Body.Close()
7097 if err != nil {
7098 return fmt.Errorf("unexpected read error: %v", err)
7099 }
7100 }
7101
7102
7103
7104
7105 if handlerN > maxSize {
7106 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
7107 }
7108 if requestSize > maxSize && handlerErr == nil {
7109 t.Error("expected error on handler side; got nil")
7110 }
7111 if requestSize <= maxSize {
7112 if handlerErr != nil {
7113 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
7114 }
7115 if handlerN != requestSize {
7116 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
7117 }
7118 }
7119 if buf.Len() != int(handlerN) {
7120 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
7121 }
7122
7123 return nil
7124 })
7125 }
7126
7127 func TestEarlyHints(t *testing.T) {
7128 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7129 h := w.Header()
7130 h.Add("Link", "</style.css>; rel=preload; as=style")
7131 h.Add("Link", "</script.js>; rel=preload; as=script")
7132 w.WriteHeader(StatusEarlyHints)
7133
7134 h.Add("Link", "</foo.js>; rel=preload; as=script")
7135 w.WriteHeader(StatusEarlyHints)
7136
7137 w.Write([]byte("stuff"))
7138 }))
7139
7140 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7141 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
7142 if !strings.Contains(got, expected) {
7143 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7144 }
7145 }
7146 func TestProcessing(t *testing.T) {
7147 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7148 w.WriteHeader(StatusProcessing)
7149 w.Write([]byte("stuff"))
7150 }))
7151
7152 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7153 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
7154 if !strings.Contains(got, expected) {
7155 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7156 }
7157 }
7158
7159 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
7160 func testParseFormCleanup(t *testing.T, mode testMode) {
7161 if mode == http2Mode {
7162 t.Skip("https://go.dev/issue/20253")
7163 }
7164
7165 const maxMemory = 1024
7166 const key = "file"
7167
7168 if runtime.GOOS == "windows" {
7169
7170 t.Skip("https://go.dev/issue/25965")
7171 }
7172
7173 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7174 r.ParseMultipartForm(maxMemory)
7175 f, _, err := r.FormFile(key)
7176 if err != nil {
7177 t.Errorf("r.FormFile(%q) = %v", key, err)
7178 return
7179 }
7180 of, ok := f.(*os.File)
7181 if !ok {
7182 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7183 return
7184 }
7185 w.Write([]byte(of.Name()))
7186 }))
7187
7188 fBuf := new(bytes.Buffer)
7189 mw := multipart.NewWriter(fBuf)
7190 mf, err := mw.CreateFormFile(key, "myfile.txt")
7191 if err != nil {
7192 t.Fatal(err)
7193 }
7194 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7195 t.Fatal(err)
7196 }
7197 if err := mw.Close(); err != nil {
7198 t.Fatal(err)
7199 }
7200 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7201 if err != nil {
7202 t.Fatal(err)
7203 }
7204 req.Header.Set("Content-Type", mw.FormDataContentType())
7205 res, err := cst.c.Do(req)
7206 if err != nil {
7207 t.Fatal(err)
7208 }
7209 defer res.Body.Close()
7210 fname, err := io.ReadAll(res.Body)
7211 if err != nil {
7212 t.Fatal(err)
7213 }
7214 cst.close()
7215 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7216 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7217 }
7218 }
7219
7220 func TestHeadBody(t *testing.T) {
7221 const identityMode = false
7222 const chunkedMode = true
7223 run(t, func(t *testing.T, mode testMode) {
7224 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7225 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7226 })
7227 }
7228
7229 func TestGetBody(t *testing.T) {
7230 const identityMode = false
7231 const chunkedMode = true
7232 run(t, func(t *testing.T, mode testMode) {
7233 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7234 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7235 })
7236 }
7237
7238 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7239 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7240 b, err := io.ReadAll(r.Body)
7241 if err != nil {
7242 t.Errorf("server reading body: %v", err)
7243 return
7244 }
7245 w.Header().Set("X-Request-Body", string(b))
7246 w.Header().Set("Content-Length", "0")
7247 }))
7248 defer cst.close()
7249 for _, reqBody := range []string{
7250 "",
7251 "",
7252 "request_body",
7253 "",
7254 } {
7255 var bodyReader io.Reader
7256 if reqBody != "" {
7257 bodyReader = strings.NewReader(reqBody)
7258 if chunked {
7259 bodyReader = bufio.NewReader(bodyReader)
7260 }
7261 }
7262 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7263 if err != nil {
7264 t.Fatal(err)
7265 }
7266 res, err := cst.c.Do(req)
7267 if err != nil {
7268 t.Fatal(err)
7269 }
7270 res.Body.Close()
7271 if got, want := res.StatusCode, 200; got != want {
7272 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7273 }
7274 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7275 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7276 }
7277 }
7278 }
7279
7280
7281
7282 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7283 func testDisableContentLength(t *testing.T, mode testMode) {
7284 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7285 w.Header()["Content-Length"] = nil
7286 fmt.Fprintf(w, "OK")
7287 }))
7288
7289 res, err := noCL.c.Get(noCL.ts.URL)
7290 if err != nil {
7291 t.Fatal(err)
7292 }
7293 if got, haveCL := res.Header["Content-Length"]; haveCL {
7294 t.Errorf("Unexpected Content-Length: %q", got)
7295 }
7296 if err := res.Body.Close(); err != nil {
7297 t.Fatal(err)
7298 }
7299
7300 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7301 fmt.Fprintf(w, "OK")
7302 }))
7303
7304 res, err = withCL.c.Get(withCL.ts.URL)
7305 if err != nil {
7306 t.Fatal(err)
7307 }
7308 if got := res.Header.Get("Content-Length"); got != "2" {
7309 t.Errorf("Content-Length: %q; want 2", got)
7310 }
7311 if err := res.Body.Close(); err != nil {
7312 t.Fatal(err)
7313 }
7314 }
7315
7316 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7317 func testErrorContentLength(t *testing.T, mode testMode) {
7318 const errorBody = "an error occurred"
7319 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7320 w.Header().Set("Content-Length", "1000")
7321 Error(w, errorBody, 400)
7322 }))
7323 res, err := cst.c.Get(cst.ts.URL)
7324 if err != nil {
7325 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7326 }
7327 defer res.Body.Close()
7328 body, err := io.ReadAll(res.Body)
7329 if err != nil {
7330 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7331 }
7332 if string(body) != errorBody+"\n" {
7333 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7334 }
7335 }
7336
// TestError checks the headers Error leaves on a ResponseWriter:
//   - a previously set Content-Length is removed;
//   - Content-Type is set to plain text;
//   - X-Content-Type-Options is forced to "nosniff".
func TestError(t *testing.T) {
	const wantContentType = "text/plain; charset=utf-8"

	rec := httptest.NewRecorder()
	rec.Header().Set("Content-Length", "1")
	rec.Header().Set("X-Content-Type-Options", "scratch and sniff")
	rec.Header().Set("Other", "foo")
	Error(rec, "oops", 432)

	hdr := rec.Header()
	if v, present := hdr["Content-Length"]; present {
		t.Errorf("Content-Length: %q, want not present", v)
	}
	if got := hdr.Get("Content-Type"); got != wantContentType {
		t.Errorf("Content-Type: %q, want %q", got, wantContentType)
	}
	if got := hdr.Get("X-Content-Type-Options"); got != "nosniff" {
		t.Errorf("X-Content-Type-Options: %q, want %q", got, "nosniff")
	}
}
7357
7358 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7359 run(t, testServerReadAfterWriteHeader100Continue)
7360 }
7361 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7362 t.Skip("https://go.dev/issue/67555")
7363 body := []byte("body")
7364 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7365 w.WriteHeader(200)
7366 NewResponseController(w).Flush()
7367 io.ReadAll(r.Body)
7368 w.Write(body)
7369 }), func(tr *Transport) {
7370 tr.ExpectContinueTimeout = 24 * time.Hour
7371 })
7372
7373 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7374 req.Header.Set("Expect", "100-continue")
7375 res, err := cst.c.Do(req)
7376 if err != nil {
7377 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7378 }
7379 defer res.Body.Close()
7380 got, err := io.ReadAll(res.Body)
7381 if err != nil {
7382 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7383 }
7384 if !bytes.Equal(got, body) {
7385 t.Fatalf("response body = %q, want %q", got, body)
7386 }
7387 }
7388
7389 func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
7390 run(t, testServerReadAfterHandlerDone100Continue)
7391 }
7392 func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
7393 t.Skip("https://go.dev/issue/67555")
7394 readyc := make(chan struct{})
7395 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7396 go func() {
7397 <-readyc
7398 io.ReadAll(r.Body)
7399 <-readyc
7400 }()
7401 }), func(tr *Transport) {
7402 tr.ExpectContinueTimeout = 24 * time.Hour
7403 })
7404
7405 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7406 req.Header.Set("Expect", "100-continue")
7407 res, err := cst.c.Do(req)
7408 if err != nil {
7409 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7410 }
7411 res.Body.Close()
7412 readyc <- struct{}{}
7413 readyc <- struct{}{}
7414 }
7415
7416 func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
7417 run(t, testServerReadAfterHandlerAbort100Continue)
7418 }
7419 func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
7420 t.Skip("https://go.dev/issue/67555")
7421 readyc := make(chan struct{})
7422 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7423 go func() {
7424 <-readyc
7425 io.ReadAll(r.Body)
7426 <-readyc
7427 }()
7428 panic(ErrAbortHandler)
7429 }), func(tr *Transport) {
7430 tr.ExpectContinueTimeout = 24 * time.Hour
7431 })
7432
7433 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7434 req.Header.Set("Expect", "100-continue")
7435 res, err := cst.c.Do(req)
7436 if err == nil {
7437 res.Body.Close()
7438 }
7439 readyc <- struct{}{}
7440 readyc <- struct{}{}
7441 }
7442
7443 func TestInvalidChunkedBodies(t *testing.T) {
7444 for _, test := range []struct {
7445 name string
7446 b string
7447 }{{
7448 name: "bare LF in chunk size",
7449 b: "1\na\r\n0\r\n\r\n",
7450 }, {
7451 name: "bare LF at body end",
7452 b: "1\r\na\r\n0\r\n\n",
7453 }} {
7454 t.Run(test.name, func(t *testing.T) {
7455 reqc := make(chan error)
7456 ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7457 got, err := io.ReadAll(r.Body)
7458 if err == nil {
7459 t.Logf("read body: %q", got)
7460 }
7461 reqc <- err
7462 })).ts
7463
7464 serverURL, err := url.Parse(ts.URL)
7465 if err != nil {
7466 t.Fatal(err)
7467 }
7468
7469 conn, err := net.Dial("tcp", serverURL.Host)
7470 if err != nil {
7471 t.Fatal(err)
7472 }
7473
7474 if _, err := conn.Write([]byte(
7475 "POST / HTTP/1.1\r\n" +
7476 "Host: localhost\r\n" +
7477 "Transfer-Encoding: chunked\r\n" +
7478 "Connection: close\r\n" +
7479 "\r\n" +
7480 test.b)); err != nil {
7481 t.Fatal(err)
7482 }
7483 conn.(*net.TCPConn).CloseWrite()
7484
7485 if err := <-reqc; err == nil {
7486 t.Errorf("server handler: io.ReadAll(r.Body) succeeded, want error")
7487 }
7488 })
7489 }
7490 }
7491
7492
7493 func TestServerTLSNextProtos(t *testing.T) {
7494 run(t, testServerTLSNextProtos, []testMode{https1Mode, http2Mode})
7495 }
7496 func testServerTLSNextProtos(t *testing.T, mode testMode) {
7497 CondSkipHTTP2(t)
7498
7499 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7500 if err != nil {
7501 t.Fatal(err)
7502 }
7503 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7504 if err != nil {
7505 t.Fatal(err)
7506 }
7507 certpool := x509.NewCertPool()
7508 certpool.AddCert(leafCert)
7509
7510 protos := new(Protocols)
7511 switch mode {
7512 case https1Mode:
7513 protos.SetHTTP1(true)
7514 case http2Mode:
7515 protos.SetHTTP2(true)
7516 }
7517
7518 wantNextProtos := []string{"http/1.1", "h2", "other"}
7519 nextProtos := slices.Clone(wantNextProtos)
7520
7521
7522 srv := &Server{
7523 TLSConfig: &tls.Config{
7524 Certificates: []tls.Certificate{cert},
7525 NextProtos: nextProtos,
7526 },
7527 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {}),
7528 Protocols: protos,
7529 }
7530 tr := &Transport{
7531 TLSClientConfig: &tls.Config{
7532 RootCAs: certpool,
7533 NextProtos: nextProtos,
7534 },
7535 Protocols: protos,
7536 }
7537
7538 listener := newLocalListener(t)
7539 srvc := make(chan error, 1)
7540 go func() {
7541 srvc <- srv.ServeTLS(listener, "", "")
7542 }()
7543 t.Cleanup(func() {
7544 srv.Close()
7545 <-srvc
7546 })
7547
7548 client := &Client{Transport: tr}
7549 resp, err := client.Get("https://" + listener.Addr().String())
7550 if err != nil {
7551 t.Fatal(err)
7552 }
7553 resp.Body.Close()
7554
7555 if !slices.Equal(nextProtos, wantNextProtos) {
7556 t.Fatalf("after running test: original NextProtos slice = %v, want %v", nextProtos, wantNextProtos)
7557 }
7558 }
7559
View as plain text