1
2
3
4
5 package httptest
6
7 import (
8 "bufio"
9 "io"
10 "net"
11 "net/http"
12 "sync"
13 "testing"
14 )
15
16 type newServerFunc func(http.Handler) *Server
17
18 var newServers = map[string]newServerFunc{
19 "NewServer": NewServer,
20 "NewTLSServer": NewTLSServer,
21
22
23
24 "NewServerManual": func(h http.Handler) *Server {
25 ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
26 ts.Start()
27 return ts
28 },
29 "NewTLSServerManual": func(h http.Handler) *Server {
30 ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
31 ts.StartTLS()
32 return ts
33 },
34 }
35
36 func TestServer(t *testing.T) {
37 for _, name := range []string{"NewServer", "NewServerManual"} {
38 t.Run(name, func(t *testing.T) {
39 newServer := newServers[name]
40 t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
41 t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) })
42 t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) })
43 t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) })
44 t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
45 })
46 }
47 for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
48 t.Run(name, func(t *testing.T) {
49 newServer := newServers[name]
50 t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
51 t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) })
52 })
53 }
54 }
55
56 func testServer(t *testing.T, newServer newServerFunc) {
57 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58 w.Write([]byte("hello"))
59 }))
60 defer ts.Close()
61 res, err := http.Get(ts.URL)
62 if err != nil {
63 t.Fatal(err)
64 }
65 got, err := io.ReadAll(res.Body)
66 res.Body.Close()
67 if err != nil {
68 t.Fatal(err)
69 }
70 if string(got) != "hello" {
71 t.Errorf("got %q, want hello", string(got))
72 }
73 }
74
75
76 func testGetAfterClose(t *testing.T, newServer newServerFunc) {
77 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
78 w.Write([]byte("hello"))
79 }))
80
81 res, err := http.Get(ts.URL)
82 if err != nil {
83 t.Fatal(err)
84 }
85 got, err := io.ReadAll(res.Body)
86 res.Body.Close()
87 if err != nil {
88 t.Fatal(err)
89 }
90 if string(got) != "hello" {
91 t.Fatalf("got %q, want hello", string(got))
92 }
93
94 ts.Close()
95
96 res, err = http.Get(ts.URL)
97 if err == nil {
98 body, _ := io.ReadAll(res.Body)
99 t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body)
100 }
101 }
102
103 func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
104 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
105 w.Write([]byte("hello"))
106 }))
107 dial := func() net.Conn {
108 c, err := net.Dial("tcp", ts.Listener.Addr().String())
109 if err != nil {
110 t.Fatal(err)
111 }
112 return c
113 }
114
115
116 cnew := dial()
117 defer cnew.Close()
118
119
120 cidle := dial()
121 defer cidle.Close()
122 cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
123 _, err := http.ReadResponse(bufio.NewReader(cidle), nil)
124 if err != nil {
125 t.Fatal(err)
126 }
127
128 ts.Close()
129 }
130
131
132 func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
133 var s *Server
134 s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
135 s.CloseClientConnections()
136 }))
137 defer s.Close()
138 res, err := http.Get(s.URL)
139 if err == nil {
140 res.Body.Close()
141 t.Fatalf("Unexpected response: %#v", res)
142 }
143 }
144
145
146
147 func testServerClient(t *testing.T, newTLSServer newServerFunc) {
148 ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
149 w.Write([]byte("hello"))
150 }))
151 defer ts.Close()
152 client := ts.Client()
153 res, err := client.Get(ts.URL)
154 if err != nil {
155 t.Fatal(err)
156 }
157 got, err := io.ReadAll(res.Body)
158 res.Body.Close()
159 if err != nil {
160 t.Fatal(err)
161 }
162 if string(got) != "hello" {
163 t.Errorf("got %q, want hello", string(got))
164 }
165 }
166
167
168
169 func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
170 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
171 }))
172 defer ts.Close()
173 client := ts.Client()
174 if _, ok := client.Transport.(*http.Transport); !ok {
175 t.Errorf("got %T, want *http.Transport", client.Transport)
176 }
177 }
178
179
180
181 func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
182 ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
183 }))
184 defer ts.Close()
185 client := ts.Client()
186 if _, ok := client.Transport.(*http.Transport); !ok {
187 t.Errorf("got %T, want *http.Transport", client.Transport)
188 }
189 }
190
191 type onlyCloseListener struct {
192 net.Listener
193 }
194
195 func (onlyCloseListener) Close() error { return nil }
196
197
198
199 func TestServerZeroValueClose(t *testing.T) {
200 ts := &Server{
201 Listener: onlyCloseListener{},
202 Config: &http.Server{},
203 }
204
205 ts.Close()
206 }
207
208
209
210 func TestCloseHijackedConnection(t *testing.T) {
211 hijacked := make(chan net.Conn)
212 ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
213 defer close(hijacked)
214 hj, ok := w.(http.Hijacker)
215 if !ok {
216 t.Fatal("failed to hijack")
217 }
218 c, _, err := hj.Hijack()
219 if err != nil {
220 t.Fatal(err)
221 }
222 hijacked <- c
223 }))
224
225 var wg sync.WaitGroup
226 wg.Add(1)
227 go func() {
228 defer wg.Done()
229 req, err := http.NewRequest("GET", ts.URL, nil)
230 if err != nil {
231 t.Log(err)
232 }
233
234 var c http.Client
235 resp, err := c.Do(req)
236 if err != nil {
237 t.Log(err)
238 return
239 }
240 resp.Body.Close()
241 }()
242
243 wg.Add(1)
244 conn := <-hijacked
245 go func(conn net.Conn) {
246 defer wg.Done()
247
248
249 conn.Close()
250 ts.Config.ConnState(conn, http.StateClosed)
251 }(conn)
252
253 wg.Add(1)
254 go func() {
255 defer wg.Done()
256 ts.Close()
257 }()
258 wg.Wait()
259 }
260
261 func TestTLSServerWithHTTP2(t *testing.T) {
262 modes := []struct {
263 name string
264 wantProto string
265 }{
266 {"http1", "HTTP/1.1"},
267 {"http2", "HTTP/2.0"},
268 }
269
270 for _, tt := range modes {
271 t.Run(tt.name, func(t *testing.T) {
272 cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
273 w.Header().Set("X-Proto", r.Proto)
274 }))
275
276 switch tt.name {
277 case "http2":
278 cst.EnableHTTP2 = true
279 cst.StartTLS()
280 default:
281 cst.Start()
282 }
283
284 defer cst.Close()
285
286 res, err := cst.Client().Get(cst.URL)
287 if err != nil {
288 t.Fatalf("Failed to make request: %v", err)
289 }
290 if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
291 t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w)
292 }
293 })
294 }
295 }
296
View as plain text