Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "encoding/json"
17 "errors"
18 "fmt"
19 "internal/testenv"
20 "io"
21 "log"
22 "math/rand"
23 "mime/multipart"
24 "net"
25 . "net/http"
26 "net/http/httptest"
27 "net/http/httptrace"
28 "net/http/httputil"
29 "net/http/internal"
30 "net/http/internal/testcert"
31 "net/url"
32 "os"
33 "path/filepath"
34 "reflect"
35 "regexp"
36 "runtime"
37 "strconv"
38 "strings"
39 "sync"
40 "sync/atomic"
41 "syscall"
42 "testing"
43 "time"
44 )
45
// dummyAddr is a trivial net.Addr whose network name and string form are
// both the underlying string value.
type dummyAddr string

// oneConnListener is a net.Listener that hands out one preset connection
// and then reports io.EOF from every later Accept.
type oneConnListener struct {
	conn net.Conn
}

// Accept returns the stored connection exactly once. Afterwards it returns
// a nil conn and io.EOF, so a Server using this listener stops accepting.
func (l *oneConnListener) Accept() (c net.Conn, err error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	c, l.conn = l.conn, nil
	return c, nil
}

// Close is a no-op; the listener holds nothing to release.
func (l *oneConnListener) Close() error { return nil }

// Addr reports a fixed placeholder address.
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("test-address") }

// Network returns the address string itself.
func (a dummyAddr) Network() string { return string(a) }

// String returns the address string itself.
func (a dummyAddr) String() string { return string(a) }

// noopConn provides the address and deadline methods of net.Conn. The
// addresses are fixed placeholders and the deadline setters always succeed.
// Embedding it leaves only Read, Write and Close for a test connection type
// to implement.
type noopConn struct{}

func (noopConn) LocalAddr() net.Addr                { return dummyAddr("local-addr") }
func (noopConn) RemoteAddr() net.Addr               { return dummyAddr("remote-addr") }
func (noopConn) SetDeadline(_ time.Time) error      { return nil }
func (noopConn) SetReadDeadline(_ time.Time) error  { return nil }
func (noopConn) SetWriteDeadline(_ time.Time) error { return nil }
85
86 type rwTestConn struct {
87 io.Reader
88 io.Writer
89 noopConn
90
91 closeFunc func() error
92 closec chan bool
93 }
94
95 func (c *rwTestConn) Close() error {
96 if c.closeFunc != nil {
97 return c.closeFunc()
98 }
99 select {
100 case c.closec <- true:
101 default:
102 }
103 return nil
104 }
105
106 type testConn struct {
107 readMu sync.Mutex
108 readBuf bytes.Buffer
109 writeBuf bytes.Buffer
110 closec chan bool
111 noopConn
112 }
113
114 func newTestConn() *testConn {
115 return &testConn{closec: make(chan bool, 1)}
116 }
117
118 func (c *testConn) Read(b []byte) (int, error) {
119 c.readMu.Lock()
120 defer c.readMu.Unlock()
121 return c.readBuf.Read(b)
122 }
123
124 func (c *testConn) Write(b []byte) (int, error) {
125 return c.writeBuf.Write(b)
126 }
127
128 func (c *testConn) Close() error {
129 select {
130 case c.closec <- true:
131 default:
132 }
133 return nil
134 }
135
136
137
// reqBytes converts a hand-written request (LF line endings, optional
// surrounding whitespace) into wire form. The result uses CRLF line endings
// and ends with the blank line that terminates the header section.
func reqBytes(req string) []byte {
	trimmed := strings.TrimSpace(req)
	withCRLF := strings.ReplaceAll(trimmed, "\n", "\r\n")
	return []byte(withCRLF + "\r\n\r\n")
}
141
142 type handlerTest struct {
143 logbuf bytes.Buffer
144 handler Handler
145 }
146
147 func newHandlerTest(h Handler) handlerTest {
148 return handlerTest{handler: h}
149 }
150
151 func (ht *handlerTest) rawResponse(req string) string {
152 reqb := reqBytes(req)
153 var output strings.Builder
154 conn := &rwTestConn{
155 Reader: bytes.NewReader(reqb),
156 Writer: &output,
157 closec: make(chan bool, 1),
158 }
159 ln := &oneConnListener{conn: conn}
160 srv := &Server{
161 ErrorLog: log.New(&ht.logbuf, "", 0),
162 Handler: ht.handler,
163 }
164 go srv.Serve(ln)
165 <-conn.closec
166 return output.String()
167 }
168
169 func TestConsumingBodyOnNextConn(t *testing.T) {
170 t.Parallel()
171 defer afterTest(t)
172 conn := new(testConn)
173 for i := 0; i < 2; i++ {
174 conn.readBuf.Write([]byte(
175 "POST / HTTP/1.1\r\n" +
176 "Host: test\r\n" +
177 "Content-Length: 11\r\n" +
178 "\r\n" +
179 "foo=1&bar=1"))
180 }
181
182 reqNum := 0
183 ch := make(chan *Request)
184 servech := make(chan error)
185 listener := &oneConnListener{conn}
186 handler := func(res ResponseWriter, req *Request) {
187 reqNum++
188 ch <- req
189 }
190
191 go func() {
192 servech <- Serve(listener, HandlerFunc(handler))
193 }()
194
195 var req *Request
196 req = <-ch
197 if req == nil {
198 t.Fatal("Got nil first request.")
199 }
200 if req.Method != "POST" {
201 t.Errorf("For request #1's method, got %q; expected %q",
202 req.Method, "POST")
203 }
204
205 req = <-ch
206 if req == nil {
207 t.Fatal("Got nil first request.")
208 }
209 if req.Method != "POST" {
210 t.Errorf("For request #2's method, got %q; expected %q",
211 req.Method, "POST")
212 }
213
214 if serveerr := <-servech; serveerr != io.EOF {
215 t.Errorf("Serve returned %q; expected EOF", serveerr)
216 }
217 }
218
// stringHandler is a Handler that reports its own value in the "Result"
// response header and writes no body.
type stringHandler string

// ServeHTTP sets the Result header to s and ignores the request.
func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	result := string(s)
	w.Header().Set("Result", result)
}
224
// handlers are the patterns registered by testHostHandlers. Each pattern is
// served by stringHandler(msg).
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}

// vtests pairs request URLs with the value testHostHandlers expects: the
// Result header of a 200 response, or the Location header of a 301.
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Subtree roots requested without their trailing slash are redirected;
	// expected is the redirect target.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
250
251 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
252 func testHostHandlers(t *testing.T, mode testMode) {
253 mux := NewServeMux()
254 for _, h := range handlers {
255 mux.Handle(h.pattern, stringHandler(h.msg))
256 }
257 ts := newClientServerTest(t, mode, mux).ts
258
259 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
260 if err != nil {
261 t.Fatal(err)
262 }
263 defer conn.Close()
264 cc := httputil.NewClientConn(conn, nil)
265 for _, vt := range vtests {
266 var r *Response
267 var req Request
268 if req.URL, err = url.Parse(vt.url); err != nil {
269 t.Errorf("cannot parse url: %v", err)
270 continue
271 }
272 if err := cc.Write(&req); err != nil {
273 t.Errorf("writing request: %v", err)
274 continue
275 }
276 r, err := cc.Read(&req)
277 if err != nil {
278 t.Errorf("reading response: %v", err)
279 continue
280 }
281 switch r.StatusCode {
282 case StatusOK:
283 s := r.Header.Get("Result")
284 if s != vt.expected {
285 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
286 }
287 case StatusMovedPermanently:
288 s := r.Header.Get("Location")
289 if s != vt.expected {
290 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
291 }
292 default:
293 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
294 }
295 }
296 }
297
// serveMuxRegister lists the patterns and handlers that TestServeMuxHandler
// and TestServeMuxHandlerRedirects install on their mux. Most entries use
// serve(code), so the matched entry shows up in the response status.
var serveMuxRegister = []struct {
	pattern string
	h       Handler
}{
	{pattern: "/dir/", h: serve(200)},
	{pattern: "/search", h: serve(201)},
	{pattern: "codesearch.google.com/search", h: serve(202)},
	{pattern: "codesearch.google.com/", h: serve(203)},
	{pattern: "example.com/", h: HandlerFunc(checkQueryStringHandler)},
}

// serve returns a handler that writes only the given status code.
func serve(code int) HandlerFunc {
	return HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	})
}

// checkQueryStringHandler compares the request URL with its own raw query.
// It rebuilds the URL with scheme "http", the request's Host and no query.
// It answers 200 when that equals "http://" followed by the raw query, and
// 500 otherwise.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	self := *r.URL
	self.Scheme, self.Host, self.RawQuery = "http", r.Host, ""
	code := 500
	if "http://"+r.URL.RawQuery == self.String() {
		code = 200
	}
	w.WriteHeader(code)
}
330
// serveMuxTests gives, for each request on a mux built from
// serveMuxRegister, the expected status code and the pattern that
// ServeMux.Handler should report.
//
// Two exact duplicate rows (GET and CONNECT "/dir/..") were removed; they
// re-ran identical cases and added no coverage.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 301, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},
	{"GET", "google.com", "/../search", 301, "/search"},
	{"GET", "google.com", "/dir/..", 301, ""},
	{"GET", "google.com", "/dir/./file", 301, "/dir/"},

	// The /foo -> /foo/ redirect applies to CONNECT requests, but path
	// canonicalization ("." and ".." cleaning) does not.
	{"CONNECT", "google.com", "/dir", 301, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
366
367 func TestServeMuxHandler(t *testing.T) {
368 setParallel(t)
369 mux := NewServeMux()
370 for _, e := range serveMuxRegister {
371 mux.Handle(e.pattern, e.h)
372 }
373
374 for _, tt := range serveMuxTests {
375 r := &Request{
376 Method: tt.method,
377 Host: tt.host,
378 URL: &url.URL{
379 Path: tt.path,
380 },
381 }
382 h, pattern := mux.Handler(r)
383 rr := httptest.NewRecorder()
384 h.ServeHTTP(rr, r)
385 if pattern != tt.pattern || rr.Code != tt.code {
386 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
387 }
388 }
389 }
390
391
392 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
393 setParallel(t)
394 defer func() {
395 if err := recover(); err == nil {
396 t.Error("expected call to mux.HandleFunc to panic")
397 }
398 }()
399 mux := NewServeMux()
400 mux.HandleFunc("/", nil)
401 }
402
// serveMuxTests2 drives TestServeMuxHandlerRedirects. Each entry gives the
// expected final status and whether a single 301 redirect on the way is
// acceptable.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
414
415
416
417 func TestServeMuxHandlerRedirects(t *testing.T) {
418 setParallel(t)
419 mux := NewServeMux()
420 for _, e := range serveMuxRegister {
421 mux.Handle(e.pattern, e.h)
422 }
423
424 for _, tt := range serveMuxTests2 {
425 tries := 1
426 turl := tt.url
427 for {
428 u, e := url.Parse(turl)
429 if e != nil {
430 t.Fatal(e)
431 }
432 r := &Request{
433 Method: tt.method,
434 Host: tt.host,
435 URL: u,
436 }
437 h, _ := mux.Handler(r)
438 rr := httptest.NewRecorder()
439 h.ServeHTTP(rr, r)
440 if rr.Code != 301 {
441 if rr.Code != tt.code {
442 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
443 }
444 break
445 }
446 if !tt.redirOk {
447 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
448 break
449 }
450 turl = rr.HeaderMap.Get("Location")
451 tries--
452 }
453 if tries < 0 {
454 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
455 }
456 }
457 }
458
459
460 func TestMuxRedirectLeadingSlashes(t *testing.T) {
461 setParallel(t)
462 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
463 for _, path := range paths {
464 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
465 if err != nil {
466 t.Errorf("%s", err)
467 }
468 mux := NewServeMux()
469 resp := httptest.NewRecorder()
470
471 mux.ServeHTTP(resp, req)
472
473 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
474 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
475 return
476 }
477
478 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
479 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
480 return
481 }
482 }
483 }
484
485
486
487
488
489 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
490 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
491 }
492 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
493 writeBackQuery := func(w ResponseWriter, r *Request) {
494 fmt.Fprintf(w, "%s", r.URL.RawQuery)
495 }
496
497 mux := NewServeMux()
498 mux.HandleFunc("/testOne", writeBackQuery)
499 mux.HandleFunc("/testTwo/", writeBackQuery)
500 mux.HandleFunc("/testThree", writeBackQuery)
501 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
502 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
503 })
504
505 ts := newClientServerTest(t, mode, mux).ts
506
507 tests := [...]struct {
508 path string
509 method string
510 want string
511 statusOk bool
512 }{
513 0: {"/testOne?this=that", "GET", "this=that", true},
514 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
515 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
516 3: {"/testTwo?", "GET", "", true},
517 4: {"/testThree?foo", "GET", "foo", true},
518 5: {"/testThree/?foo", "GET", "foo:bar", true},
519 6: {"/testThree?foo", "CONNECT", "foo", true},
520 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
521
522
523 8: {"/testOne/foo/..?foo", "GET", "foo", true},
524 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
525 }
526
527 for i, tt := range tests {
528 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
529 res, err := ts.Client().Do(req)
530 if err != nil {
531 continue
532 }
533 slurp, _ := io.ReadAll(res.Body)
534 res.Body.Close()
535 if !tt.statusOk {
536 if got, want := res.StatusCode, 404; got != want {
537 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
538 }
539 }
540 if got, want := string(slurp), tt.want; got != want {
541 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
542 }
543 }
544 }
545
546 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
547 setParallel(t)
548
549 mux := NewServeMux()
550 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
551 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
552 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
553 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
554 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
555 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
556
557 tests := []struct {
558 method string
559 url string
560 code int
561 loc string
562 want string
563 }{
564 {"GET", "http://example.com/", 404, "", ""},
565 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
566 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
567 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
568 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
569 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
570 {"CONNECT", "http://example.com/", 404, "", ""},
571 {"CONNECT", "http://example.com:3000/", 404, "", ""},
572 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
573 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
574 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
575 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
576 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
577 }
578
579 for i, tt := range tests {
580 req, _ := NewRequest(tt.method, tt.url, nil)
581 w := httptest.NewRecorder()
582 mux.ServeHTTP(w, req)
583
584 if got, want := w.Code, tt.code; got != want {
585 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
586 }
587
588 if tt.code == 301 {
589 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
590 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
591 }
592 } else {
593 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
594 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
595 }
596 }
597 }
598 }
599
600
601
602
// TestMuxNoSlashRedirectWithTrailingSlash checks that a request for "/" gets
// 404 rather than a redirect when the only registered pattern is "/{x}/",
// a wildcard segment followed by a trailing slash.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)
	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}

// TestMuxNoSlash405WithTrailingSlash is the method-qualified variant. With
// only "GET /{x}/" registered, a GET for "/" must still get 404.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)
	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
631
632 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
633 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
634 mux := NewServeMux()
635 newClientServerTest(t, mode, mux)
636 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
637 }
638
639 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
640 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
641 func benchmarkServeMux(b *testing.B, runHandler bool) {
642 type test struct {
643 path string
644 code int
645 req *Request
646 }
647
648
649 var tests []test
650 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
651 for _, e := range endpoints {
652 for i := 200; i < 230; i++ {
653 p := fmt.Sprintf("/%s/%d/", e, i)
654 tests = append(tests, test{
655 path: p,
656 code: i,
657 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
658 })
659 }
660 }
661 mux := NewServeMux()
662 for _, tt := range tests {
663 mux.Handle(tt.path, serve(tt.code))
664 }
665
666 rw := httptest.NewRecorder()
667 b.ReportAllocs()
668 b.ResetTimer()
669 for i := 0; i < b.N; i++ {
670 for _, tt := range tests {
671 *rw = httptest.ResponseRecorder{}
672 h, pattern := mux.Handler(tt.req)
673 if runHandler {
674 h.ServeHTTP(rw, tt.req)
675 if pattern != tt.path || rw.Code != tt.code {
676 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
677 }
678 }
679 }
680 }
681 }
682
683 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
684 func testServerTimeouts(t *testing.T, mode testMode) {
685 runTimeSensitiveTest(t, []time.Duration{
686 10 * time.Millisecond,
687 50 * time.Millisecond,
688 100 * time.Millisecond,
689 500 * time.Millisecond,
690 1 * time.Second,
691 }, func(t *testing.T, timeout time.Duration) error {
692 return testServerTimeoutsWithTimeout(t, timeout, mode)
693 })
694 }
695
696 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
697 var reqNum atomic.Int32
698 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
699 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
700 }), func(ts *httptest.Server) {
701 ts.Config.ReadTimeout = timeout
702 ts.Config.WriteTimeout = timeout
703 })
704 defer cst.close()
705 ts := cst.ts
706
707
708 c := ts.Client()
709 r, err := c.Get(ts.URL)
710 if err != nil {
711 return fmt.Errorf("http Get #1: %v", err)
712 }
713 got, err := io.ReadAll(r.Body)
714 expected := "req=1"
715 if string(got) != expected || err != nil {
716 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
717 string(got), err, expected)
718 }
719
720
721 t1 := time.Now()
722 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
723 if err != nil {
724 return fmt.Errorf("Dial: %v", err)
725 }
726 buf := make([]byte, 1)
727 n, err := conn.Read(buf)
728 conn.Close()
729 latency := time.Since(t1)
730 if n != 0 || err != io.EOF {
731 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
732 }
733 minLatency := timeout / 5 * 4
734 if latency < minLatency {
735 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
736 }
737
738
739
740
741 r, err = c.Get(ts.URL)
742 if err != nil {
743 return fmt.Errorf("http Get #2: %v", err)
744 }
745 got, err = io.ReadAll(r.Body)
746 r.Body.Close()
747 expected = "req=2"
748 if string(got) != expected || err != nil {
749 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
750 }
751
752 if !testing.Short() {
753 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
754 if err != nil {
755 return fmt.Errorf("long Dial: %v", err)
756 }
757 defer conn.Close()
758 go io.Copy(io.Discard, conn)
759 for i := 0; i < 5; i++ {
760 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
761 if err != nil {
762 return fmt.Errorf("on write %d: %v", i, err)
763 }
764 time.Sleep(timeout / 2)
765 }
766 }
767 return nil
768 }
769
770 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
771 func testServerReadTimeout(t *testing.T, mode testMode) {
772 respBody := "response body"
773 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
774 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
775 _, err := io.Copy(io.Discard, req.Body)
776 if !errors.Is(err, os.ErrDeadlineExceeded) {
777 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
778 }
779 res.Write([]byte(respBody))
780 }), func(ts *httptest.Server) {
781 ts.Config.ReadHeaderTimeout = -1
782 ts.Config.ReadTimeout = timeout
783 t.Logf("Server.Config.ReadTimeout = %v", timeout)
784 })
785
786 var retries atomic.Int32
787 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
788 if retries.Add(1) != 1 {
789 return nil, errors.New("too many retries")
790 }
791 return nil, nil
792 }
793
794 pr, pw := io.Pipe()
795 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
796 if err != nil {
797 t.Logf("Get error, retrying: %v", err)
798 cst.close()
799 continue
800 }
801 defer res.Body.Close()
802 got, err := io.ReadAll(res.Body)
803 if string(got) != respBody || err != nil {
804 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
805 }
806 pw.Close()
807 break
808 }
809 }
810
811 func TestServerNoReadTimeout(t *testing.T) { run(t, testServerNoReadTimeout) }
812 func testServerNoReadTimeout(t *testing.T, mode testMode) {
813 reqBody := "Hello, Gophers!"
814 resBody := "Hi, Gophers!"
815 for _, timeout := range []time.Duration{0, -1} {
816 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
817 ctl := NewResponseController(res)
818 ctl.EnableFullDuplex()
819 res.WriteHeader(StatusOK)
820
821
822 if err := ctl.Flush(); err != nil {
823 t.Errorf("server flush response: %v", err)
824 return
825 }
826 got, err := io.ReadAll(req.Body)
827 if string(got) != reqBody || err != nil {
828 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
829 }
830 res.Write([]byte(resBody))
831 }), func(ts *httptest.Server) {
832 ts.Config.ReadTimeout = timeout
833 t.Logf("Server.Config.ReadTimeout = %d", timeout)
834 })
835
836 pr, pw := io.Pipe()
837 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
838 if err != nil {
839 t.Fatal(err)
840 }
841 defer res.Body.Close()
842
843
844 time.Sleep(10 * time.Millisecond)
845 pw.Write([]byte(reqBody))
846 pw.Close()
847
848 got, err := io.ReadAll(res.Body)
849 if string(got) != resBody || err != nil {
850 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
851 }
852 }
853 }
854
855 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
856 func testServerWriteTimeout(t *testing.T, mode testMode) {
857 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
858 errc := make(chan error, 2)
859 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
860 errc <- nil
861 _, err := io.Copy(res, neverEnding('a'))
862 errc <- err
863 }), func(ts *httptest.Server) {
864 ts.Config.WriteTimeout = timeout
865 t.Logf("Server.Config.WriteTimeout = %v", timeout)
866 })
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885 var retries atomic.Int32
886 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
887 if retries.Add(1) != 1 {
888 return nil, errors.New("too many retries")
889 }
890 return nil, nil
891 }
892
893 res, err := cst.c.Get(cst.ts.URL)
894 if err != nil {
895
896 t.Logf("Get error, retrying: %v", err)
897 cst.close()
898 continue
899 }
900 defer res.Body.Close()
901 _, err = io.Copy(io.Discard, res.Body)
902 if err == nil {
903 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
904 }
905 select {
906 case <-errc:
907 err = <-errc
908 if !errors.Is(err, os.ErrDeadlineExceeded) {
909 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
910 }
911 return
912 default:
913
914 t.Logf("handler didn't run, retrying")
915 cst.close()
916 }
917 }
918 }
919
920 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
921 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
922 for _, timeout := range []time.Duration{0, -1} {
923 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
924 _, err := io.Copy(res, neverEnding('a'))
925 t.Logf("server write response: %v", err)
926 }), func(ts *httptest.Server) {
927 ts.Config.WriteTimeout = timeout
928 t.Logf("Server.Config.WriteTimeout = %d", timeout)
929 })
930
931 res, err := cst.c.Get(cst.ts.URL)
932 if err != nil {
933 t.Fatal(err)
934 }
935 defer res.Body.Close()
936 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
937 if n != 1<<20 || err != nil {
938 t.Errorf("client read response body: %d, %v", n, err)
939 }
940
941
942 res.Body.Close()
943 cst.ts.Config.Shutdown(context.Background())
944 }
945 }
946
947
948 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
949 run(t, testWriteDeadlineExtendedOnNewRequest)
950 }
951 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
952 if testing.Short() {
953 t.Skip("skipping in short mode")
954 }
955 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
956 func(ts *httptest.Server) {
957 ts.Config.WriteTimeout = 250 * time.Millisecond
958 },
959 ).ts
960
961 c := ts.Client()
962
963 for i := 1; i <= 3; i++ {
964 req, err := NewRequest("GET", ts.URL, nil)
965 if err != nil {
966 t.Fatal(err)
967 }
968
969 r, err := c.Do(req)
970 if err != nil {
971 t.Fatalf("http2 Get #%d: %v", i, err)
972 }
973 r.Body.Close()
974 time.Sleep(ts.Config.WriteTimeout / 2)
975 }
976 }
977
978
979
// tryTimeouts calls testFunc with 250ms, then 500ms, then 1s, and stops at
// the first call that returns nil. Each failure is logged. If every attempt
// fails, the test is stopped with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i := range tries {
		timeout := tries[i]
		err := testFunc(timeout)
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", timeout, err)
		if next := i + 1; next < len(tries) {
			t.Logf("retrying at %v ...", tries[next])
		}
	}
	t.Fatal("all attempts failed")
}
994
995
996 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
997 if testing.Short() {
998 t.Skip("skipping in short mode")
999 }
1000 setParallel(t)
1001 run(t, func(t *testing.T, mode testMode) {
1002 tryTimeouts(t, func(timeout time.Duration) error {
1003 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1004 })
1005 })
1006 }
1007
1008 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
1009 firstRequest := make(chan bool, 1)
1010 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1011 select {
1012 case firstRequest <- true:
1013
1014 default:
1015
1016 time.Sleep(timeout)
1017 }
1018 }), func(ts *httptest.Server) {
1019 ts.Config.WriteTimeout = timeout / 2
1020 })
1021 defer cst.close()
1022 ts := cst.ts
1023
1024 c := ts.Client()
1025
1026 req, err := NewRequest("GET", ts.URL, nil)
1027 if err != nil {
1028 return fmt.Errorf("NewRequest: %v", err)
1029 }
1030 r, err := c.Do(req)
1031 if err != nil {
1032 return fmt.Errorf("Get #1: %v", err)
1033 }
1034 r.Body.Close()
1035
1036 req, err = NewRequest("GET", ts.URL, nil)
1037 if err != nil {
1038 return fmt.Errorf("NewRequest: %v", err)
1039 }
1040 r, err = c.Do(req)
1041 if err == nil {
1042 r.Body.Close()
1043 return fmt.Errorf("Get #2 expected error, got nil")
1044 }
1045 if mode == http2Mode {
1046 expected := "stream ID 3; INTERNAL_ERROR"
1047 if !strings.Contains(err.Error(), expected) {
1048 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
1049 }
1050 }
1051 return nil
1052 }
1053
1054
1055 func TestNoWriteDeadline(t *testing.T) {
1056 if testing.Short() {
1057 t.Skip("skipping in short mode")
1058 }
1059 setParallel(t)
1060 defer afterTest(t)
1061 run(t, func(t *testing.T, mode testMode) {
1062 tryTimeouts(t, func(timeout time.Duration) error {
1063 return testNoWriteDeadline(t, mode, timeout)
1064 })
1065 })
1066 }
1067
1068 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1069 firstRequest := make(chan bool, 1)
1070 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1071 select {
1072 case firstRequest <- true:
1073
1074 default:
1075
1076 time.Sleep(timeout)
1077 }
1078 }))
1079 defer cst.close()
1080 ts := cst.ts
1081
1082 c := ts.Client()
1083
1084 for i := 0; i < 2; i++ {
1085 req, err := NewRequest("GET", ts.URL, nil)
1086 if err != nil {
1087 return fmt.Errorf("NewRequest: %v", err)
1088 }
1089 r, err := c.Do(req)
1090 if err != nil {
1091 return fmt.Errorf("Get #%d: %v", i, err)
1092 }
1093 r.Body.Close()
1094 }
1095 return nil
1096 }
1097
1098
1099
1100
1101 func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
1102 func testOnlyWriteTimeout(t *testing.T, mode testMode) {
1103 var (
1104 mu sync.RWMutex
1105 conn net.Conn
1106 )
1107 var afterTimeoutErrc = make(chan error, 1)
1108 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
1109 buf := make([]byte, 512<<10)
1110 _, err := w.Write(buf)
1111 if err != nil {
1112 t.Errorf("handler Write error: %v", err)
1113 return
1114 }
1115 mu.RLock()
1116 defer mu.RUnlock()
1117 if conn == nil {
1118 t.Error("no established connection found")
1119 return
1120 }
1121 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
1122 _, err = w.Write(buf)
1123 afterTimeoutErrc <- err
1124 }), func(ts *httptest.Server) {
1125 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
1126 }).ts
1127
1128 c := ts.Client()
1129
1130 err := func() error {
1131 res, err := c.Get(ts.URL)
1132 if err != nil {
1133 return err
1134 }
1135 _, err = io.Copy(io.Discard, res.Body)
1136 res.Body.Close()
1137 return err
1138 }()
1139 if err == nil {
1140 t.Errorf("expected an error copying body from Get request")
1141 }
1142
1143 if err := <-afterTimeoutErrc; err == nil {
1144 t.Error("expected write error after timeout")
1145 }
1146 }
1147
1148
// trackLastConnListener wraps a net.Listener and records each successfully
// accepted connection in *last, holding mu for writing while it does so.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex // guards *last
	last *net.Conn     // receives the most recently accepted conn
}

