Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "encoding/json"
17 "errors"
18 "fmt"
19 "internal/testenv"
20 "io"
21 "log"
22 "math/rand"
23 "mime/multipart"
24 "net"
25 . "net/http"
26 "net/http/httptest"
27 "net/http/httptrace"
28 "net/http/httputil"
29 "net/http/internal"
30 "net/http/internal/testcert"
31 "net/url"
32 "os"
33 "path/filepath"
34 "reflect"
35 "regexp"
36 "runtime"
37 "slices"
38 "strconv"
39 "strings"
40 "sync"
41 "sync/atomic"
42 "syscall"
43 "testing"
44 "time"
45 )
46
47 type dummyAddr string
48 type oneConnListener struct {
49 conn net.Conn
50 }
51
52 func (l *oneConnListener) Accept() (c net.Conn, err error) {
53 c = l.conn
54 if c == nil {
55 err = io.EOF
56 return
57 }
58 err = nil
59 l.conn = nil
60 return
61 }
62
63 func (l *oneConnListener) Close() error {
64 return nil
65 }
66
67 func (l *oneConnListener) Addr() net.Addr {
68 return dummyAddr("test-address")
69 }
70
71 func (a dummyAddr) Network() string {
72 return string(a)
73 }
74
75 func (a dummyAddr) String() string {
76 return string(a)
77 }
78
79 type noopConn struct{}
80
81 func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") }
82 func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") }
83 func (noopConn) SetDeadline(t time.Time) error { return nil }
84 func (noopConn) SetReadDeadline(t time.Time) error { return nil }
85 func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
86
87 type rwTestConn struct {
88 io.Reader
89 io.Writer
90 noopConn
91
92 closeFunc func() error
93 closec chan bool
94 }
95
96 func (c *rwTestConn) Close() error {
97 if c.closeFunc != nil {
98 return c.closeFunc()
99 }
100 select {
101 case c.closec <- true:
102 default:
103 }
104 return nil
105 }
106
107 type testConn struct {
108 readMu sync.Mutex
109 readBuf bytes.Buffer
110 writeBuf bytes.Buffer
111 closec chan bool
112 noopConn
113 }
114
115 func newTestConn() *testConn {
116 return &testConn{closec: make(chan bool, 1)}
117 }
118
119 func (c *testConn) Read(b []byte) (int, error) {
120 c.readMu.Lock()
121 defer c.readMu.Unlock()
122 return c.readBuf.Read(b)
123 }
124
125 func (c *testConn) Write(b []byte) (int, error) {
126 return c.writeBuf.Write(b)
127 }
128
129 func (c *testConn) Close() error {
130 select {
131 case c.closec <- true:
132 default:
133 }
134 return nil
135 }
136
137
138
// reqBytes converts hand-written request text into wire form. Surrounding
// whitespace is trimmed, LF line endings become CRLF, and the blank line
// that ends the header section is appended.
func reqBytes(req string) []byte {
	wire := strings.TrimSpace(req)
	wire = strings.ReplaceAll(wire, "\n", "\r\n")
	return []byte(wire + "\r\n\r\n")
}
142
143 type handlerTest struct {
144 logbuf bytes.Buffer
145 handler Handler
146 }
147
148 func newHandlerTest(h Handler) handlerTest {
149 return handlerTest{handler: h}
150 }
151
152 func (ht *handlerTest) rawResponse(req string) string {
153 reqb := reqBytes(req)
154 var output strings.Builder
155 conn := &rwTestConn{
156 Reader: bytes.NewReader(reqb),
157 Writer: &output,
158 closec: make(chan bool, 1),
159 }
160 ln := &oneConnListener{conn: conn}
161 srv := &Server{
162 ErrorLog: log.New(&ht.logbuf, "", 0),
163 Handler: ht.handler,
164 }
165 go srv.Serve(ln)
166 <-conn.closec
167 return output.String()
168 }
169
170 func TestConsumingBodyOnNextConn(t *testing.T) {
171 t.Parallel()
172 defer afterTest(t)
173 conn := new(testConn)
174 for i := 0; i < 2; i++ {
175 conn.readBuf.Write([]byte(
176 "POST / HTTP/1.1\r\n" +
177 "Host: test\r\n" +
178 "Content-Length: 11\r\n" +
179 "\r\n" +
180 "foo=1&bar=1"))
181 }
182
183 reqNum := 0
184 ch := make(chan *Request)
185 servech := make(chan error)
186 listener := &oneConnListener{conn}
187 handler := func(res ResponseWriter, req *Request) {
188 reqNum++
189 ch <- req
190 }
191
192 go func() {
193 servech <- Serve(listener, HandlerFunc(handler))
194 }()
195
196 var req *Request
197 req = <-ch
198 if req == nil {
199 t.Fatal("Got nil first request.")
200 }
201 if req.Method != "POST" {
202 t.Errorf("For request #1's method, got %q; expected %q",
203 req.Method, "POST")
204 }
205
206 req = <-ch
207 if req == nil {
208 t.Fatal("Got nil first request.")
209 }
210 if req.Method != "POST" {
211 t.Errorf("For request #2's method, got %q; expected %q",
212 req.Method, "POST")
213 }
214
215 if serveerr := <-servech; serveerr != io.EOF {
216 t.Errorf("Serve returned %q; expected EOF", serveerr)
217 }
218 }
219
220 type stringHandler string
221
222 func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
223 w.Header().Set("Result", string(s))
224 }
225
// handlers is the set of patterns that testHostHandlers registers on its
// mux. Each pattern is served by a stringHandler that reports msg in the
// "Result" response header. The "someHost.com/someDir/" entry is
// host-specific.
var handlers = []struct {
	pattern string
	msg     string
}{
	{"/", "Default"},
	{"/someDir/", "someDir"},
	{"/#/", "hash"},
	{"someHost.com/someDir/", "someHost.com/someDir"},
}

// vtests lists the requests that testHostHandlers sends, with the expected
// value: the "Result" header for a 200 response, or the "Location" header
// for a 301 redirect.
var vtests = []struct {
	url      string
	expected string
}{
	{"http://localhost/someDir/apage", "someDir"},
	{"http://localhost/%23/apage", "hash"},
	{"http://localhost/otherDir/apage", "Default"},
	{"http://someHost.com/someDir/apage", "someHost.com/someDir"},
	{"http://otherHost.com/someDir/apage", "someDir"},
	{"http://otherHost.com/aDir/apage", "Default"},

	// Directory paths without their trailing slash: expect a redirect to
	// the slash-terminated path.
	{"http://localhost/someDir", "/someDir/"},
	{"http://localhost/%23", "/%23/"},
	{"http://someHost.com/someDir", "/someDir/"},
}
251
// TestHostHandlers sends every vtests request over one HTTP/1 connection to
// a mux built from handlers. It checks that host-specific patterns win for
// their host, that other hosts fall back to host-independent patterns, and
// that directory URLs without a trailing slash are redirected.
func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
func testHostHandlers(t *testing.T, mode testMode) {
	mux := NewServeMux()
	for _, h := range handlers {
		mux.Handle(h.pattern, stringHandler(h.msg))
	}
	ts := newClientServerTest(t, mode, mux).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	cc := httputil.NewClientConn(conn, nil)
	for _, vt := range vtests {
		var req Request
		if req.URL, err = url.Parse(vt.url); err != nil {
			t.Errorf("cannot parse url: %v", err)
			continue
		}
		if err := cc.Write(&req); err != nil {
			t.Errorf("writing request: %v", err)
			continue
		}
		r, err := cc.Read(&req)
		if err != nil {
			t.Errorf("reading response: %v", err)
			continue
		}
		// Only the headers are inspected; release the body before the next
		// request on this connection.
		r.Body.Close()
		switch r.StatusCode {
		case StatusOK:
			s := r.Header.Get("Result")
			if s != vt.expected {
				t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
			}
		case StatusMovedPermanently:
			s := r.Header.Get("Location")
			if s != vt.expected {
				t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
			}
		default:
			t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
		}
	}
}
298
// serveMuxRegister is the pattern set that TestServeMuxHandler and
// TestServeMuxHandlerRedirects register on their mux. Each serve(code)
// handler replies with a distinct status code, so the code identifies the
// matched pattern. example.com/ uses checkQueryStringHandler instead.
var serveMuxRegister = []struct {
	pattern string
	h       Handler
}{
	{"/dir/", serve(200)},
	{"/search", serve(201)},
	{"codesearch.google.com/search", serve(202)},
	{"codesearch.google.com/", serve(203)},
	{"example.com/", HandlerFunc(checkQueryStringHandler)},
}
307 {"example.com/", HandlerFunc(checkQueryStringHandler)},
308 }
309
310
// serve returns a handler that replies with the given status code and an
// empty body.
func serve(code int) HandlerFunc {
	h := func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	}
	return h
}
316
317
318
319
// checkQueryStringHandler replies 200 when "http://" plus the raw query
// string equals the request's own URL (scheme forced to http, host taken
// from r.Host, query removed), and 500 otherwise.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	self := *r.URL
	self.Scheme, self.Host, self.RawQuery = "http", r.Host, ""
	code := 500
	if self.String() == "http://"+r.URL.RawQuery {
		code = 200
	}
	w.WriteHeader(code)
}
331
// serveMuxTests lists requests routed through a mux built from
// serveMuxRegister. Each entry gives the status code the selected handler
// writes and the pattern ServeMux.Handler is expected to report.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 301, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},
	{"GET", "google.com", "/../search", 301, "/search"},
	{"GET", "google.com", "/dir/..", 301, ""},
	{"GET", "google.com", "/dir/./file", 301, "/dir/"},

	// CONNECT requests still get the "/dir" -> "/dir/" redirect, but their
	// paths are not cleaned: "." and ".." elements are matched as-is.
	{"CONNECT", "google.com", "/dir", 301, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
367
// TestServeMuxHandler checks each serveMuxTests case against a mux built
// from serveMuxRegister. It verifies both the pattern reported by
// ServeMux.Handler and the status code the returned handler writes.
func TestServeMuxHandler(t *testing.T) {
	setParallel(t)
	mux := NewServeMux()
	for _, reg := range serveMuxRegister {
		mux.Handle(reg.pattern, reg.h)
	}

	for _, tc := range serveMuxTests {
		req := &Request{Method: tc.method, Host: tc.host, URL: &url.URL{Path: tc.path}}
		handler, gotPattern := mux.Handler(req)
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		if gotPattern == tc.pattern && rec.Code == tc.code {
			continue
		}
		t.Errorf("%s %s %s = %d, %q, want %d, %q", tc.method, tc.host, tc.path, rec.Code, gotPattern, tc.code, tc.pattern)
	}
}
391
392
// TestServeMuxHandleFuncWithNilHandler verifies that registering a nil
// handler function with HandleFunc panics.
func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
	setParallel(t)
	defer func() {
		if recover() != nil {
			return
		}
		t.Error("expected call to mux.HandleFunc to panic")
	}()
	NewServeMux().HandleFunc("/", nil)
}
403
// serveMuxTests2 lists requests that TestServeMuxHandlerRedirects serves
// through a mux built from serveMuxRegister. code is the expected final
// status. redirOk reports whether a 301 redirect is acceptable on the way
// to it. The example.com entries carry host+path in the query string, which
// checkQueryStringHandler compares against the request URL.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{"GET", "google.com", "/", 404, false},
	{"GET", "example.com", "/test/?example.com/test/", 200, false},
	{"GET", "example.com", "test/?example.com/test/", 200, true},
}
415
416
417
// TestServeMuxHandlerRedirects serves each serveMuxTests2 request through a
// mux built from serveMuxRegister. It follows at most one 301 redirect (and
// only when redirOk is set), then checks the final status code.
func TestServeMuxHandlerRedirects(t *testing.T) {
	setParallel(t)
	mux := NewServeMux()
	for _, e := range serveMuxRegister {
		mux.Handle(e.pattern, e.h)
	}

	for _, tt := range serveMuxTests2 {
		tries := 1 // at most one redirect is allowed when redirOk is true
		turl := tt.url
		for {
			u, e := url.Parse(turl)
			if e != nil {
				t.Fatal(e)
			}
			r := &Request{
				Method: tt.method,
				Host:   tt.host,
				URL:    u,
			}
			h, _ := mux.Handler(r)
			rr := httptest.NewRecorder()
			h.ServeHTTP(rr, r)
			if rr.Code != 301 {
				if rr.Code != tt.code {
					t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
				}
				break
			}
			if !tt.redirOk {
				t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
				break
			}
			// Enforce the budget inside the loop. If it were checked only
			// after the loop, a redirect cycle would never terminate.
			tries--
			if tries < 0 {
				t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
				break
			}
			// rr.Header is the recorder's header map (HeaderMap is deprecated).
			turl = rr.Header().Get("Location")
		}
	}
}
459
460
// TestMuxRedirectLeadingSlashes checks that paths with repeated leading
// slashes or leading ".." elements are answered with a 301 redirect to
// "/foo.txt". Every path is checked, even after an earlier one fails.
func TestMuxRedirectLeadingSlashes(t *testing.T) {
	setParallel(t)
	paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
	for _, path := range paths {
		req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
		if err != nil {
			// req is nil here; serving it would panic.
			t.Errorf("%s: %s", path, err)
			continue
		}
		mux := NewServeMux()
		resp := httptest.NewRecorder()

		mux.ServeHTTP(resp, req)

		if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
			t.Errorf("%s: Expected Location header set to %q; got %q", path, expected, loc)
		}

		if code, expected := resp.Code, StatusMovedPermanently; code != expected {
			t.Errorf("%s: Expected response code of StatusMovedPermanently; got %d", path, code)
		}
	}
}
485
486
487
488
489
// TestServeWithSlashRedirectKeepsQueryString checks that the mux's
// trailing-slash and path-cleaning redirects keep the original query
// string, for both GET and CONNECT. Paths of CONNECT requests are not
// cleaned, so case 9 gets the mux's 404 page instead.
func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
	run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
}
func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
	writeBackQuery := func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%s", r.URL.RawQuery)
	}

	mux := NewServeMux()
	mux.HandleFunc("/testOne", writeBackQuery)
	mux.HandleFunc("/testTwo/", writeBackQuery)
	mux.HandleFunc("/testThree", writeBackQuery)
	mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
	})

	ts := newClientServerTest(t, mode, mux).ts

	tests := [...]struct {
		path     string
		method   string
		want     string
		statusOk bool
	}{
		0: {"/testOne?this=that", "GET", "this=that", true},
		1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
		2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
		3: {"/testTwo?", "GET", "", true},
		4: {"/testThree?foo", "GET", "foo", true},
		5: {"/testThree/?foo", "GET", "foo:bar", true},
		6: {"/testThree?foo", "CONNECT", "foo", true},
		7: {"/testThree/?foo", "CONNECT", "foo:bar", true},

		// "/testOne/foo/.." is cleaned to "/testOne" for GET only.
		8: {"/testOne/foo/..?foo", "GET", "foo", true},
		9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
	}

	for i, tt := range tests {
		req, err := NewRequest(tt.method, ts.URL+tt.path, nil)
		if err != nil {
			t.Errorf("#%d: NewRequest(%q, %q): %v", i, tt.method, tt.path, err)
			continue
		}
		res, err := ts.Client().Do(req)
		if err != nil {
			// Report rather than skip: a silent continue would let a broken
			// case pass unnoticed.
			t.Errorf("#%d: %s %s: %v", i, tt.method, tt.path, err)
			continue
		}
		slurp, _ := io.ReadAll(res.Body)
		res.Body.Close()
		if !tt.statusOk {
			if got, want := res.StatusCode, 404; got != want {
				t.Errorf("#%d: Status = %d; want = %d", i, got, want)
			}
		}
		if got, want := string(slurp), tt.want; got != want {
			t.Errorf("#%d: Body = %q; want = %q", i, got, want)
		}
	}
}
546
// TestServeWithSlashRedirectForHostPatterns checks trailing-slash redirects
// when host-specific patterns are registered. In these cases, CONNECT
// requests match host patterns including the port, while GET requests match
// regardless of port (compare the example.com:3000 entries).
func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
	setParallel(t)

	mux := NewServeMux()
	mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
	mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
	mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
	mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
	mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
	mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))

	tests := []struct {
		method string
		url    string
		code   int
		loc    string
		want   string
	}{
		{"GET", "http://example.com/", 404, "", ""},
		{"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
		{"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
		{"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
		{"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
		{"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
		{"CONNECT", "http://example.com/", 404, "", ""},
		{"CONNECT", "http://example.com:3000/", 404, "", ""},
		{"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
		{"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
		{"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
		{"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
		{"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
	}

	for i, tt := range tests {
		req, err := NewRequest(tt.method, tt.url, nil)
		if err != nil {
			t.Fatalf("#%d: NewRequest(%q, %q): %v", i, tt.method, tt.url, err)
		}
		w := httptest.NewRecorder()
		mux.ServeHTTP(w, req)

		if got, want := w.Code, tt.code; got != want {
			t.Errorf("#%d: Status = %d; want = %d", i, got, want)
		}

		// w.Header returns the recorder's header map; HeaderMap is deprecated.
		if tt.code == 301 {
			if got, want := w.Header().Get("Location"), tt.loc; got != want {
				t.Errorf("#%d: Location = %q; want = %q", i, got, want)
			}
		} else {
			if got, want := w.Header().Get("Result"), tt.want; got != want {
				t.Errorf("#%d: Result = %q; want = %q", i, got, want)
			}
		}
	}
}
600
601
602
603
// TestMuxNoSlashRedirectWithTrailingSlash checks that a GET of "/" with
// only the pattern "/{x}/" registered gets a plain 404 rather than a
// redirect.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	req, _ := NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)
	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
616
617
618
619
// TestMuxNoSlash405WithTrailingSlash checks that a GET of "/" with only the
// method-qualified pattern "GET /{x}/" registered gets a 404 rather than a
// 405 or a redirect.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	req, _ := NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)
	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
632
// TestShouldRedirectConcurrency registers a handler on a mux after a test
// server is already using it. It is presumably meant to be run under -race
// to catch unsynchronized access between registration and routing.
func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
	mux := NewServeMux()
	_ = newClientServerTest(t, mode, mux)
	noop := func(ResponseWriter, *Request) {}
	mux.HandleFunc("/", noop)
}
639
// BenchmarkServeMux measures ServeMux.Handler lookup plus running the
// matched handler over 180 registered patterns. BenchmarkServeMux_SkipServe
// measures the lookup alone.
func BenchmarkServeMux(b *testing.B)           { benchmarkServeMux(b, true) }
func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
func benchmarkServeMux(b *testing.B, runHandler bool) {
	type test struct {
		path string
		code int
		req  *Request
	}

	// Register one "/<endpoint>/<code>/" pattern per (endpoint, code) pair.
	// Each handler replies with its own code, so a wrong match is caught
	// when runHandler is set.
	endpoints := []string{"search", "dir", "file", "change", "count", "s"}
	const firstCode, lastCode = 200, 230
	tests := make([]test, 0, len(endpoints)*(lastCode-firstCode))
	for _, e := range endpoints {
		for code := firstCode; code < lastCode; code++ {
			p := fmt.Sprintf("/%s/%d/", e, code)
			tests = append(tests, test{
				path: p,
				code: code,
				req:  &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
			})
		}
	}
	mux := NewServeMux()
	for _, tt := range tests {
		mux.Handle(tt.path, serve(tt.code))
	}

	rw := httptest.NewRecorder()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, tt := range tests {
			// Reset the recorder in place rather than allocating a new one.
			*rw = httptest.ResponseRecorder{}
			h, pattern := mux.Handler(tt.req)
			if !runHandler {
				continue
			}
			h.ServeHTTP(rw, tt.req)
			if pattern != tt.path || rw.Code != tt.code {
				b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
			}
		}
	}
}
683
// TestServerTimeouts runs testServerTimeoutsWithTimeout through
// runTimeSensitiveTest with the candidate timeouts below. Presumably they
// are tried in order until one passes; confirm in runTimeSensitiveTest.
func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
func testServerTimeouts(t *testing.T, mode testMode) {
	timeouts := []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		1 * time.Second,
	}
	attempt := func(t *testing.T, timeout time.Duration) error {
		return testServerTimeoutsWithTimeout(t, timeout, mode)
	}
	runTimeSensitiveTest(t, timeouts, attempt)
}
696
// testServerTimeoutsWithTimeout runs one attempt of TestServerTimeouts with
// both ReadTimeout and WriteTimeout set to timeout. It returns an error
// rather than failing t, so the caller can retry with another timeout.
func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
	var reqNum atomic.Int32
	cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
		fmt.Fprintf(res, "req=%d", reqNum.Add(1))
	}), func(ts *httptest.Server) {
		ts.Config.ReadTimeout = timeout
		ts.Config.WriteTimeout = timeout
	})
	defer cst.close()
	ts := cst.ts

	// An ordinary request succeeds.
	c := ts.Client()
	r, err := c.Get(ts.URL)
	if err != nil {
		return fmt.Errorf("http Get #1: %v", err)
	}
	got, err := io.ReadAll(r.Body)
	r.Body.Close()
	expected := "req=1"
	if string(got) != expected || err != nil {
		return fmt.Errorf("Unexpected response for request #1; got %q, %v; expected %q, nil",
			string(got), err, expected)
	}

	// A connection that never sends a request is closed by the server
	// (the client reads EOF), but no sooner than 80% of the timeout.
	t1 := time.Now()
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		return fmt.Errorf("Dial: %v", err)
	}
	buf := make([]byte, 1)
	n, err := conn.Read(buf)
	conn.Close()
	latency := time.Since(t1)
	if n != 0 || err != io.EOF {
		return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
	}
	minLatency := timeout / 5 * 4
	if latency < minLatency {
		return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
	}

	// The server keeps serving after closing the idle connection.
	r, err = c.Get(ts.URL)
	if err != nil {
		return fmt.Errorf("http Get #2: %v", err)
	}
	got, err = io.ReadAll(r.Body)
	r.Body.Close()
	expected = "req=2"
	if string(got) != expected || err != nil {
		return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
	}

	// Five requests written on one connection, timeout/2 apart, must all be
	// writable. Skipped in short mode because it sleeps 2.5*timeout.
	if !testing.Short() {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return fmt.Errorf("long Dial: %v", err)
		}
		defer conn.Close()
		go io.Copy(io.Discard, conn) // drain responses
		for i := 0; i < 5; i++ {
			_, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
			if err != nil {
				return fmt.Errorf("on write %d: %v", i, err)
			}
			time.Sleep(timeout / 2)
		}
	}
	return nil
}
770
// TestServerReadTimeout checks that Server.ReadTimeout bounds a handler's
// read of the request body. The client never finishes the body, so the
// handler's read must fail with os.ErrDeadlineExceeded, after which the
// handler still writes a response that the client must receive.
func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
func testServerReadTimeout(t *testing.T, mode testMode) {
	respBody := "response body"
	// Start very short and double the timeout on every retry. The loop ends
	// once a Post gets a response.
	for timeout := 5 * time.Millisecond; ; timeout *= 2 {
		cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
			_, err := io.Copy(io.Discard, req.Body)
			if !errors.Is(err, os.ErrDeadlineExceeded) {
				t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
			}
			res.Write([]byte(respBody))
		}), func(ts *httptest.Server) {
			// Per the Server.ReadHeaderTimeout docs, a negative value means
			// no header timeout, so only ReadTimeout applies.
			ts.Config.ReadHeaderTimeout = -1
			ts.Config.ReadTimeout = timeout
			t.Logf("Server.Config.ReadTimeout = %v", timeout)
		})

		// Any Proxy lookup after the first fails. This presumably stops the
		// Transport from silently retrying the request.
		var retries atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if retries.Add(1) != 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		// Nothing is written to pw, so the request body stays open while the
		// handler reads it.
		pr, pw := io.Pipe()
		res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
		if err != nil {
			// The request did not get a response at this timeout; retry
			// with a longer one.
			t.Logf("Get error, retrying: %v", err)
			cst.close()
			continue
		}
		defer res.Body.Close()
		got, err := io.ReadAll(res.Body)
		if string(got) != respBody || err != nil {
			t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
		}
		pw.Close()
		break
	}
}
811
// TestServerNoReadTimeout checks that a ReadTimeout of 0 or -1 imposes no
// read deadline. The handler flushes its headers, waits for a request body
// that arrives after a delay, and must still read all of it.
func TestServerNoReadTimeout(t *testing.T) { run(t, testServerNoReadTimeout) }
func testServerNoReadTimeout(t *testing.T, mode testMode) {
	reqBody := "Hello, Gophers!"
	resBody := "Hi, Gophers!"
	for _, timeout := range []time.Duration{0, -1} {
		cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
			ctl := NewResponseController(res)
			ctl.EnableFullDuplex()
			res.WriteHeader(StatusOK)
			// Flush the headers before reading the body, so the client's
			// Post returns while the request body is still open.
			if err := ctl.Flush(); err != nil {
				t.Errorf("server flush response: %v", err)
				return
			}
			got, err := io.ReadAll(req.Body)
			if string(got) != reqBody || err != nil {
				t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
			}
			res.Write([]byte(resBody))
		}), func(ts *httptest.Server) {
			ts.Config.ReadTimeout = timeout
			t.Logf("Server.Config.ReadTimeout = %v", timeout)
		})

		pr, pw := io.Pipe()
		res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
		if err != nil {
			t.Fatal(err)
		}
		defer res.Body.Close()

		// Delay the body so the handler is left waiting in its read.
		time.Sleep(10 * time.Millisecond)
		if _, err := pw.Write([]byte(reqBody)); err != nil {
			t.Errorf("client write request body: %v", err)
		}
		pw.Close()

		got, err := io.ReadAll(res.Body)
		if string(got) != resBody || err != nil {
			// %q, not %v: got is a []byte, and %v would print byte values.
			t.Errorf("client read response body: %v; got %q, want %q", err, got, resBody)
		}
	}
}
855
// TestServerWriteTimeout checks that Server.WriteTimeout cuts off a handler
// that streams an endless response. The client must see a truncated body,
// and the handler's write must fail with os.ErrDeadlineExceeded.
func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
func testServerWriteTimeout(t *testing.T, mode testMode) {
	// Double the timeout on every retry until the handler is observed running.
	for timeout := 5 * time.Millisecond; ; timeout *= 2 {
		// The handler sends nil when it starts and then the error from its
		// endless write. A buffer of 2 keeps both sends from blocking.
		errc := make(chan error, 2)
		cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
			errc <- nil
			_, err := io.Copy(res, neverEnding('a'))
			errc <- err
		}), func(ts *httptest.Server) {
			ts.Config.WriteTimeout = timeout
			t.Logf("Server.Config.WriteTimeout = %v", timeout)
		})

		// A very short timeout can expire before the request is handled.
		// In that case either the Get fails or the handler never runs; both
		// are detected below and retried with a longer timeout. Any second
		// Proxy lookup fails, presumably so the Transport cannot retry the
		// request on its own.
		var retries atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if retries.Add(1) != 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			// No response at this timeout; retry with a longer one.
			t.Logf("Get error, retrying: %v", err)
			cst.close()
			continue
		}
		defer res.Body.Close()
		_, err = io.Copy(io.Discard, res.Body)
		if err == nil {
			t.Errorf("client reading from truncated request body: got nil error, want non-nil")
		}
		select {
		case <-errc:
			// The handler ran; its second send is the write error.
			err = <-errc
			if !errors.Is(err, os.ErrDeadlineExceeded) {
				t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
			}
			return
		default:
			// No start signal: the handler did not run for this request.
			t.Logf("handler didn't run, retrying")
			cst.close()
		}
	}
}
920
921 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
922 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
923 for _, timeout := range []time.Duration{0, -1} {
924 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
925 _, err := io.Copy(res, neverEnding('a'))
926 t.Logf("server write response: %v", err)
927 }), func(ts *httptest.Server) {
928 ts.Config.WriteTimeout = timeout
929 t.Logf("Server.Config.WriteTimeout = %d", timeout)
930 })
931
932 res, err := cst.c.Get(cst.ts.URL)
933 if err != nil {
934 t.Fatal(err)
935 }
936 defer res.Body.Close()
937 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
938 if n != 1<<20 || err != nil {
939 t.Errorf("client read response body: %d, %v", n, err)
940 }
941
942
943 res.Body.Close()
944 cst.ts.Config.Shutdown(context.Background())
945 }
946 }
947
948
// TestWriteDeadlineExtendedOnNewRequest checks that each new request gets a
// fresh write deadline. Three requests spaced WriteTimeout/2 apart (likely
// on one kept-alive connection) span more than one WriteTimeout in total,
// yet each must succeed.
func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
	run(t, testWriteDeadlineExtendedOnNewRequest)
}
func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
		func(ts *httptest.Server) {
			ts.Config.WriteTimeout = 250 * time.Millisecond
		},
	).ts

	c := ts.Client()

	for i := 1; i <= 3; i++ {
		req, err := NewRequest("GET", ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}

		r, err := c.Do(req)
		if err != nil {
			// Not "http2": this test runs in every mode.
			t.Fatalf("Get #%d: %v", i, err)
		}
		r.Body.Close()
		time.Sleep(ts.Config.WriteTimeout / 2)
	}
}
978
979
980
// tryTimeouts calls testFunc with 250ms, then 500ms, then 1s, and returns
// as soon as an attempt succeeds. Each failure is logged; if every attempt
// fails, the test is stopped with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i := range tries {
		err := testFunc(tries[i])
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", tries[i], err)
		if next := i + 1; next < len(tries) {
			t.Logf("retrying at %v ...", tries[next])
		}
	}
	t.Fatal("all attempts failed")
}
995
996
// TestWriteDeadlineEnforcedPerStream runs testWriteDeadlineEnforcedPerStream
// in each mode, retrying with longer timeouts via tryTimeouts.
func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	setParallel(t)
	perMode := func(t *testing.T, mode testMode) {
		attempt := func(timeout time.Duration) error {
			return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
		}
		tryTimeouts(t, attempt)
	}
	run(t, perMode)
}
1008
// testWriteDeadlineEnforcedPerStream issues two sequential requests with
// WriteTimeout set to timeout/2. The first handler returns immediately and
// must succeed. The second sleeps for the full timeout, so its response
// misses the write deadline and the client must get an error. In HTTP/2
// mode that error must mention stream ID 3 and INTERNAL_ERROR. Failures are
// returned rather than reported, so tryTimeouts can retry.
func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
	firstRequest := make(chan bool, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
		select {
		case firstRequest <- true:
			// First request: return immediately.
		default:
			// Later requests: sleep past the write deadline.
			time.Sleep(timeout)
		}
	}), func(ts *httptest.Server) {
		ts.Config.WriteTimeout = timeout / 2
	})
	defer cst.close()
	ts := cst.ts

	c := ts.Client()

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		return fmt.Errorf("NewRequest: %v", err)
	}
	r, err := c.Do(req)
	if err != nil {
		return fmt.Errorf("Get #1: %v", err)
	}
	r.Body.Close()

	req, err = NewRequest("GET", ts.URL, nil)
	if err != nil {
		return fmt.Errorf("NewRequest: %v", err)
	}
	r, err = c.Do(req)
	if err == nil {
		r.Body.Close()
		return fmt.Errorf("Get #2 expected error, got nil")
	}
	if mode == http2Mode {
		expected := "stream ID 3; INTERNAL_ERROR"
		if !strings.Contains(err.Error(), expected) {
			return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
		}
	}
	return nil
}
1054
1055
// TestNoWriteDeadline runs testNoWriteDeadline in each mode, retrying with
// longer timeouts via tryTimeouts.
func TestNoWriteDeadline(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	setParallel(t)
	defer afterTest(t)
	perMode := func(t *testing.T, mode testMode) {
		attempt := func(timeout time.Duration) error {
			return testNoWriteDeadline(t, mode, timeout)
		}
		tryTimeouts(t, attempt)
	}
	run(t, perMode)
}
1068
// testNoWriteDeadline makes two requests against a server with no
// WriteTimeout. The second handler sleeps for timeout before returning, and
// both requests must still succeed. Failures are returned, not reported.
func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
	firstRequest := make(chan bool, 1)
	handler := HandlerFunc(func(res ResponseWriter, req *Request) {
		select {
		case firstRequest <- true:
			// First request: reply immediately.
		default:
			// Later requests: stall before replying.
			time.Sleep(timeout)
		}
	})
	cst := newClientServerTest(t, mode, handler)
	defer cst.close()
	ts := cst.ts

	client := ts.Client()

	for attempt := 0; attempt < 2; attempt++ {
		req, err := NewRequest("GET", ts.URL, nil)
		if err != nil {
			return fmt.Errorf("NewRequest: %v", err)
		}
		resp, err := client.Do(req)
		if err != nil {
			return fmt.Errorf("Get #%d: %v", attempt, err)
		}
		resp.Body.Close()
	}
	return nil
}
1098
1099
1100
1101
// TestOnlyWriteTimeout checks the effect of a write deadline that has
// already passed. The handler writes 512 KiB, then moves its connection's
// write deadline 30s into the past (using the conn captured by
// trackLastConnListener) and writes again. That second Write must fail, and
// the client must get an error reading the response body.
func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
func testOnlyWriteTimeout(t *testing.T, mode testMode) {
	var (
		mu   sync.RWMutex
		conn net.Conn // most recently accepted connection; guarded by mu
	)
	var afterTimeoutErrc = make(chan error, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		buf := make([]byte, 512<<10)
		_, err := w.Write(buf)
		if err != nil {
			t.Errorf("handler Write error: %v", err)
			return
		}
		mu.RLock()
		defer mu.RUnlock()
		if conn == nil {
			t.Error("no established connection found")
			return
		}
		// Expire the write deadline, then write again.
		conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
		_, err = w.Write(buf)
		afterTimeoutErrc <- err
	}), func(ts *httptest.Server) {
		ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
	}).ts

	c := ts.Client()

	err := func() error {
		res, err := c.Get(ts.URL)
		if err != nil {
			return err
		}
		_, err = io.Copy(io.Discard, res.Body)
		res.Body.Close()
		return err
	}()
	if err == nil {
		t.Errorf("expected an error copying body from Get request")
	}

	if err := <-afterTimeoutErrc; err == nil {
		t.Error("expected write error after timeout")
	}
}
1148
1149
1150 type trackLastConnListener struct {
1151 net.Listener
1152
1153 mu *sync.RWMutex
1154 last *net.Conn
1155 }
1156
1157 func (l trackLastConnListener) Accept() (c net.Conn, err error) {
1158 c, err = l.Listener.Accept()
1159 if err == nil {
1160 l.mu.Lock()
1161 *l.last = c
1162 l.mu.Unlock()
1163 }
1164 return
1165 }
1166
1167
// TestIdentityResponse checks how the server reconciles a handler-declared
// Content-Length with a handler-set Transfer-Encoding header and with the
// number of bytes actually written.
func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
func testIdentityResponse(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("https://go.dev/issue/56019")
	}

	// The handler declares Content-Length: 3 and echoes the "te" query
	// parameter as Transfer-Encoding.
	// overwrite=1 writes more than declared; that Write must fail with
	// ErrContentLength.
	// underwrite=1 raises Content-Length to 500 and then writes only
	// 9 bytes.
	handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
		rw.Header().Set("Content-Length", "3")
		rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
		switch {
		case req.FormValue("overwrite") == "1":
			_, err := rw.Write([]byte("foo TOO LONG"))
			if err != ErrContentLength {
				t.Errorf("expected ErrContentLength; got %v", err)
			}
		case req.FormValue("underwrite") == "1":
			rw.Header().Set("Content-Length", "500")
			rw.Write([]byte("too short"))
		default:
			rw.Write([]byte("foo"))
		}
	})

	ts := newClientServerTest(t, mode, handler).ts
	c := ts.Client()

	// With an empty or "identity" Transfer-Encoding set by the handler, the
	// response must still carry Content-Length 3 and no transfer coding.
	for _, te := range []string{"", "identity"} {
		url := ts.URL + "/?te=" + te
		res, err := c.Get(url)
		if err != nil {
			t.Fatalf("error with Get of %s: %v", url, err)
		}
		if cl, expected := res.ContentLength, int64(3); cl != expected {
			t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
		}
		if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
			t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
		}
		if tl, expected := len(res.TransferEncoding), 0; tl != expected {
			t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
				url, expected, tl, res.TransferEncoding)
		}
		res.Body.Close()
	}

	// The overwrite check happens inside the handler; the client only needs
	// the request to complete.
	url := ts.URL + "/?overwrite=1"
	res, err := c.Get(url)
	if err != nil {
		t.Fatalf("error with Get of %s: %v", url, err)
	}
	res.Body.Close()

	if mode != http1Mode {
		return
	}

	// Raw HTTP/1 check: when the handler writes fewer bytes than declared,
	// the bytes it did write must still reach the client.
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("error dialing: %v", err)
	}
	defer conn.Close()
	_, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
	if err != nil {
		t.Fatalf("error writing: %v", err)
	}

	// io.ReadAll returns only once the server closes the connection, so a
	// failing server can make this block.
	got, _ := io.ReadAll(conn)
	expectedSuffix := "\r\n\r\ntoo short"
	if !strings.HasSuffix(string(got), expectedSuffix) {
		t.Errorf("Expected output to end with %q; got response body %q",
			expectedSuffix, string(got))
	}
}
1248
1249 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1250 setParallel(t)
1251 s := newClientServerTest(t, http1Mode, h).ts
1252
1253 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1254 if err != nil {
1255 t.Fatal("dial error:", err)
1256 }
1257 defer conn.Close()
1258
1259 _, err = fmt.Fprint(conn, req)
1260 if err != nil {
1261 t.Fatal("print error:", err)
1262 }
1263
1264 r := bufio.NewReader(conn)
1265 res, err := ReadResponse(r, &Request{Method: "GET"})
1266 if err != nil {
1267 t.Fatal("ReadResponse error:", err)
1268 }
1269
1270 _, err = io.ReadAll(r)
1271 if err != nil {
1272 t.Fatal("read error:", err)
1273 }
1274
1275 if !res.Close {
1276 t.Errorf("Response.Close = false; want true")
1277 }
1278 }
1279
1280 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1281 setParallel(t)
1282 ts := newClientServerTest(t, http1Mode, handler).ts
1283 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1284 if err != nil {
1285 t.Fatal(err)
1286 }
1287 defer conn.Close()
1288 br := bufio.NewReader(conn)
1289 for i := 0; i < 2; i++ {
1290 if _, err := io.WriteString(conn, req); err != nil {
1291 t.Fatal(err)
1292 }
1293 res, err := ReadResponse(br, nil)
1294 if err != nil {
1295 t.Fatalf("res %d: %v", i+1, err)
1296 }
1297 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1298 t.Fatalf("res %d body copy: %v", i+1, err)
1299 }
1300 res.Body.Close()
1301 }
1302 }
1303
1304
1305 func TestServeHTTP10Close(t *testing.T) {
1306 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1307 ServeFile(w, r, "testdata/file")
1308 }))
1309 }
1310
1311
1312 func TestClientCanClose(t *testing.T) {
1313 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1314
1315 }))
1316 }
1317
1318
1319
1320 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1321 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1322 w.Header().Set("Connection", "close")
1323 }))
1324 }
1325
1326 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1327 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1328 w.Header().Set("Connection", "close")
1329 }))
1330 }
1331
1332 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1333 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1334
1335
1336 }))
1337 }
1338
// send204 replies with status 204 No Content and writes no body.
func send204(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) }

