1
2
3
4
5 package httptest
6
7 import (
8 "fmt"
9 "io"
10 "net/http"
11 "testing"
12 )
13
14 func TestRecorder(t *testing.T) {
15 type checkFunc func(*ResponseRecorder) error
16 check := func(fns ...checkFunc) []checkFunc { return fns }
17
18 hasStatus := func(wantCode int) checkFunc {
19 return func(rec *ResponseRecorder) error {
20 if rec.Code != wantCode {
21 return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
22 }
23 return nil
24 }
25 }
26 hasResultStatus := func(want string) checkFunc {
27 return func(rec *ResponseRecorder) error {
28 if rec.Result().Status != want {
29 return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
30 }
31 return nil
32 }
33 }
34 hasResultStatusCode := func(wantCode int) checkFunc {
35 return func(rec *ResponseRecorder) error {
36 if rec.Result().StatusCode != wantCode {
37 return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
38 }
39 return nil
40 }
41 }
42 hasResultContents := func(want string) checkFunc {
43 return func(rec *ResponseRecorder) error {
44 contentBytes, err := io.ReadAll(rec.Result().Body)
45 if err != nil {
46 return err
47 }
48 contents := string(contentBytes)
49 if contents != want {
50 return fmt.Errorf("Result().Body = %s; want %s", contents, want)
51 }
52 return nil
53 }
54 }
55 hasContents := func(want string) checkFunc {
56 return func(rec *ResponseRecorder) error {
57 if rec.Body.String() != want {
58 return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
59 }
60 return nil
61 }
62 }
63 hasFlush := func(want bool) checkFunc {
64 return func(rec *ResponseRecorder) error {
65 if rec.Flushed != want {
66 return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
67 }
68 return nil
69 }
70 }
71 hasOldHeader := func(key, want string) checkFunc {
72 return func(rec *ResponseRecorder) error {
73 if got := rec.HeaderMap.Get(key); got != want {
74 return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
75 }
76 return nil
77 }
78 }
79 hasHeader := func(key, want string) checkFunc {
80 return func(rec *ResponseRecorder) error {
81 if got := rec.Result().Header.Get(key); got != want {
82 return fmt.Errorf("final header %s = %q; want %q", key, got, want)
83 }
84 return nil
85 }
86 }
87 hasNotHeaders := func(keys ...string) checkFunc {
88 return func(rec *ResponseRecorder) error {
89 for _, k := range keys {
90 v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
91 if ok {
92 return fmt.Errorf("unexpected header %s with value %q", k, v)
93 }
94 }
95 return nil
96 }
97 }
98 hasTrailer := func(key, want string) checkFunc {
99 return func(rec *ResponseRecorder) error {
100 if got := rec.Result().Trailer.Get(key); got != want {
101 return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
102 }
103 return nil
104 }
105 }
106 hasNotTrailers := func(keys ...string) checkFunc {
107 return func(rec *ResponseRecorder) error {
108 trailers := rec.Result().Trailer
109 for _, k := range keys {
110 _, ok := trailers[http.CanonicalHeaderKey(k)]
111 if ok {
112 return fmt.Errorf("unexpected trailer %s", k)
113 }
114 }
115 return nil
116 }
117 }
118 hasContentLength := func(length int64) checkFunc {
119 return func(rec *ResponseRecorder) error {
120 if got := rec.Result().ContentLength; got != length {
121 return fmt.Errorf("ContentLength = %d; want %d", got, length)
122 }
123 return nil
124 }
125 }
126
127 for _, tt := range [...]struct {
128 name string
129 h func(w http.ResponseWriter, r *http.Request)
130 checks []checkFunc
131 }{
132 {
133 "200 default",
134 func(w http.ResponseWriter, r *http.Request) {},
135 check(hasStatus(200), hasContents("")),
136 },
137 {
138 "first code only",
139 func(w http.ResponseWriter, r *http.Request) {
140 w.WriteHeader(201)
141 w.WriteHeader(202)
142 w.Write([]byte("hi"))
143 },
144 check(hasStatus(201), hasContents("hi")),
145 },
146 {
147 "write sends 200",
148 func(w http.ResponseWriter, r *http.Request) {
149 w.Write([]byte("hi first"))
150 w.WriteHeader(201)
151 w.WriteHeader(202)
152 },
153 check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
154 },
155 {
156 "write string",
157 func(w http.ResponseWriter, r *http.Request) {
158 io.WriteString(w, "hi first")
159 },
160 check(
161 hasStatus(200),
162 hasContents("hi first"),
163 hasFlush(false),
164 hasHeader("Content-Type", "text/plain; charset=utf-8"),
165 ),
166 },
167 {
168 "flush",
169 func(w http.ResponseWriter, r *http.Request) {
170 w.(http.Flusher).Flush()
171 w.WriteHeader(201)
172 },
173 check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
174 },
175 {
176 "Content-Type detection",
177 func(w http.ResponseWriter, r *http.Request) {
178 io.WriteString(w, "<html>")
179 },
180 check(hasHeader("Content-Type", "text/html; charset=utf-8")),
181 },
182 {
183 "no Content-Type detection with Transfer-Encoding",
184 func(w http.ResponseWriter, r *http.Request) {
185 w.Header().Set("Transfer-Encoding", "some encoding")
186 io.WriteString(w, "<html>")
187 },
188 check(hasHeader("Content-Type", "")),
189 },
190 {
191 "no Content-Type detection if set explicitly",
192 func(w http.ResponseWriter, r *http.Request) {
193 w.Header().Set("Content-Type", "some/type")
194 io.WriteString(w, "<html>")
195 },
196 check(hasHeader("Content-Type", "some/type")),
197 },
198 {
199 "Content-Type detection doesn't crash if HeaderMap is nil",
200 func(w http.ResponseWriter, r *http.Request) {
201
202
203
204 w.(*ResponseRecorder).HeaderMap = nil
205 io.WriteString(w, "<html>")
206 },
207 check(hasHeader("Content-Type", "text/html; charset=utf-8")),
208 },
209 {
210 "Header is not changed after write",
211 func(w http.ResponseWriter, r *http.Request) {
212 hdr := w.Header()
213 hdr.Set("Key", "correct")
214 w.WriteHeader(200)
215 hdr.Set("Key", "incorrect")
216 },
217 check(hasHeader("Key", "correct")),
218 },
219 {
220 "Trailer headers are correctly recorded",
221 func(w http.ResponseWriter, r *http.Request) {
222 w.Header().Set("Non-Trailer", "correct")
223 w.Header().Set("Trailer", "Trailer-A, Trailer-B")
224 w.Header().Add("Trailer", "Trailer-C")
225 io.WriteString(w, "<html>")
226 w.Header().Set("Non-Trailer", "incorrect")
227 w.Header().Set("Trailer-A", "valuea")
228 w.Header().Set("Trailer-C", "valuec")
229 w.Header().Set("Trailer-NotDeclared", "should be omitted")
230 w.Header().Set("Trailer:Trailer-D", "with prefix")
231 },
232 check(
233 hasStatus(200),
234 hasHeader("Content-Type", "text/html; charset=utf-8"),
235 hasHeader("Non-Trailer", "correct"),
236 hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
237 hasTrailer("Trailer-A", "valuea"),
238 hasTrailer("Trailer-C", "valuec"),
239 hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
240 hasTrailer("Trailer-D", "with prefix"),
241 ),
242 },
243 {
244 "Header set without any write",
245 func(w http.ResponseWriter, r *http.Request) {
246 w.Header().Set("X-Foo", "1")
247
248
249
250
251
252 w.(*ResponseRecorder).Code = 0
253 },
254 check(
255 hasOldHeader("X-Foo", "1"),
256 hasStatus(0),
257 hasHeader("X-Foo", "1"),
258 hasResultStatus("200 OK"),
259 hasResultStatusCode(200),
260 ),
261 },
262 {
263 "HeaderMap vs FinalHeaders",
264 func(w http.ResponseWriter, r *http.Request) {
265 h := w.Header()
266 h.Set("X-Foo", "1")
267 w.Write([]byte("hi"))
268 h.Set("X-Foo", "2")
269 h.Set("X-Bar", "2")
270 },
271 check(
272 hasOldHeader("X-Foo", "2"),
273 hasOldHeader("X-Bar", "2"),
274 hasHeader("X-Foo", "1"),
275 hasNotHeaders("X-Bar"),
276 ),
277 },
278 {
279 "setting Content-Length header",
280 func(w http.ResponseWriter, r *http.Request) {
281 body := "Some body"
282 contentLength := fmt.Sprintf("%d", len(body))
283 w.Header().Set("Content-Length", contentLength)
284 io.WriteString(w, body)
285 },
286 check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
287 },
288 {
289 "nil ResponseRecorder.Body",
290 func(w http.ResponseWriter, r *http.Request) {
291 w.(*ResponseRecorder).Body = nil
292 io.WriteString(w, "hi")
293 },
294 check(hasResultContents("")),
295
296 },
297 } {
298 t.Run(tt.name, func(t *testing.T) {
299 r, _ := http.NewRequest("GET", "http://foo.com/", nil)
300 h := http.HandlerFunc(tt.h)
301 rec := NewRecorder()
302 h.ServeHTTP(rec, r)
303 for _, check := range tt.checks {
304 if err := check(rec); err != nil {
305 t.Error(err)
306 }
307 }
308 })
309 }
310 }
311
312
313 func TestParseContentLength(t *testing.T) {
314 tests := []struct {
315 cl string
316 want int64
317 }{
318 {
319 cl: "3",
320 want: 3,
321 },
322 {
323 cl: "+3",
324 want: -1,
325 },
326 {
327 cl: "-3",
328 want: -1,
329 },
330 {
331
332 cl: "9223372036854775807",
333 want: 9223372036854775807,
334 },
335 {
336 cl: "9223372036854775808",
337 want: -1,
338 },
339 }
340
341 for _, tt := range tests {
342 if got := parseContentLength(tt.cl); got != tt.want {
343 t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want)
344 }
345 }
346 }
347
348
349
350 func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
351 badCodes := []int{
352 -100, 0, 99, 1000, 20000,
353 }
354 for _, badCode := range badCodes {
355 badCode := badCode
356 t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) {
357 defer func() {
358 if r := recover(); r == nil {
359 t.Fatal("Expected a panic")
360 }
361 }()
362
363 handler := func(rw http.ResponseWriter, _ *http.Request) {
364 rw.WriteHeader(badCode)
365 }
366 r, _ := http.NewRequest("GET", "http://example.org/", nil)
367 rw := NewRecorder()
368 handler(rw, r)
369 })
370 }
371 }
372
View as plain text