// Accept forwards to the wrapped listener. On success it stores the new
// connection in *l.last. On failure *l.last is left untouched and the
// wrapped listener's results are returned as-is.
func (l trackLastConnListener) Accept() (net.Conn, error) {
	c, err := l.Listener.Accept()
	if err != nil {
		return c, err
	}
	l.mu.Lock()
	*l.last = c
	l.mu.Unlock()
	return c, nil
}
1165
1166
1167 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1168 func testIdentityResponse(t *testing.T, mode testMode) {
1169 if mode == http2Mode {
1170 t.Skip("https://go.dev/issue/56019")
1171 }
1172
1173 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1174 rw.Header().Set("Content-Length", "3")
1175 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1176 switch {
1177 case req.FormValue("overwrite") == "1":
1178 _, err := rw.Write([]byte("foo TOO LONG"))
1179 if err != ErrContentLength {
1180 t.Errorf("expected ErrContentLength; got %v", err)
1181 }
1182 case req.FormValue("underwrite") == "1":
1183 rw.Header().Set("Content-Length", "500")
1184 rw.Write([]byte("too short"))
1185 default:
1186 rw.Write([]byte("foo"))
1187 }
1188 })
1189
1190 ts := newClientServerTest(t, mode, handler).ts
1191 c := ts.Client()
1192
1193
1194
1195
1196
1197 for _, te := range []string{"", "identity"} {
1198 url := ts.URL + "/?te=" + te
1199 res, err := c.Get(url)
1200 if err != nil {
1201 t.Fatalf("error with Get of %s: %v", url, err)
1202 }
1203 if cl, expected := res.ContentLength, int64(3); cl != expected {
1204 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1205 }
1206 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1207 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1208 }
1209 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1210 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1211 url, expected, tl, res.TransferEncoding)
1212 }
1213 res.Body.Close()
1214 }
1215
1216
1217 url := ts.URL + "/?overwrite=1"
1218 res, err := c.Get(url)
1219 if err != nil {
1220 t.Fatalf("error with Get of %s: %v", url, err)
1221 }
1222 res.Body.Close()
1223
1224 if mode != http1Mode {
1225 return
1226 }
1227
1228
1229
1230 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1231 if err != nil {
1232 t.Fatalf("error dialing: %v", err)
1233 }
1234 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1235 if err != nil {
1236 t.Fatalf("error writing: %v", err)
1237 }
1238
1239
1240 got, _ := io.ReadAll(conn)
1241 expectedSuffix := "\r\n\r\ntoo short"
1242 if !strings.HasSuffix(string(got), expectedSuffix) {
1243 t.Errorf("Expected output to end with %q; got response body %q",
1244 expectedSuffix, string(got))
1245 }
1246 }
1247
1248 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1249 setParallel(t)
1250 s := newClientServerTest(t, http1Mode, h).ts
1251
1252 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1253 if err != nil {
1254 t.Fatal("dial error:", err)
1255 }
1256 defer conn.Close()
1257
1258 _, err = fmt.Fprint(conn, req)
1259 if err != nil {
1260 t.Fatal("print error:", err)
1261 }
1262
1263 r := bufio.NewReader(conn)
1264 res, err := ReadResponse(r, &Request{Method: "GET"})
1265 if err != nil {
1266 t.Fatal("ReadResponse error:", err)
1267 }
1268
1269 _, err = io.ReadAll(r)
1270 if err != nil {
1271 t.Fatal("read error:", err)
1272 }
1273
1274 if !res.Close {
1275 t.Errorf("Response.Close = false; want true")
1276 }
1277 }
1278
1279 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1280 setParallel(t)
1281 ts := newClientServerTest(t, http1Mode, handler).ts
1282 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1283 if err != nil {
1284 t.Fatal(err)
1285 }
1286 defer conn.Close()
1287 br := bufio.NewReader(conn)
1288 for i := 0; i < 2; i++ {
1289 if _, err := io.WriteString(conn, req); err != nil {
1290 t.Fatal(err)
1291 }
1292 res, err := ReadResponse(br, nil)
1293 if err != nil {
1294 t.Fatalf("res %d: %v", i+1, err)
1295 }
1296 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1297 t.Fatalf("res %d body copy: %v", i+1, err)
1298 }
1299 res.Body.Close()
1300 }
1301 }
1302
1303
1304 func TestServeHTTP10Close(t *testing.T) {
1305 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1306 ServeFile(w, r, "testdata/file")
1307 }))
1308 }
1309
1310
1311 func TestClientCanClose(t *testing.T) {
1312 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1313
1314 }))
1315 }
1316
1317
1318
1319 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1320 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1321 w.Header().Set("Connection", "close")
1322 }))
1323 }
1324
1325 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1326 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1327 w.Header().Set("Connection", "close")
1328 }))
1329 }
1330
1331 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1332 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1333
1334
1335 }))
1336 }
1337
// send204 replies with status 204 No Content and writes no body.
func send204(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) }

// send304 replies with status 304 Not Modified and writes no body.
func send304(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) }
1340
1341
1342 func TestHTTP10KeepAlive204Response(t *testing.T) {
1343 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1344 }
1345
1346 func TestHTTP11KeepAlive204Response(t *testing.T) {
1347 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1348 }
1349
1350 func TestHTTP10KeepAlive304Response(t *testing.T) {
1351 testTCPConnectionStaysOpen(t,
1352 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1353 HandlerFunc(send304))
1354 }
1355
1356
1357 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1358 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1359 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1360 w.(Flusher).Flush()
1361 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1362 }))
1363 type data struct {
1364 Addr string
1365 }
1366 var addrs [2]data
1367 for i := range addrs {
1368 res, err := cst.c.Get(cst.ts.URL)
1369 if err != nil {
1370 t.Fatal(err)
1371 }
1372 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1373 t.Fatal(err)
1374 }
1375 if addrs[i].Addr == "" {
1376 t.Fatal("no address")
1377 }
1378 res.Body.Close()
1379 }
1380 if addrs[0] != addrs[1] {
1381 t.Fatalf("connection not reused")
1382 }
1383 }
1384
1385 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1386 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1387 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1388 fmt.Fprintf(w, "%s", r.RemoteAddr)
1389 }))
1390
1391 res, err := cst.c.Get(cst.ts.URL)
1392 if err != nil {
1393 t.Fatalf("Get error: %v", err)
1394 }
1395 body, err := io.ReadAll(res.Body)
1396 if err != nil {
1397 t.Fatalf("ReadAll error: %v", err)
1398 }
1399 ip := string(body)
1400 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1401 t.Fatalf("Expected local addr; got %q", ip)
1402 }
1403 }
1404
// blockingRemoteAddrListener wraps a net.Listener so that every accepted
// connection comes back as a *blockingRemoteAddrConn. Each wrapped
// connection is also sent on conns before Accept returns; the send blocks
// until someone receives it.
type blockingRemoteAddrListener struct {
	net.Listener
	conns chan<- net.Conn
}

// Accept accepts from the wrapped listener, wraps the result with a fresh
// one-slot address channel, publishes it on l.conns, and returns it.
// Errors from the wrapped listener are returned with a nil connection.
func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
	inner, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	wrapped := &blockingRemoteAddrConn{Conn: inner, addrs: make(chan net.Addr, 1)}
	l.conns <- wrapped
	return wrapped, nil
}

// blockingRemoteAddrConn is a net.Conn whose RemoteAddr result is supplied
// externally through addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr
}

// RemoteAddr blocks until an address is sent on c.addrs, then returns it.
func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
	addr := <-c.addrs
	return addr
}
1431
1432
// TestServerAllowsBlockingRemoteAddr checks that one connection stuck in
// RemoteAddr does not prevent the server from serving another connection.
// Only HTTP/1 mode is exercised.
func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
	run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
}
func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
	// conns is unbuffered, so each Accept on the wrapping listener waits
	// until the test receives the new connection.
	conns := make(chan net.Conn)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
	}), func(ts *httptest.Server) {
		ts.Listener = &blockingRemoteAddrListener{
			Listener: ts.Listener,
			conns:    conns,
		}
	}).ts

	c := ts.Client()
	// Force a separate connection per request so each fetch gets its own
	// blockingRemoteAddrConn.
	c.Transport.(*Transport).DisableKeepAlives = true

	// fetch issues a GET and reports the body on response ("" on error).
	fetch := func(num int, response chan<- string) {
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		response <- string(body)
	}

	// Start request 1. Its connection's RemoteAddr will block until the
	// test supplies an address.
	response1c := make(chan string, 1)
	go fetch(1, response1c)

	// Wait for the server to accept connection 1.
	conn1 := <-conns

	// Start request 2 and grab its connection as well.
	response2c := make(chan string, 1)
	go fetch(2, response2c)
	conn2 := <-conns

	// Unblock connection 2 only; connection 1 stays blocked.
	conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("12.12.12.12"), Port: 12}

	// Request 2 must complete, carrying the injected address, while
	// connection 1 is still blocked.
	response2 := <-response2c
	if g, e := response2, "RA:12.12.12.12:12"; g != e {
		t.Fatalf("response 2 addr = %q; want %q", g, e)
	}

	// Now release connection 1.
	conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("21.21.21.21"), Port: 21}

	// ... and its response must carry the address injected for it.
	response1 := <-response1c
	if g, e := response1, "RA:21.21.21.21:21"; g != e {
		t.Fatalf("response 1 addr = %q; want %q", g, e)
	}
}
1500
1501
1502
1503 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1504 func testHeadResponses(t *testing.T, mode testMode) {
1505 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1506 _, err := w.Write([]byte("<html>"))
1507 if err != nil {
1508 t.Errorf("ResponseWriter.Write: %v", err)
1509 }
1510
1511
1512 _, err = io.Copy(w, strings.NewReader("789a"))
1513 if err != nil {
1514 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1515 }
1516 }))
1517 res, err := cst.c.Head(cst.ts.URL)
1518 if err != nil {
1519 t.Error(err)
1520 }
1521 if len(res.TransferEncoding) > 0 {
1522 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1523 }
1524 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1525 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1526 }
1527 if v := res.ContentLength; v != 10 {
1528 t.Errorf("Content-Length: %d; want 10", v)
1529 }
1530 body, err := io.ReadAll(res.Body)
1531 if err != nil {
1532 t.Error(err)
1533 }
1534 if len(body) > 0 {
1535 t.Errorf("got unexpected body %q", string(body))
1536 }
1537 }
1538
1539 func TestTLSHandshakeTimeout(t *testing.T) {
1540 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1541 }
1542 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1543 errLog := new(strings.Builder)
1544 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1545 func(ts *httptest.Server) {
1546 ts.Config.ReadTimeout = 250 * time.Millisecond
1547 ts.Config.ErrorLog = log.New(errLog, "", 0)
1548 },
1549 )
1550 ts := cst.ts
1551
1552 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1553 if err != nil {
1554 t.Fatalf("Dial: %v", err)
1555 }
1556 var buf [1]byte
1557 n, err := conn.Read(buf[:])
1558 if err == nil || n != 0 {
1559 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1560 }
1561 conn.Close()
1562
1563 cst.close()
1564 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1565 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1566 }
1567 }
1568
1569 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1570 func testTLSServer(t *testing.T, mode testMode) {
1571 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1572 if r.TLS != nil {
1573 w.Header().Set("X-TLS-Set", "true")
1574 if r.TLS.HandshakeComplete {
1575 w.Header().Set("X-TLS-HandshakeComplete", "true")
1576 }
1577 }
1578 }), func(ts *httptest.Server) {
1579 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1580 }).ts
1581
1582
1583
1584
1585
1586
1587 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1588 if err != nil {
1589 t.Fatalf("Dial: %v", err)
1590 }
1591 defer idleConn.Close()
1592
1593 if !strings.HasPrefix(ts.URL, "https://") {
1594 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1595 return
1596 }
1597 client := ts.Client()
1598 res, err := client.Get(ts.URL)
1599 if err != nil {
1600 t.Error(err)
1601 return
1602 }
1603 if res == nil {
1604 t.Errorf("got nil Response")
1605 return
1606 }
1607 defer res.Body.Close()
1608 if res.Header.Get("X-TLS-Set") != "true" {
1609 t.Errorf("expected X-TLS-Set response header")
1610 return
1611 }
1612 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1613 t.Errorf("expected X-TLS-HandshakeComplete header")
1614 }
1615 }
1616
1617 func TestServeTLS(t *testing.T) {
1618 CondSkipHTTP2(t)
1619
1620 defer afterTest(t)
1621 defer SetTestHookServerServe(nil)
1622
1623 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1624 if err != nil {
1625 t.Fatal(err)
1626 }
1627 tlsConf := &tls.Config{
1628 Certificates: []tls.Certificate{cert},
1629 }
1630
1631 ln := newLocalListener(t)
1632 defer ln.Close()
1633 addr := ln.Addr().String()
1634
1635 serving := make(chan bool, 1)
1636 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1637 serving <- true
1638 })
1639 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1640 s := &Server{
1641 Addr: addr,
1642 TLSConfig: tlsConf,
1643 Handler: handler,
1644 }
1645 errc := make(chan error, 1)
1646 go func() { errc <- s.ServeTLS(ln, "", "") }()
1647 select {
1648 case err := <-errc:
1649 t.Fatalf("ServeTLS: %v", err)
1650 case <-serving:
1651 }
1652
1653 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1654 InsecureSkipVerify: true,
1655 NextProtos: []string{"h2", "http/1.1"},
1656 })
1657 if err != nil {
1658 t.Fatal(err)
1659 }
1660 defer c.Close()
1661 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1662 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1663 }
1664 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1665 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1666 }
1667 }
1668
1669
1670 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1671 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1672 }
1673 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1674 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1675 t.Error("unexpected HTTPS request")
1676 }), func(ts *httptest.Server) {
1677 var errBuf bytes.Buffer
1678 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1679 }).ts
1680 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1681 if err != nil {
1682 t.Fatal(err)
1683 }
1684 defer conn.Close()
1685 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1686 slurp, err := io.ReadAll(conn)
1687 if err != nil {
1688 t.Fatal(err)
1689 }
1690 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1691 if !strings.HasPrefix(string(slurp), wantPrefix) {
1692 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1693 }
1694 }
1695
1696
1697 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1698 testAutomaticHTTP2_Serve(t, nil, true)
1699 }
1700
1701 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1702 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1703 }
1704
1705 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1706 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1707 }
1708
1709 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1710 setParallel(t)
1711 defer afterTest(t)
1712 ln := newLocalListener(t)
1713 ln.Close()
1714 var s Server
1715 s.TLSConfig = tlsConf
1716 if err := s.Serve(ln); err == nil {
1717 t.Fatal("expected an error")
1718 }
1719 gotH2 := s.TLSNextProto["h2"] != nil
1720 if gotH2 != wantH2 {
1721 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1722 }
1723 }
1724
1725 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1726 setParallel(t)
1727 defer afterTest(t)
1728 ln := newLocalListener(t)
1729 ln.Close()
1730 var s Server
1731
1732
1733 s.TLSConfig = &tls.Config{
1734 NextProtos: []string{"h2"},
1735 }
1736 if err := s.Serve(ln); err == nil {
1737 t.Fatal("expected an error")
1738 }
1739 on := s.TLSNextProto["h2"] != nil
1740 if !on {
1741 t.Errorf("http2 wasn't automatically enabled")
1742 }
1743 }
1744
1745 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1746 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1747 if err != nil {
1748 t.Fatal(err)
1749 }
1750 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1751 Certificates: []tls.Certificate{cert},
1752 })
1753 }
1754
1755 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1756 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1757 if err != nil {
1758 t.Fatal(err)
1759 }
1760 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1761 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1762 return &cert, nil
1763 },
1764 })
1765 }
1766
1767 func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
1768 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1769 if err != nil {
1770 t.Fatal(err)
1771 }
1772 conf := &tls.Config{
1773
1774
1775 NextProtos: []string{"h2"},
1776 Certificates: []tls.Certificate{cert},
1777 }
1778 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1779 GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
1780 return conf, nil
1781 },
1782 })
1783 }
1784
// testAutomaticHTTP2_ListenAndServe starts a Server with tlsConf through
// ListenAndServeTLS("", ""), i.e. with no certificate files. It then checks
// that a TLS client offering h2 and http/1.1 negotiates h2.
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)
	// Not parallel: installs a server-serve test hook via
	// SetTestHookServerServe and resets it on return.
	defer afterTest(t)
	defer SetTestHookServerServe(nil)
	var ok bool
	var s *Server
	const maxTries = 5
	var ln net.Listener
	// ListenAndServeTLS binds its own listener from s.Addr. Each try
	// therefore reserves a free local address by opening and immediately
	// closing a listener, then asks the server to bind that address. If the
	// server reports an error first (e.g. the address was grabbed in
	// between), the next try starts over.