// send304 replies with status 304 Not Modified and writes no body.
func send304(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) }
1341
1342
1343 func TestHTTP10KeepAlive204Response(t *testing.T) {
1344 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1345 }
1346
1347 func TestHTTP11KeepAlive204Response(t *testing.T) {
1348 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1349 }
1350
1351 func TestHTTP10KeepAlive304Response(t *testing.T) {
1352 testTCPConnectionStaysOpen(t,
1353 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1354 HandlerFunc(send304))
1355 }
1356
1357
1358 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1359 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1360 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1361 w.(Flusher).Flush()
1362 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1363 }))
1364 type data struct {
1365 Addr string
1366 }
1367 var addrs [2]data
1368 for i := range addrs {
1369 res, err := cst.c.Get(cst.ts.URL)
1370 if err != nil {
1371 t.Fatal(err)
1372 }
1373 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1374 t.Fatal(err)
1375 }
1376 if addrs[i].Addr == "" {
1377 t.Fatal("no address")
1378 }
1379 res.Body.Close()
1380 }
1381 if addrs[0] != addrs[1] {
1382 t.Fatalf("connection not reused")
1383 }
1384 }
1385
1386 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1387 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1388 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1389 fmt.Fprintf(w, "%s", r.RemoteAddr)
1390 }))
1391
1392 res, err := cst.c.Get(cst.ts.URL)
1393 if err != nil {
1394 t.Fatalf("Get error: %v", err)
1395 }
1396 body, err := io.ReadAll(res.Body)
1397 if err != nil {
1398 t.Fatalf("ReadAll error: %v", err)
1399 }
1400 ip := string(body)
1401 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1402 t.Fatalf("Expected local addr; got %q", ip)
1403 }
1404 }
1405
// blockingRemoteAddrListener wraps a net.Listener so that every accepted
// connection becomes a *blockingRemoteAddrConn. Each wrapped connection is
// also sent on conns before Accept returns, so a test can grab it and decide
// when its RemoteAddr call may complete.
type blockingRemoteAddrListener struct {
	net.Listener
	conns chan<- net.Conn
}

// Accept wraps the next connection from the underlying listener, publishes
// it on l.conns (blocking until it can be sent), and returns it. Errors from
// the underlying listener are returned with a nil conn.
func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
	inner, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	wrapped := &blockingRemoteAddrConn{Conn: inner, addrs: make(chan net.Addr, 1)}
	l.conns <- wrapped
	return wrapped, nil
}

// blockingRemoteAddrConn is a net.Conn whose RemoteAddr blocks until an
// address is delivered on addrs (buffered with capacity 1 by Accept).
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr
}

// RemoteAddr waits for a value on c.addrs and returns it.
func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
	addr := <-c.addrs
	return addr
}
1432
1433
// TestServerAllowsBlockingRemoteAddr checks that one connection whose
// RemoteAddr call is blocked does not stop the server from serving another
// connection in the meantime.
func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
	run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
}
func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
	conns := make(chan net.Conn)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
	}), func(ts *httptest.Server) {
		// Wrap the listener so each accepted conn is handed to the test and
		// its RemoteAddr blocks until the test supplies an address.
		ts.Listener = &blockingRemoteAddrListener{
			Listener: ts.Listener,
			conns:    conns,
		}
	}).ts

	c := ts.Client()
	// Disable keep-alives so each request gets its own connection.
	c.Transport.(*Transport).DisableKeepAlives = true

	// fetch performs a GET and sends the response body on response, or ""
	// if the request or the body read fails.
	fetch := func(num int, response chan<- string) {
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		response <- string(body)
	}

	// Start the first request. The server's RemoteAddr call on its
	// connection should block until an address is sent on that conn's addrs.
	response1c := make(chan string, 1)
	go fetch(1, response1c)

	// Wait for the server to accept it; grab the connection.
	conn1 := <-conns

	// Start a second request and grab its connection.
	response2c := make(chan string, 1)
	go fetch(2, response2c)
	conn2 := <-conns

	// Unblock connection 2 only ...
	conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("12.12.12.12"), Port: 12}

	// ... and its response must arrive while connection 1 is still blocked.
	response2 := <-response2c
	if g, e := response2, "RA:12.12.12.12:12"; g != e {
		t.Fatalf("response 2 addr = %q; want %q", g, e)
	}

	// Now unblock connection 1 ...
	conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("21.21.21.21"), Port: 21}

	// ... and see its response.
	response1 := <-response1c
	if g, e := response1, "RA:21.21.21.21:21"; g != e {
		t.Fatalf("response 1 addr = %q; want %q", g, e)
	}
}
1501
1502
1503
1504 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1505 func testHeadResponses(t *testing.T, mode testMode) {
1506 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1507 _, err := w.Write([]byte("<html>"))
1508 if err != nil {
1509 t.Errorf("ResponseWriter.Write: %v", err)
1510 }
1511
1512
1513 _, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
1514 if err != nil {
1515 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1516 }
1517 }))
1518 res, err := cst.c.Head(cst.ts.URL)
1519 if err != nil {
1520 t.Error(err)
1521 }
1522 if len(res.TransferEncoding) > 0 {
1523 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1524 }
1525 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1526 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1527 }
1528 if v := res.ContentLength; v != 10 {
1529 t.Errorf("Content-Length: %d; want 10", v)
1530 }
1531 body, err := io.ReadAll(res.Body)
1532 if err != nil {
1533 t.Error(err)
1534 }
1535 if len(body) > 0 {
1536 t.Errorf("got unexpected body %q", string(body))
1537 }
1538 }
1539
1540
1541
1542 func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
1543 func testHeadReaderFrom(t *testing.T, mode testMode) {
1544
1545 wantBody := strings.Repeat("a", 4096)
1546 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1547 w.(io.ReaderFrom).ReadFrom(strings.NewReader(wantBody))
1548 }))
1549 res, err := cst.c.Head(cst.ts.URL)
1550 if err != nil {
1551 t.Fatal(err)
1552 }
1553 res.Body.Close()
1554 res, err = cst.c.Get(cst.ts.URL)
1555 if err != nil {
1556 t.Fatal(err)
1557 }
1558 gotBody, err := io.ReadAll(res.Body)
1559 res.Body.Close()
1560 if err != nil {
1561 t.Fatal(err)
1562 }
1563 if string(gotBody) != wantBody {
1564 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), len(wantBody))
1565 }
1566 }
1567
1568 func TestTLSHandshakeTimeout(t *testing.T) {
1569 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1570 }
1571 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1572 errLog := new(strings.Builder)
1573 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1574 func(ts *httptest.Server) {
1575 ts.Config.ReadTimeout = 250 * time.Millisecond
1576 ts.Config.ErrorLog = log.New(errLog, "", 0)
1577 },
1578 )
1579 ts := cst.ts
1580
1581 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1582 if err != nil {
1583 t.Fatalf("Dial: %v", err)
1584 }
1585 var buf [1]byte
1586 n, err := conn.Read(buf[:])
1587 if err == nil || n != 0 {
1588 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1589 }
1590 conn.Close()
1591
1592 cst.close()
1593 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1594 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1595 }
1596 }
1597
1598 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1599 func testTLSServer(t *testing.T, mode testMode) {
1600 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1601 if r.TLS != nil {
1602 w.Header().Set("X-TLS-Set", "true")
1603 if r.TLS.HandshakeComplete {
1604 w.Header().Set("X-TLS-HandshakeComplete", "true")
1605 }
1606 }
1607 }), func(ts *httptest.Server) {
1608 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1609 }).ts
1610
1611
1612
1613
1614
1615
1616 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1617 if err != nil {
1618 t.Fatalf("Dial: %v", err)
1619 }
1620 defer idleConn.Close()
1621
1622 if !strings.HasPrefix(ts.URL, "https://") {
1623 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1624 return
1625 }
1626 client := ts.Client()
1627 res, err := client.Get(ts.URL)
1628 if err != nil {
1629 t.Error(err)
1630 return
1631 }
1632 if res == nil {
1633 t.Errorf("got nil Response")
1634 return
1635 }
1636 defer res.Body.Close()
1637 if res.Header.Get("X-TLS-Set") != "true" {
1638 t.Errorf("expected X-TLS-Set response header")
1639 return
1640 }
1641 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1642 t.Errorf("expected X-TLS-HandshakeComplete header")
1643 }
1644 }
1645
// TestServeTLS checks that Server.ServeTLS on an existing listener, with the
// certificate supplied through TLSConfig (empty certFile/keyFile), negotiates
// "h2" via ALPN with a client that offers "h2" and "http/1.1".
func TestServeTLS(t *testing.T) {
	CondSkipHTTP2(t)
	// Not parallel: SetTestHookServerServe installs a package-level hook,
	// which is reset to nil when the test returns.
	defer afterTest(t)
	defer SetTestHookServerServe(nil)

	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	tlsConf := &tls.Config{
		Certificates: []tls.Certificate{cert},
	}

	ln := newLocalListener(t)
	defer ln.Close()
	addr := ln.Addr().String()

	// The hook signals once the server has started serving.
	serving := make(chan bool, 1)
	SetTestHookServerServe(func(s *Server, ln net.Listener) {
		serving <- true
	})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
	s := &Server{
		Addr:      addr,
		TLSConfig: tlsConf,
		Handler:   handler,
	}
	errc := make(chan error, 1)
	go func() { errc <- s.ServeTLS(ln, "", "") }()
	// Either ServeTLS fails early, or the hook reports it is serving.
	select {
	case err := <-errc:
		t.Fatalf("ServeTLS: %v", err)
	case <-serving:
	}

	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1697
1698
1699 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1700 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1701 }
1702 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1703 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1704 t.Error("unexpected HTTPS request")
1705 }), func(ts *httptest.Server) {
1706 var errBuf bytes.Buffer
1707 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1708 }).ts
1709 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1710 if err != nil {
1711 t.Fatal(err)
1712 }
1713 defer conn.Close()
1714 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1715 slurp, err := io.ReadAll(conn)
1716 if err != nil {
1717 t.Fatal(err)
1718 }
1719 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1720 if !strings.HasPrefix(string(slurp), wantPrefix) {
1721 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1722 }
1723 }
1724
1725
1726 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1727 testAutomaticHTTP2_Serve(t, nil, true)
1728 }
1729
1730 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1731 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1732 }
1733
1734 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1735 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1736 }
1737
1738 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1739 setParallel(t)
1740 defer afterTest(t)
1741 ln := newLocalListener(t)
1742 ln.Close()
1743 var s Server
1744 s.TLSConfig = tlsConf
1745 if err := s.Serve(ln); err == nil {
1746 t.Fatal("expected an error")
1747 }
1748 gotH2 := s.TLSNextProto["h2"] != nil
1749 if gotH2 != wantH2 {
1750 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1751 }
1752 }
1753
1754 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1755 setParallel(t)
1756 defer afterTest(t)
1757 ln := newLocalListener(t)
1758 ln.Close()
1759 var s Server
1760
1761
1762 s.TLSConfig = &tls.Config{
1763 NextProtos: []string{"h2"},
1764 }
1765 if err := s.Serve(ln); err == nil {
1766 t.Fatal("expected an error")
1767 }
1768 on := s.TLSNextProto["h2"] != nil
1769 if !on {
1770 t.Errorf("http2 wasn't automatically enabled")
1771 }
1772 }
1773
1774 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1775 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1776 if err != nil {
1777 t.Fatal(err)
1778 }
1779 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1780 Certificates: []tls.Certificate{cert},
1781 })
1782 }
1783
1784 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1785 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1786 if err != nil {
1787 t.Fatal(err)
1788 }
1789 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1790 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1791 return &cert, nil
1792 },
1793 })
1794 }
1795
1796 func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
1797 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1798 if err != nil {
1799 t.Fatal(err)
1800 }
1801 conf := &tls.Config{
1802
1803
1804 NextProtos: []string{"h2"},
1805 Certificates: []tls.Certificate{cert},
1806 }
1807 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1808 GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
1809 return conf, nil
1810 },
1811 })
1812 }
1813
// testAutomaticHTTP2_ListenAndServe starts a Server configured with tlsConf
// via ListenAndServeTLS("", "") and checks that a TLS client offering "h2"
// and "http/1.1" negotiates "h2".
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)
	// Not parallel: SetTestHookServerServe installs a package-level hook,
	// which is reset to nil when the test returns.
	defer afterTest(t)
	defer SetTestHookServerServe(nil)
	var ok bool
	var s *Server
	const maxTries = 5
	var ln net.Listener
	// Find a free address by opening and immediately closing a listener,
	// then ask the server to listen there. The port may be taken in between,
	// in which case ListenAndServeTLS fails and we retry.