Try:
	for try := 0; try < maxTries; try++ {
		ln = newLocalListener(t)
		addr := ln.Addr().String()
		ln.Close()
		t.Logf("Got %v", addr)
		// The hook hands back the listener the server is actually serving.
		lnc := make(chan net.Listener, 1)
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		s = &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- s.ListenAndServeTLS("", "") }()
		// Whichever happens first: startup failure (retry) or the hook
		// firing (success).
		select {
		case err := <-errc:
			t.Logf("On try #%v: %v", try+1, err)
			continue
		case ln = <-lnc:
			ok = true
			t.Logf("Listening on %v", ln.Addr().String())
			break Try
		}
	}
	if !ok {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	defer ln.Close()
	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1839
// serverExpectTest describes one raw request sent by testServerExpect and
// the text expected somewhere in the first response line.
type serverExpectTest struct {
	contentLength    int    // declared body length; 0 means no body bytes are written
	chunked          bool   // send Transfer-Encoding: chunked instead of Content-Length
	expectation      string // value of the Expect request header
	readBody         bool   // whether the handler is asked to read the body
	expectedResponse string // substring expected in the first response line
}

// expectTest builds a non-chunked serverExpectTest from its four fields.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var st serverExpectTest
	st.contentLength = contentLength
	st.expectation = expectation
	st.readBody = readBody
	st.expectedResponse = expectedResponse
	return st
}
1856
// serverExpectTests lists the raw-request scenarios run by testServerExpect.
// expectedResponse is matched as a substring of the first response line.
var serverExpectTests = []serverExpectTest{
	// Normal 100-continue, including a mixed-case expectation token.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// Empty Expect value: the body is sent up front and read.
	expectTest(100, "", true, "200 OK"),

	// 100-continue, but the handler rejects the request without reading
	// the body.
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// Likewise with an empty Expect value.
	expectTest(100, "", false, "401 Unauthorized"),

	// Unknown expectations are rejected.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// 100-continue with no body (Content-Length: 0), handler reads it.
	expectTest(0, "100-continue", true, "200 OK"),
	// 100-continue with no body, handler does not read it.
	expectTest(0, "100-continue", false, "401 Unauthorized"),
	// 100-continue with a chunked body. contentLength is 0, so the test
	// client writes no body bytes itself.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1886
1887
1888
// TestServerExpect drives an HTTP/1 server with raw requests carrying
// various Expect headers. For each case in serverExpectTests it checks the
// first line of the response.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The query string selects the handler's behavior. With
		// readbody=true it consumes the body and replies "Hi" (200).
		// Otherwise it replies 401 without touching the body.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest sends one request on a fresh connection and checks the first
	// response line against test.expectedResponse.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Send the body right away only when acting like a client that does
		// not wait for "100 Continue", and only if there is a body to send.
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		// Deferred after conn.Close, so the writer goroutine finishes before
		// the connection is closed.
		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		// Write the request concurrently with reading the response; the
		// server may answer before (or without) reading the body.
		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// Non-chunked: write straight to conn with a no-op Close.
				// Chunked: the chunked writer's Close emits the final
				// zero-length chunk.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler never reads the body, so the server
						// may already have closed the connection; a write
						// error is tolerated here.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// Tolerated: we were still sending a body the server
				// never reads, and the server hung up. At the TCP level
				// the unread data can lead to a reset that discards the
				// response before we read it. NOTE(review): this is a
				// transport race rather than a server bug; confirm
				// against upstream commentary.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
1984
1985
1986
1987 func TestServerUnreadRequestBodyLittle(t *testing.T) {
1988 setParallel(t)
1989 defer afterTest(t)
1990 conn := new(testConn)
1991 body := strings.Repeat("x", 100<<10)
1992 conn.readBuf.Write([]byte(fmt.Sprintf(
1993 "POST / HTTP/1.1\r\n"+
1994 "Host: test\r\n"+
1995 "Content-Length: %d\r\n"+
1996 "\r\n", len(body))))
1997 conn.readBuf.Write([]byte(body))
1998
1999 done := make(chan bool)
2000
2001 readBufLen := func() int {
2002 conn.readMu.Lock()
2003 defer conn.readMu.Unlock()
2004 return conn.readBuf.Len()
2005 }
2006
2007 ls := &oneConnListener{conn}
2008 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2009 defer close(done)
2010 if bufLen := readBufLen(); bufLen < len(body)/2 {
2011 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
2012 }
2013 rw.WriteHeader(200)
2014 rw.(Flusher).Flush()
2015 if g, e := readBufLen(), 0; g != e {
2016 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
2017 }
2018 if c := rw.Header().Get("Connection"); c != "" {
2019 t.Errorf(`Connection header = %q; want ""`, c)
2020 }
2021 }))
2022 <-done
2023 }
2024
2025
2026
2027
2028 func TestServerUnreadRequestBodyLarge(t *testing.T) {
2029 setParallel(t)
2030 if testing.Short() && testenv.Builder() == "" {
2031 t.Log("skipping in short mode")
2032 }
2033 conn := new(testConn)
2034 body := strings.Repeat("x", 1<<20)
2035 conn.readBuf.Write([]byte(fmt.Sprintf(
2036 "POST / HTTP/1.1\r\n"+
2037 "Host: test\r\n"+
2038 "Content-Length: %d\r\n"+
2039 "\r\n", len(body))))
2040 conn.readBuf.Write([]byte(body))
2041 conn.closec = make(chan bool, 1)
2042
2043 ls := &oneConnListener{conn}
2044 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2045 if conn.readBuf.Len() < len(body)/2 {
2046 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2047 }
2048 rw.WriteHeader(200)
2049 rw.(Flusher).Flush()
2050 if conn.readBuf.Len() < len(body)/2 {
2051 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2052 }
2053 }))
2054 <-conn.closec
2055
2056 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
2057 t.Errorf("Expected a Connection: close header; got response: %s", res)
2058 }
2059 }
2060
// handlerBodyCloseTest describes one request shape for testHandlerBodyClose
// and what closing its unread body should do.
type handlerBodyCloseTest struct {
	bodySize     int  // number of body bytes sent
	bodyChunked  bool // chunked encoding instead of Content-Length
	reqConnClose bool // request carries "Connection: close"

	wantEOFSearch bool // Body.Close should consume input from the connection
	wantNextReq   bool // a pipelined follow-up request must be served
}

// connectionHeader returns the request's Connection header line, or "" when
// the request does not ask for the connection to be closed.
func (tc handlerBodyCloseTest) connectionHeader() string {
	if !tc.reqConnClose {
		return ""
	}
	return "Connection: close\r\n"
}
2076
// handlerBodyCloseTests enumerates the cases run by TestHandlerBodyClose.
// The index of each case is printed in failure messages. wantEOFSearch says
// whether closing the unread body should consume input from the connection.
// wantNextReq, when true, requires the pipelined follow-up GET to be served;
// that GET is queued only when reqConnClose is false.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// Small (20 KiB) body with Content-Length, keep-alive: Close should
	// read past the body and the next request must be served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small chunked body, keep-alive: Close should read past the body and
	// the next request must be served.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small body with Content-Length, but the request asks for
	// Connection: close, so Close is expected not to read any further.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Small chunked body with Connection: close: Close is still expected to
	// read ahead (presumably to reach the end of the chunked body and any
	// trailers).
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large (1 MiB) body with Content-Length, keep-alive: Close is expected
	// not to read, and serving the next request is not required.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Large chunked body, keep-alive: Close is expected to read some input,
	// but serving the next request is not required.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large chunked body with Connection: close: Close is still expected to
	// read ahead.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large body with Content-Length and Connection: close: Close is
	// expected not to read at all.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2161
2162 func TestHandlerBodyClose(t *testing.T) {
2163 setParallel(t)
2164 if testing.Short() && testenv.Builder() == "" {
2165 t.Skip("skipping in -short mode")
2166 }
2167 for i, tt := range handlerBodyCloseTests {
2168 testHandlerBodyClose(t, i, tt)
2169 }
2170 }
2171
2172 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2173 conn := new(testConn)
2174 body := strings.Repeat("x", tt.bodySize)
2175 if tt.bodyChunked {
2176 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2177 "Host: test\r\n" +
2178 tt.connectionHeader() +
2179 "Transfer-Encoding: chunked\r\n" +
2180 "\r\n")
2181 cw := internal.NewChunkedWriter(&conn.readBuf)
2182 io.WriteString(cw, body)
2183 cw.Close()
2184 conn.readBuf.WriteString("\r\n")
2185 } else {
2186 conn.readBuf.Write([]byte(fmt.Sprintf(
2187 "POST / HTTP/1.1\r\n"+
2188 "Host: test\r\n"+
2189 tt.connectionHeader()+
2190 "Content-Length: %d\r\n"+
2191 "\r\n", len(body))))
2192 conn.readBuf.Write([]byte(body))
2193 }
2194 if !tt.reqConnClose {
2195 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2196 }
2197 conn.closec = make(chan bool, 1)
2198
2199 readBufLen := func() int {
2200 conn.readMu.Lock()
2201 defer conn.readMu.Unlock()
2202 return conn.readBuf.Len()
2203 }
2204
2205 ls := &oneConnListener{conn}
2206 var numReqs int
2207 var size0, size1 int
2208 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2209 numReqs++
2210 if numReqs == 1 {
2211 size0 = readBufLen()
2212 req.Body.Close()
2213 size1 = readBufLen()
2214 }
2215 }))
2216 <-conn.closec
2217 if numReqs < 1 || numReqs > 2 {
2218 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2219 }
2220 didSearch := size0 != size1
2221 if didSearch != tt.wantEOFSearch {
2222 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2223 }
2224 if tt.wantNextReq && numReqs != 2 {
2225 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2226 }
2227 }
2228
2229
2230
2231 type testHandlerBodyConsumer struct {
2232 name string
2233 f func(io.ReadCloser)
2234 }
2235
2236 var testHandlerBodyConsumers = []testHandlerBodyConsumer{
2237 {"nil", func(io.ReadCloser) {}},
2238 {"close", func(r io.ReadCloser) { r.Close() }},
2239 {"discard", func(r io.ReadCloser) { io.Copy(io.Discard, r) }},
2240 }
2241
2242 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2243 setParallel(t)
2244 defer afterTest(t)
2245 for _, handler := range testHandlerBodyConsumers {
2246 conn := new(testConn)
2247 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2248 "Host: test\r\n" +
2249 "Transfer-Encoding: chunked\r\n" +
2250 "\r\n" +
2251 "hax\r\n" +
2252 "GET /secret HTTP/1.1\r\n" +
2253 "Host: test\r\n" +
2254 "\r\n")
2255
2256 conn.closec = make(chan bool, 1)
2257 ls := &oneConnListener{conn}
2258 var numReqs int
2259 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2260 numReqs++
2261 if strings.Contains(req.URL.Path, "secret") {
2262 t.Error("Request for /secret encountered, should not have happened.")
2263 }
2264 handler.f(req.Body)
2265 }))
2266 <-conn.closec
2267 if numReqs != 1 {
2268 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2269 }
2270 }
2271 }
2272
2273 func TestInvalidTrailerClosesConnection(t *testing.T) {
2274 setParallel(t)
2275 defer afterTest(t)
2276 for _, handler := range testHandlerBodyConsumers {
2277 conn := new(testConn)
2278 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2279 "Host: test\r\n" +
2280 "Trailer: hack\r\n" +
2281 "Transfer-Encoding: chunked\r\n" +
2282 "\r\n" +
2283 "3\r\n" +
2284 "hax\r\n" +
2285 "0\r\n" +
2286 "I'm not a valid trailer\r\n" +
2287 "GET /secret HTTP/1.1\r\n" +
2288 "Host: test\r\n" +
2289 "\r\n")
2290
2291 conn.closec = make(chan bool, 1)
2292 ln := &oneConnListener{conn}
2293 var numReqs int
2294 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2295 numReqs++
2296 if strings.Contains(req.URL.Path, "secret") {
2297 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2298 }
2299 handler.f(req.Body)
2300 }))
2301 <-conn.closec
2302 if numReqs != 1 {
2303 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2304 }
2305 }
2306 }
2307
2308
2309
2310
2311 type slowTestConn struct {
2312
2313 script []any
2314 closec chan bool
2315
2316 mu sync.Mutex
2317 rd, wd time.Time
2318 noopConn
2319 }
2320
2321 func (c *slowTestConn) SetDeadline(t time.Time) error {
2322 c.SetReadDeadline(t)
2323 c.SetWriteDeadline(t)
2324 return nil
2325 }
2326
2327 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2328 c.mu.Lock()
2329 defer c.mu.Unlock()
2330 c.rd = t
2331 return nil
2332 }
2333
2334 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2335 c.mu.Lock()
2336 defer c.mu.Unlock()
2337 c.wd = t
2338 return nil
2339 }
2340
2341 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2342 c.mu.Lock()
2343 defer c.mu.Unlock()
2344 restart:
2345 if !c.rd.IsZero() && time.Now().After(c.rd) {
2346 return 0, syscall.ETIMEDOUT
2347 }
2348 if len(c.script) == 0 {
2349 return 0, io.EOF
2350 }
2351
2352 switch cue := c.script[0].(type) {
2353 case time.Duration:
2354 if !c.rd.IsZero() {
2355
2356
2357 if remaining := time.Until(c.rd); remaining < cue {
2358 c.script[0] = cue - remaining
2359 time.Sleep(remaining)
2360 return 0, syscall.ETIMEDOUT
2361 }
2362 }
2363 c.script = c.script[1:]
2364 time.Sleep(cue)
2365 goto restart
2366
2367 case string:
2368 n = copy(b, cue)
2369
2370 if len(cue) > n {
2371 c.script[0] = cue[n:]
2372 } else {
2373 c.script = c.script[1:]
2374 }
2375
2376 default:
2377 panic("unknown cue in slowTestConn script")
2378 }
2379
2380 return
2381 }
2382
2383 func (c *slowTestConn) Close() error {
2384 select {
2385 case c.closec <- true:
2386 default:
2387 }
2388 return nil
2389 }
2390
2391 func (c *slowTestConn) Write(b []byte) (int, error) {
2392 if !c.wd.IsZero() && time.Now().After(c.wd) {
2393 return 0, syscall.ETIMEDOUT
2394 }
2395 return len(b), nil
2396 }
2397
2398 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2399 if testing.Short() {
2400 t.Skip("skipping in -short mode")
2401 }
2402 defer afterTest(t)
2403 for _, handler := range testHandlerBodyConsumers {
2404 conn := &slowTestConn{
2405 script: []any{
2406 "POST /public HTTP/1.1\r\n" +
2407 "Host: test\r\n" +
2408 "Content-Length: 10000\r\n" +
2409 "\r\n",
2410 "foo bar baz",
2411 600 * time.Millisecond,
2412 "GET /secret HTTP/1.1\r\n" +
2413 "Host: test\r\n" +
2414 "\r\n",
2415 },
2416 closec: make(chan bool, 1),
2417 }
2418 ls := &oneConnListener{conn}
2419
2420 var numReqs int
2421 s := Server{
2422 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2423 numReqs++
2424 if strings.Contains(req.URL.Path, "secret") {
2425 t.Error("Request for /secret encountered, should not have happened.")
2426 }
2427 handler.f(req.Body)
2428 }),
2429 ReadTimeout: 400 * time.Millisecond,
2430 }
2431 go s.Serve(ls)
2432 <-conn.closec
2433
2434 if numReqs != 1 {
2435 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2436 }
2437 }
2438 }
2439
2440
// cancelableTimeoutContext wraps a Context so that cancellation of the
// embedded context is reported by Err as context.DeadlineExceeded. All other
// Context methods are promoted unchanged from the embedded value.
type cancelableTimeoutContext struct {
	context.Context
}

// Err returns nil while the embedded context is live, and
// context.DeadlineExceeded once the embedded context reports any error.
func (c cancelableTimeoutContext) Err() error {
	if err := c.Context.Err(); err == nil {
		return nil
	}
	return context.DeadlineExceeded
}
2451
2452 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2453 func testTimeoutHandler(t *testing.T, mode testMode) {
2454 sendHi := make(chan bool, 1)
2455 writeErrors := make(chan error, 1)
2456 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2457 <-sendHi
2458 _, werr := w.Write([]byte("hi"))
2459 writeErrors <- werr
2460 })
2461 ctx, cancel := context.WithCancel(context.Background())
2462 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2463 cst := newClientServerTest(t, mode, h)
2464
2465
2466 sendHi <- true
2467 res, err := cst.c.Get(cst.ts.URL)
2468 if err != nil {
2469 t.Error(err)
2470 }
2471 if g, e := res.StatusCode, StatusOK; g != e {
2472 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2473 }
2474 body, _ := io.ReadAll(res.Body)
2475 if g, e := string(body), "hi"; g != e {
2476 t.Errorf("got body %q; expected %q", g, e)
2477 }
2478 if g := <-writeErrors; g != nil {
2479 t.Errorf("got unexpected Write error on first request: %v", g)
2480 }
2481
2482
2483 cancel()
2484
2485 res, err = cst.c.Get(cst.ts.URL)
2486 if err != nil {
2487 t.Error(err)
2488 }
2489 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2490 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2491 }
2492 body, _ = io.ReadAll(res.Body)
2493 if !strings.Contains(string(body), "<title>Timeout</title>") {
2494 t.Errorf("expected timeout body; got %q", string(body))
2495 }
2496 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2497 t.Errorf("response content-type = %q; want %q", g, w)
2498 }
2499
2500
2501
2502 sendHi <- true
2503 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2504 t.Errorf("expected Write error of %v; got %v", e, g)
2505 }
2506 }
2507
2508
2509 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2510 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2511 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2512 ms, _ := strconv.Atoi(r.URL.Path[1:])
2513 if ms == 0 {
2514 ms = 1
2515 }
2516 for i := 0; i < ms; i++ {
2517 w.Write([]byte("hi"))
2518 time.Sleep(time.Millisecond)
2519 }
2520 })
2521
2522 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2523
2524 c := ts.Client()
2525
2526 var wg sync.WaitGroup
2527 gate := make(chan bool, 10)
2528 n := 50
2529 if testing.Short() {
2530 n = 10
2531 gate = make(chan bool, 3)
2532 }
2533 for i := 0; i < n; i++ {
2534 gate <- true
2535 wg.Add(1)
2536 go func() {
2537 defer wg.Done()
2538 defer func() { <-gate }()
2539 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2540 if err == nil {
2541 io.Copy(io.Discard, res.Body)
2542 res.Body.Close()
2543 }
2544 }()
2545 }
2546 wg.Wait()
2547 }
2548
2549
2550
2551 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2552 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2553 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2554 w.WriteHeader(204)
2555 })
2556
2557 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2558
2559 var wg sync.WaitGroup
2560 gate := make(chan bool, 50)
2561 n := 500
2562 if testing.Short() {
2563 n = 10
2564 }
2565
2566 c := ts.Client()
2567 for i := 0; i < n; i++ {
2568 gate <- true
2569 wg.Add(1)
2570 go func() {
2571 defer wg.Done()
2572 defer func() { <-gate }()
2573 res, err := c.Get(ts.URL)
2574 if err != nil {
2575
2576
2577 t.Log(err)
2578 return
2579 }
2580 defer res.Body.Close()
2581 io.Copy(io.Discard, res.Body)
2582 }()
2583 }
2584 wg.Wait()
2585 }
2586
2587
2588 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2589 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2590 sendHi := make(chan bool, 1)
2591 writeErrors := make(chan error, 1)
2592 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2593 w.Header().Set("Content-Type", "text/plain")
2594 <-sendHi
2595 _, werr := w.Write([]byte("hi"))
2596 writeErrors <- werr
2597 })
2598 ctx, cancel := context.WithCancel(context.Background())
2599 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2600 cst := newClientServerTest(t, mode, h)
2601
2602
2603 sendHi <- true
2604 res, err := cst.c.Get(cst.ts.URL)
2605 if err != nil {
2606 t.Error(err)
2607 }
2608 if g, e := res.StatusCode, StatusOK; g != e {
2609 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2610 }
2611 body, _ := io.ReadAll(res.Body)
2612 if g, e := string(body), "hi"; g != e {
2613 t.Errorf("got body %q; expected %q", g, e)
2614 }
2615 if g := <-writeErrors; g != nil {
2616 t.Errorf("got unexpected Write error on first request: %v", g)
2617 }
2618
2619
2620 cancel()
2621
2622 res, err = cst.c.Get(cst.ts.URL)
2623 if err != nil {
2624 t.Error(err)
2625 }
2626 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2627 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2628 }
2629 body, _ = io.ReadAll(res.Body)
2630 if !strings.Contains(string(body), "<title>Timeout</title>") {
2631 t.Errorf("expected timeout body; got %q", string(body))
2632 }
2633
2634
2635
2636 sendHi <- true
2637 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2638 t.Errorf("expected Write error of %v; got %v", e, g)
2639 }
2640 }
2641
2642
2643 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2644 run(t, testTimeoutHandlerStartTimerWhenServing)
2645 }
2646 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2647 if testing.Short() {
2648 t.Skip("skipping sleeping test in -short mode")
2649 }
2650 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2651 w.WriteHeader(StatusNoContent)
2652 }
2653 timeout := 300 * time.Millisecond
2654 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2655 defer ts.Close()
2656
2657 c := ts.Client()
2658
2659
2660
2661
2662 time.Sleep(2 * timeout)
2663 res, err := c.Get(ts.URL)
2664 if err != nil {
2665 t.Fatal(err)
2666 }
2667 defer res.Body.Close()
2668 if res.StatusCode != StatusNoContent {
2669 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2670 }
2671 }
2672
2673 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2674 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2675 writeErrors := make(chan error, 1)
2676 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2677 w.Header().Set("Content-Type", "text/plain")
2678 var err error
2679
2680
2681
2682 for i := 0; i < 100; i++ {
2683 _, err = w.Write([]byte("a"))
2684 if err != nil {
2685 break
2686 }
2687 time.Sleep(1 * time.Millisecond)
2688 }
2689 writeErrors <- err
2690 })
2691 ctx, cancel := context.WithCancel(context.Background())
2692 cancel()
2693 h := NewTestTimeoutHandler(sayHi, ctx)
2694 cst := newClientServerTest(t, mode, h)
2695 defer cst.close()
2696
2697 res, err := cst.c.Get(cst.ts.URL)
2698 if err != nil {
2699 t.Error(err)
2700 }
2701 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2702 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2703 }
2704 body, _ := io.ReadAll(res.Body)
2705 if g, e := string(body), ""; g != e {
2706 t.Errorf("got body %q; expected %q", g, e)
2707 }
2708 if g, e := <-writeErrors, context.Canceled; g != e {
2709 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2710 }
2711 }
2712
2713
2714 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2715 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2716 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2717
2718 }
2719 timeout := 300 * time.Millisecond
2720 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2721
2722 c := ts.Client()
2723
2724 res, err := c.Get(ts.URL)
2725 if err != nil {
2726 t.Fatal(err)
2727 }
2728 defer res.Body.Close()
2729 if res.StatusCode != StatusOK {
2730 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2731 }
2732 }
2733
2734
2735 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2736 wrapper := func(h Handler) Handler {
2737 return TimeoutHandler(h, time.Second, "")
2738 }
2739 run(t, func(t *testing.T, mode testMode) {
2740 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2741 }, testNotParallel)
2742 }
2743
2744 func TestRedirectBadPath(t *testing.T) {
2745
2746
2747 rr := httptest.NewRecorder()
2748 req := &Request{
2749 Method: "GET",
2750 URL: &url.URL{
2751 Scheme: "http",
2752 Path: "not-empty-but-no-leading-slash",
2753 },
2754 }
2755 Redirect(rr, req, "", 304)
2756 if rr.Code != 304 {
2757 t.Errorf("Code = %d; want 304", rr.Code)
2758 }
2759 }
2760
2761
2762 func TestRedirect(t *testing.T) {
2763 req, _ := NewRequest("GET", "http://example.com/qux/", nil)
2764
2765 var tests = []struct {
2766 in string
2767 want string
2768 }{
2769
2770 {"http://foobar.com/baz", "http://foobar.com/baz"},
2771
2772 {"https://foobar.com/baz", "https://foobar.com/baz"},
2773
2774 {"test://foobar.com/baz", "test://foobar.com/baz"},
2775
2776 {"//foobar.com/baz", "//foobar.com/baz"},
2777
2778 {"/foobar.com/baz", "/foobar.com/baz"},
2779
2780 {"foobar.com/baz", "/qux/foobar.com/baz"},
2781
2782 {"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
2783
2784 {"///foobar.com/baz", "/foobar.com/baz"},
2785
2786
2787 {"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
2788 {"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
2789 "http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
2790
2791 {"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
2792 {"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
2793 }
2794
2795 for _, tt := range tests {
2796 rec := httptest.NewRecorder()
2797 Redirect(rec, req, tt.in, 302)
2798 if got, want := rec.Code, 302; got != want {
2799 t.Errorf("Redirect(%q) generated status code %v; want %v", tt.in, got, want)
2800 }
2801 if got := rec.Header().Get("Location"); got != tt.want {
2802 t.Errorf("Redirect(%q) generated Location header %q; want %q", tt.in, got, tt.want)
2803 }
2804 }
2805 }
2806
2807
2808
2809 func TestRedirectContentTypeAndBody(t *testing.T) {
2810 type ctHeader struct {
2811 Values []string
2812 }
2813
2814 var tests = []struct {
2815 method string
2816 ct *ctHeader
2817 wantCT string
2818 wantBody string
2819 }{
2820 {MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
2821 {MethodHead, nil, "text/html; charset=utf-8", ""},
2822 {MethodPost, nil, "", ""},
2823 {MethodDelete, nil, "", ""},
2824 {"foo", nil, "", ""},
2825 {MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
2826 {MethodGet, &ctHeader{[]string{}}, "", ""},
2827 {MethodGet, &ctHeader{nil}, "", ""},
2828 }
2829 for _, tt := range tests {
2830 req := httptest.NewRequest(tt.method, "http://example.com/qux/", nil)
2831 rec := httptest.NewRecorder()
2832 if tt.ct != nil {
2833 rec.Header()["Content-Type"] = tt.ct.Values
2834 }
2835 Redirect(rec, req, "/foo", 302)
2836 if got, want := rec.Code, 302; got != want {
2837 t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tt.method, tt.ct, got, want)
2838 }
2839 if got, want := rec.Header().Get("Content-Type"), tt.wantCT; got != want {
2840 t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tt.method, tt.ct, got, want)
2841 }
2842 resp := rec.Result()
2843 body, err := io.ReadAll(resp.Body)
2844 if err != nil {
2845 t.Fatal(err)
2846 }
2847 if got, want := string(body), tt.wantBody; got != want {
2848 t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tt.method, tt.ct, got, want)
2849 }
2850 }
2851 }
2852
2853
2854
2855
2856
2857
2858
2859 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
2860
2861 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
2862 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
2863 all, err := io.ReadAll(r.Body)
2864 if err != nil {
2865 t.Fatalf("handler ReadAll: %v", err)
2866 }
2867 if len(all) != 0 {
2868 t.Errorf("handler got %d bytes; expected 0", len(all))
2869 }
2870 rw.Header().Set("Content-Length", "0")
2871 }))
2872
2873 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2874 if err != nil {
2875 t.Fatal(err)
2876 }
2877 req.ContentLength = 0
2878
2879 var resp [5]*Response
2880 for i := range resp {
2881 resp[i], err = cst.c.Do(req)
2882 if err != nil {
2883 t.Fatalf("client post #%d: %v", i, err)
2884 }
2885 }
2886
2887 for i := range resp {
2888 all, err := io.ReadAll(resp[i].Body)
2889 if err != nil {
2890 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2891 }
2892 if len(all) != 0 {
2893 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2894 }
2895 }
2896 }
2897
2898 func TestHandlerPanicNil(t *testing.T) {
2899 run(t, func(t *testing.T, mode testMode) {
2900 testHandlerPanic(t, false, mode, nil, nil)
2901 }, testNotParallel)
2902 }
2903
2904 func TestHandlerPanic(t *testing.T) {
2905 run(t, func(t *testing.T, mode testMode) {
2906 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
2907 }, testNotParallel)
2908 }
2909
2910 func TestHandlerPanicWithHijack(t *testing.T) {
2911
2912 run(t, func(t *testing.T, mode testMode) {
2913 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
2914 }, []testMode{http1Mode})
2915 }
2916
2917 func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
2918
2919
2920
2921
2922
2923
2924
2925
2926 pr, pw := io.Pipe()
2927 defer pw.Close()
2928
2929 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
2930 if withHijack {
2931 rwc, _, err := w.(Hijacker).Hijack()
2932 if err != nil {
2933 t.Logf("unexpected error: %v", err)
2934 }
2935 defer rwc.Close()
2936 }
2937 panic(panicValue)
2938 })
2939 if wrapper != nil {
2940 handler = wrapper(handler)
2941 }
2942 cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
2943 ts.Config.ErrorLog = log.New(pw, "", 0)
2944 })
2945
2946
2947 done := make(chan bool, 1)
2948 go func() {
2949 buf := make([]byte, 4<<10)
2950 _, err := pr.Read(buf)
2951 pr.Close()
2952 if err != nil && err != io.EOF {
2953 t.Error(err)
2954 }
2955 done <- true
2956 }()
2957
2958 _, err := cst.c.Get(cst.ts.URL)
2959 if err == nil {
2960 t.Logf("expected an error")
2961 }
2962
2963 if panicValue == nil {
2964 return
2965 }
2966
2967 <-done
2968 }
2969
// terrorWriter is an io.Writer that reports every write as a test error on t.
// It always claims the full write succeeded.
type terrorWriter struct{ t *testing.T }

// Write reports p via t.Errorf and returns len(p), nil.
func (tw terrorWriter) Write(p []byte) (int, error) {
	n := len(p)
	tw.t.Errorf("%s", p)
	return n, nil
}
2976
2977
2978
2979 func TestServerWriteHijackZeroBytes(t *testing.T) {
2980 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
2981 }
2982 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
2983 done := make(chan struct{})
2984 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2985 defer close(done)
2986 w.(Flusher).Flush()
2987 conn, _, err := w.(Hijacker).Hijack()
2988 if err != nil {
2989 t.Errorf("Hijack: %v", err)
2990 return
2991 }
2992 defer conn.Close()
2993 _, err = w.Write(nil)
2994 if err != ErrHijacked {
2995 t.Errorf("Write error = %v; want ErrHijacked", err)
2996 }
2997 }), func(ts *httptest.Server) {
2998 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
2999 }).ts
3000
3001 c := ts.Client()
3002 res, err := c.Get(ts.URL)
3003 if err != nil {
3004 t.Fatal(err)
3005 }
3006 res.Body.Close()
3007 <-done
3008 }
3009
3010 func TestServerNoDate(t *testing.T) {
3011 run(t, func(t *testing.T, mode testMode) {
3012 testServerNoHeader(t, mode, "Date")
3013 })
3014 }
3015
3016 func TestServerContentType(t *testing.T) {
3017 run(t, func(t *testing.T, mode testMode) {
3018 testServerNoHeader(t, mode, "Content-Type")
3019 })
3020 }
3021
3022 func testServerNoHeader(t *testing.T, mode testMode, header string) {
3023 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3024 w.Header()[header] = nil
3025 io.WriteString(w, "<html>foo</html>")
3026 }))
3027 res, err := cst.c.Get(cst.ts.URL)
3028 if err != nil {
3029 t.Fatal(err)
3030 }
3031 res.Body.Close()
3032 if got, ok := res.Header[header]; ok {
3033 t.Fatalf("Expected no %s header; got %q", header, got)
3034 }
3035 }
3036
3037 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3038 func testStripPrefix(t *testing.T, mode testMode) {
3039 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3040 w.Header().Set("X-Path", r.URL.Path)
3041 w.Header().Set("X-RawPath", r.URL.RawPath)
3042 })
3043 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3044
3045 c := ts.Client()
3046
3047 cases := []struct {
3048 reqPath string
3049 path string
3050 rawPath string
3051 }{
3052 {"/foo/bar/qux", "/qux", ""},
3053 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3054 {"/foo%2Fbar/qux", "", ""},
3055 {"/bar", "", ""},
3056 }
3057 for _, tc := range cases {
3058 t.Run(tc.reqPath, func(t *testing.T) {
3059 res, err := c.Get(ts.URL + tc.reqPath)
3060 if err != nil {
3061 t.Fatal(err)
3062 }
3063 res.Body.Close()
3064 if tc.path == "" {
3065 if res.StatusCode != StatusNotFound {
3066 t.Errorf("got %q, want 404 Not Found", res.Status)
3067 }
3068 return
3069 }
3070 if res.StatusCode != StatusOK {
3071 t.Fatalf("got %q, want 200 OK", res.Status)
3072 }
3073 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3074 t.Errorf("got Path %q, want %q", g, w)
3075 }
3076 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3077 t.Errorf("got RawPath %q, want %q", g, w)
3078 }
3079 })
3080 }
3081 }
3082
3083
3084 func TestStripPrefixNotModifyRequest(t *testing.T) {
3085 h := StripPrefix("/foo", NotFoundHandler())
3086 req := httptest.NewRequest("GET", "/foo/bar", nil)
3087 h.ServeHTTP(httptest.NewRecorder(), req)
3088 if req.URL.Path != "/foo/bar" {
3089 t.Errorf("StripPrefix should not modify the provided Request, but it did")
3090 }
3091 }
3092
3093 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
3094 func testRequestLimit(t *testing.T, mode testMode) {
3095 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3096 t.Fatalf("didn't expect to get request in Handler")
3097 }), optQuietLog)
3098 req, _ := NewRequest("GET", cst.ts.URL, nil)
3099 var bytesPerHeader = len("header12345: val12345\r\n")
3100 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
3101 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
3102 }
3103 res, err := cst.c.Do(req)
3104 if res != nil {
3105 defer res.Body.Close()
3106 }
3107 if mode == http2Mode {
3108
3109
3110
3111
3112 if err == nil && res.StatusCode != 431 {
3113 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3114 }
3115 } else {
3116
3117
3118
3119
3120 if err != nil {
3121 t.Fatalf("Do: %v", err)
3122 }
3123 if res.StatusCode != 431 {
3124 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3125 }
3126 }
3127 }
3128
// neverEnding is an io.Reader that fills every buffer with its own byte
// value and never reports EOF.
type neverEnding byte

// Read fills all of p with byte(b) and returns len(p), nil.
func (b neverEnding) Read(p []byte) (n int, err error) {
	fill := byte(b)
	for i := 0; i < len(p); i++ {
		p[i] = fill
	}
	n = len(p)
	return
}
3137
// bodyLimitReader is an io.ReadCloser that produces 'a' bytes. Reads are
// allowed while the running byte count does not exceed limit; after that,
// and after Close, Read fails. All fields are guarded by mu.
type bodyLimitReader struct {
	mu     sync.Mutex
	count  int           // total bytes handed out so far
	limit  int           // Reads fail once count exceeds this
	closed chan struct{} // closed by Close
}

// Read fills p with 'a' and adds len(p) to count. It fails with "closed"
// after Close, and with "at limit" once count has exceeded limit.
func (r *bodyLimitReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	select {
	case <-r.closed:
		return 0, errors.New("closed")
	default:
	}

	if r.count <= r.limit {
		r.count += len(p)
		for i := range p {
			p[i] = 'a'
		}
		return len(p), nil
	}
	return 0, errors.New("at limit")
}