Try:
	for try := 0; try < maxTries; try++ {
		ln = newLocalListener(t)
		addr := ln.Addr().String()
		ln.Close()
		t.Logf("Got %v", addr)
		// The hook hands back the listener the server actually serves on.
		lnc := make(chan net.Listener, 1)
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		s = &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- s.ListenAndServeTLS("", "") }()
		select {
		case err := <-errc:
			t.Logf("On try #%v: %v", try+1, err)
			continue
		case ln = <-lnc:
			ok = true
			t.Logf("Listening on %v", ln.Addr().String())
			break Try
		}
	}
	if !ok {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	// Closing the server's listener on return stops the server.
	defer ln.Close()
	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1868
// serverExpectTest is one case for testServerExpect: a raw HTTP/1.1 POST
// with an Expect header, plus the status text expected in the first
// response line.
type serverExpectTest struct {
	contentLength    int    // body length; sent as Content-Length unless chunked
	chunked          bool   // send Transfer-Encoding: chunked instead of Content-Length
	expectation      string // value of the Expect header ("" sends an empty one)
	readBody         bool   // sent as ?readbody=; the handler reads the body only when true
	expectedResponse string // substring expected in the first response line
}

// expectTest builds a non-chunked serverExpectTest from its arguments.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var tc serverExpectTest
	tc.contentLength = contentLength
	tc.expectation = expectation
	tc.readBody = readBody
	tc.expectedResponse = expectedResponse
	return tc
}
1885
// serverExpectTests drives testServerExpect. Each entry sends a POST with
// the given Expect header and checks the first response line.
var serverExpectTests = []serverExpectTest{
	// Standard 100-continue; the Expect value is matched case-insensitively.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// No Expect value: the body is sent right away and the handler replies 200.
	expectTest(100, "", true, "200 OK"),

	// 100-continue, but the handler rejects the request without reading the
	// body, so the client should get the rejection instead of 100 Continue.
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// Likewise without 100-continue.
	expectTest(100, "", false, "401 Unauthorized"),

	// Unknown expectations are rejected with 417.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// 100-continue with an empty body is accepted and handled normally.
	expectTest(0, "100-continue", true, "200 OK"),
	// 100-continue with an empty body that the handler doesn't read.
	expectTest(0, "100-continue", false, "401 Unauthorized"),
	// 100-continue with no Content-Length but a chunked body.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1915
1916
1917
// TestServerExpect exercises the server's handling of the Expect request
// header over raw HTTP/1.1 connections, one per serverExpectTests entry.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The readbody flag comes from the raw query rather than
		// r.FormValue, since form parsing on a POST may read r.Body, and
		// the body must be read only conditionally here.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest dials a fresh connection, writes the request (and, when
	// appropriate, the body) from a goroutine, and checks the first response
	// line against test.expectedResponse.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Send the body immediately only when acting like a client that
		// doesn't use 100-continue; with 100-continue no body is sent.
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		// Deferred after conn.Close, so it runs first: wait for the writer
		// goroutine to finish before the connection is closed.
		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler doesn't read the body, so the server
						// may already have hung up on us; see the matching
						// ReadString error handling below.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// Tolerated race: we may still have been writing the body
				// when the server hung up. A TCP stack may then send a RST
				// for the unread data, which can make this read fail.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
2013
2014
2015
2016 func TestServerUnreadRequestBodyLittle(t *testing.T) {
2017 setParallel(t)
2018 defer afterTest(t)
2019 conn := new(testConn)
2020 body := strings.Repeat("x", 100<<10)
2021 conn.readBuf.Write([]byte(fmt.Sprintf(
2022 "POST / HTTP/1.1\r\n"+
2023 "Host: test\r\n"+
2024 "Content-Length: %d\r\n"+
2025 "\r\n", len(body))))
2026 conn.readBuf.Write([]byte(body))
2027
2028 done := make(chan bool)
2029
2030 readBufLen := func() int {
2031 conn.readMu.Lock()
2032 defer conn.readMu.Unlock()
2033 return conn.readBuf.Len()
2034 }
2035
2036 ls := &oneConnListener{conn}
2037 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2038 defer close(done)
2039 if bufLen := readBufLen(); bufLen < len(body)/2 {
2040 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
2041 }
2042 rw.WriteHeader(200)
2043 rw.(Flusher).Flush()
2044 if g, e := readBufLen(), 0; g != e {
2045 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
2046 }
2047 if c := rw.Header().Get("Connection"); c != "" {
2048 t.Errorf(`Connection header = %q; want ""`, c)
2049 }
2050 }))
2051 <-done
2052 }
2053
2054
2055
2056
2057 func TestServerUnreadRequestBodyLarge(t *testing.T) {
2058 setParallel(t)
2059 if testing.Short() && testenv.Builder() == "" {
2060 t.Log("skipping in short mode")
2061 }
2062 conn := new(testConn)
2063 body := strings.Repeat("x", 1<<20)
2064 conn.readBuf.Write([]byte(fmt.Sprintf(
2065 "POST / HTTP/1.1\r\n"+
2066 "Host: test\r\n"+
2067 "Content-Length: %d\r\n"+
2068 "\r\n", len(body))))
2069 conn.readBuf.Write([]byte(body))
2070 conn.closec = make(chan bool, 1)
2071
2072 ls := &oneConnListener{conn}
2073 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2074 if conn.readBuf.Len() < len(body)/2 {
2075 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2076 }
2077 rw.WriteHeader(200)
2078 rw.(Flusher).Flush()
2079 if conn.readBuf.Len() < len(body)/2 {
2080 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2081 }
2082 }))
2083 <-conn.closec
2084
2085 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
2086 t.Errorf("Expected a Connection: close header; got response: %s", res)
2087 }
2088 }
2089
// handlerBodyCloseTest is one testHandlerBodyClose case. It describes the
// first request's body and what req.Body.Close is expected to do with it.
type handlerBodyCloseTest struct {
	bodySize     int  // number of body bytes sent with the first request
	bodyChunked  bool // send the body with Transfer-Encoding: chunked
	reqConnClose bool // first request carries "Connection: close" (and no follow-up GET)

	wantEOFSearch bool // expect req.Body.Close to consume more input from the conn
	wantNextReq   bool // expect the follow-up GET to reach the handler
}

// connectionHeader returns the "Connection: close" header line (CRLF
// terminated) when reqConnClose is set, and "" otherwise.
func (tc handlerBodyCloseTest) connectionHeader() string {
	header := ""
	if tc.reqConnClose {
		header = "Connection: close\r\n"
	}
	return header
}
2105
// handlerBodyCloseTests lists the cases run by TestHandlerBodyClose. Each
// case varies body size (20 KB or 1 MB), framing (Content-Length or chunked)
// and whether the request says "Connection: close". Each records whether
// req.Body.Close is expected to read further into the connection and whether
// a pipelined follow-up GET is expected to be served.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// 20 KB with Content-Length, keep-alive: Close is expected to read past
	// the body, and the next request is served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// 20 KB chunked, keep-alive: Close reads past the body, and the next
	// request is served.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// 20 KB with Content-Length and Connection: close: nothing follows on
	// this connection, so Close is expected not to read at all.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// 20 KB chunked with Connection: close: Close is still expected to read
	// (presumably to look for trailers after the final chunk).
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MB with Content-Length, keep-alive: expected to give up without
	// reading, and the connection is not reused.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false, // has a Content-Length
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// 1 MB chunked, keep-alive: Close reads some input before giving up, and
	// the next request is not served.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MB chunked with Connection: close: Close is still expected to read
	// (presumably searching for trailers).
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MB with Content-Length and Connection: close: Close is expected not
	// to read at all.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2190
2191 func TestHandlerBodyClose(t *testing.T) {
2192 setParallel(t)
2193 if testing.Short() && testenv.Builder() == "" {
2194 t.Skip("skipping in -short mode")
2195 }
2196 for i, tt := range handlerBodyCloseTests {
2197 testHandlerBodyClose(t, i, tt)
2198 }
2199 }
2200
2201 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2202 conn := new(testConn)
2203 body := strings.Repeat("x", tt.bodySize)
2204 if tt.bodyChunked {
2205 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2206 "Host: test\r\n" +
2207 tt.connectionHeader() +
2208 "Transfer-Encoding: chunked\r\n" +
2209 "\r\n")
2210 cw := internal.NewChunkedWriter(&conn.readBuf)
2211 io.WriteString(cw, body)
2212 cw.Close()
2213 conn.readBuf.WriteString("\r\n")
2214 } else {
2215 conn.readBuf.Write([]byte(fmt.Sprintf(
2216 "POST / HTTP/1.1\r\n"+
2217 "Host: test\r\n"+
2218 tt.connectionHeader()+
2219 "Content-Length: %d\r\n"+
2220 "\r\n", len(body))))
2221 conn.readBuf.Write([]byte(body))
2222 }
2223 if !tt.reqConnClose {
2224 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2225 }
2226 conn.closec = make(chan bool, 1)
2227
2228 readBufLen := func() int {
2229 conn.readMu.Lock()
2230 defer conn.readMu.Unlock()
2231 return conn.readBuf.Len()
2232 }
2233
2234 ls := &oneConnListener{conn}
2235 var numReqs int
2236 var size0, size1 int
2237 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2238 numReqs++
2239 if numReqs == 1 {
2240 size0 = readBufLen()
2241 req.Body.Close()
2242 size1 = readBufLen()
2243 }
2244 }))
2245 <-conn.closec
2246 if numReqs < 1 || numReqs > 2 {
2247 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2248 }
2249 didSearch := size0 != size1
2250 if didSearch != tt.wantEOFSearch {
2251 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2252 }
2253 if tt.wantNextReq && numReqs != 2 {
2254 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2255 }
2256 }
2257
2258
2259
// testHandlerBodyConsumer is a named strategy that a test handler applies
// to the incoming request body.
type testHandlerBodyConsumer struct {
	name string
	f    func(io.ReadCloser)
}

// testHandlerBodyConsumers lists the body strategies the connection-closing
// tests iterate over: ignore the body, close it unread, or drain it.
var testHandlerBodyConsumers = []testHandlerBodyConsumer{
	{name: "nil", f: func(io.ReadCloser) {}},
	{name: "close", f: func(body io.ReadCloser) { body.Close() }},
	{name: "discard", f: func(body io.ReadCloser) { io.Copy(io.Discard, body) }},
}
2270
2271 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2272 setParallel(t)
2273 defer afterTest(t)
2274 for _, handler := range testHandlerBodyConsumers {
2275 conn := new(testConn)
2276 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2277 "Host: test\r\n" +
2278 "Transfer-Encoding: chunked\r\n" +
2279 "\r\n" +
2280 "hax\r\n" +
2281 "GET /secret HTTP/1.1\r\n" +
2282 "Host: test\r\n" +
2283 "\r\n")
2284
2285 conn.closec = make(chan bool, 1)
2286 ls := &oneConnListener{conn}
2287 var numReqs int
2288 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2289 numReqs++
2290 if strings.Contains(req.URL.Path, "secret") {
2291 t.Error("Request for /secret encountered, should not have happened.")
2292 }
2293 handler.f(req.Body)
2294 }))
2295 <-conn.closec
2296 if numReqs != 1 {
2297 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2298 }
2299 }
2300 }
2301
2302 func TestInvalidTrailerClosesConnection(t *testing.T) {
2303 setParallel(t)
2304 defer afterTest(t)
2305 for _, handler := range testHandlerBodyConsumers {
2306 conn := new(testConn)
2307 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2308 "Host: test\r\n" +
2309 "Trailer: hack\r\n" +
2310 "Transfer-Encoding: chunked\r\n" +
2311 "\r\n" +
2312 "3\r\n" +
2313 "hax\r\n" +
2314 "0\r\n" +
2315 "I'm not a valid trailer\r\n" +
2316 "GET /secret HTTP/1.1\r\n" +
2317 "Host: test\r\n" +
2318 "\r\n")
2319
2320 conn.closec = make(chan bool, 1)
2321 ln := &oneConnListener{conn}
2322 var numReqs int
2323 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2324 numReqs++
2325 if strings.Contains(req.URL.Path, "secret") {
2326 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2327 }
2328 handler.f(req.Body)
2329 }))
2330 <-conn.closec
2331 if numReqs != 1 {
2332 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2333 }
2334 }
2335 }
2336
2337
2338
2339
2340 type slowTestConn struct {
2341
2342 script []any
2343 closec chan bool
2344
2345 mu sync.Mutex
2346 rd, wd time.Time
2347 noopConn
2348 }
2349
2350 func (c *slowTestConn) SetDeadline(t time.Time) error {
2351 c.SetReadDeadline(t)
2352 c.SetWriteDeadline(t)
2353 return nil
2354 }
2355
2356 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2357 c.mu.Lock()
2358 defer c.mu.Unlock()
2359 c.rd = t
2360 return nil
2361 }
2362
2363 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2364 c.mu.Lock()
2365 defer c.mu.Unlock()
2366 c.wd = t
2367 return nil
2368 }
2369
2370 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2371 c.mu.Lock()
2372 defer c.mu.Unlock()
2373 restart:
2374 if !c.rd.IsZero() && time.Now().After(c.rd) {
2375 return 0, syscall.ETIMEDOUT
2376 }
2377 if len(c.script) == 0 {
2378 return 0, io.EOF
2379 }
2380
2381 switch cue := c.script[0].(type) {
2382 case time.Duration:
2383 if !c.rd.IsZero() {
2384
2385
2386 if remaining := time.Until(c.rd); remaining < cue {
2387 c.script[0] = cue - remaining
2388 time.Sleep(remaining)
2389 return 0, syscall.ETIMEDOUT
2390 }
2391 }
2392 c.script = c.script[1:]
2393 time.Sleep(cue)
2394 goto restart
2395
2396 case string:
2397 n = copy(b, cue)
2398
2399 if len(cue) > n {
2400 c.script[0] = cue[n:]
2401 } else {
2402 c.script = c.script[1:]
2403 }
2404
2405 default:
2406 panic("unknown cue in slowTestConn script")
2407 }
2408
2409 return
2410 }
2411
2412 func (c *slowTestConn) Close() error {
2413 select {
2414 case c.closec <- true:
2415 default:
2416 }
2417 return nil
2418 }
2419
2420 func (c *slowTestConn) Write(b []byte) (int, error) {
2421 if !c.wd.IsZero() && time.Now().After(c.wd) {
2422 return 0, syscall.ETIMEDOUT
2423 }
2424 return len(b), nil
2425 }
2426
2427 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2428 if testing.Short() {
2429 t.Skip("skipping in -short mode")
2430 }
2431 defer afterTest(t)
2432 for _, handler := range testHandlerBodyConsumers {
2433 conn := &slowTestConn{
2434 script: []any{
2435 "POST /public HTTP/1.1\r\n" +
2436 "Host: test\r\n" +
2437 "Content-Length: 10000\r\n" +
2438 "\r\n",
2439 "foo bar baz",
2440 600 * time.Millisecond,
2441 "GET /secret HTTP/1.1\r\n" +
2442 "Host: test\r\n" +
2443 "\r\n",
2444 },
2445 closec: make(chan bool, 1),
2446 }
2447 ls := &oneConnListener{conn}
2448
2449 var numReqs int
2450 s := Server{
2451 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2452 numReqs++
2453 if strings.Contains(req.URL.Path, "secret") {
2454 t.Error("Request for /secret encountered, should not have happened.")
2455 }
2456 handler.f(req.Body)
2457 }),
2458 ReadTimeout: 400 * time.Millisecond,
2459 }
2460 go s.Serve(ls)
2461 <-conn.closec
2462
2463 if numReqs != 1 {
2464 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2465 }
2466 }
2467 }
2468
2469
// cancelableTimeoutContext wraps a Context and reports any cancellation as a
// deadline. Once the wrapped context is done, for whatever reason, Err
// returns context.DeadlineExceeded. Tests use it to trigger a
// TimeoutHandler's timeout on demand by cancelling the parent context.
type cancelableTimeoutContext struct {
	context.Context
}

// Err returns nil while the wrapped context is still live. Afterwards it
// returns context.DeadlineExceeded, regardless of the underlying cause.
func (c cancelableTimeoutContext) Err() error {
	if err := c.Context.Err(); err == nil {
		return nil
	}
	return context.DeadlineExceeded
}
2480
2481 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2482 func testTimeoutHandler(t *testing.T, mode testMode) {
2483 sendHi := make(chan bool, 1)
2484 writeErrors := make(chan error, 1)
2485 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2486 <-sendHi
2487 _, werr := w.Write([]byte("hi"))
2488 writeErrors <- werr
2489 })
2490 ctx, cancel := context.WithCancel(context.Background())
2491 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2492 cst := newClientServerTest(t, mode, h)
2493
2494
2495 sendHi <- true
2496 res, err := cst.c.Get(cst.ts.URL)
2497 if err != nil {
2498 t.Error(err)
2499 }
2500 if g, e := res.StatusCode, StatusOK; g != e {
2501 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2502 }
2503 body, _ := io.ReadAll(res.Body)
2504 if g, e := string(body), "hi"; g != e {
2505 t.Errorf("got body %q; expected %q", g, e)
2506 }
2507 if g := <-writeErrors; g != nil {
2508 t.Errorf("got unexpected Write error on first request: %v", g)
2509 }
2510
2511
2512 cancel()
2513
2514 res, err = cst.c.Get(cst.ts.URL)
2515 if err != nil {
2516 t.Error(err)
2517 }
2518 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2519 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2520 }
2521 body, _ = io.ReadAll(res.Body)
2522 if !strings.Contains(string(body), "<title>Timeout</title>") {
2523 t.Errorf("expected timeout body; got %q", string(body))
2524 }
2525 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2526 t.Errorf("response content-type = %q; want %q", g, w)
2527 }
2528
2529
2530
2531 sendHi <- true
2532 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2533 t.Errorf("expected Write error of %v; got %v", e, g)
2534 }
2535 }
2536
2537
2538 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2539 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2540 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2541 ms, _ := strconv.Atoi(r.URL.Path[1:])
2542 if ms == 0 {
2543 ms = 1
2544 }
2545 for i := 0; i < ms; i++ {
2546 w.Write([]byte("hi"))
2547 time.Sleep(time.Millisecond)
2548 }
2549 })
2550
2551 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2552
2553 c := ts.Client()
2554
2555 var wg sync.WaitGroup
2556 gate := make(chan bool, 10)
2557 n := 50
2558 if testing.Short() {
2559 n = 10
2560 gate = make(chan bool, 3)
2561 }
2562 for i := 0; i < n; i++ {
2563 gate <- true
2564 wg.Add(1)
2565 go func() {
2566 defer wg.Done()
2567 defer func() { <-gate }()
2568 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2569 if err == nil {
2570 io.Copy(io.Discard, res.Body)
2571 res.Body.Close()
2572 }
2573 }()
2574 }
2575 wg.Wait()
2576 }
2577
2578
2579
2580 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2581 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2582 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2583 w.WriteHeader(204)
2584 })
2585
2586 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2587
2588 var wg sync.WaitGroup
2589 gate := make(chan bool, 50)
2590 n := 500
2591 if testing.Short() {
2592 n = 10
2593 }
2594
2595 c := ts.Client()
2596 for i := 0; i < n; i++ {
2597 gate <- true
2598 wg.Add(1)
2599 go func() {
2600 defer wg.Done()
2601 defer func() { <-gate }()
2602 res, err := c.Get(ts.URL)
2603 if err != nil {
2604
2605
2606 t.Log(err)
2607 return
2608 }
2609 defer res.Body.Close()
2610 io.Copy(io.Discard, res.Body)
2611 }()
2612 }
2613 wg.Wait()
2614 }
2615
2616
2617 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2618 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2619 sendHi := make(chan bool, 1)
2620 writeErrors := make(chan error, 1)
2621 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2622 w.Header().Set("Content-Type", "text/plain")
2623 <-sendHi
2624 _, werr := w.Write([]byte("hi"))
2625 writeErrors <- werr
2626 })
2627 ctx, cancel := context.WithCancel(context.Background())
2628 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2629 cst := newClientServerTest(t, mode, h)
2630
2631
2632 sendHi <- true
2633 res, err := cst.c.Get(cst.ts.URL)
2634 if err != nil {
2635 t.Error(err)
2636 }
2637 if g, e := res.StatusCode, StatusOK; g != e {
2638 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2639 }
2640 body, _ := io.ReadAll(res.Body)
2641 if g, e := string(body), "hi"; g != e {
2642 t.Errorf("got body %q; expected %q", g, e)
2643 }
2644 if g := <-writeErrors; g != nil {
2645 t.Errorf("got unexpected Write error on first request: %v", g)
2646 }
2647
2648
2649 cancel()
2650
2651 res, err = cst.c.Get(cst.ts.URL)
2652 if err != nil {
2653 t.Error(err)
2654 }
2655 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2656 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2657 }
2658 body, _ = io.ReadAll(res.Body)
2659 if !strings.Contains(string(body), "<title>Timeout</title>") {
2660 t.Errorf("expected timeout body; got %q", string(body))
2661 }
2662
2663
2664
2665 sendHi <- true
2666 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2667 t.Errorf("expected Write error of %v; got %v", e, g)
2668 }
2669 }
2670
2671
2672 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2673 run(t, testTimeoutHandlerStartTimerWhenServing)
2674 }
2675 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2676 if testing.Short() {
2677 t.Skip("skipping sleeping test in -short mode")
2678 }
2679 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2680 w.WriteHeader(StatusNoContent)
2681 }
2682 timeout := 300 * time.Millisecond
2683 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2684 defer ts.Close()
2685
2686 c := ts.Client()
2687
2688
2689
2690
2691 time.Sleep(2 * timeout)
2692 res, err := c.Get(ts.URL)
2693 if err != nil {
2694 t.Fatal(err)
2695 }
2696 defer res.Body.Close()
2697 if res.StatusCode != StatusNoContent {
2698 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2699 }
2700 }
2701
2702 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2703 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2704 writeErrors := make(chan error, 1)
2705 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2706 w.Header().Set("Content-Type", "text/plain")
2707 var err error
2708
2709
2710
2711 for i := 0; i < 100; i++ {
2712 _, err = w.Write([]byte("a"))
2713 if err != nil {
2714 break
2715 }
2716 time.Sleep(1 * time.Millisecond)
2717 }
2718 writeErrors <- err
2719 })
2720 ctx, cancel := context.WithCancel(context.Background())
2721 cancel()
2722 h := NewTestTimeoutHandler(sayHi, ctx)
2723 cst := newClientServerTest(t, mode, h)
2724 defer cst.close()
2725
2726 res, err := cst.c.Get(cst.ts.URL)
2727 if err != nil {
2728 t.Error(err)
2729 }
2730 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2731 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2732 }
2733 body, _ := io.ReadAll(res.Body)
2734 if g, e := string(body), ""; g != e {
2735 t.Errorf("got body %q; expected %q", g, e)
2736 }
2737 if g, e := <-writeErrors, context.Canceled; g != e {
2738 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2739 }
2740 }
2741
2742
2743 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2744 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2745 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2746
2747 }
2748 timeout := 300 * time.Millisecond
2749 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2750
2751 c := ts.Client()
2752
2753 res, err := c.Get(ts.URL)
2754 if err != nil {
2755 t.Fatal(err)
2756 }
2757 defer res.Body.Close()
2758 if res.StatusCode != StatusOK {
2759 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2760 }
2761 }
2762
2763
2764 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2765 wrapper := func(h Handler) Handler {
2766 return TimeoutHandler(h, time.Second, "")
2767 }
2768 run(t, func(t *testing.T, mode testMode) {
2769 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2770 }, testNotParallel)
2771 }
2772
// TestRedirectBadPath checks an invalid request URL: a non-empty path
// without a leading slash. Redirect must not crash on it and must still
// write the requested status code.
func TestRedirectBadPath(t *testing.T) {
	req := &Request{
		Method: "GET",
		URL:    &url.URL{Scheme: "http", Path: "not-empty-but-no-leading-slash"},
	}
	rec := httptest.NewRecorder()
	Redirect(rec, req, "", 304)
	if rec.Code != 304 {
		t.Errorf("Code = %d; want 304", rec.Code)
	}
}
2789
2790
// TestRedirect checks the status code and Location header that Redirect
// produces for a range of targets, resolved against a request for
// http://example.com/qux/:
//   - absolute URLs pass through unchanged;
//   - relative paths are resolved against the request path;
//   - non-ASCII path bytes are percent-encoded.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	cases := []struct {
		in   string
		want string
	}{
		// Absolute URLs, whatever the scheme, are left alone.
		{"http://foobar.com/baz", "http://foobar.com/baz"},
		{"https://foobar.com/baz", "https://foobar.com/baz"},
		{"test://foobar.com/baz", "test://foobar.com/baz"},
		// Scheme-relative and host-relative references.
		{"//foobar.com/baz", "//foobar.com/baz"},
		{"/foobar.com/baz", "/foobar.com/baz"},
		// Relative paths resolve against /qux/.
		{"foobar.com/baz", "/qux/foobar.com/baz"},
		{"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
		// Three leading slashes collapse to one.
		{"///foobar.com/baz", "/foobar.com/baz"},
		// URLs inside the query string are not touched.
		{"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
		{"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			"http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
		// Non-ASCII path bytes are percent-encoded.
		{"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	for _, tc := range cases {
		rec := httptest.NewRecorder()
		Redirect(rec, req, tc.in, 302)
		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q) generated status code %v; want %v", tc.in, got, want)
		}
		if got := rec.Header().Get("Location"); got != tc.want {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", tc.in, got, tc.want)
		}
	}
}
2835
2836
2837
// TestRedirectContentTypeAndBody checks the Content-Type and body that
// Redirect writes:
//   - GET gets a short HTML link.
//   - HEAD gets the HTML Content-Type but no body.
//   - Other methods get neither.
//   - A Content-Type key already present in the header map, even with an
//     empty or nil value, suppresses the default; any value it has is kept.
func TestRedirectContentTypeAndBody(t *testing.T) {
	type ctHeader struct {
		Values []string
	}

	cases := []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, "text/html; charset=utf-8", ""},
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	}
	for _, tc := range cases {
		rec := httptest.NewRecorder()
		if tc.ct != nil {
			rec.Header()["Content-Type"] = tc.ct.Values
		}
		req := httptest.NewRequest(tc.method, "http://example.com/qux/", nil)
		Redirect(rec, req, "/foo", 302)

		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, got, want)
		}
		if got, want := rec.Header().Get("Content-Type"), tc.wantCT; got != want {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, got, want)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got, want := string(body), tc.wantBody; got != want {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, got, want)
		}
	}
}
2881
2882
2883
2884
2885
2886
2887
2888 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
2889
2890 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
2891 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
2892 all, err := io.ReadAll(r.Body)
2893 if err != nil {
2894 t.Fatalf("handler ReadAll: %v", err)
2895 }
2896 if len(all) != 0 {
2897 t.Errorf("handler got %d bytes; expected 0", len(all))
2898 }
2899 rw.Header().Set("Content-Length", "0")
2900 }))
2901
2902 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2903 if err != nil {
2904 t.Fatal(err)
2905 }
2906 req.ContentLength = 0
2907
2908 var resp [5]*Response
2909 for i := range resp {
2910 resp[i], err = cst.c.Do(req)
2911 if err != nil {
2912 t.Fatalf("client post #%d: %v", i, err)
2913 }
2914 }
2915
2916 for i := range resp {
2917 all, err := io.ReadAll(resp[i].Body)
2918 if err != nil {
2919 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2920 }
2921 if len(all) != 0 {
2922 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2923 }
2924 }
2925 }
2926
2927 func TestHandlerPanicNil(t *testing.T) {
2928 run(t, func(t *testing.T, mode testMode) {
2929 testHandlerPanic(t, false, mode, nil, nil)
2930 }, testNotParallel)
2931 }
2932
2933 func TestHandlerPanic(t *testing.T) {
2934 run(t, func(t *testing.T, mode testMode) {
2935 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
2936 }, testNotParallel)
2937 }
2938
2939 func TestHandlerPanicWithHijack(t *testing.T) {
2940
2941 run(t, func(t *testing.T, mode testMode) {
2942 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
2943 }, []testMode{http1Mode})
2944 }
2945
// testHandlerPanic serves one request with a handler that panics with
// panicValue. withHijack makes the handler hijack the connection first.
// wrapper, if non-nil, wraps the handler. The server's ErrorLog is
// redirected into a pipe so the test can see when the server logs something.
func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
	// pw receives everything the server's ErrorLog writes. pr is read by
	// the goroutine below.
	pr, pw := io.Pipe()
	defer pw.Close()

	var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
		if withHijack {
			rwc, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Logf("unexpected error: %v", err)
			}
			// NOTE(review): if Hijack failed, rwc is presumably nil and this
			// deferred Close would itself panic during unwinding.
			defer rwc.Close()
		}
		panic(panicValue)
	})
	if wrapper != nil {
		handler = wrapper(handler)
	}
	cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(pw, "", 0)
	})

	// Signal done after the first Read from the ErrorLog pipe returns. That
	// happens when the server logs anything, or with io.EOF once pw is
	// closed. Only one chunk of up to 4KiB is consumed.
	done := make(chan bool, 1)
	go func() {
		buf := make([]byte, 4<<10)
		_, err := pr.Read(buf)
		pr.Close()
		if err != nil && err != io.EOF {
			t.Error(err)
		}
		done <- true
	}()

	// The request is expected to fail because the handler panicked. A
	// success is only logged, not treated as a test failure.
	_, err := cst.c.Get(cst.ts.URL)
	if err == nil {
		t.Logf("expected an error")
	}

	// For a nil panic value the test does not wait for a log write,
	// presumably because none is guaranteed in that case (TODO confirm).
	if panicValue == nil {
		return
	}

	// Block until the server has written to its ErrorLog.
	<-done
}
2998
// terrorWriter is an io.Writer that reports every write as an error on t.
// Installed as a server's ErrorLog output, it makes any log output fail
// the test.
type terrorWriter struct{ t *testing.T }