// Close closes r.closed under mu, which makes subsequent Reads fail.
// Calling it twice panics.
func (r *bodyLimitReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()
	close(r.closed)
	return nil
}
3169
3170 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3171 func testRequestBodyLimit(t *testing.T, mode testMode) {
3172 const limit = 1 << 20
3173 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3174 r.Body = MaxBytesReader(w, r.Body, limit)
3175 n, err := io.Copy(io.Discard, r.Body)
3176 if err == nil {
3177 t.Errorf("expected error from io.Copy")
3178 }
3179 if n != limit {
3180 t.Errorf("io.Copy = %d, want %d", n, limit)
3181 }
3182 mbErr, ok := err.(*MaxBytesError)
3183 if !ok {
3184 t.Errorf("expected MaxBytesError, got %T", err)
3185 }
3186 if mbErr.Limit != limit {
3187 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3188 }
3189 }))
3190
3191 body := &bodyLimitReader{
3192 closed: make(chan struct{}),
3193 limit: limit * 200,
3194 }
3195 req, _ := NewRequest("POST", cst.ts.URL, body)
3196
3197
3198
3199
3200
3201
3202
3203
3204
3205
3206 resp, err := cst.c.Do(req)
3207 if err == nil {
3208 resp.Body.Close()
3209 }
3210
3211
3212 <-body.closed
3213
3214 if body.count > limit*100 {
3215 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3216 limit, body.count)
3217 }
3218 }
3219
3220
3221
3222 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
3223 func testClientWriteShutdown(t *testing.T, mode testMode) {
3224 if runtime.GOOS == "plan9" {
3225 t.Skip("skipping test; see https://golang.org/issue/17906")
3226 }
3227 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3228 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3229 if err != nil {
3230 t.Fatalf("Dial: %v", err)
3231 }
3232 err = conn.(*net.TCPConn).CloseWrite()
3233 if err != nil {
3234 t.Fatalf("CloseWrite: %v", err)
3235 }
3236
3237 bs, err := io.ReadAll(conn)
3238 if err != nil {
3239 t.Errorf("ReadAll: %v", err)
3240 }
3241 got := string(bs)
3242 if got != "" {
3243 t.Errorf("read %q from server; want nothing", got)
3244 }
3245 }
3246
3247
3248
3249 func TestServerBufferedChunking(t *testing.T) {
3250 conn := new(testConn)
3251 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3252 conn.closec = make(chan bool, 1)
3253 ls := &oneConnListener{conn}
3254 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3255 rw.(Flusher).Flush()
3256 rw.Write([]byte{'x'})
3257 rw.Write([]byte{'y'})
3258 rw.Write([]byte{'z'})
3259 }))
3260 <-conn.closec
3261 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3262 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3263 conn.writeBuf.Bytes())
3264 }
3265 }
3266
3267
3268
3269
3270
3271 func TestServerGracefulClose(t *testing.T) {
3272
3273 run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
3274 }
3275 func testServerGracefulClose(t *testing.T, mode testMode) {
3276 runTimeSensitiveTest(t, []time.Duration{
3277 1 * time.Millisecond,
3278 5 * time.Millisecond,
3279 10 * time.Millisecond,
3280 50 * time.Millisecond,
3281 100 * time.Millisecond,
3282 500 * time.Millisecond,
3283 time.Second,
3284 5 * time.Second,
3285 }, func(t *testing.T, timeout time.Duration) error {
3286 SetRSTAvoidanceDelay(t, timeout)
3287 t.Logf("set RST avoidance delay to %v", timeout)
3288
3289 const bodySize = 5 << 20
3290 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3291 for i := 0; i < bodySize; i++ {
3292 req = append(req, 'x')
3293 }
3294
3295 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3296 Error(w, "bye", StatusUnauthorized)
3297 }))
3298
3299
3300 defer cst.close()
3301 ts := cst.ts
3302
3303 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3304 if err != nil {
3305 return err
3306 }
3307 writeErr := make(chan error)
3308 go func() {
3309 _, err := conn.Write(req)
3310 writeErr <- err
3311 }()
3312 defer func() {
3313 conn.Close()
3314
3315
3316
3317 <-writeErr
3318 }()
3319
3320 br := bufio.NewReader(conn)
3321 lineNum := 0
3322 for {
3323 line, err := br.ReadString('\n')
3324 if err == io.EOF {
3325 break
3326 }
3327 if err != nil {
3328 return fmt.Errorf("ReadLine: %v", err)
3329 }
3330 lineNum++
3331 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3332 t.Errorf("Response line = %q; want a 401", line)
3333 }
3334 }
3335 return nil
3336 })
3337 }
3338
3339 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3340 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3341 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3342 if r.Method != "get" {
3343 t.Errorf(`Got method %q; want "get"`, r.Method)
3344 }
3345 }))
3346 defer cst.close()
3347 req, _ := NewRequest("get", cst.ts.URL, nil)
3348 res, err := cst.c.Do(req)
3349 if err != nil {
3350 t.Error(err)
3351 return
3352 }
3353
3354 res.Body.Close()
3355 }
3356
3357
3358
3359
3360
3361 func TestContentLengthZero(t *testing.T) {
3362 run(t, testContentLengthZero, []testMode{http1Mode})
3363 }
3364 func testContentLengthZero(t *testing.T, mode testMode) {
3365 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3366
3367 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3368 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3369 if err != nil {
3370 t.Fatalf("error dialing: %v", err)
3371 }
3372 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3373 if err != nil {
3374 t.Fatalf("error writing: %v", err)
3375 }
3376 req, _ := NewRequest("GET", "/", nil)
3377 res, err := ReadResponse(bufio.NewReader(conn), req)
3378 if err != nil {
3379 t.Fatalf("error reading response: %v", err)
3380 }
3381 if te := res.TransferEncoding; len(te) > 0 {
3382 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3383 }
3384 if cl := res.ContentLength; cl != 0 {
3385 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3386 }
3387 conn.Close()
3388 }
3389 }
3390
3391 func TestCloseNotifier(t *testing.T) {
3392 run(t, testCloseNotifier, []testMode{http1Mode})
3393 }
3394 func testCloseNotifier(t *testing.T, mode testMode) {
3395 gotReq := make(chan bool, 1)
3396 sawClose := make(chan bool, 1)
3397 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3398 gotReq <- true
3399 cc := rw.(CloseNotifier).CloseNotify()
3400 <-cc
3401 sawClose <- true
3402 })).ts
3403 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3404 if err != nil {
3405 t.Fatalf("error dialing: %v", err)
3406 }
3407 diec := make(chan bool)
3408 go func() {
3409 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3410 if err != nil {
3411 t.Error(err)
3412 return
3413 }
3414 <-diec
3415 conn.Close()
3416 }()
3417 For:
3418 for {
3419 select {
3420 case <-gotReq:
3421 diec <- true
3422 case <-sawClose:
3423 break For
3424 }
3425 }
3426 ts.Close()
3427 }
3428
3429
3430
3431
3432
3433 func TestCloseNotifierPipelined(t *testing.T) {
3434 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3435 }
3436 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3437 gotReq := make(chan bool, 2)
3438 sawClose := make(chan bool, 2)
3439 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3440 gotReq <- true
3441 cc := rw.(CloseNotifier).CloseNotify()
3442 select {
3443 case <-cc:
3444 t.Error("unexpected CloseNotify")
3445 case <-time.After(100 * time.Millisecond):
3446 }
3447 sawClose <- true
3448 })).ts
3449 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3450 if err != nil {
3451 t.Fatalf("error dialing: %v", err)
3452 }
3453 diec := make(chan bool, 1)
3454 defer close(diec)
3455 go func() {
3456 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3457 _, err = io.WriteString(conn, req+req)
3458 if err != nil {
3459 t.Error(err)
3460 return
3461 }
3462 <-diec
3463 conn.Close()
3464 }()
3465 reqs := 0
3466 closes := 0
3467 for {
3468 select {
3469 case <-gotReq:
3470 reqs++
3471 if reqs > 2 {
3472 t.Fatal("too many requests")
3473 }
3474 case <-sawClose:
3475 closes++
3476 if closes > 1 {
3477 return
3478 }
3479 }
3480 }
3481 }
3482
3483 func TestCloseNotifierChanLeak(t *testing.T) {
3484 defer afterTest(t)
3485 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3486 for i := 0; i < 20; i++ {
3487 var output bytes.Buffer
3488 conn := &rwTestConn{
3489 Reader: bytes.NewReader(req),
3490 Writer: &output,
3491 closec: make(chan bool, 1),
3492 }
3493 ln := &oneConnListener{conn: conn}
3494 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3495
3496
3497
3498 _ = rw.(CloseNotifier).CloseNotify()
3499 })
3500 go Serve(ln, handler)
3501 <-conn.closec
3502 }
3503 }
3504
3505
3506
3507
3508
3509
3510
3511
3512
3513
3514 func TestHijackAfterCloseNotifier(t *testing.T) {
3515 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3516 }
3517 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3518 script := make(chan string, 2)
3519 script <- "closenotify"
3520 script <- "hijack"
3521 close(script)
3522 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3523 plan := <-script
3524 switch plan {
3525 default:
3526 panic("bogus plan; too many requests")
3527 case "closenotify":
3528 w.(CloseNotifier).CloseNotify()
3529 w.Header().Set("X-Addr", r.RemoteAddr)
3530 case "hijack":
3531 c, _, err := w.(Hijacker).Hijack()
3532 if err != nil {
3533 t.Errorf("Hijack in Handler: %v", err)
3534 return
3535 }
3536 if _, ok := c.(*net.TCPConn); !ok {
3537
3538
3539 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3540 }
3541 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3542 c.Close()
3543 return
3544 }
3545 })).ts
3546 res1, err := ts.Client().Get(ts.URL)
3547 if err != nil {
3548 log.Fatal(err)
3549 }
3550 res2, err := ts.Client().Get(ts.URL)
3551 if err != nil {
3552 log.Fatal(err)
3553 }
3554 addr1 := res1.Header.Get("X-Addr")
3555 addr2 := res2.Header.Get("X-Addr")
3556 if addr1 == "" || addr1 != addr2 {
3557 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3558 }
3559 }
3560
// TestHijackBeforeRequestBodyRead checks that a handler can obtain a
// CloseNotify channel before reading a 1 MiB request body, read that body
// completely and correctly, and then be notified once the client closes the
// connection. NOTE(review): despite the name, the visible code never calls
// Hijack. HTTP/1 only.
func TestHijackBeforeRequestBodyRead(t *testing.T) {
	run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
}
func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
	// bodyOkay receives true only when the handler validated the body; on
	// any early return the deferred close makes the receiver see false.
	bodyOkay := make(chan bool, 1)
	gotCloseNotify := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(bodyOkay)

		// Detach the body from the Request so it is read only through the
		// local reqBody. NOTE(review): the precise motivation is not
		// visible here — presumably to ensure the CloseNotify machinery does
		// not depend on r.Body; confirm.
		reqBody := r.Body
		r.Body = nil

		gone := w.(CloseNotifier).CloseNotify()
		slurp, err := io.ReadAll(reqBody)
		if err != nil {
			t.Errorf("Body read: %v", err)
			return
		}
		if len(slurp) != len(requestBody) {
			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
			return
		}
		if !bytes.Equal(slurp, requestBody) {
			t.Error("Backend read wrong request body.")
			return
		}
		bodyOkay <- true
		// Block until the client closes its side of the connection.
		<-gone
		gotCloseNotify <- true
	})).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
		len(requestBody), requestBody)
	if !<-bodyOkay {
		// The handler already reported the failure via t.Error/t.Errorf.
		return
	}
	// Closing the client conn should fire the handler's CloseNotify channel.
	conn.Close()
	<-gotCloseNotify
}
3608
3609 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3610 func testOptions(t *testing.T, mode testMode) {
3611 uric := make(chan string, 2)
3612 mux := NewServeMux()
3613 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3614 uric <- r.RequestURI
3615 })
3616 ts := newClientServerTest(t, mode, mux).ts
3617
3618 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3619 if err != nil {
3620 t.Fatal(err)
3621 }
3622 defer conn.Close()
3623
3624
3625 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3626 if err != nil {
3627 t.Fatal(err)
3628 }
3629 br := bufio.NewReader(conn)
3630 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3631 if err != nil {
3632 t.Fatal(err)
3633 }
3634 if res.StatusCode != 200 {
3635 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3636 }
3637
3638
3639 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3640 if err != nil {
3641 t.Fatal(err)
3642 }
3643 res, err = ReadResponse(br, &Request{Method: "GET"})
3644 if err != nil {
3645 t.Fatal(err)
3646 }
3647 if res.StatusCode != 400 {
3648 t.Errorf("Got non-400 response to GET *: %#v", res)
3649 }
3650
3651 res, err = Get(ts.URL + "/second")
3652 if err != nil {
3653 t.Fatal(err)
3654 }
3655 res.Body.Close()
3656 if got := <-uric; got != "/second" {
3657 t.Errorf("Handler saw request for %q; want /second", got)
3658 }
3659 }
3660
// TestOptionsHandler checks that, with Server.DisableGeneralOptionsHandler
// set, an "OPTIONS *" request is delivered to the user's handler with its
// method and RequestURI intact. HTTP/1 only.
func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
func testOptionsHandler(t *testing.T, mode testMode) {
	// rc hands the request seen by the handler back to the test goroutine.
	rc := make(chan *Request, 1)

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		rc <- r
	}), func(ts *httptest.Server) {
		ts.Config.DisableGeneralOptionsHandler = true
	}).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	_, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
	if err != nil {
		t.Fatal(err)
	}

	// No response is read; only the request the handler received matters.
	if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
		t.Errorf("Expected OPTIONS * request, got %v", got)
	}
}
3686
3687
3688
3689
3690
3691
3692
3693
3694
3695
3696 func TestHeaderToWire(t *testing.T) {
3697 tests := []struct {
3698 name string
3699 handler func(ResponseWriter, *Request)
3700 check func(got, logs string) error
3701 }{
3702 {
3703 name: "write without Header",
3704 handler: func(rw ResponseWriter, r *Request) {
3705 rw.Write([]byte("hello world"))
3706 },
3707 check: func(got, logs string) error {
3708 if !strings.Contains(got, "Content-Length:") {
3709 return errors.New("no content-length")
3710 }
3711 if !strings.Contains(got, "Content-Type: text/plain") {
3712 return errors.New("no content-type")
3713 }
3714 return nil
3715 },
3716 },
3717 {
3718 name: "Header mutation before write",
3719 handler: func(rw ResponseWriter, r *Request) {
3720 h := rw.Header()
3721 h.Set("Content-Type", "some/type")
3722 rw.Write([]byte("hello world"))
3723 h.Set("Too-Late", "bogus")
3724 },
3725 check: func(got, logs string) error {
3726 if !strings.Contains(got, "Content-Length:") {
3727 return errors.New("no content-length")
3728 }
3729 if !strings.Contains(got, "Content-Type: some/type") {
3730 return errors.New("wrong content-type")
3731 }
3732 if strings.Contains(got, "Too-Late") {
3733 return errors.New("don't want too-late header")
3734 }
3735 return nil
3736 },
3737 },
3738 {
3739 name: "write then useless Header mutation",
3740 handler: func(rw ResponseWriter, r *Request) {
3741 rw.Write([]byte("hello world"))
3742 rw.Header().Set("Too-Late", "Write already wrote headers")
3743 },
3744 check: func(got, logs string) error {
3745 if strings.Contains(got, "Too-Late") {
3746 return errors.New("header appeared from after WriteHeader")
3747 }
3748 return nil
3749 },
3750 },
3751 {
3752 name: "flush then write",
3753 handler: func(rw ResponseWriter, r *Request) {
3754 rw.(Flusher).Flush()
3755 rw.Write([]byte("post-flush"))
3756 rw.Header().Set("Too-Late", "Write already wrote headers")
3757 },
3758 check: func(got, logs string) error {
3759 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3760 return errors.New("not chunked")
3761 }
3762 if strings.Contains(got, "Too-Late") {
3763 return errors.New("header appeared from after WriteHeader")
3764 }
3765 return nil
3766 },
3767 },
3768 {
3769 name: "header then flush",
3770 handler: func(rw ResponseWriter, r *Request) {
3771 rw.Header().Set("Content-Type", "some/type")
3772 rw.(Flusher).Flush()
3773 rw.Write([]byte("post-flush"))
3774 rw.Header().Set("Too-Late", "Write already wrote headers")
3775 },
3776 check: func(got, logs string) error {
3777 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3778 return errors.New("not chunked")
3779 }
3780 if strings.Contains(got, "Too-Late") {
3781 return errors.New("header appeared from after WriteHeader")
3782 }
3783 if !strings.Contains(got, "Content-Type: some/type") {
3784 return errors.New("wrong content-type")
3785 }
3786 return nil
3787 },
3788 },
3789 {
3790 name: "sniff-on-first-write content-type",
3791 handler: func(rw ResponseWriter, r *Request) {
3792 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3793 rw.Header().Set("Content-Type", "x/wrong")
3794 },
3795 check: func(got, logs string) error {
3796 if !strings.Contains(got, "Content-Type: text/html") {
3797 return errors.New("wrong content-type; want html")
3798 }
3799 return nil
3800 },
3801 },
3802 {
3803 name: "explicit content-type wins",
3804 handler: func(rw ResponseWriter, r *Request) {
3805 rw.Header().Set("Content-Type", "some/type")
3806 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3807 },
3808 check: func(got, logs string) error {
3809 if !strings.Contains(got, "Content-Type: some/type") {
3810 return errors.New("wrong content-type; want html")
3811 }
3812 return nil
3813 },
3814 },
3815 {
3816 name: "empty handler",
3817 handler: func(rw ResponseWriter, r *Request) {
3818 },
3819 check: func(got, logs string) error {
3820 if !strings.Contains(got, "Content-Length: 0") {
3821 return errors.New("want 0 content-length")
3822 }
3823 return nil
3824 },
3825 },
3826 {
3827 name: "only Header, no write",
3828 handler: func(rw ResponseWriter, r *Request) {
3829 rw.Header().Set("Some-Header", "some-value")
3830 },
3831 check: func(got, logs string) error {
3832 if !strings.Contains(got, "Some-Header") {
3833 return errors.New("didn't get header")
3834 }
3835 return nil
3836 },
3837 },
3838 {
3839 name: "WriteHeader call",
3840 handler: func(rw ResponseWriter, r *Request) {
3841 rw.WriteHeader(404)
3842 rw.Header().Set("Too-Late", "some-value")
3843 },
3844 check: func(got, logs string) error {
3845 if !strings.Contains(got, "404") {
3846 return errors.New("wrong status")
3847 }
3848 if strings.Contains(got, "Too-Late") {
3849 return errors.New("shouldn't have seen Too-Late")
3850 }
3851 return nil
3852 },
3853 },
3854 }
3855 for _, tc := range tests {
3856 ht := newHandlerTest(HandlerFunc(tc.handler))
3857 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3858 logs := ht.logbuf.String()
3859 if err := tc.check(got, logs); err != nil {
3860 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3861 }
3862 }
3863 }
3864
3865 type errorListener struct {
3866 errs []error
3867 }
3868
3869 func (l *errorListener) Accept() (c net.Conn, err error) {
3870 if len(l.errs) == 0 {
3871 return nil, io.EOF
3872 }
3873 err = l.errs[0]
3874 l.errs = l.errs[1:]
3875 return
3876 }
3877
3878 func (l *errorListener) Close() error {
3879 return nil
3880 }
3881
3882 func (l *errorListener) Addr() net.Addr {
3883 return dummyAddr("test-address")
3884 }
3885
3886 func TestAcceptMaxFds(t *testing.T) {
3887 setParallel(t)
3888
3889 ln := &errorListener{[]error{
3890 &net.OpError{
3891 Op: "accept",
3892 Err: syscall.EMFILE,
3893 }}}
3894 server := &Server{
3895 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3896 ErrorLog: log.New(io.Discard, "", 0),
3897 }
3898 err := server.Serve(ln)
3899 if err != io.EOF {
3900 t.Errorf("got error %v, want EOF", err)
3901 }
3902 }
3903
// TestWriteAfterHijack checks that, after a handler hijacks the connection,
// writes made through both the returned bufio.ReadWriter and the raw conn
// land on the wire in order. The server must write nothing of its own:
// the captured output has to equal the two hijack writes exactly.
func TestWriteAfterHijack(t *testing.T) {
	req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
	var buf strings.Builder
	wrotec := make(chan bool, 1)
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		// This conn shadows the outer test conn: it is the hijacked net.Conn.
		conn, bufrw, err := rw.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		// Write after the handler has returned, from another goroutine.
		go func() {
			bufrw.Write([]byte("[hijack-to-bufw]"))
			bufrw.Flush()
			conn.Write([]byte("[hijack-to-conn]"))
			conn.Close()
			wrotec <- true
		}()
	})
	ln := &oneConnListener{conn: conn}
	go Serve(ln, handler)
	<-conn.closec
	<-wrotec
	if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
		t.Errorf("wrote %q; want %q", g, w)
	}
}
3935
// TestDoubleHijack checks that a second Hijack call on the same
// ResponseWriter returns an error instead of handing out the connection again.
func TestDoubleHijack(t *testing.T) {
	req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
	var buf bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		conn, _, err := rw.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		_, _, err = rw.(Hijacker).Hijack()
		if err == nil {
			t.Errorf("got err = nil; want err != nil")
		}
		// Closing the hijacked conn signals closec on the test conn.
		conn.Close()
	})
	ln := &oneConnListener{conn: conn}
	go Serve(ln, handler)
	<-conn.closec
}
3960
3961
3962
3963
3964
3965
3966
// TestHTTP10ConnectionHeader checks the Connection header the server sends
// back to HTTP/1.0 clients. There should be none for a plain request, and
// "keep-alive" when the client asked for keep-alive. HTTP/1 only.
func TestHTTP10ConnectionHeader(t *testing.T) {
	run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
}
func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
	mux := NewServeMux()
	mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
	ts := newClientServerTest(t, mode, mux).ts

	// The requests are HTTP/1.0, so they are written by hand on raw TCP
	// connections instead of going through the client.
	tests := []struct {
		req    string   // raw request bytes to send
		expect []string // wanted Connection header values (nil = absent)
	}{
		{
			req:    "GET / HTTP/1.0\r\n\r\n",
			expect: nil,
		},
		{
			req:    "OPTIONS * HTTP/1.0\r\n\r\n",
			expect: nil,
		},
		{
			req:    "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
			expect: []string{"keep-alive"},
		},
	}

	for _, tt := range tests {
		// A fresh connection per case.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal("dial err:", err)
		}

		_, err = fmt.Fprint(conn, tt.req)
		if err != nil {
			t.Fatal("conn write err:", err)
		}

		resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
		if err != nil {
			t.Fatal("ReadResponse err:", err)
		}
		conn.Close()
		resp.Body.Close()

		got := resp.Header["Connection"]
		if !reflect.DeepEqual(got, tt.expect) {
			t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
		}
	}
}
4018
4019
// TestServerReaderFromOrder has the handler start an io.Copy from a pipe into
// the ResponseWriter, then read the full 3 MiB request body, and only then
// feed "hi" into the pipe. The client must receive exactly "hi". Presumably
// this checks that a pending copy into the ResponseWriter (which may use
// io.ReaderFrom) does not interfere with reading the request body — TODO
// confirm the original motivation.
func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
func testServerReaderFromOrder(t *testing.T, mode testMode) {
	pr, pw := io.Pipe()
	const size = 3 << 20
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		rw.Header().Set("Content-Type", "text/plain")
		done := make(chan bool)
		// This copy blocks until pw is written to below.
		go func() {
			io.Copy(rw, pr)
			close(done)
		}()
		// Give the copy goroutine time to start and block on the pipe.
		time.Sleep(25 * time.Millisecond)
		n, err := io.Copy(io.Discard, req.Body)
		if err != nil {
			t.Errorf("handler Copy: %v", err)
			return
		}
		if n != size {
			t.Errorf("handler Copy = %d; want %d", n, size)
		}
		pw.Write([]byte("hi"))
		pw.Close()
		// Don't return (finishing the response) until the copy is done.
		<-done
	}))

	req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
	if err != nil {
		t.Fatal(err)
	}
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	all, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if string(all) != "hi" {
		t.Errorf("Body = %q; want hi", all)
	}
}
4062
4063
4064 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4065 for _, code := range []int{StatusNotModified, StatusNoContent} {
4066 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4067 if r.URL.Path == "/header" {
4068 w.Header().Set("Content-Length", "123")
4069 }
4070 w.WriteHeader(code)
4071 if r.URL.Path == "/more" {
4072 w.Write([]byte("stuff"))
4073 }
4074 }))
4075 for _, req := range []string{
4076 "GET / HTTP/1.0",
4077 "GET /header HTTP/1.0",
4078 "GET /more HTTP/1.0",
4079 "GET / HTTP/1.1\nHost: foo",
4080 "GET /header HTTP/1.1\nHost: foo",
4081 "GET /more HTTP/1.1\nHost: foo",
4082 } {
4083 got := ht.rawResponse(req)
4084 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4085 if !strings.Contains(got, wantStatus) {
4086 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4087 } else if strings.Contains(got, "Content-Length") {
4088 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4089 } else if strings.Contains(got, "stuff") {
4090 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4091 }
4092 }
4093 }
4094 }
4095
4096 func TestContentTypeOkayOn204(t *testing.T) {
4097 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4098 w.Header().Set("Content-Length", "123")
4099 w.Header().Set("Content-Type", "foo/bar")
4100 w.WriteHeader(204)
4101 }))
4102 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4103 if !strings.Contains(got, "Content-Type: foo/bar") {
4104 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4105 }
4106 if strings.Contains(got, "Content-Length: 123") {
4107 t.Errorf("Response = %q; don't want a Content-Length", got)
4108 }
4109 }
4110
4111
4112
4113
4114
4115
4116
// TestTransportAndServerSharedBodyRace exercises a proxy whose handler passes
// its inbound request body directly as the body of an outbound request. It
// then cancels the outbound request partway through the response, while the
// body may still be shared by the server and the Transport. The test is
// presumably meant to be run under the race detector (TODO confirm).
func TestTransportAndServerSharedBodyRace(t *testing.T) {
	run(t, testTransportAndServerSharedBodyRace, testNotParallel)
}
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	// runTimeSensitiveTest is given increasing RST-avoidance delays; it
	// presumably retries the function with each successive value until
	// one attempt returns nil (TODO confirm against its definition).
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 1 << 20

		var wg sync.WaitGroup
		backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Track the handler so backend.close waits for it to return.
			// NOTE(review): wg.Add runs inside the handler rather than before
			// the request is sent, so it relies on the handler having started
			// before the deferred wg.Wait below runs.
			wg.Add(1)
			defer wg.Done()

			// Echo up to bodySize bytes of the request body back, then
			// block until this request's context is done (presumably when
			// the proxy cancels its outbound request).
			n, err := io.CopyN(rw, req.Body, bodySize)
			t.Logf("backend CopyN: %v, %v", n, err)
			<-req.Context().Done()
		}))
		// Wait for the backend handler to finish before closing the
		// backend server.
		defer func() {
			wg.Wait()
			backend.close()
		}()

		var proxy *clientServerTest
		proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// The inbound body is reused as the outbound body.
			req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
			req2.ContentLength = bodySize
			cancel := make(chan struct{})
			req2.Cancel = cancel

			bresp, err := proxy.c.Do(req2)
			if err != nil {
				t.Errorf("Proxy outbound request: %v", err)
				return
			}
			// Read only half of the echoed response before cancelling.
			_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
			if err != nil {
				t.Errorf("Proxy copy error: %v", err)
				return
			}
			t.Cleanup(func() { bresp.Body.Close() })

			// Cancel the outbound request while the shared body may still
			// be in use. HTTP/2 mode uses the Request.Cancel channel; HTTP/1
			// uses Transport.CancelRequest, which requires proxy.c.Transport
			// to be a *Transport.
			if mode == http2Mode {
				close(cancel)
			} else {
				proxy.c.Transport.(*Transport).CancelRequest(req2)
			}
			rw.Write([]byte("OK"))
		}))
		defer proxy.close()

		req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
		res, err := proxy.c.Do(req)
		if err != nil {
			return fmt.Errorf("original request: %v", err)
		}
		res.Body.Close()
		return nil
	})
}
4206
4207
4208
4209
// TestRequestBodyCloseDoesntBlock sends a request that announces a 100000-byte
// body but never sends it. The handler starts a goroutine that reads from the
// body and then returns after 500ms. That pending Read must fail with an
// error instead of blocking forever. HTTP/1 only; skipped in -short mode.
func TestRequestBodyCloseDoesntBlock(t *testing.T) {
	run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
}
func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}

	readErrCh := make(chan error, 1) // result of the handler's body Read
	errCh := make(chan error, 2)     // client-side dial/write failures

	server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		go func(body io.Reader) {
			_, err := body.Read(make([]byte, 100))
			readErrCh <- err
		}(req.Body)
		time.Sleep(500 * time.Millisecond)
	})).ts

	// Closed when the test returns, releasing the client goroutine.
	closeConn := make(chan bool)
	defer close(closeConn)
	go func() {
		conn, err := net.Dial("tcp", server.Listener.Addr().String())
		if err != nil {
			errCh <- err
			return
		}
		defer conn.Close()
		_, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
		if err != nil {
			errCh <- err
			return
		}
		// Send no body and keep the connection open, so the body bytes
		// promised by Content-Length never arrive.
		<-closeConn
	}()
	select {
	case err := <-readErrCh:
		if err == nil {
			t.Error("Read was nil. Expected error.")
		}
	case err := <-errCh:
		t.Error(err)
	}
}
4256
4257
4258 func TestResponseWriterWriteString(t *testing.T) {
4259 okc := make(chan bool, 1)
4260 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4261 _, ok := w.(io.StringWriter)
4262 okc <- ok
4263 }))
4264 ht.rawResponse("GET / HTTP/1.0")
4265 select {
4266 case ok := <-okc:
4267 if !ok {
4268 t.Error("ResponseWriter did not implement io.StringWriter")
4269 }
4270 default:
4271 t.Error("handler was never called")
4272 }
4273 }
4274
4275 func TestAppendTime(t *testing.T) {
4276 var b [len(TimeFormat)]byte
4277 t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60))
4278 res := ExportAppendTime(b[:0], t1)
4279 t2, err := ParseTime(string(res))
4280 if err != nil {
4281 t.Fatalf("Error parsing time: %s", err)
4282 }
4283 if !t1.Equal(t2) {
4284 t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res))
4285 }
4286 }
4287
// TestServerConnState checks the sequence of ConnState transitions the server
// reports for one connection. It covers keep-alive reuse, Connection: close
// from either side, hijacking (with and without a panic), a connection that
// never sends a request, a malformed request, and a raw HTTP/1.1 exchange.
// HTTP/1 only.
func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
func testServerConnState(t *testing.T, mode testMode) {
	// Per-path handlers, dispatched by the server handler below.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// stateLog records the states observed for a single connection.
	type stateLog struct {
		active   net.Conn        // the connection being tracked (set on StateNew)
		got      []ConnState     // states observed so far
		want     []ConnState     // expected sequence
		complete chan<- struct{} // closed once got reaches len(want) or diverges from want
	}
	// activeLog holds at most one stateLog; receiving from it takes
	// exclusive ownership and sending it back releases it.
	activeLog := make(chan *stateLog, 1)

	// wantLog installs a fresh stateLog expecting want, runs doRequests, and
	// waits until the ConnState hook signals completion. It then takes the
	// log back and reports a mismatch between the observed and wanted states.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		<-complete
		sl := <-activeLog
		if !reflect.DeepEqual(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
		}
		// The log is not returned to activeLog here; the next wantLog call
		// (or the deferred cleanup) installs a new one.
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
		ts.Config.ConnState = func(c net.Conn, state ConnState) {
			if c == nil {
				t.Errorf("nil conn seen in state %s", state)
				return
			}
			sl := <-activeLog
			// The first conn seen in StateNew becomes the tracked conn;
			// any other conn is an error.
			if sl.active == nil && state == StateNew {
				sl.active = c
			} else if sl.active != c {
				t.Errorf("unexpected conn in state %s", state)
				activeLog <- sl
				return
			}
			sl.got = append(sl.got, state)
			// Signal completion once the sequence is long enough, or as
			// soon as it stops matching a prefix of want.
			if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) {
				close(sl.complete)
				sl.complete = nil
			}
			activeLog <- sl
		}
	}).ts
	defer func() {
		// Provide a log so ConnState callbacks that still run during
		// shutdown can receive from activeLog instead of blocking.
		activeLog <- &stateLog{}
		ts.Close()
	}()

	c := ts.Client()

	// mustGet issues a GET with optional header name/value pairs and drains
	// the response body, reporting (not failing on) request errors.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// Keep-alive reuse, then the server closes via "Connection: close".
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Keep-alive reuse, then the client asks for "Connection: close".
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	// A panic after hijacking must not add states beyond StateHijacked.
	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// A connection that never sends a request.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A malformed request; wait for a byte of the server's reply first.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// A raw HTTP/1.1 exchange where the client closes after the response.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4450
4451 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4452 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4453 }
4454 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4455 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4456 }), func(ts *httptest.Server) {
4457 ts.Config.SetKeepAlivesEnabled(false)
4458 }).ts
4459 res, err := ts.Client().Get(ts.URL)
4460 if err != nil {
4461 t.Fatal(err)
4462 }
4463 defer res.Body.Close()
4464 if !res.Close {
4465 t.Errorf("Body.Close == false; want true")
4466 }
4467 }
4468
4469
// TestServerEmptyBodyRace fires 20 concurrent bodiless GETs and checks that the
// handler runs exactly once per request.
func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
func testServerEmptyBodyRace(t *testing.T, mode testMode) {
	var n int32 // handler invocations, updated atomically
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		atomic.AddInt32(&n, 1)
	}), optQuietLog)
	var wg sync.WaitGroup
	const reqs = 20
	for i := 0; i < reqs; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				// Retry once after a short pause. NOTE(review): presumably
				// this tolerates transient connection errors under load;
				// the exact failure being tolerated is not visible here.
				time.Sleep(10 * time.Millisecond)
				res, err = cst.c.Get(cst.ts.URL)
				if err != nil {
					t.Error(err)
					return
				}
			}
			defer res.Body.Close()
			_, err = io.Copy(io.Discard, res.Body)
			if err != nil {
				t.Error(err)
				return
			}
		}()
	}
	wg.Wait()
	if got := atomic.LoadInt32(&n); got != reqs {
		t.Errorf("handler ran %d times; want %d", got, reqs)
	}
}
4506
4507 func TestServerConnStateNew(t *testing.T) {
4508 sawNew := false
4509 srv := &Server{
4510 ConnState: func(c net.Conn, state ConnState) {
4511 if state == StateNew {
4512 sawNew = true
4513 }
4514 },
4515 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4516 }
4517 srv.Serve(&oneConnListener{
4518 conn: &rwTestConn{
4519 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4520 Writer: io.Discard,
4521 },
4522 })
4523 if !sawNew {
4524 t.Error("StateNew not seen")
4525 }
4526 }
4527
// closeWriteTestConn is an rwTestConn that also has a CloseWrite method,
// which only records that it was called.
type closeWriteTestConn struct {
	rwTestConn
	didCloseWrite bool // set by CloseWrite
}

// CloseWrite records the call in didCloseWrite and reports success.
func (c *closeWriteTestConn) CloseWrite() error {
	c.didCloseWrite = true
	return nil
}
4537
4538 func TestCloseWrite(t *testing.T) {
4539 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4540
4541 var srv Server
4542 var testConn closeWriteTestConn
4543 c := ExportServerNewConn(&srv, &testConn)
4544 ExportCloseWriteAndWait(c)
4545 if !testConn.didCloseWrite {
4546 t.Error("didn't see CloseWrite call")
4547 }
4548 }
4549
4550
4551
4552
4553
4554
4555
4556
4557 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4558 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4559 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4560 io.WriteString(w, "Hello, ")
4561 w.(Flusher).Flush()
4562 conn, buf, _ := w.(Hijacker).Hijack()
4563 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4564 if err := buf.Flush(); err != nil {
4565 t.Error(err)
4566 }
4567 if err := conn.Close(); err != nil {
4568 t.Error(err)
4569 }
4570 })).ts
4571 res, err := Get(ts.URL)
4572 if err != nil {
4573 t.Fatal(err)
4574 }
4575 defer res.Body.Close()
4576 all, err := io.ReadAll(res.Body)
4577 if err != nil {
4578 t.Fatal(err)
4579 }
4580 if want := "Hello, world!"; string(all) != want {
4581 t.Errorf("Got %q; want %q", all, want)
4582 }
4583 }
4584
4585
4586
4587
4588
4589
4590
// TestServerKeepAliveAfterWriteError checks that the server does not reuse a
// connection after a write on it has failed. The handler sleeps longer than
// the server's WriteTimeout before flushing, so every request should fail.
// Each request should also arrive on a distinct client connection.
// HTTP/1 only; skipped in -short mode.
func TestServerKeepAliveAfterWriteError(t *testing.T) {
	run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
}
func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	const numReq = 3
	addrc := make(chan string, numReq) // client addresses seen by the handler
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addrc <- r.RemoteAddr
		// Outlast the 250ms WriteTimeout so the flush fails.
		time.Sleep(500 * time.Millisecond)
		w.(Flusher).Flush()
	}), func(ts *httptest.Server) {
		ts.Config.WriteTimeout = 250 * time.Millisecond
	}).ts

	// Issue the requests sequentially; errc is closed after the last one.
	errc := make(chan error, numReq)
	go func() {
		defer close(errc)
		for i := 0; i < numReq; i++ {
			res, err := Get(ts.URL)
			if res != nil {
				res.Body.Close()
			}
			errc <- err
		}
	}()

	addrSeen := map[string]bool{}
	numOkay := 0
	for {
		select {
		case v := <-addrc:
			addrSeen[v] = true
		case err, ok := <-errc:
			if !ok {
				// All requests done: each must have used its own
				// connection, and none may have succeeded.
				if len(addrSeen) != numReq {
					t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
				}
				if numOkay != 0 {
					t.Errorf("got %d successful client requests; want 0", numOkay)
				}
				return
			}
			if err == nil {
				numOkay++
			}
		}
	}
}
4642
4643
4644
// TestNoContentLengthIfTransferEncoding checks that when a handler sets its
// own Transfer-Encoding, the server adds neither an automatic Content-Length
// nor a sniffed Content-Type to the response headers. HTTP/1 only.
func TestNoContentLengthIfTransferEncoding(t *testing.T) {
	run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
}
func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "foo")
		io.WriteString(w, "<html>")
	})).ts
	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
		t.Fatal(err)
	}
	// Collect the raw header block (up to the first blank line) by hand,
	// because a client would choke on the bogus "foo" transfer encoding.
	bs := bufio.NewScanner(c)
	var got strings.Builder
	for bs.Scan() {
		if strings.TrimSpace(bs.Text()) == "" {
			break
		}
		got.WriteString(bs.Text())
		got.WriteByte('\n')
	}
	if err := bs.Err(); err != nil {
		t.Fatal(err)
	}
	if strings.Contains(got.String(), "Content-Length") {
		t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
	}
	if strings.Contains(got.String(), "Content-Type") {
		t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
	}
}
4680
4681
4682
// TestTolerateCRLFBeforeRequestLine checks that blank lines (extra CRLFs)
// between a request's body and the next request line are skipped, so both
// pipelined requests are served.
func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
	req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
		"\r\n\r\n" + // stray CRLFs before the second request line
		"GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
	var buf bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	numReq := 0
	go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
		numReq++
	}))
	// Reading numReq after closec is received is ordered by the channel.
	<-conn.closec
	if numReq != 2 {
		t.Errorf("num requests = %d; want 2", numReq)
		t.Logf("Res: %s", buf.Bytes())
	}
}
4704
// TestIssue13893_Expect100 checks that the server passes the client's
// "Expect: 100-continue" header through to the handler instead of
// filtering it out of r.Header.
func TestIssue13893_Expect100(t *testing.T) {
	// reqBytes presumably normalizes the raw string's newlines to CRLF
	// (TODO confirm against its definition).
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorld

`)
	var buf bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
		if _, ok := r.Header["Expect"]; !ok {
			t.Error("Expect header should not be filtered out")
		}
	}))
	<-conn.closec
}
4731
// TestIssue11549_Expect100 pipelines three requests on one connection:
//   - a 100-continue request whose handler reads the body;
//   - a 100-continue request whose handler does not read the body;
//   - a plain GET.
// After the unread 100-continue body, the server is expected to answer with
// "Connection: close" and stop, so only two requests are handled and the
// third is ignored.
func TestIssue11549_Expect100(t *testing.T) {
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorldPUT /noreadbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

GET /should-be-ignored HTTP/1.1
Host: foo

`)
	var buf strings.Builder
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	numReq := 0
	go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
		numReq++
		if r.URL.Path == "/readbody" {
			io.ReadAll(r.Body)
		}
		io.WriteString(w, "Hello world!")
	}))
	<-conn.closec
	if numReq != 2 {
		t.Errorf("num requests = %d; want 2", numReq)
	}
	if !strings.Contains(buf.String(), "Connection: close\r\n") {
		t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
	}
}
4774
4775
4776
4777 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4778 setParallel(t)
4779 conn := newTestConn()
4780 conn.readBuf.WriteString(
4781 "POST / HTTP/1.1\r\n" +
4782 "Host: test\r\n" +
4783 "Content-Length: 9999999999\r\n" +
4784 "\r\n" + strings.Repeat("a", 1<<20))
4785
4786 ls := &oneConnListener{conn}
4787 var inHandlerLen int
4788 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4789 inHandlerLen = conn.readBuf.Len()
4790 rw.WriteHeader(404)
4791 }))
4792 <-conn.closec
4793 afterHandlerLen := conn.readBuf.Len()
4794
4795 if afterHandlerLen != inHandlerLen {
4796 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4797 }
4798 }
4799
4800 func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
4801 func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
4802 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4803 r.Body = nil
4804 fmt.Fprintf(w, "%v", r.RemoteAddr)
4805 }))
4806 get := func() string {
4807 res, err := cst.c.Get(cst.ts.URL)
4808 if err != nil {
4809 t.Fatal(err)
4810 }
4811 defer res.Body.Close()
4812 slurp, err := io.ReadAll(res.Body)
4813 if err != nil {
4814 t.Fatal(err)
4815 }
4816 return string(slurp)
4817 }
4818 a, b := get(), get()
4819 if a != b {
4820 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
4821 }
4822 }
4823
4824
4825
4826 func TestServerValidatesHostHeader(t *testing.T) {
4827 tests := []struct {
4828 proto string
4829 host string
4830 want int
4831 }{
4832 {"HTTP/0.9", "", 505},
4833
4834 {"HTTP/1.1", "", 400},
4835 {"HTTP/1.1", "Host: \r\n", 200},
4836 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4837 {"HTTP/1.1", "Host: foo.com\r\n", 200},
4838 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
4839 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
4840 {"HTTP/1.1", "Host: ::1\r\n", 200},
4841 {"HTTP/1.1", "Host: [::1]\r\n", 200},
4842 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
4843 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
4844 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4845 {"HTTP/1.1", "Host: \x06\r\n", 400},
4846 {"HTTP/1.1", "Host: \xff\r\n", 400},
4847 {"HTTP/1.1", "Host: {\r\n", 400},
4848 {"HTTP/1.1", "Host: }\r\n", 400},
4849 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
4850
4851
4852
4853 {"HTTP/1.0", "", 200},
4854 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
4855 {"HTTP/1.0", "Host: \xff\r\n", 400},
4856
4857
4858 {"PRI * HTTP/2.0", "", 200},
4859
4860
4861 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
4862
4863
4864 {"PRI / HTTP/2.0", "", 505},
4865 {"GET / HTTP/2.0", "", 505},
4866 {"GET / HTTP/3.0", "", 505},
4867 }
4868 for _, tt := range tests {
4869 conn := newTestConn()
4870 methodTarget := "GET / "
4871 if !strings.HasPrefix(tt.proto, "HTTP/") {
4872 methodTarget = ""
4873 }
4874 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
4875
4876 ln := &oneConnListener{conn}
4877 srv := Server{
4878 ErrorLog: quietLog,
4879 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4880 }
4881 go srv.Serve(ln)
4882 <-conn.closec
4883 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4884 if err != nil {
4885 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
4886 continue
4887 }
4888 if res.StatusCode != tt.want {
4889 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
4890 }
4891 }
4892 }
4893
4894 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
4895 run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
4896 }
4897 func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
4898 const upgradeResponse = "upgrade here"
4899 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4900 conn, br, err := w.(Hijacker).Hijack()
4901 if err != nil {
4902 t.Error(err)
4903 return
4904 }
4905 defer conn.Close()
4906 if r.Method != "PRI" || r.RequestURI != "*" {
4907 t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
4908 return
4909 }
4910 if !r.Close {
4911 t.Errorf("Request.Close = true; want false")
4912 }
4913 const want = "SM\r\n\r\n"
4914 buf := make([]byte, len(want))
4915 n, err := io.ReadFull(br, buf)
4916 if err != nil || string(buf[:n]) != want {
4917 t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
4918 return
4919 }
4920 io.WriteString(conn, upgradeResponse)
4921 })).ts
4922
4923 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4924 if err != nil {
4925 t.Fatalf("Dial: %v", err)
4926 }
4927 defer c.Close()
4928 io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
4929 slurp, err := io.ReadAll(c)
4930 if err != nil {
4931 t.Fatal(err)
4932 }
4933 if string(slurp) != upgradeResponse {
4934 t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
4935 }
4936 }
4937
4938
4939
4940 func TestServerValidatesHeaders(t *testing.T) {
4941 setParallel(t)
4942 tests := []struct {
4943 header string
4944 want int
4945 }{
4946 {"", 200},
4947 {"Foo: bar\r\n", 200},
4948 {"X-Foo: bar\r\n", 200},
4949 {"Foo: a space\r\n", 200},
4950
4951 {"A space: foo\r\n", 400},
4952 {"foo\xffbar: foo\r\n", 400},
4953 {"foo\x00bar: foo\r\n", 400},
4954 {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},
4955
4956
4957 {"Foo : bar\r\n", 400},
4958 {"Foo\t: bar\r\n", 400},
4959
4960
4961
4962 {": empty key\r\n", 400},
4963
4964
4965
4966
4967 {"Content-Length: notdigits\r\n", 400},
4968 {"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},
4969
4970 {"foo: foo foo\r\n", 200},
4971 {"foo: foo\tfoo\r\n", 200},
4972 {"foo: foo\x00foo\r\n", 400},
4973 {"foo: foo\x7ffoo\r\n", 400},
4974 {"foo: foo\xfffoo\r\n", 200},
4975 }
4976 for _, tt := range tests {
4977 conn := newTestConn()
4978 io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
4979
4980 ln := &oneConnListener{conn}
4981 srv := Server{
4982 ErrorLog: quietLog,
4983 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4984 }
4985 go srv.Serve(ln)
4986 <-conn.closec
4987 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4988 if err != nil {
4989 t.Errorf("For %q, ReadResponse: %v", tt.header, res)
4990 continue
4991 }
4992 if res.StatusCode != tt.want {
4993 t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
4994 }
4995 }
4996 }
4997
4998 func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
4999 run(t, testServerRequestContextCancel_ServeHTTPDone)
5000 }
5001 func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
5002 ctxc := make(chan context.Context, 1)
5003 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5004 ctx := r.Context()
5005 select {
5006 case <-ctx.Done():
5007 t.Error("should not be Done in ServeHTTP")
5008 default:
5009 }
5010 ctxc <- ctx
5011 }))
5012 res, err := cst.c.Get(cst.ts.URL)
5013 if err != nil {
5014 t.Fatal(err)
5015 }
5016 res.Body.Close()
5017 ctx := <-ctxc
5018 select {
5019 case <-ctx.Done():
5020 default:
5021 t.Error("context should be done after ServeHTTP completes")
5022 }
5023 }
5024
5025
5026
5027
5028
5029 func TestServerRequestContextCancel_ConnClose(t *testing.T) {
5030 run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
5031 }
5032 func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
5033 inHandler := make(chan struct{})
5034 handlerDone := make(chan struct{})
5035 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5036 close(inHandler)
5037 <-r.Context().Done()
5038 close(handlerDone)
5039 })).ts
5040 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5041 if err != nil {
5042 t.Fatal(err)
5043 }
5044 defer c.Close()
5045 io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
5046 <-inHandler
5047 c.Close()
5048 <-handlerDone
5049 }
5050
5051 func TestServerContext_ServerContextKey(t *testing.T) {
5052 run(t, testServerContext_ServerContextKey)
5053 }
5054 func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
5055 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5056 ctx := r.Context()
5057 got := ctx.Value(ServerContextKey)
5058 if _, ok := got.(*Server); !ok {
5059 t.Errorf("context value = %T; want *http.Server", got)
5060 }
5061 }))
5062 res, err := cst.c.Get(cst.ts.URL)
5063 if err != nil {
5064 t.Fatal(err)
5065 }
5066 res.Body.Close()
5067 }
5068
5069 func TestServerContext_LocalAddrContextKey(t *testing.T) {
5070 run(t, testServerContext_LocalAddrContextKey)
5071 }
5072 func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
5073 ch := make(chan any, 1)
5074 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5075 ch <- r.Context().Value(LocalAddrContextKey)
5076 }))
5077 if _, err := cst.c.Head(cst.ts.URL); err != nil {
5078 t.Fatal(err)
5079 }
5080
5081 host := cst.ts.Listener.Addr().String()
5082 got := <-ch
5083 if addr, ok := got.(net.Addr); !ok {
5084 t.Errorf("local addr value = %T; want net.Addr", got)
5085 } else if fmt.Sprint(addr) != host {
5086 t.Errorf("local addr = %v; want %v", addr, host)
5087 }
5088 }
5089
5090
5091 func TestHandlerSetTransferEncodingChunked(t *testing.T) {
5092 setParallel(t)
5093 defer afterTest(t)
5094 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5095 w.Header().Set("Transfer-Encoding", "chunked")
5096 w.Write([]byte("hello"))
5097 }))
5098 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5099 const hdr = "Transfer-Encoding: chunked"
5100 if n := strings.Count(resp, hdr); n != 1 {
5101 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5102 }
5103 }
5104
5105
5106 func TestHandlerSetTransferEncodingGzip(t *testing.T) {
5107 setParallel(t)
5108 defer afterTest(t)
5109 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5110 w.Header().Set("Transfer-Encoding", "gzip")
5111 gz := gzip.NewWriter(w)
5112 gz.Write([]byte("hello"))
5113 gz.Close()
5114 }))
5115 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5116 for _, v := range []string{"gzip", "chunked"} {
5117 hdr := "Transfer-Encoding: " + v
5118 if n := strings.Count(resp, hdr); n != 1 {
5119 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5120 }
5121 }
5122 }
5123
5124 func BenchmarkClientServer(b *testing.B) {
5125 run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
5126 }
5127 func benchmarkClientServer(b *testing.B, mode testMode) {
5128 b.ReportAllocs()
5129 b.StopTimer()
5130 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5131 fmt.Fprintf(rw, "Hello world.\n")
5132 })).ts
5133 b.StartTimer()
5134
5135 c := ts.Client()
5136 for i := 0; i < b.N; i++ {
5137 res, err := c.Get(ts.URL)
5138 if err != nil {
5139 b.Fatal("Get:", err)
5140 }
5141 all, err := io.ReadAll(res.Body)
5142 res.Body.Close()
5143 if err != nil {
5144 b.Fatal("ReadAll:", err)
5145 }
5146 body := string(all)
5147 if body != "Hello world.\n" {
5148 b.Fatal("Got body:", body)
5149 }
5150 }
5151
5152 b.StopTimer()
5153 }
5154
5155 func BenchmarkClientServerParallel(b *testing.B) {
5156 for _, parallelism := range []int{4, 64} {
5157 b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
5158 run(b, func(b *testing.B, mode testMode) {
5159 benchmarkClientServerParallel(b, parallelism, mode)
5160 }, []testMode{http1Mode, https1Mode, http2Mode})
5161 })
5162 }
5163 }
5164
5165 func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
5166 b.ReportAllocs()
5167 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5168 fmt.Fprintf(rw, "Hello world.\n")
5169 })).ts
5170 b.ResetTimer()
5171 b.SetParallelism(parallelism)
5172 b.RunParallel(func(pb *testing.PB) {
5173 c := ts.Client()
5174 for pb.Next() {
5175 res, err := c.Get(ts.URL)
5176 if err != nil {
5177 b.Logf("Get: %v", err)
5178 continue
5179 }
5180 all, err := io.ReadAll(res.Body)
5181 res.Body.Close()
5182 if err != nil {
5183 b.Logf("ReadAll: %v", err)
5184 continue
5185 }
5186 body := string(all)
5187 if body != "Hello world.\n" {
5188 panic("Got body: " + body)
5189 }
5190 }
5191 })
5192 }
5193
5194
5195
5196
5197
5198
5199
5200
5201
5202
5203 func BenchmarkServer(b *testing.B) {
5204 b.ReportAllocs()
5205
5206 if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
5207 n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
5208 if err != nil {
5209 panic(err)
5210 }
5211 for i := 0; i < n; i++ {
5212 res, err := Get(url)
5213 if err != nil {
5214 log.Panicf("Get: %v", err)
5215 }
5216 all, err := io.ReadAll(res.Body)
5217 res.Body.Close()
5218 if err != nil {
5219 log.Panicf("ReadAll: %v", err)
5220 }
5221 body := string(all)
5222 if body != "Hello world.\n" {
5223 log.Panicf("Got body: %q", body)
5224 }
5225 }
5226 os.Exit(0)
5227 return
5228 }
5229
5230 var res = []byte("Hello world.\n")
5231 b.StopTimer()
5232 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5233 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5234 rw.Write(res)
5235 }))
5236 defer ts.Close()
5237 b.StartTimer()
5238
5239 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5240 cmd.Env = append([]string{
5241 fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
5242 fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
5243 }, os.Environ()...)
5244 out, err := cmd.CombinedOutput()
5245 if err != nil {
5246 b.Errorf("Test failure: %v, with output: %s", err, out)
5247 }
5248 }
5249
5250
// getNoBody issues a GET for urlStr with the default client and closes the
// response body before returning, so only the response metadata is usable.
func getNoBody(urlStr string) (*Response, error) {
	resp, err := Get(urlStr)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	return resp, nil
}
5259
5260
5261
// BenchmarkClient measures sequential client GET round trips against a
// server running in a child process that re-runs this benchmark.
//
// When TEST_BENCH_SERVER is set, this process is that child. It listens on
// localhost (port TEST_BENCH_SERVER_PORT, or "0" if unset), prints the
// listening address on stdout, and serves data on DefaultServeMux until a
// request with a non-empty "stop" form value makes it exit(0).
func BenchmarkClient(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	defer afterTest(b)

	var data = []byte("Hello world.\n")
	if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
		// Child (server) mode; this branch never returns.
		port := os.Getenv("TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			fmt.Fprintln(os.Stderr, err.Error())
			os.Exit(1)
		}
		// The parent reads this first stdout line to learn where to connect.
		fmt.Println(ln.Addr().String())
		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.Write(data)
		})
		// A zero Server has a nil Handler, so it serves DefaultServeMux,
		// where the handler above was registered.
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent (client) mode: start the child server process.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkClient$")
	cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}

	// done receives cmd.Wait's result once and is then closed. The normal
	// path and the deferred cleanup each receive from it: whichever comes
	// second gets the zero value from the closed channel instead of blocking.
	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
		close(done)
	}()
	defer func() {
		// Canceling ctx presumably stops the child if it is still running
		// (CommandContext semantics); then wait for it to be reaped.
		cancel()
		<-done
	}()

	// The first line the child prints is its listening address. Probe it
	// once before timing starts.
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	// Only the request loop below is timed.
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		body, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if !bytes.Equal(body, data) {
			b.Fatalf("Got body: %q", body)
		}
	}
	b.StopTimer()

	// Ask the child to exit. The error is ignored because the child calls
	// os.Exit inside the handler, so this request may fail. Then require a
	// clean exit status.
	getNoBody(url + "?stop=yes")
	if err := <-done; err != nil {
		b.Fatalf("subprocess failed: %v", err)
	}
}
5350
5351 func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
5352 b.ReportAllocs()
5353 req := reqBytes(`GET / HTTP/1.0
5354 Host: golang.org
5355 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5356 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5357 Accept-Encoding: gzip,deflate,sdch
5358 Accept-Language: en-US,en;q=0.8
5359 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5360 `)
5361 res := []byte("Hello world!\n")
5362
5363 conn := newTestConn()
5364 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5365 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5366 rw.Write(res)
5367 })
5368 ln := new(oneConnListener)
5369 for i := 0; i < b.N; i++ {
5370 conn.readBuf.Reset()
5371 conn.writeBuf.Reset()
5372 conn.readBuf.Write(req)
5373 ln.conn = conn
5374 Serve(ln, handler)
5375 <-conn.closec
5376 }
5377 }
5378
5379
// repeatReader is an io.Reader that yields content count times in a row and
// then reports io.EOF.
type repeatReader struct {
	content []byte
	count   int // repetitions still to deliver, including the current one
	off     int // read position within the current repetition
}

// Read copies as much of the current repetition as fits into p. When a
// repetition has been fully delivered, count is decremented and the position
// rewinds to the start of content.
func (r *repeatReader) Read(p []byte) (int, error) {
	if r.count < 1 {
		return 0, io.EOF
	}
	n := copy(p, r.content[r.off:])
	r.off += n
	if r.off < len(r.content) {
		return n, nil
	}
	r.count--
	r.off = 0
	return n, nil
}
5398
5399 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5400 b.ReportAllocs()
5401
5402 req := reqBytes(`GET / HTTP/1.1
5403 Host: golang.org
5404 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5405 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5406 Accept-Encoding: gzip,deflate,sdch
5407 Accept-Language: en-US,en;q=0.8
5408 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5409 `)
5410 res := []byte("Hello world!\n")
5411
5412 conn := &rwTestConn{
5413 Reader: &repeatReader{content: req, count: b.N},
5414 Writer: io.Discard,
5415 closec: make(chan bool, 1),
5416 }
5417 handled := 0
5418 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5419 handled++
5420 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5421 rw.Write(res)
5422 })
5423 ln := &oneConnListener{conn: conn}
5424 go Serve(ln, handler)
5425 <-conn.closec
5426 if b.N != handled {
5427 b.Errorf("b.N=%d but handled %d", b.N, handled)
5428 }
5429 }
5430
5431
5432
5433 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5434 b.ReportAllocs()
5435
5436 req := reqBytes(`GET / HTTP/1.1
5437 Host: golang.org
5438 `)
5439 res := []byte("Hello world!\n")
5440
5441 conn := &rwTestConn{
5442 Reader: &repeatReader{content: req, count: b.N},
5443 Writer: io.Discard,
5444 closec: make(chan bool, 1),
5445 }
5446 handled := 0
5447 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5448 handled++
5449 rw.Write(res)
5450 })
5451 ln := &oneConnListener{conn: conn}
5452 go Serve(ln, handler)
5453 <-conn.closec
5454 if b.N != handled {
5455 b.Errorf("b.N=%d but handled %d", b.N, handled)
5456 }
5457 }
5458
// someResponse is the fragment repeated to build response.
const someResponse = "<html>some response</html>"