// Write fails the test with p as the message and reports all of p written.
func (w terrorWriter) Write(p []byte) (int, error) {
	n := len(p)
	w.t.Errorf("%s", p)
	return n, nil
}
3005
3006
3007
3008 func TestServerWriteHijackZeroBytes(t *testing.T) {
3009 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
3010 }
3011 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
3012 done := make(chan struct{})
3013 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3014 defer close(done)
3015 w.(Flusher).Flush()
3016 conn, _, err := w.(Hijacker).Hijack()
3017 if err != nil {
3018 t.Errorf("Hijack: %v", err)
3019 return
3020 }
3021 defer conn.Close()
3022 _, err = w.Write(nil)
3023 if err != ErrHijacked {
3024 t.Errorf("Write error = %v; want ErrHijacked", err)
3025 }
3026 }), func(ts *httptest.Server) {
3027 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
3028 }).ts
3029
3030 c := ts.Client()
3031 res, err := c.Get(ts.URL)
3032 if err != nil {
3033 t.Fatal(err)
3034 }
3035 res.Body.Close()
3036 <-done
3037 }
3038
3039 func TestServerNoDate(t *testing.T) {
3040 run(t, func(t *testing.T, mode testMode) {
3041 testServerNoHeader(t, mode, "Date")
3042 })
3043 }
3044
3045 func TestServerContentType(t *testing.T) {
3046 run(t, func(t *testing.T, mode testMode) {
3047 testServerNoHeader(t, mode, "Content-Type")
3048 })
3049 }
3050
3051 func testServerNoHeader(t *testing.T, mode testMode, header string) {
3052 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3053 w.Header()[header] = nil
3054 io.WriteString(w, "<html>foo</html>")
3055 }))
3056 res, err := cst.c.Get(cst.ts.URL)
3057 if err != nil {
3058 t.Fatal(err)
3059 }
3060 res.Body.Close()
3061 if got, ok := res.Header[header]; ok {
3062 t.Fatalf("Expected no %s header; got %q", header, got)
3063 }
3064 }
3065
3066 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3067 func testStripPrefix(t *testing.T, mode testMode) {
3068 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3069 w.Header().Set("X-Path", r.URL.Path)
3070 w.Header().Set("X-RawPath", r.URL.RawPath)
3071 })
3072 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3073
3074 c := ts.Client()
3075
3076 cases := []struct {
3077 reqPath string
3078 path string
3079 rawPath string
3080 }{
3081 {"/foo/bar/qux", "/qux", ""},
3082 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3083 {"/foo%2Fbar/qux", "", ""},
3084 {"/bar", "", ""},
3085 }
3086 for _, tc := range cases {
3087 t.Run(tc.reqPath, func(t *testing.T) {
3088 res, err := c.Get(ts.URL + tc.reqPath)
3089 if err != nil {
3090 t.Fatal(err)
3091 }
3092 res.Body.Close()
3093 if tc.path == "" {
3094 if res.StatusCode != StatusNotFound {
3095 t.Errorf("got %q, want 404 Not Found", res.Status)
3096 }
3097 return
3098 }
3099 if res.StatusCode != StatusOK {
3100 t.Fatalf("got %q, want 200 OK", res.Status)
3101 }
3102 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3103 t.Errorf("got Path %q, want %q", g, w)
3104 }
3105 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3106 t.Errorf("got RawPath %q, want %q", g, w)
3107 }
3108 })
3109 }
3110 }
3111
3112
// TestStripPrefixNotModifyRequest checks that StripPrefix leaves the
// caller's *Request untouched: after ServeHTTP the caller still sees the
// original path.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	req := httptest.NewRequest("GET", "/foo/bar", nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)
	if got := req.URL.Path; got != "/foo/bar" {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
3121
3122 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
3123 func testRequestLimit(t *testing.T, mode testMode) {
3124 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3125 t.Fatalf("didn't expect to get request in Handler")
3126 }), optQuietLog)
3127 req, _ := NewRequest("GET", cst.ts.URL, nil)
3128 var bytesPerHeader = len("header12345: val12345\r\n")
3129 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
3130 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
3131 }
3132 res, err := cst.c.Do(req)
3133 if res != nil {
3134 defer res.Body.Close()
3135 }
3136 if mode == http2Mode {
3137
3138
3139
3140
3141 if err == nil && res.StatusCode != 431 {
3142 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3143 }
3144 } else {
3145
3146
3147
3148
3149 if err != nil {
3150 t.Fatalf("Do: %v", err)
3151 }
3152 if res.StatusCode != 431 {
3153 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3154 }
3155 }
3156 }
3157
// neverEnding is an io.Reader that yields an endless stream of its own byte
// value. Every Read fills all of p and never returns an error.
type neverEnding byte

// Read fills p entirely with the byte b and reports len(p) bytes read.
func (b neverEnding) Read(p []byte) (n int, err error) {
	fill := byte(b)
	for i := range p {
		p[i] = fill
	}
	return len(p), nil
}
3166
// bodyLimitReader is an endless request body made of 'a' bytes. Its Read
// fails with "at limit" once more than limit bytes have been produced, and
// with "closed" after Close. mu guards count and serializes Read with Close.
type bodyLimitReader struct {
	mu     sync.Mutex
	count  int           // bytes handed out so far
	limit  int           // Read fails once count exceeds this
	closed chan struct{} // created by the caller; Close closes it
}

// Read fills all of p with 'a' and returns len(p). It fails instead if the
// reader has been closed, or if count already exceeds limit. The limit is
// checked before reading, so the last successful Read may overshoot it.
func (r *bodyLimitReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	select {
	case <-r.closed:
		return 0, errors.New("closed")
	default:
		if r.count > r.limit {
			return 0, errors.New("at limit")
		}
	}
	for i := range p {
		p[i] = 'a'
	}
	r.count += len(p)
	return len(p), nil
}