// response is the body written by the BenchmarkServerHandler* benchmarks:
// as many whole copies of someResponse as fit in 2 KiB (2<<10 bytes).
var response = bytes.Repeat([]byte(someResponse), (2<<10)/len(someResponse))
5463
5464
5465 func BenchmarkServerHandlerTypeLen(b *testing.B) {
5466 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5467 w.Header().Set("Content-Type", "text/html")
5468 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5469 w.Write(response)
5470 }))
5471 }
5472
5473
5474 func BenchmarkServerHandlerNoLen(b *testing.B) {
5475 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5476 w.Header().Set("Content-Type", "text/html")
5477 w.Write(response)
5478 }))
5479 }
5480
5481
5482 func BenchmarkServerHandlerNoType(b *testing.B) {
5483 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5484 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5485 w.Write(response)
5486 }))
5487 }
5488
5489
5490 func BenchmarkServerHandlerNoHeader(b *testing.B) {
5491 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5492 w.Write(response)
5493 }))
5494 }
5495
5496 func benchmarkHandler(b *testing.B, h Handler) {
5497 b.ReportAllocs()
5498 req := reqBytes(`GET / HTTP/1.1
5499 Host: golang.org
5500 `)
5501 conn := &rwTestConn{
5502 Reader: &repeatReader{content: req, count: b.N},
5503 Writer: io.Discard,
5504 closec: make(chan bool, 1),
5505 }
5506 handled := 0
5507 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5508 handled++
5509 h.ServeHTTP(rw, r)
5510 })
5511 ln := &oneConnListener{conn: conn}
5512 go Serve(ln, handler)
5513 <-conn.closec
5514 if b.N != handled {
5515 b.Errorf("b.N=%d but handled %d", b.N, handled)
5516 }
5517 }
5518
5519 func BenchmarkServerHijack(b *testing.B) {
5520 b.ReportAllocs()
5521 req := reqBytes(`GET / HTTP/1.1
5522 Host: golang.org
5523 `)
5524 h := HandlerFunc(func(w ResponseWriter, r *Request) {
5525 conn, _, err := w.(Hijacker).Hijack()
5526 if err != nil {
5527 panic(err)
5528 }
5529 conn.Close()
5530 })
5531 conn := &rwTestConn{
5532 Writer: io.Discard,
5533 closec: make(chan bool, 1),
5534 }
5535 ln := &oneConnListener{conn: conn}
5536 for i := 0; i < b.N; i++ {
5537 conn.Reader = bytes.NewReader(req)
5538 ln.conn = conn
5539 Serve(ln, h)
5540 <-conn.closec
5541 }
5542 }
5543
5544 func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
5545 func benchmarkCloseNotifier(b *testing.B, mode testMode) {
5546 b.ReportAllocs()
5547 b.StopTimer()
5548 sawClose := make(chan bool)
5549 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
5550 <-rw.(CloseNotifier).CloseNotify()
5551 sawClose <- true
5552 })).ts
5553 b.StartTimer()
5554 for i := 0; i < b.N; i++ {
5555 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5556 if err != nil {
5557 b.Fatalf("error dialing: %v", err)
5558 }
5559 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
5560 if err != nil {
5561 b.Fatal(err)
5562 }
5563 conn.Close()
5564 <-sawClose
5565 }
5566 b.StopTimer()
5567 }
5568
5569
5570 func TestConcurrentServerServe(t *testing.T) {
5571 setParallel(t)
5572 for i := 0; i < 100; i++ {
5573 ln1 := &oneConnListener{conn: nil}
5574 ln2 := &oneConnListener{conn: nil}
5575 srv := Server{}
5576 go func() { srv.Serve(ln1) }()
5577 go func() { srv.Serve(ln2) }()
5578 }
5579 }
5580
// TestServerIdleTimeout checks two HTTP/1 server timeouts. First, an idle
// keep-alive connection is closed after Server.IdleTimeout, so a later
// request arrives on a new connection. Second, a connection that sends an
// incomplete request header is closed after Server.ReadHeaderTimeout.
func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	// Each attempt uses a different ReadHeaderTimeout (IdleTimeout is twice
	// that). runTimeSensitiveTest presumably moves on to the next, larger
	// duration when the callback returns an error; confirm against its
	// definition.
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}, func(t *testing.T, readHeaderTimeout time.Duration) error {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			// Echo the client address so the test can tell which
			// connection served each request.
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		c := ts.Client()

		// get returns the server-observed client address for one request.
		get := func() (string, error) {
			res, err := c.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			slurp, err := io.ReadAll(res.Body)
			if err != nil {
				// A body read failure after the response headers arrived
				// is treated as a hard failure, not as a retryable
				// timing error.
				t.Fatal(err)
			}
			return string(slurp), nil
		}

		// Two back-to-back requests should share one keep-alive connection.
		a1, err := get()
		if err != nil {
			return err
		}
		a2, err := get()
		if err != nil {
			return err
		}
		if a1 != a2 {
			return fmt.Errorf("did requests on different connections")
		}
		// Sleeping past IdleTimeout should let the server close the idle
		// connection, so the third request must use a new one.
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		a3, err := get()
		if err != nil {
			return err
		}
		if a2 == a3 {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Send a request line and Host header but never finish the header
		// block. After ReadHeaderTimeout the server should have closed the
		// connection, so reading one byte must fail.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5656
// get issues a GET for url with c and returns the whole response body as a
// string. Any request or read error fails the test immediately.
func get(t *testing.T, c *Client, url string) string {
	resp, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	var sb strings.Builder
	_, err = io.Copy(&sb, resp.Body)
	resp.Body.Close()
	if err != nil {
		t.Fatal(err)
	}
	return sb.String()
}
5669
5670
5671
5672 func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
5673 run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
5674 }
5675 func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
5676 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5677 io.WriteString(w, r.RemoteAddr)
5678 })).ts
5679
5680 c := ts.Client()
5681 tr := c.Transport.(*Transport)
5682
5683 get := func() string { return get(t, c, ts.URL) }
5684
5685 a1, a2 := get(), get()
5686 if a1 == a2 {
5687 t.Logf("made two requests from a single conn %q (as expected)", a1)
5688 } else {
5689 t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
5690 }
5691
5692
5693
5694
5695
5696 if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
5697 t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
5698 }
5699
5700
5701 ts.Config.SetKeepAlivesEnabled(false)
5702
5703 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5704 if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
5705 if d > 0 {
5706 t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
5707 }
5708 return false
5709 }
5710 return true
5711 })
5712
5713
5714
5715
5716 }
5717
// TestServerShutdown checks graceful Shutdown while a request is in flight:
//   - the connection serving that request is the single StateActive conn;
//   - OnShutdown hooks run;
//   - once the hook has run, new requests fail because the listener is
//     closed;
//   - Shutdown returns nil after the in-flight request completes.
func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) }
func testServerShutdown(t *testing.T, mode testMode) {
	// Declared before the handler so the handler closure can use it; it is
	// assigned by newClientServerTest below.
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		// On the first request only: snapshot connection states and start
		// Shutdown concurrently while this request is still active.
		once.Do(func() {
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait for the OnShutdown hook registered below to fire.
			<-gotOnShutdown

			// From here, new requests should fail because the listener
			// has been closed. Loop until one fails (or the test fails).
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		// Reply so the in-flight request completes and Shutdown can finish.
		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown // the hook must have run (channel is closed by it)

	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
5778
5779 func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) }
5780 func testServerShutdownStateNew(t *testing.T, mode testMode) {
5781 if testing.Short() {
5782 t.Skip("test takes 5-6 seconds; skipping in short mode")
5783 }
5784
5785 var connAccepted sync.WaitGroup
5786 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5787
5788 }), func(ts *httptest.Server) {
5789 ts.Config.ConnState = func(conn net.Conn, state ConnState) {
5790 if state == StateNew {
5791 connAccepted.Done()
5792 }
5793 }
5794 }).ts
5795
5796
5797 connAccepted.Add(1)
5798 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5799 if err != nil {
5800 t.Fatal(err)
5801 }
5802 defer c.Close()
5803
5804
5805
5806
5807
5808 connAccepted.Wait()
5809
5810 shutdownRes := make(chan error, 1)
5811 go func() {
5812 shutdownRes <- ts.Config.Shutdown(context.Background())
5813 }()
5814 readRes := make(chan error, 1)
5815 go func() {
5816 _, err := c.Read([]byte{0})
5817 readRes <- err
5818 }()
5819
5820
5821
5822
5823 const expectTimeout = 5 * time.Second
5824
5825 t0 := time.Now()
5826 select {
5827 case got := <-shutdownRes:
5828 d := time.Since(t0)
5829 if got != nil {
5830 t.Fatalf("shutdown error after %v: %v", d, err)
5831 }
5832 if d < expectTimeout/2 {
5833 t.Errorf("shutdown too soon after %v", d)
5834 }
5835 case <-time.After(expectTimeout * 3 / 2):
5836 t.Fatalf("timeout waiting for shutdown")
5837 }
5838
5839
5840
5841 if err := <-readRes; err == nil {
5842 t.Error("expected error from Read")
5843 }
5844 }
5845
5846
// TestServerCloseDeadlock checks that Close may be called twice on a Server
// that was never started without blocking.
func TestServerCloseDeadlock(t *testing.T) {
	srv := new(Server)
	for i := 0; i < 2; i++ {
		srv.Close()
	}
}
5852
5853
5854
5855 func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }
5856 func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
5857 if mode == http2Mode {
5858 restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
5859 defer restore()
5860 }
5861
5862 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
5863 defer cst.close()
5864 srv := cst.ts.Config
5865 srv.SetKeepAlivesEnabled(false)
5866 for try := 0; try < 2; try++ {
5867 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5868 if !srv.ExportAllConnsIdle() {
5869 if d > 0 {
5870 t.Logf("test server still has active conns after %v", d)
5871 }
5872 return false
5873 }
5874 return true
5875 })
5876 conns := 0
5877 var info httptrace.GotConnInfo
5878 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5879 GotConn: func(v httptrace.GotConnInfo) {
5880 conns++
5881 info = v
5882 },
5883 })
5884 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
5885 if err != nil {
5886 t.Fatal(err)
5887 }
5888 res, err := cst.c.Do(req)
5889 if err != nil {
5890 t.Fatal(err)
5891 }
5892 res.Body.Close()
5893 if conns != 1 {
5894 t.Fatalf("request %v: got %v conns, want 1", try, conns)
5895 }
5896 if info.Reused || info.WasIdle {
5897 t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
5898 }
5899 }
5900 }
5901
5902
5903
5904
5905 func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
5906 func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
5907 runTimeSensitiveTest(t, []time.Duration{
5908 10 * time.Millisecond,
5909 50 * time.Millisecond,
5910 250 * time.Millisecond,
5911 time.Second,
5912 2 * time.Second,
5913 }, func(t *testing.T, timeout time.Duration) error {
5914 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5915 select {
5916 case <-time.After(2 * timeout):
5917 fmt.Fprint(w, "ok")
5918 case <-r.Context().Done():
5919 fmt.Fprint(w, r.Context().Err())
5920 }
5921 }), func(ts *httptest.Server) {
5922 ts.Config.ReadTimeout = timeout
5923 t.Logf("Server.Config.ReadTimeout = %v", timeout)
5924 })
5925 defer cst.close()
5926 ts := cst.ts
5927
5928 var retries atomic.Int32
5929 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
5930 if retries.Add(1) != 1 {
5931 return nil, errors.New("too many retries")
5932 }
5933 return nil, nil
5934 }
5935
5936 c := ts.Client()
5937
5938 res, err := c.Get(ts.URL)
5939 if err != nil {
5940 return fmt.Errorf("Get: %v", err)
5941 }
5942 slurp, err := io.ReadAll(res.Body)
5943 res.Body.Close()
5944 if err != nil {
5945 return fmt.Errorf("Body ReadAll: %v", err)
5946 }
5947 if string(slurp) != "ok" {
5948 return fmt.Errorf("got: %q, want ok", slurp)
5949 }
5950 return nil
5951 })
5952 }
5953
5954
5955
5956
5957 func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
5958 run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
5959 }
5960 func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
5961 runTimeSensitiveTest(t, []time.Duration{
5962 10 * time.Millisecond,
5963 50 * time.Millisecond,
5964 250 * time.Millisecond,
5965 time.Second,
5966 2 * time.Second,
5967 }, func(t *testing.T, timeout time.Duration) error {
5968 cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
5969 ts.Config.ReadHeaderTimeout = timeout
5970 ts.Config.IdleTimeout = 0
5971 })
5972 defer cst.close()
5973 ts := cst.ts
5974
5975
5976
5977 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5978 if err != nil {
5979 t.Fatalf("dial failed: %v", err)
5980 }
5981 br := bufio.NewReader(conn)
5982 defer conn.Close()
5983
5984 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
5985 return fmt.Errorf("writing first request failed: %v", err)
5986 }
5987
5988 if _, err := ReadResponse(br, nil); err != nil {
5989 return fmt.Errorf("first response (before timeout) failed: %v", err)
5990 }
5991
5992
5993
5994 time.Sleep(timeout * 3 / 2)
5995
5996 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
5997 return fmt.Errorf("writing second request failed: %v", err)
5998 }
5999
6000 if _, err := ReadResponse(br, nil); err != nil {
6001 return fmt.Errorf("second response (after timeout) failed: %v", err)
6002 }
6003
6004 return nil
6005 })
6006 }
6007
6008
6009
// runTimeSensitiveTest calls test once per entry in durations, in order, and
// stops at the first call that returns nil. An error from any attempt except
// the last is logged and the next duration is tried. An error is fatal if it
// comes from the last attempt, or if t has already been marked as failed.
// An empty durations slice does nothing.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	last := len(durations) - 1
	for i, d := range durations {
		err := test(t, d)
		switch {
		case err == nil:
			return
		case i == last || t.Failed():
			t.Fatalf("failed with duration %v: %v", d, err)
		default:
			t.Logf("retrying after error with duration %v: %v", d, err)
		}
	}
}
6022
6023
6024
6025 func TestServerDuplicateBackgroundRead(t *testing.T) {
6026 run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
6027 }
6028 func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
6029 if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
6030 testenv.SkipFlaky(t, 24826)
6031 }
6032
6033 goroutines := 5
6034 requests := 2000
6035 if testing.Short() {
6036 goroutines = 3
6037 requests = 100
6038 }
6039
6040 hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
6041
6042 reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6043
6044 var wg sync.WaitGroup
6045 for i := 0; i < goroutines; i++ {
6046 wg.Add(1)
6047 go func() {
6048 defer wg.Done()
6049 cn, err := net.Dial("tcp", hts.Listener.Addr().String())
6050 if err != nil {
6051 t.Error(err)
6052 return
6053 }
6054 defer cn.Close()
6055
6056 wg.Add(1)
6057 go func() {
6058 defer wg.Done()
6059 io.Copy(io.Discard, cn)
6060 }()
6061
6062 for j := 0; j < requests; j++ {
6063 if t.Failed() {
6064 return
6065 }
6066 _, err := cn.Write(reqBytes)
6067 if err != nil {
6068 t.Error(err)
6069 return
6070 }
6071 }
6072 }()
6073 }
6074 wg.Wait()
6075 }
6076
6077
6078
6079
6080
6081
6082 func TestServerHijackGetsBackgroundByte(t *testing.T) {
6083 run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
6084 }
6085 func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
6086 if runtime.GOOS == "plan9" {
6087 t.Skip("skipping test; see https://golang.org/issue/18657")
6088 }
6089 done := make(chan struct{})
6090 inHandler := make(chan bool, 1)
6091 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6092 defer close(done)
6093
6094
6095 inHandler <- true
6096
6097 conn, buf, err := w.(Hijacker).Hijack()
6098 if err != nil {
6099 t.Error(err)
6100 return
6101 }
6102 defer conn.Close()
6103
6104 peek, err := buf.Reader.Peek(3)
6105 if string(peek) != "foo" || err != nil {
6106 t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
6107 }
6108
6109 select {
6110 case <-r.Context().Done():
6111 t.Error("context unexpectedly canceled")
6112 default:
6113 }
6114 })).ts
6115
6116 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6117 if err != nil {
6118 t.Fatal(err)
6119 }
6120 defer cn.Close()
6121 if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6122 t.Fatal(err)
6123 }
6124 <-inHandler
6125 if _, err := cn.Write([]byte("foo")); err != nil {
6126 t.Fatal(err)
6127 }
6128
6129 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6130 t.Fatal(err)
6131 }
6132 <-done
6133 }
6134
6135
6136
6137
6138 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
6139 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
6140 }
6141 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
6142 if runtime.GOOS == "plan9" {
6143 t.Skip("skipping test; see https://golang.org/issue/18657")
6144 }
6145 done := make(chan struct{})
6146 const size = 8 << 10
6147 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6148 defer close(done)
6149
6150 conn, buf, err := w.(Hijacker).Hijack()
6151 if err != nil {
6152 t.Error(err)
6153 return
6154 }
6155 defer conn.Close()
6156 slurp, err := io.ReadAll(buf.Reader)
6157 if err != nil {
6158 t.Errorf("Copy: %v", err)
6159 }
6160 allX := true
6161 for _, v := range slurp {
6162 if v != 'x' {
6163 allX = false
6164 }
6165 }
6166 if len(slurp) != size {
6167 t.Errorf("read %d; want %d", len(slurp), size)
6168 } else if !allX {
6169 t.Errorf("read %q; want %d 'x'", slurp, size)
6170 }
6171 })).ts
6172
6173 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6174 if err != nil {
6175 t.Fatal(err)
6176 }
6177 defer cn.Close()
6178 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6179 strings.Repeat("x", size)); err != nil {
6180 t.Fatal(err)
6181 }
6182 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6183 t.Fatal(err)
6184 }
6185
6186 <-done
6187 }
6188
6189
6190 func TestServerValidatesMethod(t *testing.T) {
6191 tests := []struct {
6192 method string
6193 want int
6194 }{
6195 {"GET", 200},
6196 {"GE(T", 400},
6197 }
6198 for _, tt := range tests {
6199 conn := newTestConn()
6200 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6201
6202 ln := &oneConnListener{conn}
6203 go Serve(ln, serve(200))
6204 <-conn.closec
6205 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6206 if err != nil {
6207 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6208 continue
6209 }
6210 if res.StatusCode != tt.want {
6211 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6212 }
6213 }
6214 }
6215
6216
// eofListenerNotComparable is a net.Listener whose underlying type is a
// slice. Comparing two net.Listener interface values holding it with ==
// would therefore panic. Accept always fails with io.EOF, Addr returns nil,
// and Close does nothing.
type eofListenerNotComparable []int

func (eofListenerNotComparable) Accept() (net.Conn, error) {
	return nil, io.EOF
}

func (eofListenerNotComparable) Addr() net.Addr {
	return nil
}

func (eofListenerNotComparable) Close() error {
	return nil
}
6222
6223
6224 func TestServerListenNotComparableListener(t *testing.T) {
6225 var s Server
6226 s.Serve(make(eofListenerNotComparable, 1))
6227 }
6228
6229
// countCloseListener wraps a net.Listener and counts calls to Close in
// closes, which is updated atomically. Only the first Close is forwarded to
// the wrapped Listener, and only if one is set. Every other call returns nil.
type countCloseListener struct {
	net.Listener
	closes int32
}

func (p *countCloseListener) Close() error {
	if n := atomic.AddInt32(&p.closes, 1); n != 1 || p.Listener == nil {
		return nil
	}
	return p.Listener.Close()
}
6242
6243
6244 func TestServerCloseListenerOnce(t *testing.T) {
6245 setParallel(t)
6246 defer afterTest(t)
6247
6248 ln := newLocalListener(t)
6249 defer ln.Close()
6250
6251 cl := &countCloseListener{Listener: ln}
6252 server := &Server{}
6253 sdone := make(chan bool, 1)
6254
6255 go func() {
6256 server.Serve(cl)
6257 sdone <- true
6258 }()
6259 time.Sleep(10 * time.Millisecond)
6260 server.Shutdown(context.Background())
6261 ln.Close()
6262 <-sdone
6263
6264 nclose := atomic.LoadInt32(&cl.closes)
6265 if nclose != 1 {
6266 t.Errorf("Close calls = %v; want 1", nclose)
6267 }
6268 }
6269
6270
6271 func TestServerShutdownThenServe(t *testing.T) {
6272 var srv Server
6273 cl := &countCloseListener{Listener: nil}
6274 srv.Shutdown(context.Background())
6275 got := srv.Serve(cl)
6276 if got != ErrServerClosed {
6277 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6278 }
6279 nclose := atomic.LoadInt32(&cl.closes)
6280 if nclose != 1 {
6281 t.Errorf("Close calls = %v; want 1", nclose)
6282 }
6283 }
6284
6285
// TestStripPortFromHost checks that ServeMux routes a request for
// http://example.com:9000/ to the "example.com/" pattern rather than to
// "example.com:9000/". In other words, the port is ignored when matching
// the host.
func TestStripPortFromHost(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("example.com/", func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "OK")
	})
	mux.HandleFunc("example.com:9000/", func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "uh-oh!")
	})

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6306
6307 func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
6308 func testServerContexts(t *testing.T, mode testMode) {
6309 type baseKey struct{}
6310 type connKey struct{}
6311 ch := make(chan context.Context, 1)
6312 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6313 ch <- r.Context()
6314 }), func(ts *httptest.Server) {
6315 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6316 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6317 t.Errorf("unexpected onceClose listener type %T", ln)
6318 }
6319 return context.WithValue(context.Background(), baseKey{}, "base")
6320 }
6321 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6322 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6323 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6324 }
6325 return context.WithValue(ctx, connKey{}, "conn")
6326 }
6327 }).ts
6328 res, err := ts.Client().Get(ts.URL)
6329 if err != nil {
6330 t.Fatal(err)
6331 }
6332 res.Body.Close()
6333 ctx := <-ch
6334 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6335 t.Errorf("base context key = %#v; want %q", got, want)
6336 }
6337 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6338 t.Errorf("conn context key = %#v; want %q", got, want)
6339 }
6340 }
6341
6342
6343 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6344 run(t, testConnContextNotModifyingAllContexts)
6345 }
6346 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6347 type connKey struct{}
6348 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6349 rw.Header().Set("Connection", "close")
6350 }), func(ts *httptest.Server) {
6351 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6352 if got := ctx.Value(connKey{}); got != nil {
6353 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6354 }
6355 return context.WithValue(ctx, connKey{}, "conn")
6356 }
6357 }).ts
6358
6359 var res *Response
6360 var err error
6361
6362 res, err = ts.Client().Get(ts.URL)
6363 if err != nil {
6364 t.Fatal(err)
6365 }
6366 res.Body.Close()
6367
6368 res, err = ts.Client().Get(ts.URL)
6369 if err != nil {
6370 t.Fatal(err)
6371 }
6372 res.Body.Close()
6373 }
6374
6375
6376
6377 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6378 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6379 }
6380 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6381 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6382 w.Write([]byte("Hello, World!"))
6383 })).ts
6384
6385 serverURL, err := url.Parse(cst.URL)
6386 if err != nil {
6387 t.Fatalf("Failed to parse server URL: %v", err)
6388 }
6389
6390 unsupportedTEs := []string{
6391 "fugazi",
6392 "foo-bar",
6393 "unknown",
6394 `" chunked"`,
6395 }
6396
6397 for _, badTE := range unsupportedTEs {
6398 http1ReqBody := fmt.Sprintf(""+
6399 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6400 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6401
6402 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6403 if err != nil {
6404 t.Errorf("%q. unexpected error: %v", badTE, err)
6405 continue
6406 }
6407
6408 wantBody := fmt.Sprintf("" +
6409 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6410 "Connection: close\r\n\r\nUnsupported transfer encoding")
6411
6412 if string(gotBody) != wantBody {
6413 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6414 }
6415 }
6416 }
6417
6418
6419 func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
6420 func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
6421 type setting struct {
6422 name string
6423 body []byte
6424
6425
6426
6427
6428 contentEncoding any
6429 wantContentType string
6430 }
6431
6432 settings := []*setting{
6433 {
6434 name: "gzip content-encoding, gzipped",
6435 contentEncoding: "application/gzip",
6436 wantContentType: "",
6437 body: func() []byte {
6438 buf := new(bytes.Buffer)
6439 gzw := gzip.NewWriter(buf)
6440 gzw.Write([]byte("doctype html><p>Hello</p>"))
6441 gzw.Close()
6442 return buf.Bytes()
6443 }(),
6444 },
6445 {
6446 name: "zlib content-encoding, zlibbed",
6447 contentEncoding: "application/zlib",
6448 wantContentType: "",
6449 body: func() []byte {
6450 buf := new(bytes.Buffer)
6451 zw := zlib.NewWriter(buf)
6452 zw.Write([]byte("doctype html><p>Hello</p>"))
6453 zw.Close()
6454 return buf.Bytes()
6455 }(),
6456 },
6457 {
6458 name: "no content-encoding",
6459 wantContentType: "application/x-gzip",
6460 body: func() []byte {
6461 buf := new(bytes.Buffer)
6462 gzw := gzip.NewWriter(buf)
6463 gzw.Write([]byte("doctype html><p>Hello</p>"))
6464 gzw.Close()
6465 return buf.Bytes()
6466 }(),
6467 },
6468 {
6469 name: "phony content-encoding",
6470 contentEncoding: "foo/bar",
6471 body: []byte("doctype html><p>Hello</p>"),
6472 },
6473 {
6474 name: "empty but set content-encoding",
6475 contentEncoding: "",
6476 wantContentType: "audio/mpeg",
6477 body: []byte("ID3"),
6478 },
6479 }
6480
6481 for _, tt := range settings {
6482 t.Run(tt.name, func(t *testing.T) {
6483 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6484 if tt.contentEncoding != nil {
6485 rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
6486 }
6487 rw.Write(tt.body)
6488 }))
6489
6490 res, err := cst.c.Get(cst.ts.URL)
6491 if err != nil {
6492 t.Fatalf("Failed to fetch URL: %v", err)
6493 }
6494 defer res.Body.Close()
6495
6496 if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
6497 if w != nil {
6498 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
6499 } else if g != "" {
6500 t.Errorf("Unexpected Content-Encoding %q", g)
6501 }
6502 }
6503
6504 if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
6505 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
6506 }
6507 })
6508 }
6509 }
6510
6511
6512
6513 func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
6514 run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
6515 }
// testTimeoutHandlerSuperfluousLogs checks the "superfluous
// response.WriteHeader call" log entries produced when a handler wrapped in
// TimeoutHandler calls WriteHeader(404) four times. This is tested both when
// the handler returns before the timeout and when it returns after it. Each
// case expects the dumped response to match wantResp exactly, and expects
// exactly three log entries (one per call after the first). Each entry must
// name this function's handler closure and give the file:line of the
// matching call.
//
// NOTE: the expected line numbers are computed from runtime.Caller inside
// the handler. That calculation assumes the four WriteHeader calls sit on the
// four consecutive lines directly above the runtime.Caller call. Do not
// insert any lines (not even comments) among them.
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Record this file's base name and this function's full name; both are
	// embedded in the expected log-entry pattern below.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// exitHandler releases the handler once it has written its
			// headers. It is closed when the subtest returns.
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			// lastLine receives the line number of the handler's
			// runtime.Caller call.
			lastLine := make(chan int, 1)

			// The four WriteHeader calls must stay directly above
			// runtime.Caller(0); see the function comment.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// When the handler should finish in time, pre-load exitHandler
			// (buffered, capacity 1) so the handler does not block.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)

			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Use a long timeout so the non-blocking handler finishes first.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop headers that wantResp does not include, so that the dump
			// can be compared byte for byte.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			// Dump the response including its body.
			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// The server log must contain exactly one entry per superfluous
			// WriteHeader call.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", " ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// lastLine holds the runtime.Caller line, which is one past the
			// fourth WriteHeader. Subtracting 3 therefore gives the line of
			// the second call, the first superfluous one.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Entry i must point at the (i+1)-th superfluous call.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6617