// Close closes r.closed so that later Reads fail. Calling it twice panics.
func (r *bodyLimitReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()
	close(r.closed)
	return nil
}
3198
3199 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3200 func testRequestBodyLimit(t *testing.T, mode testMode) {
3201 const limit = 1 << 20
3202 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3203 r.Body = MaxBytesReader(w, r.Body, limit)
3204 n, err := io.Copy(io.Discard, r.Body)
3205 if err == nil {
3206 t.Errorf("expected error from io.Copy")
3207 }
3208 if n != limit {
3209 t.Errorf("io.Copy = %d, want %d", n, limit)
3210 }
3211 mbErr, ok := err.(*MaxBytesError)
3212 if !ok {
3213 t.Errorf("expected MaxBytesError, got %T", err)
3214 }
3215 if mbErr.Limit != limit {
3216 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3217 }
3218 }))
3219
3220 body := &bodyLimitReader{
3221 closed: make(chan struct{}),
3222 limit: limit * 200,
3223 }
3224 req, _ := NewRequest("POST", cst.ts.URL, body)
3225
3226
3227
3228
3229
3230
3231
3232
3233
3234
3235 resp, err := cst.c.Do(req)
3236 if err == nil {
3237 resp.Body.Close()
3238 }
3239
3240
3241 <-body.closed
3242
3243 if body.count > limit*100 {
3244 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3245 limit, body.count)
3246 }
3247 }
3248
3249
3250
3251 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
3252 func testClientWriteShutdown(t *testing.T, mode testMode) {
3253 if runtime.GOOS == "plan9" {
3254 t.Skip("skipping test; see https://golang.org/issue/17906")
3255 }
3256 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3257 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3258 if err != nil {
3259 t.Fatalf("Dial: %v", err)
3260 }
3261 err = conn.(*net.TCPConn).CloseWrite()
3262 if err != nil {
3263 t.Fatalf("CloseWrite: %v", err)
3264 }
3265
3266 bs, err := io.ReadAll(conn)
3267 if err != nil {
3268 t.Errorf("ReadAll: %v", err)
3269 }
3270 got := string(bs)
3271 if got != "" {
3272 t.Errorf("read %q from server; want nothing", got)
3273 }
3274 }
3275
3276
3277
3278 func TestServerBufferedChunking(t *testing.T) {
3279 conn := new(testConn)
3280 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3281 conn.closec = make(chan bool, 1)
3282 ls := &oneConnListener{conn}
3283 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3284 rw.(Flusher).Flush()
3285 rw.Write([]byte{'x'})
3286 rw.Write([]byte{'y'})
3287 rw.Write([]byte{'z'})
3288 }))
3289 <-conn.closec
3290 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3291 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3292 conn.writeBuf.Bytes())
3293 }
3294 }
3295
3296
3297
3298
3299
3300 func TestServerGracefulClose(t *testing.T) {
3301
3302 run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
3303 }
3304 func testServerGracefulClose(t *testing.T, mode testMode) {
3305 runTimeSensitiveTest(t, []time.Duration{
3306 1 * time.Millisecond,
3307 5 * time.Millisecond,
3308 10 * time.Millisecond,
3309 50 * time.Millisecond,
3310 100 * time.Millisecond,
3311 500 * time.Millisecond,
3312 time.Second,
3313 5 * time.Second,
3314 }, func(t *testing.T, timeout time.Duration) error {
3315 SetRSTAvoidanceDelay(t, timeout)
3316 t.Logf("set RST avoidance delay to %v", timeout)
3317
3318 const bodySize = 5 << 20
3319 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3320 for i := 0; i < bodySize; i++ {
3321 req = append(req, 'x')
3322 }
3323
3324 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3325 Error(w, "bye", StatusUnauthorized)
3326 }))
3327
3328
3329 defer cst.close()
3330 ts := cst.ts
3331
3332 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3333 if err != nil {
3334 return err
3335 }
3336 writeErr := make(chan error)
3337 go func() {
3338 _, err := conn.Write(req)
3339 writeErr <- err
3340 }()
3341 defer func() {
3342 conn.Close()
3343
3344
3345
3346 <-writeErr
3347 }()
3348
3349 br := bufio.NewReader(conn)
3350 lineNum := 0
3351 for {
3352 line, err := br.ReadString('\n')
3353 if err == io.EOF {
3354 break
3355 }
3356 if err != nil {
3357 return fmt.Errorf("ReadLine: %v", err)
3358 }
3359 lineNum++
3360 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3361 t.Errorf("Response line = %q; want a 401", line)
3362 }
3363 }
3364 return nil
3365 })
3366 }
3367
3368 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3369 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3370 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3371 if r.Method != "get" {
3372 t.Errorf(`Got method %q; want "get"`, r.Method)
3373 }
3374 }))
3375 defer cst.close()
3376 req, _ := NewRequest("get", cst.ts.URL, nil)
3377 res, err := cst.c.Do(req)
3378 if err != nil {
3379 t.Error(err)
3380 return
3381 }
3382
3383 res.Body.Close()
3384 }
3385
3386
3387
3388
3389
3390 func TestContentLengthZero(t *testing.T) {
3391 run(t, testContentLengthZero, []testMode{http1Mode})
3392 }
3393 func testContentLengthZero(t *testing.T, mode testMode) {
3394 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3395
3396 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3397 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3398 if err != nil {
3399 t.Fatalf("error dialing: %v", err)
3400 }
3401 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3402 if err != nil {
3403 t.Fatalf("error writing: %v", err)
3404 }
3405 req, _ := NewRequest("GET", "/", nil)
3406 res, err := ReadResponse(bufio.NewReader(conn), req)
3407 if err != nil {
3408 t.Fatalf("error reading response: %v", err)
3409 }
3410 if te := res.TransferEncoding; len(te) > 0 {
3411 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3412 }
3413 if cl := res.ContentLength; cl != 0 {
3414 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3415 }
3416 conn.Close()
3417 }
3418 }
3419
3420 func TestCloseNotifier(t *testing.T) {
3421 run(t, testCloseNotifier, []testMode{http1Mode})
3422 }
3423 func testCloseNotifier(t *testing.T, mode testMode) {
3424 gotReq := make(chan bool, 1)
3425 sawClose := make(chan bool, 1)
3426 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3427 gotReq <- true
3428 cc := rw.(CloseNotifier).CloseNotify()
3429 <-cc
3430 sawClose <- true
3431 })).ts
3432 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3433 if err != nil {
3434 t.Fatalf("error dialing: %v", err)
3435 }
3436 diec := make(chan bool)
3437 go func() {
3438 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3439 if err != nil {
3440 t.Error(err)
3441 return
3442 }
3443 <-diec
3444 conn.Close()
3445 }()
3446 For:
3447 for {
3448 select {
3449 case <-gotReq:
3450 diec <- true
3451 case <-sawClose:
3452 break For
3453 }
3454 }
3455 ts.Close()
3456 }
3457
3458
3459
3460
3461
3462 func TestCloseNotifierPipelined(t *testing.T) {
3463 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3464 }
3465 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3466 gotReq := make(chan bool, 2)
3467 sawClose := make(chan bool, 2)
3468 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3469 gotReq <- true
3470 cc := rw.(CloseNotifier).CloseNotify()
3471 select {
3472 case <-cc:
3473 t.Error("unexpected CloseNotify")
3474 case <-time.After(100 * time.Millisecond):
3475 }
3476 sawClose <- true
3477 })).ts
3478 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3479 if err != nil {
3480 t.Fatalf("error dialing: %v", err)
3481 }
3482 diec := make(chan bool, 1)
3483 defer close(diec)
3484 go func() {
3485 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3486 _, err = io.WriteString(conn, req+req)
3487 if err != nil {
3488 t.Error(err)
3489 return
3490 }
3491 <-diec
3492 conn.Close()
3493 }()
3494 reqs := 0
3495 closes := 0
3496 for {
3497 select {
3498 case <-gotReq:
3499 reqs++
3500 if reqs > 2 {
3501 t.Fatal("too many requests")
3502 }
3503 case <-sawClose:
3504 closes++
3505 if closes > 1 {
3506 return
3507 }
3508 }
3509 }
3510 }
3511
3512 func TestCloseNotifierChanLeak(t *testing.T) {
3513 defer afterTest(t)
3514 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3515 for i := 0; i < 20; i++ {
3516 var output bytes.Buffer
3517 conn := &rwTestConn{
3518 Reader: bytes.NewReader(req),
3519 Writer: &output,
3520 closec: make(chan bool, 1),
3521 }
3522 ln := &oneConnListener{conn: conn}
3523 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3524
3525
3526
3527 _ = rw.(CloseNotifier).CloseNotify()
3528 })
3529 go Serve(ln, handler)
3530 <-conn.closec
3531 }
3532 }
3533
3534
3535
3536
3537
3538
3539
3540
3541
3542
3543 func TestHijackAfterCloseNotifier(t *testing.T) {
3544 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3545 }
3546 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3547 script := make(chan string, 2)
3548 script <- "closenotify"
3549 script <- "hijack"
3550 close(script)
3551 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3552 plan := <-script
3553 switch plan {
3554 default:
3555 panic("bogus plan; too many requests")
3556 case "closenotify":
3557 w.(CloseNotifier).CloseNotify()
3558 w.Header().Set("X-Addr", r.RemoteAddr)
3559 case "hijack":
3560 c, _, err := w.(Hijacker).Hijack()
3561 if err != nil {
3562 t.Errorf("Hijack in Handler: %v", err)
3563 return
3564 }
3565 if _, ok := c.(*net.TCPConn); !ok {
3566
3567
3568 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3569 }
3570 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3571 c.Close()
3572 return
3573 }
3574 })).ts
3575 res1, err := ts.Client().Get(ts.URL)
3576 if err != nil {
3577 log.Fatal(err)
3578 }
3579 res2, err := ts.Client().Get(ts.URL)
3580 if err != nil {
3581 log.Fatal(err)
3582 }
3583 addr1 := res1.Header.Get("X-Addr")
3584 addr2 := res2.Header.Get("X-Addr")
3585 if addr1 == "" || addr1 != addr2 {
3586 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3587 }
3588 }
3589
3590 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3591 run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
3592 }
3593 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3594 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3595 bodyOkay := make(chan bool, 1)
3596 gotCloseNotify := make(chan bool, 1)
3597 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3598 defer close(bodyOkay)
3599
3600 reqBody := r.Body
3601 r.Body = nil
3602
3603 gone := w.(CloseNotifier).CloseNotify()
3604 slurp, err := io.ReadAll(reqBody)
3605 if err != nil {
3606 t.Errorf("Body read: %v", err)
3607 return
3608 }
3609 if len(slurp) != len(requestBody) {
3610 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3611 return
3612 }
3613 if !bytes.Equal(slurp, requestBody) {
3614 t.Error("Backend read wrong request body.")
3615 return
3616 }
3617 bodyOkay <- true
3618 <-gone
3619 gotCloseNotify <- true
3620 })).ts
3621
3622 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3623 if err != nil {
3624 t.Fatal(err)
3625 }
3626 defer conn.Close()
3627
3628 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3629 len(requestBody), requestBody)
3630 if !<-bodyOkay {
3631
3632 return
3633 }
3634 conn.Close()
3635 <-gotCloseNotify
3636 }
3637
3638 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3639 func testOptions(t *testing.T, mode testMode) {
3640 uric := make(chan string, 2)
3641 mux := NewServeMux()
3642 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3643 uric <- r.RequestURI
3644 })
3645 ts := newClientServerTest(t, mode, mux).ts
3646
3647 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3648 if err != nil {
3649 t.Fatal(err)
3650 }
3651 defer conn.Close()
3652
3653
3654 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3655 if err != nil {
3656 t.Fatal(err)
3657 }
3658 br := bufio.NewReader(conn)
3659 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3660 if err != nil {
3661 t.Fatal(err)
3662 }
3663 if res.StatusCode != 200 {
3664 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3665 }
3666
3667
3668 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3669 if err != nil {
3670 t.Fatal(err)
3671 }
3672 res, err = ReadResponse(br, &Request{Method: "GET"})
3673 if err != nil {
3674 t.Fatal(err)
3675 }
3676 if res.StatusCode != 400 {
3677 t.Errorf("Got non-400 response to GET *: %#v", res)
3678 }
3679
3680 res, err = Get(ts.URL + "/second")
3681 if err != nil {
3682 t.Fatal(err)
3683 }
3684 res.Body.Close()
3685 if got := <-uric; got != "/second" {
3686 t.Errorf("Handler saw request for %q; want /second", got)
3687 }
3688 }
3689
3690 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3691 func testOptionsHandler(t *testing.T, mode testMode) {
3692 rc := make(chan *Request, 1)
3693
3694 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3695 rc <- r
3696 }), func(ts *httptest.Server) {
3697 ts.Config.DisableGeneralOptionsHandler = true
3698 }).ts
3699
3700 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3701 if err != nil {
3702 t.Fatal(err)
3703 }
3704 defer conn.Close()
3705
3706 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3707 if err != nil {
3708 t.Fatal(err)
3709 }
3710
3711 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3712 t.Errorf("Expected OPTIONS * request, got %v", got)
3713 }
3714 }
3715
3716
3717
3718
3719
3720
3721
3722
3723
3724
3725 func TestHeaderToWire(t *testing.T) {
3726 tests := []struct {
3727 name string
3728 handler func(ResponseWriter, *Request)
3729 check func(got, logs string) error
3730 }{
3731 {
3732 name: "write without Header",
3733 handler: func(rw ResponseWriter, r *Request) {
3734 rw.Write([]byte("hello world"))
3735 },
3736 check: func(got, logs string) error {
3737 if !strings.Contains(got, "Content-Length:") {
3738 return errors.New("no content-length")
3739 }
3740 if !strings.Contains(got, "Content-Type: text/plain") {
3741 return errors.New("no content-type")
3742 }
3743 return nil
3744 },
3745 },
3746 {
3747 name: "Header mutation before write",
3748 handler: func(rw ResponseWriter, r *Request) {
3749 h := rw.Header()
3750 h.Set("Content-Type", "some/type")
3751 rw.Write([]byte("hello world"))
3752 h.Set("Too-Late", "bogus")
3753 },
3754 check: func(got, logs string) error {
3755 if !strings.Contains(got, "Content-Length:") {
3756 return errors.New("no content-length")
3757 }
3758 if !strings.Contains(got, "Content-Type: some/type") {
3759 return errors.New("wrong content-type")
3760 }
3761 if strings.Contains(got, "Too-Late") {
3762 return errors.New("don't want too-late header")
3763 }
3764 return nil
3765 },
3766 },
3767 {
3768 name: "write then useless Header mutation",
3769 handler: func(rw ResponseWriter, r *Request) {
3770 rw.Write([]byte("hello world"))
3771 rw.Header().Set("Too-Late", "Write already wrote headers")
3772 },
3773 check: func(got, logs string) error {
3774 if strings.Contains(got, "Too-Late") {
3775 return errors.New("header appeared from after WriteHeader")
3776 }
3777 return nil
3778 },
3779 },
3780 {
3781 name: "flush then write",
3782 handler: func(rw ResponseWriter, r *Request) {
3783 rw.(Flusher).Flush()
3784 rw.Write([]byte("post-flush"))
3785 rw.Header().Set("Too-Late", "Write already wrote headers")
3786 },
3787 check: func(got, logs string) error {
3788 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3789 return errors.New("not chunked")
3790 }
3791 if strings.Contains(got, "Too-Late") {
3792 return errors.New("header appeared from after WriteHeader")
3793 }
3794 return nil
3795 },
3796 },
3797 {
3798 name: "header then flush",
3799 handler: func(rw ResponseWriter, r *Request) {
3800 rw.Header().Set("Content-Type", "some/type")
3801 rw.(Flusher).Flush()
3802 rw.Write([]byte("post-flush"))
3803 rw.Header().Set("Too-Late", "Write already wrote headers")
3804 },
3805 check: func(got, logs string) error {
3806 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3807 return errors.New("not chunked")
3808 }
3809 if strings.Contains(got, "Too-Late") {
3810 return errors.New("header appeared from after WriteHeader")
3811 }
3812 if !strings.Contains(got, "Content-Type: some/type") {
3813 return errors.New("wrong content-type")
3814 }
3815 return nil
3816 },
3817 },
3818 {
3819 name: "sniff-on-first-write content-type",
3820 handler: func(rw ResponseWriter, r *Request) {
3821 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3822 rw.Header().Set("Content-Type", "x/wrong")
3823 },
3824 check: func(got, logs string) error {
3825 if !strings.Contains(got, "Content-Type: text/html") {
3826 return errors.New("wrong content-type; want html")
3827 }
3828 return nil
3829 },
3830 },
3831 {
3832 name: "explicit content-type wins",
3833 handler: func(rw ResponseWriter, r *Request) {
3834 rw.Header().Set("Content-Type", "some/type")
3835 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3836 },
3837 check: func(got, logs string) error {
3838 if !strings.Contains(got, "Content-Type: some/type") {
3839 return errors.New("wrong content-type; want html")
3840 }
3841 return nil
3842 },
3843 },
3844 {
3845 name: "empty handler",
3846 handler: func(rw ResponseWriter, r *Request) {
3847 },
3848 check: func(got, logs string) error {
3849 if !strings.Contains(got, "Content-Length: 0") {
3850 return errors.New("want 0 content-length")
3851 }
3852 return nil
3853 },
3854 },
3855 {
3856 name: "only Header, no write",
3857 handler: func(rw ResponseWriter, r *Request) {
3858 rw.Header().Set("Some-Header", "some-value")
3859 },
3860 check: func(got, logs string) error {
3861 if !strings.Contains(got, "Some-Header") {
3862 return errors.New("didn't get header")
3863 }
3864 return nil
3865 },
3866 },
3867 {
3868 name: "WriteHeader call",
3869 handler: func(rw ResponseWriter, r *Request) {
3870 rw.WriteHeader(404)
3871 rw.Header().Set("Too-Late", "some-value")
3872 },
3873 check: func(got, logs string) error {
3874 if !strings.Contains(got, "404") {
3875 return errors.New("wrong status")
3876 }
3877 if strings.Contains(got, "Too-Late") {
3878 return errors.New("shouldn't have seen Too-Late")
3879 }
3880 return nil
3881 },
3882 },
3883 }
3884 for _, tc := range tests {
3885 ht := newHandlerTest(HandlerFunc(tc.handler))
3886 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3887 logs := ht.logbuf.String()
3888 if err := tc.check(got, logs); err != nil {
3889 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3890 }
3891 }
3892 }
3893
3894 type errorListener struct {
3895 errs []error
3896 }
3897
3898 func (l *errorListener) Accept() (c net.Conn, err error) {
3899 if len(l.errs) == 0 {
3900 return nil, io.EOF
3901 }
3902 err = l.errs[0]
3903 l.errs = l.errs[1:]
3904 return
3905 }
3906
3907 func (l *errorListener) Close() error {
3908 return nil
3909 }
3910
3911 func (l *errorListener) Addr() net.Addr {
3912 return dummyAddr("test-address")
3913 }
3914
3915 func TestAcceptMaxFds(t *testing.T) {
3916 setParallel(t)
3917
3918 ln := &errorListener{[]error{
3919 &net.OpError{
3920 Op: "accept",
3921 Err: syscall.EMFILE,
3922 }}}
3923 server := &Server{
3924 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3925 ErrorLog: log.New(io.Discard, "", 0),
3926 }
3927 err := server.Serve(ln)
3928 if err != io.EOF {
3929 t.Errorf("got error %v, want EOF", err)
3930 }
3931 }
3932
3933 func TestWriteAfterHijack(t *testing.T) {
3934 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3935 var buf strings.Builder
3936 wrotec := make(chan bool, 1)
3937 conn := &rwTestConn{
3938 Reader: bytes.NewReader(req),
3939 Writer: &buf,
3940 closec: make(chan bool, 1),
3941 }
3942 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3943 conn, bufrw, err := rw.(Hijacker).Hijack()
3944 if err != nil {
3945 t.Error(err)
3946 return
3947 }
3948 go func() {
3949 bufrw.Write([]byte("[hijack-to-bufw]"))
3950 bufrw.Flush()
3951 conn.Write([]byte("[hijack-to-conn]"))
3952 conn.Close()
3953 wrotec <- true
3954 }()
3955 })
3956 ln := &oneConnListener{conn: conn}
3957 go Serve(ln, handler)
3958 <-conn.closec
3959 <-wrotec
3960 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
3961 t.Errorf("wrote %q; want %q", g, w)
3962 }
3963 }
3964
3965 func TestDoubleHijack(t *testing.T) {
3966 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3967 var buf bytes.Buffer
3968 conn := &rwTestConn{
3969 Reader: bytes.NewReader(req),
3970 Writer: &buf,
3971 closec: make(chan bool, 1),
3972 }
3973 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3974 conn, _, err := rw.(Hijacker).Hijack()
3975 if err != nil {
3976 t.Error(err)
3977 return
3978 }
3979 _, _, err = rw.(Hijacker).Hijack()
3980 if err == nil {
3981 t.Errorf("got err = nil; want err != nil")
3982 }
3983 conn.Close()
3984 })
3985 ln := &oneConnListener{conn: conn}
3986 go Serve(ln, handler)
3987 <-conn.closec
3988 }
3989
3990
3991
3992
3993
3994
3995
3996 func TestHTTP10ConnectionHeader(t *testing.T) {
3997 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
3998 }
3999 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
4000 mux := NewServeMux()
4001 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
4002 ts := newClientServerTest(t, mode, mux).ts
4003
4004
4005 tests := []struct {
4006 req string
4007 expect []string
4008 }{
4009 {
4010 req: "GET / HTTP/1.0\r\n\r\n",
4011 expect: nil,
4012 },
4013 {
4014 req: "OPTIONS * HTTP/1.0\r\n\r\n",
4015 expect: nil,
4016 },
4017 {
4018 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
4019 expect: []string{"keep-alive"},
4020 },
4021 }
4022
4023 for _, tt := range tests {
4024 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
4025 if err != nil {
4026 t.Fatal("dial err:", err)
4027 }
4028
4029 _, err = fmt.Fprint(conn, tt.req)
4030 if err != nil {
4031 t.Fatal("conn write err:", err)
4032 }
4033
4034 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
4035 if err != nil {
4036 t.Fatal("ReadResponse err:", err)
4037 }
4038 conn.Close()
4039 resp.Body.Close()
4040
4041 got := resp.Header["Connection"]
4042 if !slices.Equal(got, tt.expect) {
4043 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
4044 }
4045 }
4046 }
4047
4048
4049 func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
4050 func testServerReaderFromOrder(t *testing.T, mode testMode) {
4051 pr, pw := io.Pipe()
4052 const size = 3 << 20
4053 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4054 rw.Header().Set("Content-Type", "text/plain")
4055 done := make(chan bool)
4056 go func() {
4057 io.Copy(rw, pr)
4058 close(done)
4059 }()
4060 time.Sleep(25 * time.Millisecond)
4061 n, err := io.Copy(io.Discard, req.Body)
4062 if err != nil {
4063 t.Errorf("handler Copy: %v", err)
4064 return
4065 }
4066 if n != size {
4067 t.Errorf("handler Copy = %d; want %d", n, size)
4068 }
4069 pw.Write([]byte("hi"))
4070 pw.Close()
4071 <-done
4072 }))
4073
4074 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
4075 if err != nil {
4076 t.Fatal(err)
4077 }
4078 res, err := cst.c.Do(req)
4079 if err != nil {
4080 t.Fatal(err)
4081 }
4082 all, err := io.ReadAll(res.Body)
4083 if err != nil {
4084 t.Fatal(err)
4085 }
4086 res.Body.Close()
4087 if string(all) != "hi" {
4088 t.Errorf("Body = %q; want hi", all)
4089 }
4090 }
4091
4092
4093 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4094 for _, code := range []int{StatusNotModified, StatusNoContent} {
4095 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4096 if r.URL.Path == "/header" {
4097 w.Header().Set("Content-Length", "123")
4098 }
4099 w.WriteHeader(code)
4100 if r.URL.Path == "/more" {
4101 w.Write([]byte("stuff"))
4102 }
4103 }))
4104 for _, req := range []string{
4105 "GET / HTTP/1.0",
4106 "GET /header HTTP/1.0",
4107 "GET /more HTTP/1.0",
4108 "GET / HTTP/1.1\nHost: foo",
4109 "GET /header HTTP/1.1\nHost: foo",
4110 "GET /more HTTP/1.1\nHost: foo",
4111 } {
4112 got := ht.rawResponse(req)
4113 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4114 if !strings.Contains(got, wantStatus) {
4115 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4116 } else if strings.Contains(got, "Content-Length") {
4117 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4118 } else if strings.Contains(got, "stuff") {
4119 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4120 }
4121 }
4122 }
4123 }
4124
4125 func TestContentTypeOkayOn204(t *testing.T) {
4126 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4127 w.Header().Set("Content-Length", "123")
4128 w.Header().Set("Content-Type", "foo/bar")
4129 w.WriteHeader(204)
4130 }))
4131 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4132 if !strings.Contains(got, "Content-Type: foo/bar") {
4133 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4134 }
4135 if strings.Contains(got, "Content-Length: 123") {
4136 t.Errorf("Response = %q; don't want a Content-Length", got)
4137 }
4138 }
4139
4140
4141
4142
4143
4144
4145
// TestTransportAndServerSharedBodyRace runs a proxy whose handler forwards its
// inbound req.Body directly as the body of an outbound request to a backend.
// The proxy reads half of the backend's response and then cancels the
// outbound request. Judging by the name, this targets a race between the
// Server and the Transport over that shared body — NOTE(review): confirm.
func TestTransportAndServerSharedBodyRace(t *testing.T) {
	run(t, testTransportAndServerSharedBodyRace, testNotParallel)
}
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	// The body below runs under runTimeSensitiveTest with each of these RST
	// avoidance delays. Presumably it retries with the next, larger delay if a
	// run returns an error — TODO confirm against runTimeSensitiveTest.
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 1 << 20

		// wg tracks backend handlers that are still running, so the deferred
		// cleanup below can wait for them before closing the backend.
		var wg sync.WaitGroup
		backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// NOTE(review): wg.Add runs inside the handler and so may race
			// with the deferred wg.Wait. This appears to rely on the handler
			// having started before the proxied request (and therefore the
			// cleanup) completes. Confirm against sync.WaitGroup's rule that
			// an Add from zero must happen before Wait.
			wg.Add(1)
			defer wg.Done()

			// Echo the request body back as the response, then stay in the
			// handler until the request context is done.
			n, err := io.CopyN(rw, req.Body, bodySize)
			t.Logf("backend CopyN: %v, %v", n, err)
			<-req.Context().Done()
		}))

		// Wait for outstanding backend handlers before closing the backend.
		defer func() {
			wg.Wait()
			backend.close()
		}()

		var proxy *clientServerTest
		proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Reuse the inbound body as the outbound body (NewRequest's error
			// is ignored).
			req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
			req2.ContentLength = bodySize
			cancel := make(chan struct{})
			req2.Cancel = cancel

			bresp, err := proxy.c.Do(req2)
			if err != nil {
				t.Errorf("Proxy outbound request: %v", err)
				return
			}
			// Read only half of the backend response before cancelling.
			_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
			if err != nil {
				t.Errorf("Proxy copy error: %v", err)
				return
			}
			t.Cleanup(func() { bresp.Body.Close() })

			// Cancel the outbound request mid-response. HTTP/2 mode cancels
			// through the Request.Cancel channel; otherwise the test calls
			// Transport.CancelRequest.
			if mode == http2Mode {
				close(cancel)
			} else {
				proxy.c.Transport.(*Transport).CancelRequest(req2)
			}
			rw.Write([]byte("OK"))
		}))
		defer proxy.close()

		// The client request to the proxy carries a bodySize body of 'a's.
		req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
		res, err := proxy.c.Do(req)
		if err != nil {
			return fmt.Errorf("original request: %v", err)
		}
		res.Body.Close()
		return nil
	})
}
4235
4236
4237
4238
4239 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4240 run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
4241 }
4242 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4243 if testing.Short() {
4244 t.Skip("skipping in -short mode")
4245 }
4246
4247 readErrCh := make(chan error, 1)
4248 errCh := make(chan error, 2)
4249
4250 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4251 go func(body io.Reader) {
4252 _, err := body.Read(make([]byte, 100))
4253 readErrCh <- err
4254 }(req.Body)
4255 time.Sleep(500 * time.Millisecond)
4256 })).ts
4257
4258 closeConn := make(chan bool)
4259 defer close(closeConn)
4260 go func() {
4261 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4262 if err != nil {
4263 errCh <- err
4264 return
4265 }
4266 defer conn.Close()
4267 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4268 if err != nil {
4269 errCh <- err
4270 return
4271 }
4272
4273
4274 <-closeConn
4275 }()
4276 select {
4277 case err := <-readErrCh:
4278 if err == nil {
4279 t.Error("Read was nil. Expected error.")
4280 }
4281 case err := <-errCh:
4282 t.Error(err)
4283 }
4284 }
4285
4286
4287 func TestResponseWriterWriteString(t *testing.T) {
4288 okc := make(chan bool, 1)
4289 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4290 _, ok := w.(io.StringWriter)
4291 okc <- ok
4292 }))
4293 ht.rawResponse("GET / HTTP/1.0")
4294 select {
4295 case ok := <-okc:
4296 if !ok {
4297 t.Error("ResponseWriter did not implement io.StringWriter")
4298 }
4299 default:
4300 t.Error("handler was never called")
4301 }
4302 }
4303
4304 func TestAppendTime(t *testing.T) {
4305 var b [len(TimeFormat)]byte
4306 t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60))
4307 res := ExportAppendTime(b[:0], t1)
4308 t2, err := ParseTime(string(res))
4309 if err != nil {
4310 t.Fatalf("Error parsing time: %s", err)
4311 }
4312 if !t1.Equal(t2) {
4313 t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res))
4314 }
4315 }
4316
func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }

// testServerConnState runs a series of connection scenarios. For each one it
// checks the exact sequence of ConnState values that the server reports
// through Server.ConnState for the connection under test.
func testServerConnState(t *testing.T, mode testMode) {
	// handler maps each request path to the behavior being exercised.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// stateLog records the states observed for the one connection currently
	// under test.
	type stateLog struct {
		// active is the conn being recorded; it is set on its StateNew.
		active net.Conn
		// got holds the states observed so far.
		got []ConnState
		// want holds the expected sequence.
		want []ConnState
		// complete is closed once got has len(want) entries or its prefix
		// diverges from want.
		complete chan<- struct{}
	}
	// activeLog holds the current *stateLog and acts as a lock: receiving
	// from it takes ownership of the log, sending returns it.
	activeLog := make(chan *stateLog, 1)

	// wantLog installs a fresh log expecting want, runs doRequests, waits for
	// the ConnState hook to signal completion, and then compares the
	// sequences. It does not put the log back, so a ConnState callback that
	// arrives afterwards blocks until the next log is installed.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		<-complete
		sl := <-activeLog
		if !slices.Equal(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
		}
	}

	// ErrorLog is discarded, presumably to keep the intentional panic in
	// /hijack-panic out of the test output.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
		ts.Config.ConnState = func(c net.Conn, state ConnState) {
			if c == nil {
				t.Errorf("nil conn seen in state %s", state)
				return
			}
			// Take ownership of the current log.
			sl := <-activeLog
			if sl.active == nil && state == StateNew {
				// The first new conn becomes the one under test.
				sl.active = c
			} else if sl.active != c {
				t.Errorf("unexpected conn in state %s", state)
				activeLog <- sl
				return
			}
			sl.got = append(sl.got, state)
			// Signal completion once enough states have arrived, or as soon
			// as the observed prefix no longer matches want.
			if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
				close(sl.complete)
				sl.complete = nil
			}
			activeLog <- sl
		}
	}).ts
	// Install an empty log so that any ConnState callback during ts.Close
	// can take it instead of blocking.
	defer func() {
		activeLog <- &stateLog{}
		ts.Close()
	}()

	c := ts.Client()

	// mustGet issues a GET with optional header key/value pairs and drains
	// the response body.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// Keep-alive request followed by a request the server closes.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Keep-alive request followed by a request the client marks Connection: close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Hijacked connection.
	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	// Hijacked connection whose handler then panics.
	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// Client connects and closes without sending anything.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// Malformed request.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// Successful request, then the client closes the idle connection.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4479
4480 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4481 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4482 }
4483 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4484 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4485 }), func(ts *httptest.Server) {
4486 ts.Config.SetKeepAlivesEnabled(false)
4487 }).ts
4488 res, err := ts.Client().Get(ts.URL)
4489 if err != nil {
4490 t.Fatal(err)
4491 }
4492 defer res.Body.Close()
4493 if !res.Close {
4494 t.Errorf("Body.Close == false; want true")
4495 }
4496 }
4497
4498
4499 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4500 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4501 var n int32
4502 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4503 atomic.AddInt32(&n, 1)
4504 }), optQuietLog)
4505 var wg sync.WaitGroup
4506 const reqs = 20
4507 for i := 0; i < reqs; i++ {
4508 wg.Add(1)
4509 go func() {
4510 defer wg.Done()
4511 res, err := cst.c.Get(cst.ts.URL)
4512 if err != nil {
4513
4514
4515 time.Sleep(10 * time.Millisecond)
4516 res, err = cst.c.Get(cst.ts.URL)
4517 if err != nil {
4518 t.Error(err)
4519 return
4520 }
4521 }
4522 defer res.Body.Close()
4523 _, err = io.Copy(io.Discard, res.Body)
4524 if err != nil {
4525 t.Error(err)
4526 return
4527 }
4528 }()
4529 }
4530 wg.Wait()
4531 if got := atomic.LoadInt32(&n); got != reqs {
4532 t.Errorf("handler ran %d times; want %d", got, reqs)
4533 }
4534 }
4535
4536 func TestServerConnStateNew(t *testing.T) {
4537 sawNew := false
4538 srv := &Server{
4539 ConnState: func(c net.Conn, state ConnState) {
4540 if state == StateNew {
4541 sawNew = true
4542 }
4543 },
4544 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4545 }
4546 srv.Serve(&oneConnListener{
4547 conn: &rwTestConn{
4548 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4549 Writer: io.Discard,
4550 },
4551 })
4552 if !sawNew {
4553 t.Error("StateNew not seen")
4554 }
4555 }
4556
4557 type closeWriteTestConn struct {
4558 rwTestConn
4559 didCloseWrite bool
4560 }
4561
4562 func (c *closeWriteTestConn) CloseWrite() error {
4563 c.didCloseWrite = true
4564 return nil
4565 }
4566
4567 func TestCloseWrite(t *testing.T) {
4568 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4569
4570 var srv Server
4571 var testConn closeWriteTestConn
4572 c := ExportServerNewConn(&srv, &testConn)
4573 ExportCloseWriteAndWait(c)
4574 if !testConn.didCloseWrite {
4575 t.Error("didn't see CloseWrite call")
4576 }
4577 }
4578
4579
4580
4581
4582
4583
4584
4585
4586 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4587 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4588 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4589 io.WriteString(w, "Hello, ")
4590 w.(Flusher).Flush()
4591 conn, buf, _ := w.(Hijacker).Hijack()
4592 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4593 if err := buf.Flush(); err != nil {
4594 t.Error(err)
4595 }
4596 if err := conn.Close(); err != nil {
4597 t.Error(err)
4598 }
4599 })).ts
4600 res, err := Get(ts.URL)
4601 if err != nil {
4602 t.Fatal(err)
4603 }
4604 defer res.Body.Close()
4605 all, err := io.ReadAll(res.Body)
4606 if err != nil {
4607 t.Fatal(err)
4608 }
4609 if want := "Hello, world!"; string(all) != want {
4610 t.Errorf("Got %q; want %q", all, want)
4611 }
4612 }
4613
4614
4615
4616
4617
4618
4619
4620 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4621 run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
4622 }
4623 func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
4624 if testing.Short() {
4625 t.Skip("skipping in -short mode")
4626 }
4627 const numReq = 3
4628 addrc := make(chan string, numReq)
4629 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4630 addrc <- r.RemoteAddr
4631 time.Sleep(500 * time.Millisecond)
4632 w.(Flusher).Flush()
4633 }), func(ts *httptest.Server) {
4634 ts.Config.WriteTimeout = 250 * time.Millisecond
4635 }).ts
4636
4637 errc := make(chan error, numReq)
4638 go func() {
4639 defer close(errc)
4640 for i := 0; i < numReq; i++ {
4641 res, err := Get(ts.URL)
4642 if res != nil {
4643 res.Body.Close()
4644 }
4645 errc <- err
4646 }
4647 }()
4648
4649 addrSeen := map[string]bool{}
4650 numOkay := 0
4651 for {
4652 select {
4653 case v := <-addrc:
4654 addrSeen[v] = true
4655 case err, ok := <-errc:
4656 if !ok {
4657 if len(addrSeen) != numReq {
4658 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4659 }
4660 if numOkay != 0 {
4661 t.Errorf("got %d successful client requests; want 0", numOkay)
4662 }
4663 return
4664 }
4665 if err == nil {
4666 numOkay++
4667 }
4668 }
4669 }
4670 }
4671
4672
4673
4674 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4675 run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
4676 }
4677 func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
4678 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4679 w.Header().Set("Transfer-Encoding", "foo")
4680 io.WriteString(w, "<html>")
4681 })).ts
4682 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4683 if err != nil {
4684 t.Fatalf("Dial: %v", err)
4685 }
4686 defer c.Close()
4687 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4688 t.Fatal(err)
4689 }
4690 bs := bufio.NewScanner(c)
4691 var got strings.Builder
4692 for bs.Scan() {
4693 if strings.TrimSpace(bs.Text()) == "" {
4694 break
4695 }
4696 got.WriteString(bs.Text())
4697 got.WriteByte('\n')
4698 }
4699 if err := bs.Err(); err != nil {
4700 t.Fatal(err)
4701 }
4702 if strings.Contains(got.String(), "Content-Length") {
4703 t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
4704 }
4705 if strings.Contains(got.String(), "Content-Type") {
4706 t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
4707 }
4708 }
4709
4710
4711
// TestTolerateCRLFBeforeRequestLine sends, over one fake connection, a POST,
// then a stray pair of CRLFs, then a GET. It expects the server to skip the
// stray CRLFs and invoke the handler for both requests.
func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
	const (
		postReq = "POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC"
		stray   = "\r\n\r\n"
		getReq  = "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n"
	)
	var out bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader([]byte(postReq + stray + getReq)),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	served := 0
	countRequests := HandlerFunc(func(rw ResponseWriter, r *Request) {
		served++
	})
	go Serve(&oneConnListener{conn: conn}, countRequests)

	// The close signal orders the handler's writes to served before this read.
	<-conn.closec
	if served != 2 {
		t.Errorf("num requests = %d; want 2", served)
		t.Logf("Res: %s", out.Bytes())
	}
}
4733
// TestIssue13893_Expect100 checks that a request carrying
// "Expect: 100-continue" still exposes its Expect header to the handler.
func TestIssue13893_Expect100(t *testing.T) {
	// A PUT with Expect: 100-continue and a 10-byte body.
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorld

`)
	var out bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	checkExpect := func(w ResponseWriter, r *Request) {
		if _, present := r.Header["Expect"]; !present {
			t.Error("Expect header should not be filtered out")
		}
	}
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(checkExpect))
	<-conn.closec
}
4760
// TestIssue11549_Expect100 pipelines three requests on one connection.
// The first is an Expect: 100-continue PUT whose body the handler reads.
// The second is another such PUT whose body the handler ignores.
// The third is a GET that should never be served.
// It expects exactly two handler invocations and a "Connection: close"
// response header, presumably because the unread body prevents the
// connection from being reused.
func TestIssue11549_Expect100(t *testing.T) {
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorldPUT /noreadbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

GET /should-be-ignored HTTP/1.1
Host: foo

`)
	var out strings.Builder
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	calls := 0
	handler := func(w ResponseWriter, r *Request) {
		calls++
		// Only /readbody consumes its body; /noreadbody leaves it unread.
		if r.URL.Path == "/readbody" {
			io.ReadAll(r.Body)
		}
		io.WriteString(w, "Hello world!")
	}
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(handler))
	<-conn.closec

	if calls != 2 {
		t.Errorf("num requests = %d; want 2", calls)
	}
	if resp := out.String(); !strings.Contains(resp, "Connection: close\r\n") {
		t.Errorf("expected 'Connection: close' in response; got: %s", resp)
	}
}
4803
4804
4805
// TestHandlerFinishSkipBigContentLengthRead sends a POST declaring a very
// large Content-Length (followed by 1 MiB of body) to a handler that ignores
// the body. It expects the fake connection's unread byte count to be the same
// after the handler returns as it was inside the handler, meaning the server
// did not drain the body.
func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
	setParallel(t)
	conn := newTestConn()
	header := "POST / HTTP/1.1\r\n" +
		"Host: test\r\n" +
		"Content-Length: 9999999999\r\n" +
		"\r\n"
	conn.readBuf.WriteString(header + strings.Repeat("a", 1<<20))

	var lenInHandler int
	go Serve(&oneConnListener{conn}, HandlerFunc(func(rw ResponseWriter, req *Request) {
		lenInHandler = conn.readBuf.Len()
		rw.WriteHeader(404)
	}))
	<-conn.closec

	if lenAfter := conn.readBuf.Len(); lenAfter != lenInHandler {
		t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", lenInHandler, lenAfter)
	}
}
4828
// TestHandlerSetsBodyNil drives testHandlerSetsBodyNil via run without an
// explicit mode list.
func TestHandlerSetsBodyNil(t *testing.T) {
	run(t, testHandlerSetsBodyNil)
}
// testHandlerSetsBodyNil has the handler replace r.Body with nil and echo
// the client's RemoteAddr. Two sequential GETs must report the same
// RemoteAddr, i.e. the nil body must not prevent connection reuse.
func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
	echoAddr := HandlerFunc(func(w ResponseWriter, r *Request) {
		r.Body = nil
		fmt.Fprintf(w, "%v", r.RemoteAddr)
	})
	cst := newClientServerTest(t, mode, echoAddr)

	fetchAddr := func() string {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		defer res.Body.Close()
		body, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatal(err)
		}
		return string(body)
	}

	first := fetchAddr()
	second := fetchAddr()
	if first != second {
		t.Errorf("Failed to reuse connections between requests: %v vs %v", first, second)
	}
}
4852
4853
4854
// TestServerValidatesHostHeader writes one raw request per table entry to a
// fake connection and checks the status code the server answers with. It
// covers the Host header rules for HTTP/1.1 and HTTP/1.0 and the handling of
// unsupported protocol versions.
func TestServerValidatesHostHeader(t *testing.T) {
	tests := []struct {
		proto string // protocol version, or a full request line if it lacks the "HTTP/" prefix
		host  string // raw Host header line(s) including CRLF, or "" for none
		want  int    // expected response status code
	}{
		// HTTP/0.9 is rejected as an unsupported version.
		{"HTTP/0.9", "", 505},

		// HTTP/1.1: a Host header is required, may be empty, must contain
		// only valid host bytes, and must not be repeated.
		{"HTTP/1.1", "", 400},
		{"HTTP/1.1", "Host: \r\n", 200},
		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
		{"HTTP/1.1", "Host: foo.com\r\n", 200},
		{"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
		{"HTTP/1.1", "Host: foo.com:80\r\n", 200},
		{"HTTP/1.1", "Host: ::1\r\n", 200},
		{"HTTP/1.1", "Host: [::1]\r\n", 200},
		{"HTTP/1.1", "Host: [::1]:80\r\n", 200},
		{"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
		{"HTTP/1.1", "Host: \x06\r\n", 400},
		{"HTTP/1.1", "Host: \xff\r\n", 400},
		{"HTTP/1.1", "Host: {\r\n", 400},
		{"HTTP/1.1", "Host: }\r\n", 400},
		{"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},

		// HTTP/1.0 does not require a Host header, but duplicate or
		// invalid ones are still rejected.
		{"HTTP/1.0", "", 200},
		{"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
		{"HTTP/1.0", "Host: \xff\r\n", 400},

		// The HTTP/2 connection preface request line is accepted without a Host.
		{"PRI * HTTP/2.0", "", 200},

		// CONNECT with an authority-form target is accepted without a Host.
		{"CONNECT golang.org:443 HTTP/1.1", "", 200},

		// Other HTTP/2 and HTTP/3 request lines are unsupported versions.
		{"PRI / HTTP/2.0", "", 505},
		{"GET / HTTP/2.0", "", 505},
		{"GET / HTTP/3.0", "", 505},
	}
	for _, tt := range tests {
		conn := newTestConn()
		// Entries whose proto is not a bare "HTTP/x.y" already carry their
		// own method and request target.
		methodTarget := "GET / "
		if !strings.HasPrefix(tt.proto, "HTTP/") {
			methodTarget = ""
		}
		io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the read error itself; res carries no useful information here.
			t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
		}
	}
}
4922
// TestServerHandlersCanHandleH2PRI drives testServerHandlersCanHandleH2PRI
// in HTTP/1 mode only.
func TestServerHandlersCanHandleH2PRI(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerHandlersCanHandleH2PRI, modes)
}
// testServerHandlersCanHandleH2PRI sends the HTTP/2 connection preface
// ("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") over a plain TCP connection. The
// handler must see a PRI request for target "*", hijack the connection, read
// the trailing "SM\r\n\r\n" from the hijacked buffered reader, and reply with
// raw bytes that the client reads back verbatim.
func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
	const upgradeResponse = "upgrade here"
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, br, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		if r.Method != "PRI" || r.RequestURI != "*" {
			t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
			return
		}
		// A PRI request is expected to be marked Close.
		if !r.Close {
			t.Errorf("Request.Close = false; want true")
		}
		// The rest of the preface is still buffered in br.
		const want = "SM\r\n\r\n"
		buf := make([]byte, len(want))
		n, err := io.ReadFull(br, buf)
		if err != nil || string(buf[:n]) != want {
			t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
			return
		}
		io.WriteString(conn, upgradeResponse)
	})).ts

	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	if _, err := io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"); err != nil {
		t.Fatal(err)
	}
	// The handler closes the hijacked conn after writing, so ReadAll
	// returns everything the handler wrote.
	slurp, err := io.ReadAll(c)
	if err != nil {
		t.Fatal(err)
	}
	if string(slurp) != upgradeResponse {
		t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
	}
}
4966
4967
4968
// TestServerValidatesHeaders sends one raw GET per table entry, with the
// entry's header line(s) appended, and checks the status code the server
// answers with.
func TestServerValidatesHeaders(t *testing.T) {
	setParallel(t)
	tests := []struct {
		header string // raw header line(s) including CRLF, or "" for none
		want   int    // expected response status code
	}{
		// Well-formed headers.
		{"", 200},
		{"Foo: bar\r\n", 200},
		{"X-Foo: bar\r\n", 200},
		{"Foo: a space\r\n", 200},

		// Invalid header names, and a header too large to accept (431).
		{"A space: foo\r\n", 400},
		{"foo\xffbar: foo\r\n", 400},
		{"foo\x00bar: foo\r\n", 400},
		{"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},

		// Whitespace between the field name and the colon is rejected.
		{"Foo : bar\r\n", 400},
		{"Foo\t: bar\r\n", 400},

		// An empty field name is rejected.
		{": empty key\r\n", 400},

		// A non-numeric Content-Length is rejected, even alongside
		// Transfer-Encoding: chunked.
		{"Content-Length: notdigits\r\n", 400},
		{"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},

		// Field values: SP, HTAB and 0xff bytes are accepted; NUL and DEL are not.
		{"foo: foo foo\r\n", 200},
		{"foo: foo\tfoo\r\n", 200},
		{"foo: foo\x00foo\r\n", 400},
		{"foo: foo\x7ffoo\r\n", 400},
		{"foo: foo\xfffoo\r\n", 200},
	}
	for _, tt := range tests {
		conn := newTestConn()
		io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the read error itself; res carries no useful information here.
			t.Errorf("For %q, ReadResponse: %v", tt.header, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
		}
	}
}
5026
// TestServerRequestContextCancel_ServeHTTPDone drives
// testServerRequestContextCancel_ServeHTTPDone via run without an explicit
// mode list.
func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
	test := testServerRequestContextCancel_ServeHTTPDone
	run(t, test)
}
// testServerRequestContextCancel_ServeHTTPDone checks two things about a
// request's context. It must not be done while ServeHTTP is running. It must
// be done once the client has received the response and closed its body.
func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
	isDone := func(ctx context.Context) bool {
		select {
		case <-ctx.Done():
			return true
		default:
			return false
		}
	}

	handlerCtx := make(chan context.Context, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		ctx := r.Context()
		if isDone(ctx) {
			t.Error("should not be Done in ServeHTTP")
		}
		handlerCtx <- ctx
	}))

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	if ctx := <-handlerCtx; !isDone(ctx) {
		t.Error("context should be done after ServeHTTP completes")
	}
}
5053
5054
5055
5056
5057
// TestServerRequestContextCancel_ConnClose drives
// testServerRequestContextCancel_ConnClose in HTTP/1 mode only.
func TestServerRequestContextCancel_ConnClose(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerRequestContextCancel_ConnClose, modes)
}
// testServerRequestContextCancel_ConnClose starts a request whose handler
// blocks on the request context. It then closes the client connection and
// waits for the handler to observe cancellation; the test hangs if the
// context is never canceled.
func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
	entered := make(chan struct{})
	finished := make(chan struct{})
	blockUntilCanceled := HandlerFunc(func(w ResponseWriter, r *Request) {
		close(entered)
		<-r.Context().Done()
		close(finished)
	})
	ts := newClientServerTest(t, mode, blockUntilCanceled).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
	<-entered
	// Drop the client side while the handler is still blocked.
	conn.Close()
	<-finished
}
5079
// TestServerContext_ServerContextKey drives
// testServerContext_ServerContextKey via run without an explicit mode list.
func TestServerContext_ServerContextKey(t *testing.T) {
	test := testServerContext_ServerContextKey
	run(t, test)
}
// testServerContext_ServerContextKey checks that the request context carries
// a *Server under ServerContextKey.
func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
	checkServer := HandlerFunc(func(w ResponseWriter, r *Request) {
		v := r.Context().Value(ServerContextKey)
		if _, isServer := v.(*Server); !isServer {
			t.Errorf("context value = %T; want *http.Server", v)
		}
	})
	cst := newClientServerTest(t, mode, checkServer)

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
5097
// TestServerContext_LocalAddrContextKey drives
// testServerContext_LocalAddrContextKey via run without an explicit mode list.
func TestServerContext_LocalAddrContextKey(t *testing.T) {
	test := testServerContext_LocalAddrContextKey
	run(t, test)
}
// testServerContext_LocalAddrContextKey checks that the request context
// carries, under LocalAddrContextKey, a net.Addr whose string form equals the
// test server's listener address.
func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
	ch := make(chan any, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		ch <- r.Context().Value(LocalAddrContextKey)
	}))
	res, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// The client must close every response body, including an empty HEAD
	// body, so the transport can release or reuse the connection.
	res.Body.Close()

	host := cst.ts.Listener.Addr().String()
	got := <-ch
	if addr, ok := got.(net.Addr); !ok {
		t.Errorf("local addr value = %T; want net.Addr", got)
	} else if fmt.Sprint(addr) != host {
		t.Errorf("local addr = %v; want %v", addr, host)
	}
}
5118
5119
// TestHandlerSetTransferEncodingChunked checks that a handler which itself
// sets "Transfer-Encoding: chunked" gets exactly one such header in the raw
// response, not a duplicate added by the server.
func TestHandlerSetTransferEncodingChunked(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	chunkedHello := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "chunked")
		w.Write([]byte("hello"))
	})
	ht := newHandlerTest(chunkedHello)
	resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")

	const hdr = "Transfer-Encoding: chunked"
	if count := strings.Count(resp, hdr); count != 1 {
		t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, count, resp)
	}
}
5133
5134
// TestHandlerSetTransferEncodingGzip checks what happens when a handler sets
// "Transfer-Encoding: gzip" and writes a gzip stream. The raw response must
// contain exactly one "Transfer-Encoding: gzip" and exactly one
// "Transfer-Encoding: chunked" substring.
func TestHandlerSetTransferEncodingGzip(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	gzipHello := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "gzip")
		zw := gzip.NewWriter(w)
		zw.Write([]byte("hello"))
		zw.Close()
	})
	ht := newHandlerTest(gzipHello)
	resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")

	encodings := []string{"gzip", "chunked"}
	for i := range encodings {
		hdr := "Transfer-Encoding: " + encodings[i]
		count := strings.Count(resp, hdr)
		if count != 1 {
			t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, count, resp)
		}
	}
}
5152
// BenchmarkClientServer drives benchmarkClientServer over HTTP/1, HTTPS/1
// and HTTP/2.
func BenchmarkClientServer(b *testing.B) {
	modes := []testMode{http1Mode, https1Mode, http2Mode}
	run(b, benchmarkClientServer, modes)
}
// benchmarkClientServer times b.N sequential GET round trips, each reading and
// verifying a fixed "Hello world.\n" body. Server setup is excluded from the
// timed region.
func benchmarkClientServer(b *testing.B, mode testMode) {
	b.ReportAllocs()
	b.StopTimer()
	hello := HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	})
	ts := newClientServerTest(b, mode, hello).ts
	b.StartTimer()

	client := ts.Client()
	for n := 0; n < b.N; n++ {
		res, err := client.Get(ts.URL)
		if err != nil {
			b.Fatal("Get:", err)
		}
		all, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatal("ReadAll:", err)
		}
		if body := string(all); body != "Hello world.\n" {
			b.Fatal("Got body:", body)
		}
	}

	b.StopTimer()
}
5183
// BenchmarkClientServerParallel runs benchmarkClientServerParallel as one
// sub-benchmark per parallelism level (4 and 64) and protocol mode.
func BenchmarkClientServerParallel(b *testing.B) {
	levels := []int{4, 64}
	for _, level := range levels {
		level := level
		b.Run(fmt.Sprint(level), func(b *testing.B) {
			bench := func(b *testing.B, mode testMode) {
				benchmarkClientServerParallel(b, level, mode)
			}
			run(b, bench, []testMode{http1Mode, https1Mode, http2Mode})
		})
	}
}
5193
// benchmarkClientServerParallel times concurrent GET round trips using
// b.RunParallel with the given parallelism. Request and read errors are
// logged and skipped. A wrong body panics, because b.Fatal cannot be called
// from RunParallel goroutines.
func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
	b.ReportAllocs()
	hello := HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	})
	ts := newClientServerTest(b, mode, hello).ts
	b.ResetTimer()
	b.SetParallelism(parallelism)

	b.RunParallel(func(pb *testing.PB) {
		client := ts.Client()
		for pb.Next() {
			res, err := client.Get(ts.URL)
			if err != nil {
				b.Logf("Get: %v", err)
				continue
			}
			all, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				b.Logf("ReadAll: %v", err)
				continue
			}
			if body := string(all); body != "Hello world.\n" {
				panic("Got body: " + body)
			}
		}
	})
}
5222
5223
5224
5225
5226
5227
5228
5229
5230
5231
5232 func BenchmarkServer(b *testing.B) {
5233 b.ReportAllocs()
5234
5235 if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
5236 n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
5237 if err != nil {
5238 panic(err)
5239 }
5240 for i := 0; i < n; i++ {
5241 res, err := Get(url)
5242 if err != nil {
5243 log.Panicf("Get: %v", err)
5244 }
5245 all, err := io.ReadAll(res.Body)
5246 res.Body.Close()
5247 if err != nil {
5248 log.Panicf("ReadAll: %v", err)
5249 }
5250 body := string(all)
5251 if body != "Hello world.\n" {
5252 log.Panicf("Got body: %q", body)
5253 }
5254 }
5255 os.Exit(0)
5256 return
5257 }
5258
5259 var res = []byte("Hello world.\n")
5260 b.StopTimer()
5261 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5262 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5263 rw.Write(res)
5264 }))
5265 defer ts.Close()
5266 b.StartTimer()
5267
5268 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5269 cmd.Env = append([]string{
5270 fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
5271 fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
5272 }, os.Environ()...)
5273 out, err := cmd.CombinedOutput()
5274 if err != nil {
5275 b.Errorf("Test failure: %v, with output: %s", err, out)
5276 }
5277 }
5278
5279
// getNoBody issues a GET for urlStr with the default client and closes the
// response body without reading it. It returns the response (status and
// headers remain usable) or the request error.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5288
5289
5290
// BenchmarkClient measures client GET round trips against a server running
// in a separate process.
//
// The benchmark binary plays both roles. When TEST_BENCH_SERVER is set, this
// process is the server child. It listens on localhost (port from
// TEST_BENCH_SERVER_PORT, defaulting to "0"), prints the listen address to
// stdout, and serves a fixed body. A request with a non-empty "stop" form
// value makes it exit. Otherwise this process is the parent: it re-executes
// itself as the child, reads the address, and times b.N GET requests.
func BenchmarkClient(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	defer afterTest(b)

	var data = []byte("Hello world.\n")
	if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
		// Child (server) mode. This branch never falls through: it ends in
		// os.Exit or log.Fatal.
		port := os.Getenv("TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			fmt.Fprintln(os.Stderr, err.Error())
			os.Exit(1)
		}
		// The parent reads this first stdout line to find the server.
		fmt.Println(ln.Addr().String())
		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.Write(data)
		})
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent (client) mode: start this same benchmark as the server child.
	// The command is bound to ctx; cancel presumably terminates the child.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkClient$")
	cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}

	done := make(chan error, 1)
	go func() {
		// Deliver the child's exit status once, then close done so that any
		// later receive (such as the deferred one below) does not block.
		done <- cmd.Wait()
		close(done)
	}()
	// On return, cancel the child and wait until cmd.Wait has finished.
	defer func() {
		cancel()
		<-done
	}()

	// The child's first stdout line is its listen address.
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	// Probe the child once before the timer starts.
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	// Timed region: b.N full GET round trips with body verification.
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		body, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if !bytes.Equal(body, data) {
			b.Fatalf("Got body: %q", body)
		}
	}
	b.StopTimer()

	// Ask the child to exit. The request's error is ignored: the child calls
	// os.Exit while handling it, so the request presumably fails.
	getNoBody(url + "?stop=yes")
	if err := <-done; err != nil {
		b.Fatalf("subprocess failed: %v", err)
	}
}
5379
// BenchmarkServerFakeConnNoKeepAlive serves one HTTP/1.0 request (no
// keep-alive) per iteration over a reused in-memory connection. Serve runs
// synchronously and returns once the one-connection listener is exhausted.
func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.0
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	payload := []byte("Hello world!\n")

	conn := newTestConn()
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(payload)
	})
	ln := new(oneConnListener)
	for n := 0; n < b.N; n++ {
		// Reset the fake connection and re-arm the listener for this iteration.
		conn.readBuf.Reset()
		conn.writeBuf.Reset()
		conn.readBuf.Write(req)
		ln.conn = conn
		Serve(ln, handler)
		<-conn.closec
	}
}
5407
5408
5409 type repeatReader struct {
5410 content []byte
5411 count int
5412 off int
5413 }
5414
5415 func (r *repeatReader) Read(p []byte) (n int, err error) {
5416 if r.count <= 0 {
5417 return 0, io.EOF
5418 }
5419 n = copy(p, r.content[r.off:])
5420 r.off += n
5421 if r.off == len(r.content) {
5422 r.count--
5423 r.off = 0
5424 }
5425 return
5426 }
5427
// BenchmarkServerFakeConnWithKeepAlive pipelines b.N identical HTTP/1.1
// requests, with realistic browser headers, over a single in-memory
// keep-alive connection. It checks that the handler ran exactly b.N times.
func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
	b.ReportAllocs()

	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	payload := []byte("Hello world!\n")

	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	calls := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(rw ResponseWriter, r *Request) {
		calls++
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(payload)
	}))
	<-conn.closec
	if calls != b.N {
		b.Errorf("b.N=%d but handled %d", b.N, calls)
	}
}
5459
5460
5461
// BenchmarkServerFakeConnWithKeepAliveLite is the minimal-headers variant of
// BenchmarkServerFakeConnWithKeepAlive. Requests carry only a Host header and
// the handler sets no response headers.
func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
	b.ReportAllocs()

	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	payload := []byte("Hello world!\n")

	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	calls := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(rw ResponseWriter, r *Request) {
		calls++
		rw.Write(payload)
	}))
	<-conn.closec
	if calls != b.N {
		b.Errorf("b.N=%d but handled %d", b.N, calls)
	}
}
5487
5488 const someResponse = "<html>some response</html>"
5489
5490
5491 var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse))
5492
5493
// BenchmarkServerHandlerTypeLen benchmarks a handler that sets both
// Content-Type and Content-Length before writing response.
func BenchmarkServerHandlerTypeLen(b *testing.B) {
	h := func(w ResponseWriter, r *Request) {
		hdr := w.Header()
		hdr.Set("Content-Type", "text/html")
		hdr.Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	}
	benchmarkHandler(b, HandlerFunc(h))
}
5501
5502
// BenchmarkServerHandlerNoLen benchmarks a handler that sets Content-Type
// but leaves Content-Length to the server.
func BenchmarkServerHandlerNoLen(b *testing.B) {
	h := func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/html")
		w.Write(response)
	}
	benchmarkHandler(b, HandlerFunc(h))
}
5509
5510
// BenchmarkServerHandlerNoType benchmarks a handler that sets Content-Length
// but leaves Content-Type to the server.
func BenchmarkServerHandlerNoType(b *testing.B) {
	h := func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	}
	benchmarkHandler(b, HandlerFunc(h))
}
5517
5518
// BenchmarkServerHandlerNoHeader benchmarks a handler that sets no headers
// at all and only writes response.
func BenchmarkServerHandlerNoHeader(b *testing.B) {
	h := func(w ResponseWriter, r *Request) { w.Write(response) }
	benchmarkHandler(b, HandlerFunc(h))
}
5524
// benchmarkHandler pipelines b.N minimal HTTP/1.1 GETs over one in-memory
// keep-alive connection into h. It checks that h was invoked exactly b.N
// times; responses are discarded.
func benchmarkHandler(b *testing.B, h Handler) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	calls := 0
	counting := HandlerFunc(func(rw ResponseWriter, r *Request) {
		calls++
		h.ServeHTTP(rw, r)
	})
	go Serve(&oneConnListener{conn: conn}, counting)
	<-conn.closec
	if calls != b.N {
		b.Errorf("b.N=%d but handled %d", b.N, calls)
	}
}
5547
// BenchmarkServerHijack measures serving one request per iteration whose
// handler immediately hijacks and closes the connection. Serve runs
// synchronously on a re-armed one-connection listener.
func BenchmarkServerHijack(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	hijackAndClose := HandlerFunc(func(w ResponseWriter, r *Request) {
		c, _, err := w.(Hijacker).Hijack()
		if err != nil {
			panic(err)
		}
		c.Close()
	})
	conn := &rwTestConn{
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	for n := 0; n < b.N; n++ {
		conn.Reader = bytes.NewReader(req)
		ln.conn = conn
		Serve(ln, hijackAndClose)
		<-conn.closec
	}
}
5572
// BenchmarkCloseNotifier drives benchmarkCloseNotifier in HTTP/1 mode only.
func BenchmarkCloseNotifier(b *testing.B) {
	modes := []testMode{http1Mode}
	run(b, benchmarkCloseNotifier, modes)
}
// benchmarkCloseNotifier measures how quickly a handler blocked on
// CloseNotify observes the client closing its connection. Each iteration
// dials, sends one request, closes, and waits for the handler's signal.
func benchmarkCloseNotifier(b *testing.B, mode testMode) {
	b.ReportAllocs()
	b.StopTimer()
	sawClose := make(chan bool)
	waitForClose := HandlerFunc(func(rw ResponseWriter, req *Request) {
		<-rw.(CloseNotifier).CloseNotify()
		sawClose <- true
	})
	ts := newClientServerTest(b, mode, waitForClose).ts
	b.StartTimer()

	addr := ts.Listener.Addr().String()
	for n := 0; n < b.N; n++ {
		conn, err := net.Dial("tcp", addr)
		if err != nil {
			b.Fatalf("error dialing: %v", err)
		}
		if _, err := fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"); err != nil {
			b.Fatal(err)
		}
		conn.Close()
		<-sawClose
	}
	b.StopTimer()
}
5597
5598
5599 func TestConcurrentServerServe(t *testing.T) {
5600 setParallel(t)
5601 for i := 0; i < 100; i++ {
5602 ln1 := &oneConnListener{conn: nil}
5603 ln2 := &oneConnListener{conn: nil}
5604 srv := Server{}
5605 go func() { srv.Serve(ln1) }()
5606 go func() { srv.Serve(ln2) }()
5607 }
5608 }
5609
// TestServerIdleTimeout drives testServerIdleTimeout in HTTP/1 mode only.
func TestServerIdleTimeout(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerIdleTimeout, modes)
}
// testServerIdleTimeout checks both Server.IdleTimeout and
// Server.ReadHeaderTimeout.
//
// Two back-to-back requests must share a keep-alive connection (same
// RemoteAddr). After sleeping 1.5x IdleTimeout, a third request must arrive on
// a different connection. A raw connection that sends an incomplete header
// block must be closed by the server once ReadHeaderTimeout elapses.
//
// Each attempt returns an error instead of failing when timing did not work
// out. runTimeSensitiveTest presumably retries with the next, longer
// duration from the list.
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}, func(t *testing.T, readHeaderTimeout time.Duration) error {
		// The handler echoes the client's RemoteAddr, so equal responses mean
		// the same underlying connection was used.
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		c := ts.Client()

		get := func() (string, error) {
			res, err := c.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			slurp, err := io.ReadAll(res.Body)
			if err != nil {
				// Once response headers have arrived, neither timeout is
				// expected to interfere with reading the body, so a read
				// error here fails the test instead of triggering a retry.
				t.Fatal(err)
			}
			return string(slurp), nil
		}

		a1, err := get()
		if err != nil {
			return err
		}
		a2, err := get()
		if err != nil {
			return err
		}
		if a1 != a2 {
			return fmt.Errorf("did requests on different connections")
		}
		// Sleep past IdleTimeout so the server closes the idle keep-alive
		// connection; the next request must then use a new connection.
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		a3, err := get()
		if err != nil {
			return err
		}
		if a2 == a3 {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Send a header block without the terminating blank line and wait
		// past ReadHeaderTimeout. The server is expected to have closed the
		// connection, so reading a byte must fail.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5685
5686 func get(t *testing.T, c *Client, url string) string {
5687 res, err := c.Get(url)
5688 if err != nil {
5689 t.Fatal(err)
5690 }
5691 defer res.Body.Close()
5692 slurp, err := io.ReadAll(res.Body)
5693 if err != nil {
5694 t.Fatal(err)
5695 }
5696 return string(slurp)
5697 }
5698
5699
5700
// TestServerSetKeepAlivesEnabledClosesConns drives
// testServerSetKeepAlivesEnabledClosesConns in HTTP/1 mode only.
func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerSetKeepAlivesEnabledClosesConns, modes)
}
// testServerSetKeepAlivesEnabledClosesConns first checks that two sequential
// requests share one keep-alive connection, which then sits idle in the
// client's pool. It then calls SetKeepAlivesEnabled(false) on the server and
// waits until the client's pool of idle connections is empty.
func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
	echoAddr := HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, echoAddr).ts

	client := ts.Client()
	tr := client.Transport.(*Transport)
	fetch := func() string { return get(t, client, ts.URL) }

	// The handler echoes RemoteAddr, so equal bodies mean one shared conn.
	first, second := fetch(), fetch()
	if first != second {
		t.Errorf("server reported requests from %q and %q; expected same connection", first, second)
	} else {
		t.Logf("made two requests from a single conn %q (as expected)", first)
	}

	// That connection should now be parked as idle in the transport.
	if idle := tr.IdleConnStrsForTesting(); len(idle) != 1 {
		t.Errorf("found %d idle conns (%q); want 1", len(idle), idle)
	}

	// Disabling keep-alives on the server should eventually leave the
	// client with no idle connections.
	ts.Config.SetKeepAlivesEnabled(false)

	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		idle := tr.IdleConnStrsForTesting()
		if len(idle) == 0 {
			return true
		}
		if d > 0 {
			t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, idle)
		}
		return false
	})
}
5746
// TestServerShutdown drives testServerShutdown via run without an explicit
// mode list.
func TestServerShutdown(t *testing.T) {
	run(t, testServerShutdown)
}
// testServerShutdown exercises Server.Shutdown while a request is in flight.
//
// The first request's handler records the server's per-state connection
// counts and starts Shutdown in a goroutine. It then waits for the
// RegisterOnShutdown hook and checks that new requests now fail, meaning the
// listener was closed before the hook ran. After that the handler completes
// its own response. The test then checks that Shutdown returned nil and that
// exactly one connection was StateActive when the handler ran.
func testServerShutdown(t *testing.T, mode testMode) {
	// Declared before handler because the handler refers to it.
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		// Only the first request snapshots connection states and starts Shutdown.
		once.Do(func() {
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait until the OnShutdown hook (registered below) has been
			// called, i.e. Shutdown is under way.
			<-gotOnShutdown

			// New requests must fail now. A success is an error, except
			// in HTTP/2 mode, where it is logged and retried to work
			// around go.dev/issue/59038.
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	// This in-flight request triggers the shutdown sequence in the handler.
	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown

	// The snapshot was taken while the first request was being handled.
	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
5807
// TestServerShutdownStateNew drives testServerShutdownStateNew via run
// without an explicit mode list.
func TestServerShutdownStateNew(t *testing.T) {
	run(t, testServerShutdownStateNew)
}
// testServerShutdownStateNew checks Shutdown against a connection that was
// accepted (StateNew) but never sent a request. Shutdown is expected to keep
// waiting for a while (roughly expectTimeout) before giving up and closing
// the connection. After that the client's pending Read must fail.
func testServerShutdownStateNew(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("test takes 5-6 seconds; skipping in short mode")
	}

	var connAccepted sync.WaitGroup
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The client below never sends a request, so nothing is served.
	}), func(ts *httptest.Server) {
		ts.Config.ConnState = func(conn net.Conn, state ConnState) {
			if state == StateNew {
				connAccepted.Done()
			}
		}
	}).ts

	// Add before dialing so the ConnState callback's Done cannot run first.
	connAccepted.Add(1)
	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	// Wait until the server has accepted the connection and reported it as
	// StateNew, so that Shutdown below is certain to see it.
	connAccepted.Wait()

	shutdownRes := make(chan error, 1)
	go func() {
		shutdownRes <- ts.Config.Shutdown(context.Background())
	}()
	// The client sends nothing, so this Read is expected to block until the
	// server closes the connection.
	readRes := make(chan error, 1)
	go func() {
		_, err := c.Read([]byte{0})
		readRes <- err
	}()

	// expectTimeout is how long Shutdown is expected to keep waiting on a
	// connection that stays in StateNew. NOTE(review): this presumably mirrors
	// the server's internal StateNew grace period; update both together.
	const expectTimeout = 5 * time.Second

	t0 := time.Now()
	select {
	case got := <-shutdownRes:
		d := time.Since(t0)
		if got != nil {
			// Report Shutdown's own error; err here is the (nil) Dial error.
			t.Fatalf("shutdown error after %v: %v", d, got)
		}
		if d < expectTimeout/2 {
			t.Errorf("shutdown too soon after %v", d)
		}
	case <-time.After(expectTimeout * 3 / 2):
		t.Fatalf("timeout waiting for shutdown")
	}

	// Once Shutdown has given up on the StateNew connection, the client's
	// pending Read must fail.
	if err := <-readRes; err == nil {
		t.Error("expected error from Read")
	}
}
5874
5875
5876 func TestServerCloseDeadlock(t *testing.T) {
5877 var s Server
5878 s.Close()
5879 s.Close()
5880 }
5881
5882
5883
// TestServerKeepAlivesEnabled drives testServerKeepAlivesEnabled
// non-parallel, since it may adjust a package-level test hook.
func TestServerKeepAlivesEnabled(t *testing.T) {
	run(t, testServerKeepAlivesEnabled, testNotParallel)
}
// testServerKeepAlivesEnabled disables keep-alives on the server and makes
// two requests. Before each one it waits for all server connections to be
// idle. Each request must use exactly one connection that is neither reused
// nor taken from the idle pool.
func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
	if mode == http2Mode {
		// Test hook: set to 10ms for HTTP/2, restored when the test returns.
		restoreGoawayTimeout := ExportSetH2GoawayTimeout(10 * time.Millisecond)
		defer restoreGoawayTimeout()
	}

	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
	defer cst.close()
	srv := cst.ts.Config
	srv.SetKeepAlivesEnabled(false)

	for attempt := 0; attempt < 2; attempt++ {
		waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
			if srv.ExportAllConnsIdle() {
				return true
			}
			if d > 0 {
				t.Logf("test server still has active conns after %v", d)
			}
			return false
		})

		var (
			gotConns int
			connInfo httptrace.GotConnInfo
		)
		trace := &httptrace.ClientTrace{
			GotConn: func(v httptrace.GotConnInfo) {
				gotConns++
				connInfo = v
			},
		}
		ctx := httptrace.WithClientTrace(context.Background(), trace)
		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()

		if gotConns != 1 {
			t.Fatalf("request %v: got %v conns, want 1", attempt, gotConns)
		}
		if connInfo.Reused || connInfo.WasIdle {
			t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", attempt, connInfo.Reused, connInfo.WasIdle)
		}
	}
}
5930
5931
5932
5933
// TestServerCancelsReadTimeoutWhenIdle runs testServerCancelsReadTimeoutWhenIdle
// through the shared run helper.
func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }

// testServerCancelsReadTimeoutWhenIdle checks that Server.ReadTimeout does not
// cut short a handler that is still running. The handler waits 2*timeout
// before writing "ok", and the client must receive exactly "ok" rather than
// the text of a canceled request context. runTimeSensitiveTest retries with
// the next, larger timeout whenever an attempt returns an error.
func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Outlive the read timeout. If the request context is canceled
			// first, echo the context error so the body check below fails.
			select {
			case <-time.After(2 * timeout):
				fmt.Fprint(w, "ok")
			case <-r.Context().Done():
				fmt.Fprint(w, r.Context().Err())
			}
		}), func(ts *httptest.Server) {
			ts.Config.ReadTimeout = timeout
			t.Logf("Server.Config.ReadTimeout = %v", timeout)
		})
		defer cst.close()
		ts := cst.ts

		// The Proxy hook is used only to count attempts. Any attempt after
		// the first (a transport retry) fails with "too many retries",
		// presumably so a silent retry cannot mask a server-side failure.
		var retries atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if retries.Add(1) != 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		// NOTE(review): assumes ts.Client() shares cst.c's Transport, so the
		// Proxy hook above applies to this client — confirm.
		c := ts.Client()

		res, err := c.Get(ts.URL)
		if err != nil {
			return fmt.Errorf("Get: %v", err)
		}
		slurp, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			return fmt.Errorf("Body ReadAll: %v", err)
		}
		if string(slurp) != "ok" {
			return fmt.Errorf("got: %q, want ok", slurp)
		}
		return nil
	})
}
5982
5983
5984
5985
// TestServerCancelsReadHeaderTimeoutWhenIdle runs
// testServerCancelsReadHeaderTimeoutWhenIdle in HTTP/1 mode only.
func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
	run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
}

// testServerCancelsReadHeaderTimeoutWhenIdle checks that a keep-alive
// connection left idle for longer than ReadHeaderTimeout can still carry a
// second request. It writes one request on a raw connection, reads the
// response, sleeps for 1.5x the timeout, then sends and reads a second request
// on the same connection. runTimeSensitiveTest retries with a larger timeout
// whenever an attempt returns an error.
func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = timeout
			// NOTE(review): per the Server.IdleTimeout docs, a zero
			// IdleTimeout with ReadTimeout also unset (it is not set here)
			// means no idle timeout. Only ReadHeaderTimeout could then
			// affect the idle connection — confirm.
			ts.Config.IdleTimeout = 0
		})
		defer cst.close()
		ts := cst.ts

		// Use a raw TCP connection so both requests are certain to travel
		// over the same connection.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("dial failed: %v", err)
		}
		br := bufio.NewReader(conn)
		defer conn.Close()

		if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
			return fmt.Errorf("writing first request failed: %v", err)
		}

		if _, err := ReadResponse(br, nil); err != nil {
			return fmt.Errorf("first response (before timeout) failed: %v", err)
		}

		// Leave the connection idle for longer than ReadHeaderTimeout.
		time.Sleep(timeout * 3 / 2)

		if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
			return fmt.Errorf("writing second request failed: %v", err)
		}

		if _, err := ReadResponse(br, nil); err != nil {
			return fmt.Errorf("second response (after timeout) failed: %v", err)
		}

		return nil
	})
}
6036
6037
6038
// runTimeSensitiveTest runs test once for each entry in durations, in order,
// and returns as soon as one run reports success (a nil error). A failing run
// is logged and the next duration is tried. The exceptions are a failure on
// the last duration, or a failure after t has already been marked failed; in
// either case the test is aborted with t.Fatalf.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	lastIdx := len(durations) - 1
	for idx := range durations {
		d := durations[idx]
		err := test(t, d)
		switch {
		case err == nil:
			return
		case idx == lastIdx, t.Failed():
			t.Fatalf("failed with duration %v: %v", d, err)
		}
		t.Logf("retrying after error with duration %v: %v", d, err)
	}
}
6051
6052
6053
// TestServerDuplicateBackgroundRead runs testServerDuplicateBackgroundRead in
// HTTP/1 mode only.
func TestServerDuplicateBackgroundRead(t *testing.T) {
	run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
}

// testServerDuplicateBackgroundRead opens several raw TCP connections. On each
// one it writes many pipelined GET requests back to back without waiting for
// responses, while a helper goroutine per connection discards whatever the
// server sends back. The test asserts only that dialing and writing succeed.
// NOTE(review): presumably it targets races in the server's background read
// of pipelined requests — confirm against the originating issue.
func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
	if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
		testenv.SkipFlaky(t, 24826)
	}

	goroutines := 5
	requests := 2000
	if testing.Short() {
		goroutines = 3
		requests = 100
	}

	hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts

	reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")

	var wg sync.WaitGroup
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			cn, err := net.Dial("tcp", hts.Listener.Addr().String())
			if err != nil {
				t.Error(err)
				return
			}
			defer cn.Close()

			// Drain responses; presumably this keeps the server from blocking
			// on writes. The copy ends once a read on cn fails, e.g. after
			// the deferred cn.Close above.
			wg.Add(1)
			go func() {
				defer wg.Done()
				io.Copy(io.Discard, cn)
			}()

			for j := 0; j < requests; j++ {
				// Stop early once any goroutine has reported a failure.
				if t.Failed() {
					return
				}
				_, err := cn.Write(reqBytes)
				if err != nil {
					t.Error(err)
					return
				}
			}
		}()
	}
	wg.Wait()
}
6105
6106
6107
6108
6109
6110
// TestServerHijackGetsBackgroundByte runs testServerHijackGetsBackgroundByte in
// HTTP/1 mode only.
func TestServerHijackGetsBackgroundByte(t *testing.T) {
	run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
}

// testServerHijackGetsBackgroundByte checks two things. First, bytes the
// client sends after its request, while the handler is already running, can
// be read from the bufio.Reader returned by Hijack. Second, hijacking does not
// cancel the request's context.
func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	done := make(chan struct{})
	inHandler := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		// Tell the client the handler is running, so it sends the extra
		// bytes only after the request has been dispatched.
		inHandler <- true

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		// The client wrote "foo" after the request; it must be readable
		// through the hijacked reader.
		peek, err := buf.Reader.Peek(3)
		if string(peek) != "foo" || err != nil {
			t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
		}

		// Hijacking must not have canceled the request context.
		select {
		case <-r.Context().Done():
			t.Error("context unexpectedly canceled")
		default:
		}
	})).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	<-inHandler
	if _, err := cn.Write([]byte("foo")); err != nil {
		t.Fatal(err)
	}

	// Half-close the write side so the server sees EOF after "foo".
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}
	<-done
}
6163
6164
6165
6166
6167 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
6168 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
6169 }
6170 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
6171 if runtime.GOOS == "plan9" {
6172 t.Skip("skipping test; see https://golang.org/issue/18657")
6173 }
6174 done := make(chan struct{})
6175 const size = 8 << 10
6176 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6177 defer close(done)
6178
6179 conn, buf, err := w.(Hijacker).Hijack()
6180 if err != nil {
6181 t.Error(err)
6182 return
6183 }
6184 defer conn.Close()
6185 slurp, err := io.ReadAll(buf.Reader)
6186 if err != nil {
6187 t.Errorf("Copy: %v", err)
6188 }
6189 allX := true
6190 for _, v := range slurp {
6191 if v != 'x' {
6192 allX = false
6193 }
6194 }
6195 if len(slurp) != size {
6196 t.Errorf("read %d; want %d", len(slurp), size)
6197 } else if !allX {
6198 t.Errorf("read %q; want %d 'x'", slurp, size)
6199 }
6200 })).ts
6201
6202 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6203 if err != nil {
6204 t.Fatal(err)
6205 }
6206 defer cn.Close()
6207 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6208 strings.Repeat("x", size)); err != nil {
6209 t.Fatal(err)
6210 }
6211 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6212 t.Fatal(err)
6213 }
6214
6215 <-done
6216 }
6217
6218
6219 func TestServerValidatesMethod(t *testing.T) {
6220 tests := []struct {
6221 method string
6222 want int
6223 }{
6224 {"GET", 200},
6225 {"GE(T", 400},
6226 }
6227 for _, tt := range tests {
6228 conn := newTestConn()
6229 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6230
6231 ln := &oneConnListener{conn}
6232 go Serve(ln, serve(200))
6233 <-conn.closec
6234 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6235 if err != nil {
6236 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6237 continue
6238 }
6239 if res.StatusCode != tt.want {
6240 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6241 }
6242 }
6243 }
6244
6245
// eofListenerNotComparable is a net.Listener whose underlying type is a slice,
// so its values cannot be compared with ==. Accept always fails with io.EOF,
// Addr returns nil, and Close is a no-op.
type eofListenerNotComparable []int

// Accept never yields a connection; it always reports io.EOF.
func (l eofListenerNotComparable) Accept() (net.Conn, error) {
	return nil, io.EOF
}

// Addr reports no address.
func (l eofListenerNotComparable) Addr() net.Addr {
	return nil
}

// Close does nothing and always succeeds.
func (l eofListenerNotComparable) Close() error {
	return nil
}
6251
6252
6253 func TestServerListenNotComparableListener(t *testing.T) {
6254 var s Server
6255 s.Serve(make(eofListenerNotComparable, 1))
6256 }
6257
6258
// countCloseListener wraps a net.Listener and counts calls to Close. Only the
// first call is forwarded to the wrapped listener, and only when one is set;
// every later call just returns nil.
type countCloseListener struct {
	net.Listener
	closes int32
}

// Close atomically increments the call count. On the first call it closes the
// wrapped listener (if any) and returns that listener's error.
func (l *countCloseListener) Close() error {
	if atomic.AddInt32(&l.closes, 1) != 1 || l.Listener == nil {
		return nil
	}
	return l.Listener.Close()
}
6271
6272
// TestServerCloseListenerOnce checks that stopping a Serve loop with Shutdown
// makes the server call Close exactly once on the listener it was given. The
// listener is wrapped in countCloseListener; the direct ln.Close calls below
// bypass the wrapper, so they are not counted.
func TestServerCloseListenerOnce(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()

	cl := &countCloseListener{Listener: ln}
	server := &Server{}
	sdone := make(chan bool, 1)

	go func() {
		server.Serve(cl)
		sdone <- true
	}()
	// Best-effort pause so Serve is likely running before Shutdown.
	time.Sleep(10 * time.Millisecond)
	server.Shutdown(context.Background())
	ln.Close()
	<-sdone

	nclose := atomic.LoadInt32(&cl.closes)
	if nclose != 1 {
		t.Errorf("Close calls = %v; want 1", nclose)
	}
}
6298
6299
6300 func TestServerShutdownThenServe(t *testing.T) {
6301 var srv Server
6302 cl := &countCloseListener{Listener: nil}
6303 srv.Shutdown(context.Background())
6304 got := srv.Serve(cl)
6305 if got != ErrServerClosed {
6306 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6307 }
6308 nclose := atomic.LoadInt32(&cl.closes)
6309 if nclose != 1 {
6310 t.Errorf("Close calls = %v; want 1", nclose)
6311 }
6312 }
6313
6314
// TestStripPortFromHost checks that ServeMux ignores the port in the request's
// host when matching host-specific patterns. A request for example.com:9000
// must be routed to the "example.com/" pattern, not "example.com:9000/".
func TestStripPortFromHost(t *testing.T) {
	mux := NewServeMux()

	// reply returns a handler that writes body verbatim.
	reply := func(body string) func(ResponseWriter, *Request) {
		return func(w ResponseWriter, _ *Request) {
			fmt.Fprint(w, body)
		}
	}
	mux.HandleFunc("example.com/", reply("OK"))
	mux.HandleFunc("example.com:9000/", reply("uh-oh!"))

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6335
// TestServerContexts runs testServerContexts through the shared run helper.
func TestServerContexts(t *testing.T) { run(t, testServerContexts) }

// testServerContexts checks the context plumbing of Server.BaseContext and
// Server.ConnContext:
//   - BaseContext must not receive a listener whose type name contains
//     "onceClose".
//   - ConnContext must receive the context that BaseContext returned.
//   - The handler's request context must carry the values added by both.
func testServerContexts(t *testing.T, mode testMode) {
	type baseKey struct{}
	type connKey struct{}
	ch := make(chan context.Context, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		// Hand the request context to the test goroutine for inspection.
		ch <- r.Context()
	}), func(ts *httptest.Server) {
		ts.Config.BaseContext = func(ln net.Listener) context.Context {
			// The type is matched by name because onceClose is presumably an
			// unexported net/http type, invisible to this external test package.
			if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
				t.Errorf("unexpected onceClose listener type %T", ln)
			}
			return context.WithValue(context.Background(), baseKey{}, "base")
		}
		ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
			if got, want := ctx.Value(baseKey{}), "base"; got != want {
				t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
			}
			return context.WithValue(ctx, connKey{}, "conn")
		}
	}).ts
	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	ctx := <-ch
	if got, want := ctx.Value(baseKey{}), "base"; got != want {
		t.Errorf("base context key = %#v; want %q", got, want)
	}
	if got, want := ctx.Value(connKey{}), "conn"; got != want {
		t.Errorf("conn context key = %#v; want %q", got, want)
	}
}
6370
6371
6372 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6373 run(t, testConnContextNotModifyingAllContexts)
6374 }
6375 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6376 type connKey struct{}
6377 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6378 rw.Header().Set("Connection", "close")
6379 }), func(ts *httptest.Server) {
6380 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6381 if got := ctx.Value(connKey{}); got != nil {
6382 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6383 }
6384 return context.WithValue(ctx, connKey{}, "conn")
6385 }
6386 }).ts
6387
6388 var res *Response
6389 var err error
6390
6391 res, err = ts.Client().Get(ts.URL)
6392 if err != nil {
6393 t.Fatal(err)
6394 }
6395 res.Body.Close()
6396
6397 res, err = ts.Client().Get(ts.URL)
6398 if err != nil {
6399 t.Fatal(err)
6400 }
6401 res.Body.Close()
6402 }
6403
6404
6405
6406 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6407 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6408 }
6409 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6410 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6411 w.Write([]byte("Hello, World!"))
6412 })).ts
6413
6414 serverURL, err := url.Parse(cst.URL)
6415 if err != nil {
6416 t.Fatalf("Failed to parse server URL: %v", err)
6417 }
6418
6419 unsupportedTEs := []string{
6420 "fugazi",
6421 "foo-bar",
6422 "unknown",
6423 `" chunked"`,
6424 }
6425
6426 for _, badTE := range unsupportedTEs {
6427 http1ReqBody := fmt.Sprintf(""+
6428 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6429 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6430
6431 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6432 if err != nil {
6433 t.Errorf("%q. unexpected error: %v", badTE, err)
6434 continue
6435 }
6436
6437 wantBody := fmt.Sprintf("" +
6438 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6439 "Connection: close\r\n\r\nUnsupported transfer encoding")
6440
6441 if string(gotBody) != wantBody {
6442 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6443 }
6444 }
6445 }
6446
6447
// TestContentEncodingNoSniffing runs testContentEncodingNoSniffing through the
// shared run helper.
func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }

// testContentEncodingNoSniffing checks when the server sniffs a Content-Type
// for a handler that writes a body without setting one. In the table below:
//   - a non-empty Content-Encoding expects no Content-Type;
//   - an absent or empty Content-Encoding expects the sniffed type.
func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
	type setting struct {
		name string
		body []byte

		// contentEncoding is nil when the handler must not set a
		// Content-Encoding header at all; otherwise it holds the string
		// (possibly empty) that the handler sets.
		contentEncoding any
		wantContentType string
	}

	settings := []*setting{
		{
			name:            "gzip content-encoding, gzipped",
			contentEncoding: "application/gzip",
			wantContentType: "",
			body: func() []byte {
				buf := new(bytes.Buffer)
				gzw := gzip.NewWriter(buf)
				gzw.Write([]byte("doctype html><p>Hello</p>"))
				gzw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "zlib content-encoding, zlibbed",
			contentEncoding: "application/zlib",
			wantContentType: "",
			body: func() []byte {
				buf := new(bytes.Buffer)
				zw := zlib.NewWriter(buf)
				zw.Write([]byte("doctype html><p>Hello</p>"))
				zw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "no content-encoding",
			wantContentType: "application/x-gzip",
			body: func() []byte {
				buf := new(bytes.Buffer)
				gzw := gzip.NewWriter(buf)
				gzw.Write([]byte("doctype html><p>Hello</p>"))
				gzw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "phony content-encoding",
			contentEncoding: "foo/bar",
			body:            []byte("doctype html><p>Hello</p>"),
		},
		{
			name:            "empty but set content-encoding",
			contentEncoding: "",
			wantContentType: "audio/mpeg",
			body:            []byte("ID3"),
		},
	}

	for _, tt := range settings {
		t.Run(tt.name, func(t *testing.T) {
			cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
				if tt.contentEncoding != nil {
					rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
				}
				rw.Write(tt.body)
			}))

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Failed to fetch URL: %v", err)
			}
			defer res.Body.Close()

			// g is a string and w an any. When w is nil (header never set),
			// the response must not carry a Content-Encoding at all.
			if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
				if w != nil {
					t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
				} else if g != "" {
					t.Errorf("Unexpected Content-Encoding %q", g)
				}
			}

			if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
				t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
			}
		})
	}
}
6539
6540
6541
// TestTimeoutHandlerSuperfluousLogs runs testTimeoutHandlerSuperfluousLogs in
// HTTP/1 mode only.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
}

// testTimeoutHandlerSuperfluousLogs wraps, in TimeoutHandler, a handler that
// calls WriteHeader four times. It checks that the server logs exactly three
// "superfluous response.WriteHeader call" lines, each naming the file and line
// of the offending call. This is checked both when the handler returns before
// the timeout and when it returns after it, along with the response each case
// produces.
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Capture this file's base name and this function's name so the logged
	// call sites can be matched exactly below.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			lastLine := make(chan int, 1)

			// The handler calls WriteHeader four times (the last three are
			// superfluous), then reports the line of its runtime.Caller call.
			// The log check below derives the superfluous call sites from that
			// line by subtraction, so no lines may be inserted between the
			// first WriteHeader call and the runtime.Caller call.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// When no timeout is wanted, let the handler return immediately.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)

			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Use a long timeout so TimeoutHandler cannot fire before the
				// (already released) handler returns.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop Date (varies per run) and Content-Type (absent from the
			// expected dumps) before comparing.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			// Compare the full response dump, body included.
			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Expect exactly one log entry per superfluous WriteHeader call.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", "  ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// NOTE: despite its name, lastSpuriousLine is the line of the
			// runtime.Caller call, one past the last WriteHeader. Subtracting
			// 3 gives the line of the second (first superfluous) WriteHeader.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Log entry i must point at line firstSpuriousLine+i.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6646
6647
6648
6649
// fetchWireResponse dials host over TCP and writes http1ReqBody verbatim. It
// returns every byte the peer sends back until a read fails or the peer closes
// the connection; the connection is always closed before returning.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	c, dialErr := net.Dial("tcp", host)
	if dialErr != nil {
		return nil, dialErr
	}
	defer c.Close()

	if _, writeErr := c.Write(http1ReqBody); writeErr != nil {
		return nil, writeErr
	}
	raw, readErr := io.ReadAll(c)
	return raw, readErr
}
6662
// BenchmarkResponseStatusLine measures writing a status line for code 200 via
// the Export_writeStatusLine test hook, in parallel. Output goes into a
// buffered writer that discards it. Each goroutine reuses its own 3-byte
// scratch buffer. NOTE(review): the boolean argument is passed as true;
// presumably it selects HTTP/1.1 — confirm.
func BenchmarkResponseStatusLine(b *testing.B) {
	b.ReportAllocs()
	b.RunParallel(func(pb *testing.PB) {
		bw := bufio.NewWriter(io.Discard)
		var buf3 [3]byte
		for pb.Next() {
			Export_writeStatusLine(bw, true, 200, buf3[:])
		}
	})
}
6673
// TestDisableKeepAliveUpgrade runs testDisableKeepAliveUpgrade in HTTP/1 mode
// only.
func TestDisableKeepAliveUpgrade(t *testing.T) {
	run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
}

// testDisableKeepAliveUpgrade checks that a 101 Switching Protocols upgrade
// still works when keep-alives are disabled on both the server and the client
// transport. The client's Response.Body must be an io.ReadWriteCloser, and
// bytes written to it must come back from the echoing, hijacked handler.
func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Connection", "Upgrade")
		w.Header().Set("Upgrade", "someProto")
		w.WriteHeader(StatusSwitchingProtocols)
		c, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			return
		}
		defer c.Close()

		// Echo everything the client sends on the upgraded connection back
		// to it, reading through buf so any already-buffered bytes come first.
		io.Copy(c, buf)
	}), func(ts *httptest.Server) {
		ts.Config.SetKeepAlivesEnabled(false)
	}).ts

	cl := s.Client()
	cl.Transport.(*Transport).DisableKeepAlives = true

	resp, err := cl.Get(s.URL)
	if err != nil {
		t.Fatalf("failed to perform request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != StatusSwitchingProtocols {
		t.Fatalf("unexpected status code: %v", resp.StatusCode)
	}

	// For a 101 response, the body gives write access to the upgraded
	// connection.
	rwc, ok := resp.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
	}

	_, err = rwc.Write([]byte("hello"))
	if err != nil {
		t.Fatalf("failed to write to body: %v", err)
	}

	b := make([]byte, 5)
	_, err = io.ReadFull(rwc, b)
	if err != nil {
		t.Fatalf("failed to read from body: %v", err)
	}

	if string(b) != "hello" {
		t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
	}
}
6732
// tlogWriter is an io.Writer that forwards each write to t.Log, so output sent
// through it (for example from a log.Logger) appears in the test's log.
type tlogWriter struct {
	t *testing.T
}

// Write logs p with a single t.Log call and always reports a complete,
// successful write.
func (lw tlogWriter) Write(p []byte) (int, error) {
	lw.t.Log(string(p))
	return len(p), nil
}
6739
6740 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6741 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6742 }
6743 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6744 const wantBody = "want"
6745 const wantUpgrade = "someProto"
6746 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6747 w.Header().Set("Connection", "Upgrade")
6748 w.Header().Set("Upgrade", wantUpgrade)
6749 w.WriteHeader(StatusSwitchingProtocols)
6750 NewResponseController(w).Flush()
6751
6752
6753 w.WriteHeader(200)
6754 if _, err := w.Write([]byte("x")); err == nil {
6755 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6756 }
6757
6758 c, _, err := NewResponseController(w).Hijack()
6759 if err != nil {
6760 t.Errorf("Hijack: %v", err)
6761 return
6762 }
6763 defer c.Close()
6764 if _, err := c.Write([]byte(wantBody)); err != nil {
6765 t.Errorf("Write to hijacked body: %v", err)
6766 }
6767 }), func(ts *httptest.Server) {
6768
6769 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
6770 }).ts
6771
6772 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6773 if err != nil {
6774 t.Fatalf("net.Dial: %v", err)
6775 }
6776 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
6777 if err != nil {
6778 t.Fatalf("conn.Write: %v", err)
6779 }
6780 defer conn.Close()
6781
6782 r := bufio.NewReader(conn)
6783 res, err := ReadResponse(r, &Request{Method: "GET"})
6784 if err != nil {
6785 t.Fatal("ReadResponse error:", err)
6786 }
6787 if res.StatusCode != StatusSwitchingProtocols {
6788 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
6789 }
6790 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
6791 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
6792 }
6793 body, err := io.ReadAll(r)
6794 if err != nil {
6795 t.Error(err)
6796 }
6797 if string(body) != wantBody {
6798 t.Errorf("Response body = %q, want %q", string(body), wantBody)
6799 }
6800 }
6801
6802 func TestMuxRedirectRelative(t *testing.T) {
6803 setParallel(t)
6804 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
6805 if err != nil {
6806 t.Errorf("%s", err)
6807 }
6808 mux := NewServeMux()
6809 resp := httptest.NewRecorder()
6810 mux.ServeHTTP(resp, req)
6811 if got, want := resp.Header().Get("Location"), "/"; got != want {
6812 t.Errorf("Location header expected %q; got %q", want, got)
6813 }
6814 if got, want := resp.Code, StatusMovedPermanently; got != want {
6815 t.Errorf("Expected response code %d; got %d", want, got)
6816 }
6817 }
6818
6819
// TestQuerySemicolon checks how semicolons in the URL query are treated, both
// by default and when the handler is wrapped in AllowQuerySemicolons. Each
// case gives:
//   - the "x" value expected in each configuration;
//   - whether ParseForm should report an error in the default configuration.
func TestQuerySemicolon(t *testing.T) {
	t.Cleanup(func() { afterTest(t) })

	tests := []struct {
		query              string
		xNoSemicolons      string
		xWithSemicolons    string
		expectParseFormErr bool
	}{
		{"?a=1;x=bad&x=good", "good", "bad", true},
		{"?a=1;b=bad&x=good", "good", "good", true},
		{"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
		{"?a=1;x=good;x=bad", "", "good", true},
	}

	run(t, func(t *testing.T, mode testMode) {
		for _, tt := range tests {
			t.Run(tt.query+"/allow=false", func(t *testing.T) {
				allowSemicolons := false
				testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
			})
			// With AllowQuerySemicolons, ParseForm is never expected to fail.
			t.Run(tt.query+"/allow=true", func(t *testing.T) {
				allowSemicolons, expectParseFormErr := true, false
				testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
			})
		}
	})
}
6848
// testQuerySemicolon serves a single GET for ts.URL+query. The handler is
// wrapped in AllowQuerySemicolons when allowSemicolons is set. Inside the
// handler it checks that:
//   - ParseForm fails with a semicolon-related error exactly when
//     expectParseFormErr is set;
//   - FormValue("x") agrees with URL.Query().Get("x").
// The handler echoes x, and the client expects status 200 with body wantX.
func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
	writeBackX := func(w ResponseWriter, r *Request) {
		x := r.URL.Query().Get("x")
		if expectParseFormErr {
			if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
				t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
			}
		} else {
			if err := r.ParseForm(); err != nil {
				t.Errorf("expected no error from ParseForm, got %v", err)
			}
		}
		if got := r.FormValue("x"); x != got {
			t.Errorf("got %q from FormValue, want %q", got, x)
		}
		fmt.Fprintf(w, "%s", x)
	}

	h := Handler(HandlerFunc(writeBackX))
	if allowSemicolons {
		h = AllowQuerySemicolons(h)
	}

	// Capture the server's error log; it is not inspected afterwards here.
	logBuf := &strings.Builder{}
	ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(logBuf, "", 0)
	}).ts

	req, _ := NewRequest("GET", ts.URL+query, nil)
	res, err := ts.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	slurp, _ := io.ReadAll(res.Body)
	res.Body.Close()
	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("Status = %d; want = %d", got, want)
	}
	if got, want := string(slurp), wantX; got != want {
		t.Errorf("Body = %q; want = %q", got, want)
	}
}
6891
// TestMaxBytesHandler runs testMaxBytesHandler as a separate subtest for every
// combination of body-size limit and request-body size listed below.
// NOTE(review): testNotParallel is passed to run, presumably because each
// attempt changes the package-level RST avoidance delay — confirm.
func TestMaxBytesHandler(t *testing.T) {
	defer afterTest(t)

	for _, maxSize := range []int64{100, 1_000, 1_000_000} {
		for _, requestSize := range []int64{100, 1_000, 1_000_000} {
			t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
				func(t *testing.T) {
					run(t, func(t *testing.T, mode testMode) {
						testMaxBytesHandler(t, mode, maxSize, requestSize)
					}, testNotParallel)
				})
		}
	}
}
6907
6908 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
6909 runTimeSensitiveTest(t, []time.Duration{
6910 1 * time.Millisecond,
6911 5 * time.Millisecond,
6912 10 * time.Millisecond,
6913 50 * time.Millisecond,
6914 100 * time.Millisecond,
6915 500 * time.Millisecond,
6916 time.Second,
6917 5 * time.Second,
6918 }, func(t *testing.T, timeout time.Duration) error {
6919 SetRSTAvoidanceDelay(t, timeout)
6920 t.Logf("set RST avoidance delay to %v", timeout)
6921
6922 var (
6923 handlerN int64
6924 handlerErr error
6925 )
6926 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
6927 var buf bytes.Buffer
6928 handlerN, handlerErr = io.Copy(&buf, r.Body)
6929 io.Copy(w, &buf)
6930 })
6931
6932 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
6933
6934
6935 defer cst.close()
6936 ts := cst.ts
6937 c := ts.Client()
6938
6939 body := strings.Repeat("a", int(requestSize))
6940 var wg sync.WaitGroup
6941 defer wg.Wait()
6942 getBody := func() (io.ReadCloser, error) {
6943 wg.Add(1)
6944 body := &wgReadCloser{
6945 Reader: strings.NewReader(body),
6946 wg: &wg,
6947 }
6948 return body, nil
6949 }
6950 reqBody, _ := getBody()
6951 req, err := NewRequest("POST", ts.URL, reqBody)
6952 if err != nil {
6953 reqBody.Close()
6954 t.Fatal(err)
6955 }
6956 req.ContentLength = int64(len(body))
6957 req.GetBody = getBody
6958 req.Header.Set("Content-Type", "text/plain")
6959
6960 var buf strings.Builder
6961 res, err := c.Do(req)
6962 if err != nil {
6963 return fmt.Errorf("unexpected connection error: %v", err)
6964 } else {
6965 _, err = io.Copy(&buf, res.Body)
6966 res.Body.Close()
6967 if err != nil {
6968 return fmt.Errorf("unexpected read error: %v", err)
6969 }
6970 }
6971
6972
6973
6974
6975 if handlerN > maxSize {
6976 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
6977 }
6978 if requestSize > maxSize && handlerErr == nil {
6979 t.Error("expected error on handler side; got nil")
6980 }
6981 if requestSize <= maxSize {
6982 if handlerErr != nil {
6983 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
6984 }
6985 if handlerN != requestSize {
6986 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
6987 }
6988 }
6989 if buf.Len() != int(handlerN) {
6990 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
6991 }
6992
6993 return nil
6994 })
6995 }
6996
6997 func TestEarlyHints(t *testing.T) {
6998 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6999 h := w.Header()
7000 h.Add("Link", "</style.css>; rel=preload; as=style")
7001 h.Add("Link", "</script.js>; rel=preload; as=script")
7002 w.WriteHeader(StatusEarlyHints)
7003
7004 h.Add("Link", "</foo.js>; rel=preload; as=script")
7005 w.WriteHeader(StatusEarlyHints)
7006
7007 w.Write([]byte("stuff"))
7008 }))
7009
7010 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7011 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
7012 if !strings.Contains(got, expected) {
7013 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7014 }
7015 }
7016 func TestProcessing(t *testing.T) {
7017 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7018 w.WriteHeader(StatusProcessing)
7019 w.Write([]byte("stuff"))
7020 }))
7021
7022 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7023 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
7024 if !strings.Contains(got, expected) {
7025 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7026 }
7027 }
7028
7029 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
7030 func testParseFormCleanup(t *testing.T, mode testMode) {
7031 if mode == http2Mode {
7032 t.Skip("https://go.dev/issue/20253")
7033 }
7034
7035 const maxMemory = 1024
7036 const key = "file"
7037
7038 if runtime.GOOS == "windows" {
7039
7040 t.Skip("https://go.dev/issue/25965")
7041 }
7042
7043 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7044 r.ParseMultipartForm(maxMemory)
7045 f, _, err := r.FormFile(key)
7046 if err != nil {
7047 t.Errorf("r.FormFile(%q) = %v", key, err)
7048 return
7049 }
7050 of, ok := f.(*os.File)
7051 if !ok {
7052 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7053 return
7054 }
7055 w.Write([]byte(of.Name()))
7056 }))
7057
7058 fBuf := new(bytes.Buffer)
7059 mw := multipart.NewWriter(fBuf)
7060 mf, err := mw.CreateFormFile(key, "myfile.txt")
7061 if err != nil {
7062 t.Fatal(err)
7063 }
7064 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7065 t.Fatal(err)
7066 }
7067 if err := mw.Close(); err != nil {
7068 t.Fatal(err)
7069 }
7070 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7071 if err != nil {
7072 t.Fatal(err)
7073 }
7074 req.Header.Set("Content-Type", mw.FormDataContentType())
7075 res, err := cst.c.Do(req)
7076 if err != nil {
7077 t.Fatal(err)
7078 }
7079 defer res.Body.Close()
7080 fname, err := io.ReadAll(res.Body)
7081 if err != nil {
7082 t.Fatal(err)
7083 }
7084 cst.close()
7085 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7086 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7087 }
7088 }
7089
7090 func TestHeadBody(t *testing.T) {
7091 const identityMode = false
7092 const chunkedMode = true
7093 run(t, func(t *testing.T, mode testMode) {
7094 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7095 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7096 })
7097 }
7098
7099 func TestGetBody(t *testing.T) {
7100 const identityMode = false
7101 const chunkedMode = true
7102 run(t, func(t *testing.T, mode testMode) {
7103 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7104 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7105 })
7106 }
7107
7108 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7109 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7110 b, err := io.ReadAll(r.Body)
7111 if err != nil {
7112 t.Errorf("server reading body: %v", err)
7113 return
7114 }
7115 w.Header().Set("X-Request-Body", string(b))
7116 w.Header().Set("Content-Length", "0")
7117 }))
7118 defer cst.close()
7119 for _, reqBody := range []string{
7120 "",
7121 "",
7122 "request_body",
7123 "",
7124 } {
7125 var bodyReader io.Reader
7126 if reqBody != "" {
7127 bodyReader = strings.NewReader(reqBody)
7128 if chunked {
7129 bodyReader = bufio.NewReader(bodyReader)
7130 }
7131 }
7132 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7133 if err != nil {
7134 t.Fatal(err)
7135 }
7136 res, err := cst.c.Do(req)
7137 if err != nil {
7138 t.Fatal(err)
7139 }
7140 res.Body.Close()
7141 if got, want := res.StatusCode, 200; got != want {
7142 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7143 }
7144 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7145 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7146 }
7147 }
7148 }
7149
7150
7151
7152 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7153 func testDisableContentLength(t *testing.T, mode testMode) {
7154 if mode == http2Mode {
7155 t.Skip("skipping until h2_bundle.go is updated; see https://go-review.googlesource.com/c/net/+/471535")
7156 }
7157
7158 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7159 w.Header()["Content-Length"] = nil
7160 fmt.Fprintf(w, "OK")
7161 }))
7162
7163 res, err := noCL.c.Get(noCL.ts.URL)
7164 if err != nil {
7165 t.Fatal(err)
7166 }
7167 if got, haveCL := res.Header["Content-Length"]; haveCL {
7168 t.Errorf("Unexpected Content-Length: %q", got)
7169 }
7170 if err := res.Body.Close(); err != nil {
7171 t.Fatal(err)
7172 }
7173
7174 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7175 fmt.Fprintf(w, "OK")
7176 }))
7177
7178 res, err = withCL.c.Get(withCL.ts.URL)
7179 if err != nil {
7180 t.Fatal(err)
7181 }
7182 if got := res.Header.Get("Content-Length"); got != "2" {
7183 t.Errorf("Content-Length: %q; want 2", got)
7184 }
7185 if err := res.Body.Close(); err != nil {
7186 t.Fatal(err)
7187 }
7188 }
7189
7190 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7191 func testErrorContentLength(t *testing.T, mode testMode) {
7192 const errorBody = "an error occurred"
7193 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7194 w.Header().Set("Content-Length", "1000")
7195 Error(w, errorBody, 400)
7196 }))
7197 res, err := cst.c.Get(cst.ts.URL)
7198 if err != nil {
7199 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7200 }
7201 defer res.Body.Close()
7202 body, err := io.ReadAll(res.Body)
7203 if err != nil {
7204 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7205 }
7206 if string(body) != errorBody+"\n" {
7207 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7208 }
7209 }
7210
// TestError checks the headers Error leaves behind. A pre-set
// Content-Length is removed, Content-Type becomes plain UTF-8 text, and
// X-Content-Type-Options is forced to "nosniff".
func TestError(t *testing.T) {
	const (
		wantContentType = "text/plain; charset=utf-8"
		wantNoSniff     = "nosniff"
	)

	rec := httptest.NewRecorder()
	pre := rec.Header()
	pre.Set("Content-Length", "1")
	pre.Set("X-Content-Type-Options", "scratch and sniff")
	pre.Set("Other", "foo")
	Error(rec, "oops", 432)

	after := rec.Header()
	if cl, present := after["Content-Length"]; present {
		t.Errorf("Content-Length: %q, want not present", cl)
	}
	if ct := after.Get("Content-Type"); ct != wantContentType {
		t.Errorf("Content-Type: %q, want %q", ct, wantContentType)
	}
	if xcto := after.Get("X-Content-Type-Options"); xcto != wantNoSniff {
		t.Errorf("X-Content-Type-Options: %q, want %q", xcto, wantNoSniff)
	}
}
7231
7232 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7233 run(t, testServerReadAfterWriteHeader100Continue)
7234 }
7235 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7236 t.Skip("https://go.dev/issue/67555")
7237 body := []byte("body")
7238 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7239 w.WriteHeader(200)
7240 NewResponseController(w).Flush()
7241 io.ReadAll(r.Body)
7242 w.Write(body)
7243 }), func(tr *Transport) {
7244 tr.ExpectContinueTimeout = 24 * time.Hour
7245 })
7246
7247 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7248 req.Header.Set("Expect", "100-continue")
7249 res, err := cst.c.Do(req)
7250 if err != nil {
7251 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7252 }
7253 defer res.Body.Close()
7254 got, err := io.ReadAll(res.Body)
7255 if err != nil {
7256 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7257 }
7258 if !bytes.Equal(got, body) {
7259 t.Fatalf("response body = %q, want %q", got, body)
7260 }
7261 }
7262
func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
	run(t, testServerReadAfterHandlerDone100Continue)
}

// testServerReadAfterHandlerDone100Continue exercises reading an
// "Expect: 100-continue" request body from a goroutine after the handler
// has already returned. The test is currently skipped (see the linked issue).
func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
	t.Skip("https://go.dev/issue/67555")
	// Unbuffered, so each send below blocks until the handler's goroutine
	// receives it; this sequences the client and the goroutine.
	readyc := make(chan struct{})
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The handler returns immediately without touching the body; the
		// goroutine reads it only once the client signals.
		go func() {
			<-readyc
			io.ReadAll(r.Body)
			<-readyc
		}()
	}), func(tr *Transport) {
		// Very long timeout so the transport does not give up waiting for
		// the server's response and send the body on its own.
		tr.ExpectContinueTimeout = 24 * time.Hour
	})

	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
	}
	res.Body.Close()
	// First send: let the goroutine start reading the body now that the
	// response has been received. Second send: block until that read has
	// returned.
	readyc <- struct{}{}
	readyc <- struct{}{}
}
7289
func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
	run(t, testServerReadAfterHandlerAbort100Continue)
}

// testServerReadAfterHandlerAbort100Continue is like the HandlerDone variant,
// but the handler aborts with panic(ErrAbortHandler) instead of returning
// normally. A goroutine then reads the "Expect: 100-continue" body after the
// abort. The test is currently skipped (see the linked issue).
func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
	t.Skip("https://go.dev/issue/67555")
	// Unbuffered, so each send below blocks until the goroutine receives it.
	readyc := make(chan struct{})
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		go func() {
			<-readyc
			io.ReadAll(r.Body)
			<-readyc
		}()
		panic(ErrAbortHandler)
	}), func(tr *Transport) {
		// Very long timeout so the transport does not give up waiting for
		// the server's response and send the body on its own.
		tr.ExpectContinueTimeout = 24 * time.Hour
	})

	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
	req.Header.Set("Expect", "100-continue")
	// The aborted handler may or may not produce a response, so an error
	// here is tolerated; only close the body if there is one.
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
	}
	// First send: start the goroutine's body read. Second send: wait for
	// that read to return.
	readyc <- struct{}{}
	readyc <- struct{}{}
}
7316
View as plain text