6618
6619
6620
// fetchWireResponse dials host over TCP and writes http1ReqBody to the
// connection unchanged. It then returns every byte the peer sends until the
// connection is closed; if a read error occurs, it returns whatever was read
// so far along with that error. The connection is always closed before
// returning.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	c, dialErr := net.Dial("tcp", host)
	if dialErr != nil {
		return nil, dialErr
	}
	defer c.Close()

	if _, writeErr := c.Write(http1ReqBody); writeErr != nil {
		return nil, writeErr
	}
	return io.ReadAll(c)
}
6633
6634 func BenchmarkResponseStatusLine(b *testing.B) {
6635 b.ReportAllocs()
6636 b.RunParallel(func(pb *testing.PB) {
6637 bw := bufio.NewWriter(io.Discard)
6638 var buf3 [3]byte
6639 for pb.Next() {
6640 Export_writeStatusLine(bw, true, 200, buf3[:])
6641 }
6642 })
6643 }
6644
6645 func TestDisableKeepAliveUpgrade(t *testing.T) {
6646 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6647 }
6648 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6649 if testing.Short() {
6650 t.Skip("skipping in short mode")
6651 }
6652
6653 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6654 w.Header().Set("Connection", "Upgrade")
6655 w.Header().Set("Upgrade", "someProto")
6656 w.WriteHeader(StatusSwitchingProtocols)
6657 c, buf, err := w.(Hijacker).Hijack()
6658 if err != nil {
6659 return
6660 }
6661 defer c.Close()
6662
6663
6664
6665 io.Copy(c, buf)
6666 }), func(ts *httptest.Server) {
6667 ts.Config.SetKeepAlivesEnabled(false)
6668 }).ts
6669
6670 cl := s.Client()
6671 cl.Transport.(*Transport).DisableKeepAlives = true
6672
6673 resp, err := cl.Get(s.URL)
6674 if err != nil {
6675 t.Fatalf("failed to perform request: %v", err)
6676 }
6677 defer resp.Body.Close()
6678
6679 if resp.StatusCode != StatusSwitchingProtocols {
6680 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6681 }
6682
6683 rwc, ok := resp.Body.(io.ReadWriteCloser)
6684 if !ok {
6685 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6686 }
6687
6688 _, err = rwc.Write([]byte("hello"))
6689 if err != nil {
6690 t.Fatalf("failed to write to body: %v", err)
6691 }
6692
6693 b := make([]byte, 5)
6694 _, err = io.ReadFull(rwc, b)
6695 if err != nil {
6696 t.Fatalf("failed to read from body: %v", err)
6697 }
6698
6699 if string(b) != "hello" {
6700 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6701 }
6702 }
6703
// tlogWriter is an io.Writer that forwards each Write to t.Log as a single
// log call. It always reports the whole buffer as written, with no error.
type tlogWriter struct{ t *testing.T }

func (w tlogWriter) Write(p []byte) (int, error) {
	n := len(p)
	w.t.Log(string(p))
	return n, nil
}
6710
6711 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6712 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6713 }
6714 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6715 const wantBody = "want"
6716 const wantUpgrade = "someProto"
6717 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6718 w.Header().Set("Connection", "Upgrade")
6719 w.Header().Set("Upgrade", wantUpgrade)
6720 w.WriteHeader(StatusSwitchingProtocols)
6721 NewResponseController(w).Flush()
6722
6723
6724 w.WriteHeader(200)
6725 if _, err := w.Write([]byte("x")); err == nil {
6726 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6727 }
6728
6729 c, _, err := NewResponseController(w).Hijack()
6730 if err != nil {
6731 t.Errorf("Hijack: %v", err)
6732 return
6733 }
6734 defer c.Close()
6735 if _, err := c.Write([]byte(wantBody)); err != nil {
6736 t.Errorf("Write to hijacked body: %v", err)
6737 }
6738 }), func(ts *httptest.Server) {
6739
6740 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
6741 }).ts
6742
6743 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6744 if err != nil {
6745 t.Fatalf("net.Dial: %v", err)
6746 }
6747 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
6748 if err != nil {
6749 t.Fatalf("conn.Write: %v", err)
6750 }
6751 defer conn.Close()
6752
6753 r := bufio.NewReader(conn)
6754 res, err := ReadResponse(r, &Request{Method: "GET"})
6755 if err != nil {
6756 t.Fatal("ReadResponse error:", err)
6757 }
6758 if res.StatusCode != StatusSwitchingProtocols {
6759 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
6760 }
6761 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
6762 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
6763 }
6764 body, err := io.ReadAll(r)
6765 if err != nil {
6766 t.Error(err)
6767 }
6768 if string(body) != wantBody {
6769 t.Errorf("Response body = %q, want %q", string(body), wantBody)
6770 }
6771 }
6772
6773 func TestMuxRedirectRelative(t *testing.T) {
6774 setParallel(t)
6775 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
6776 if err != nil {
6777 t.Errorf("%s", err)
6778 }
6779 mux := NewServeMux()
6780 resp := httptest.NewRecorder()
6781 mux.ServeHTTP(resp, req)
6782 if got, want := resp.Header().Get("Location"), "/"; got != want {
6783 t.Errorf("Location header expected %q; got %q", want, got)
6784 }
6785 if got, want := resp.Code, StatusMovedPermanently; got != want {
6786 t.Errorf("Expected response code %d; got %d", want, got)
6787 }
6788 }
6789
6790
6791 func TestQuerySemicolon(t *testing.T) {
6792 t.Cleanup(func() { afterTest(t) })
6793
6794 tests := []struct {
6795 query string
6796 xNoSemicolons string
6797 xWithSemicolons string
6798 expectParseFormErr bool
6799 }{
6800 {"?a=1;x=bad&x=good", "good", "bad", true},
6801 {"?a=1;b=bad&x=good", "good", "good", true},
6802 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
6803 {"?a=1;x=good;x=bad", "", "good", true},
6804 }
6805
6806 run(t, func(t *testing.T, mode testMode) {
6807 for _, tt := range tests {
6808 t.Run(tt.query+"/allow=false", func(t *testing.T) {
6809 allowSemicolons := false
6810 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
6811 })
6812 t.Run(tt.query+"/allow=true", func(t *testing.T) {
6813 allowSemicolons, expectParseFormErr := true, false
6814 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
6815 })
6816 }
6817 })
6818 }
6819
6820 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
6821 writeBackX := func(w ResponseWriter, r *Request) {
6822 x := r.URL.Query().Get("x")
6823 if expectParseFormErr {
6824 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
6825 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
6826 }
6827 } else {
6828 if err := r.ParseForm(); err != nil {
6829 t.Errorf("expected no error from ParseForm, got %v", err)
6830 }
6831 }
6832 if got := r.FormValue("x"); x != got {
6833 t.Errorf("got %q from FormValue, want %q", got, x)
6834 }
6835 fmt.Fprintf(w, "%s", x)
6836 }
6837
6838 h := Handler(HandlerFunc(writeBackX))
6839 if allowSemicolons {
6840 h = AllowQuerySemicolons(h)
6841 }
6842
6843 logBuf := &strings.Builder{}
6844 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
6845 ts.Config.ErrorLog = log.New(logBuf, "", 0)
6846 }).ts
6847
6848 req, _ := NewRequest("GET", ts.URL+query, nil)
6849 res, err := ts.Client().Do(req)
6850 if err != nil {
6851 t.Fatal(err)
6852 }
6853 slurp, _ := io.ReadAll(res.Body)
6854 res.Body.Close()
6855 if got, want := res.StatusCode, 200; got != want {
6856 t.Errorf("Status = %d; want = %d", got, want)
6857 }
6858 if got, want := string(slurp), wantX; got != want {
6859 t.Errorf("Body = %q; want = %q", got, want)
6860 }
6861 }
6862
6863 func TestMaxBytesHandler(t *testing.T) {
6864
6865 defer afterTest(t)
6866
6867 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
6868 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
6869 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
6870 func(t *testing.T) {
6871 run(t, func(t *testing.T, mode testMode) {
6872 testMaxBytesHandler(t, mode, maxSize, requestSize)
6873 }, testNotParallel)
6874 })
6875 }
6876 }
6877 }
6878
6879 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
6880 runTimeSensitiveTest(t, []time.Duration{
6881 1 * time.Millisecond,
6882 5 * time.Millisecond,
6883 10 * time.Millisecond,
6884 50 * time.Millisecond,
6885 100 * time.Millisecond,
6886 500 * time.Millisecond,
6887 time.Second,
6888 5 * time.Second,
6889 }, func(t *testing.T, timeout time.Duration) error {
6890 SetRSTAvoidanceDelay(t, timeout)
6891 t.Logf("set RST avoidance delay to %v", timeout)
6892
6893 var (
6894 handlerN int64
6895 handlerErr error
6896 )
6897 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
6898 var buf bytes.Buffer
6899 handlerN, handlerErr = io.Copy(&buf, r.Body)
6900 io.Copy(w, &buf)
6901 })
6902
6903 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
6904
6905
6906 defer cst.close()
6907 ts := cst.ts
6908 c := ts.Client()
6909
6910 body := strings.Repeat("a", int(requestSize))
6911 var wg sync.WaitGroup
6912 defer wg.Wait()
6913 getBody := func() (io.ReadCloser, error) {
6914 wg.Add(1)
6915 body := &wgReadCloser{
6916 Reader: strings.NewReader(body),
6917 wg: &wg,
6918 }
6919 return body, nil
6920 }
6921 reqBody, _ := getBody()
6922 req, err := NewRequest("POST", ts.URL, reqBody)
6923 if err != nil {
6924 reqBody.Close()
6925 t.Fatal(err)
6926 }
6927 req.ContentLength = int64(len(body))
6928 req.GetBody = getBody
6929 req.Header.Set("Content-Type", "text/plain")
6930
6931 var buf strings.Builder
6932 res, err := c.Do(req)
6933 if err != nil {
6934 return fmt.Errorf("unexpected connection error: %v", err)
6935 } else {
6936 _, err = io.Copy(&buf, res.Body)
6937 res.Body.Close()
6938 if err != nil {
6939 return fmt.Errorf("unexpected read error: %v", err)
6940 }
6941 }
6942
6943
6944
6945
6946 if handlerN > maxSize {
6947 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
6948 }
6949 if requestSize > maxSize && handlerErr == nil {
6950 t.Error("expected error on handler side; got nil")
6951 }
6952 if requestSize <= maxSize {
6953 if handlerErr != nil {
6954 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
6955 }
6956 if handlerN != requestSize {
6957 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
6958 }
6959 }
6960 if buf.Len() != int(handlerN) {
6961 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
6962 }
6963
6964 return nil
6965 })
6966 }
6967
6968 func TestEarlyHints(t *testing.T) {
6969 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6970 h := w.Header()
6971 h.Add("Link", "</style.css>; rel=preload; as=style")
6972 h.Add("Link", "</script.js>; rel=preload; as=script")
6973 w.WriteHeader(StatusEarlyHints)
6974
6975 h.Add("Link", "</foo.js>; rel=preload; as=script")
6976 w.WriteHeader(StatusEarlyHints)
6977
6978 w.Write([]byte("stuff"))
6979 }))
6980
6981 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
6982 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
6983 if !strings.Contains(got, expected) {
6984 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
6985 }
6986 }
6987 func TestProcessing(t *testing.T) {
6988 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6989 w.WriteHeader(StatusProcessing)
6990 w.Write([]byte("stuff"))
6991 }))
6992
6993 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
6994 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
6995 if !strings.Contains(got, expected) {
6996 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
6997 }
6998 }
6999
7000 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
7001 func testParseFormCleanup(t *testing.T, mode testMode) {
7002 if mode == http2Mode {
7003 t.Skip("https://go.dev/issue/20253")
7004 }
7005
7006 const maxMemory = 1024
7007 const key = "file"
7008
7009 if runtime.GOOS == "windows" {
7010
7011 t.Skip("https://go.dev/issue/25965")
7012 }
7013
7014 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7015 r.ParseMultipartForm(maxMemory)
7016 f, _, err := r.FormFile(key)
7017 if err != nil {
7018 t.Errorf("r.FormFile(%q) = %v", key, err)
7019 return
7020 }
7021 of, ok := f.(*os.File)
7022 if !ok {
7023 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7024 return
7025 }
7026 w.Write([]byte(of.Name()))
7027 }))
7028
7029 fBuf := new(bytes.Buffer)
7030 mw := multipart.NewWriter(fBuf)
7031 mf, err := mw.CreateFormFile(key, "myfile.txt")
7032 if err != nil {
7033 t.Fatal(err)
7034 }
7035 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7036 t.Fatal(err)
7037 }
7038 if err := mw.Close(); err != nil {
7039 t.Fatal(err)
7040 }
7041 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7042 if err != nil {
7043 t.Fatal(err)
7044 }
7045 req.Header.Set("Content-Type", mw.FormDataContentType())
7046 res, err := cst.c.Do(req)
7047 if err != nil {
7048 t.Fatal(err)
7049 }
7050 defer res.Body.Close()
7051 fname, err := io.ReadAll(res.Body)
7052 if err != nil {
7053 t.Fatal(err)
7054 }
7055 cst.close()
7056 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7057 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7058 }
7059 }
7060
7061 func TestHeadBody(t *testing.T) {
7062 const identityMode = false
7063 const chunkedMode = true
7064 run(t, func(t *testing.T, mode testMode) {
7065 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7066 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7067 })
7068 }
7069
7070 func TestGetBody(t *testing.T) {
7071 const identityMode = false
7072 const chunkedMode = true
7073 run(t, func(t *testing.T, mode testMode) {
7074 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7075 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7076 })
7077 }
7078
7079 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7080 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7081 b, err := io.ReadAll(r.Body)
7082 if err != nil {
7083 t.Errorf("server reading body: %v", err)
7084 return
7085 }
7086 w.Header().Set("X-Request-Body", string(b))
7087 w.Header().Set("Content-Length", "0")
7088 }))
7089 defer cst.close()
7090 for _, reqBody := range []string{
7091 "",
7092 "",
7093 "request_body",
7094 "",
7095 } {
7096 var bodyReader io.Reader
7097 if reqBody != "" {
7098 bodyReader = strings.NewReader(reqBody)
7099 if chunked {
7100 bodyReader = bufio.NewReader(bodyReader)
7101 }
7102 }
7103 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7104 if err != nil {
7105 t.Fatal(err)
7106 }
7107 res, err := cst.c.Do(req)
7108 if err != nil {
7109 t.Fatal(err)
7110 }
7111 res.Body.Close()
7112 if got, want := res.StatusCode, 200; got != want {
7113 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7114 }
7115 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7116 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7117 }
7118 }
7119 }
7120
7121
7122
7123 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7124 func testDisableContentLength(t *testing.T, mode testMode) {
7125 if mode == http2Mode {
7126 t.Skip("skipping until h2_bundle.go is updated; see https://go-review.googlesource.com/c/net/+/471535")
7127 }
7128
7129 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7130 w.Header()["Content-Length"] = nil
7131 fmt.Fprintf(w, "OK")
7132 }))
7133
7134 res, err := noCL.c.Get(noCL.ts.URL)
7135 if err != nil {
7136 t.Fatal(err)
7137 }
7138 if got, haveCL := res.Header["Content-Length"]; haveCL {
7139 t.Errorf("Unexpected Content-Length: %q", got)
7140 }
7141 if err := res.Body.Close(); err != nil {
7142 t.Fatal(err)
7143 }
7144
7145 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7146 fmt.Fprintf(w, "OK")
7147 }))
7148
7149 res, err = withCL.c.Get(withCL.ts.URL)
7150 if err != nil {
7151 t.Fatal(err)
7152 }
7153 if got := res.Header.Get("Content-Length"); got != "2" {
7154 t.Errorf("Content-Length: %q; want 2", got)
7155 }
7156 if err := res.Body.Close(); err != nil {
7157 t.Fatal(err)
7158 }
7159 }
7160
7161 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7162 func testErrorContentLength(t *testing.T, mode testMode) {
7163 const errorBody = "an error occurred"
7164 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7165 w.Header().Set("Content-Length", "1000")
7166 Error(w, errorBody, 400)
7167 }))
7168 res, err := cst.c.Get(cst.ts.URL)
7169 if err != nil {
7170 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7171 }
7172 defer res.Body.Close()
7173 body, err := io.ReadAll(res.Body)
7174 if err != nil {
7175 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7176 }
7177 if string(body) != errorBody+"\n" {
7178 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7179 }
7180 }
7181
// TestError checks what Error writes to a ResponseWriter that already has
// headers set:
//   - any pre-set Content-Length is removed;
//   - Content-Type is forced to "text/plain; charset=utf-8";
//   - X-Content-Type-Options is overridden to "nosniff";
//   - unrelated headers are left untouched;
//   - the given status code is written;
//   - the message is written followed by a newline.
func TestError(t *testing.T) {
	w := httptest.NewRecorder()
	w.Header().Set("Content-Length", "1")
	w.Header().Set("X-Content-Type-Options", "scratch and sniff")
	w.Header().Set("Other", "foo")
	Error(w, "oops", 432)

	h := w.Header()
	if v, ok := h["Content-Length"]; ok {
		t.Errorf("Content-Length: %q, want not present", v)
	}
	if v := h.Get("Content-Type"); v != "text/plain; charset=utf-8" {
		t.Errorf("Content-Type: %q, want %q", v, "text/plain; charset=utf-8")
	}
	if v := h.Get("X-Content-Type-Options"); v != "nosniff" {
		t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff")
	}
	// A header that Error has no reason to touch must survive.
	if v := h.Get("Other"); v != "foo" {
		t.Errorf("Other: %q, want %q", v, "foo")
	}
	if w.Code != 432 {
		t.Errorf("status code: %d, want %d", w.Code, 432)
	}
	if got, want := w.Body.String(), "oops\n"; got != want {
		t.Errorf("body: %q, want %q", got, want)
	}
}
7202
7203 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7204 run(t, testServerReadAfterWriteHeader100Continue)
7205 }
7206 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7207 t.Skip("https://go.dev/issue/67555")
7208 body := []byte("body")
7209 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7210 w.WriteHeader(200)
7211 NewResponseController(w).Flush()
7212 io.ReadAll(r.Body)
7213 w.Write(body)
7214 }), func(tr *Transport) {
7215 tr.ExpectContinueTimeout = 24 * time.Hour
7216 })
7217
7218 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7219 req.Header.Set("Expect", "100-continue")
7220 res, err := cst.c.Do(req)
7221 if err != nil {
7222 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7223 }
7224 defer res.Body.Close()
7225 got, err := io.ReadAll(res.Body)
7226 if err != nil {
7227 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7228 }
7229 if !bytes.Equal(got, body) {
7230 t.Fatalf("response body = %q, want %q", got, body)
7231 }
7232 }
7233
7234 func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
7235 run(t, testServerReadAfterHandlerDone100Continue)
7236 }
7237 func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
7238 t.Skip("https://go.dev/issue/67555")
7239 readyc := make(chan struct{})
7240 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7241 go func() {
7242 <-readyc
7243 io.ReadAll(r.Body)
7244 <-readyc
7245 }()
7246 }), func(tr *Transport) {
7247 tr.ExpectContinueTimeout = 24 * time.Hour
7248 })
7249
7250 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7251 req.Header.Set("Expect", "100-continue")
7252 res, err := cst.c.Do(req)
7253 if err != nil {
7254 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7255 }
7256 res.Body.Close()
7257 readyc <- struct{}{}
7258 readyc <- struct{}{}
7259 }
7260
7261 func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
7262 run(t, testServerReadAfterHandlerAbort100Continue)
7263 }
7264 func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
7265 t.Skip("https://go.dev/issue/67555")
7266 readyc := make(chan struct{})
7267 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7268 go func() {
7269 <-readyc
7270 io.ReadAll(r.Body)
7271 <-readyc
7272 }()
7273 panic(ErrAbortHandler)
7274 }), func(tr *Transport) {
7275 tr.ExpectContinueTimeout = 24 * time.Hour
7276 })
7277
7278 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7279 req.Header.Set("Expect", "100-continue")
7280 res, err := cst.c.Do(req)
7281 if err == nil {
7282 res.Body.Close()
7283 }
7284 readyc <- struct{}{}
7285 readyc <- struct{}{}
7286 }
7287
View as plain text