1
2
3
4
5 package httptest
6
7 import (
8 "bytes"
9 "fmt"
10 "io"
11 "net/http"
12 "net/textproto"
13 "strconv"
14 "strings"
15
16 "golang.org/x/net/http/httpguts"
17 )
18
19
20
// ResponseRecorder is an implementation of http.ResponseWriter that
// records its mutations (status, headers, body, flushes) for later
// inspection.
type ResponseRecorder struct {
	// Code is the HTTP status code recorded by WriteHeader, including the
	// implicit 200 committed by the first Write, WriteString or Flush.
	// NewRecorder initializes it to 200; on a zero-value recorder it stays
	// 0 until a status is written, and Result reports 0 as 200.
	Code int

	// HeaderMap is the live header map returned by Header. Result reports
	// a snapshot of it taken when the status was committed, not its
	// current contents (trailer values are the exception: they are read
	// from this live map).
	HeaderMap http.Header

	// Body is the buffer that Write and WriteString append to.
	// If nil, written data is silently discarded.
	Body *bytes.Buffer

	// Flushed reports whether Flush has been called.
	Flushed bool

	result      *http.Response // cached return value of Result
	snapHeader  http.Header    // clone of HeaderMap taken at WriteHeader (or at the first Result call)
	wroteHeader bool           // true once a status is committed; later WriteHeader calls are no-ops
}
49
50
51 func NewRecorder() *ResponseRecorder {
52 return &ResponseRecorder{
53 HeaderMap: make(http.Header),
54 Body: new(bytes.Buffer),
55 Code: 200,
56 }
57 }
58
59
60
// DefaultRemoteAddr is a fixed placeholder IPv4 address ("1.2.3.4", with
// no port). ResponseRecorder does not reference it; presumably other code
// in the package uses it as a default Request.RemoteAddr — confirm before
// relying on that.
const DefaultRemoteAddr = "1.2.3.4"
62
63
64
65
66
67 func (rw *ResponseRecorder) Header() http.Header {
68 m := rw.HeaderMap
69 if m == nil {
70 m = make(http.Header)
71 rw.HeaderMap = m
72 }
73 return m
74 }
75
76
77
78
79
80
81
82
83 func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
84 if rw.wroteHeader {
85 return
86 }
87 if len(str) > 512 {
88 str = str[:512]
89 }
90
91 m := rw.Header()
92
93 _, hasType := m["Content-Type"]
94 hasTE := m.Get("Transfer-Encoding") != ""
95 if !hasType && !hasTE {
96 if b == nil {
97 b = []byte(str)
98 }
99 m.Set("Content-Type", http.DetectContentType(b))
100 }
101
102 rw.WriteHeader(200)
103 }
104
105
106
107 func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
108 code := rw.Code
109 if !bodyAllowedForStatus(code) {
110 return 0, http.ErrBodyNotAllowed
111 }
112 rw.writeHeader(buf, "")
113 if rw.Body != nil {
114 rw.Body.Write(buf)
115 }
116 return len(buf), nil
117 }
118
119
120
121 func (rw *ResponseRecorder) WriteString(str string) (int, error) {
122 code := rw.Code
123 if !bodyAllowedForStatus(code) {
124 return 0, http.ErrBodyNotAllowed
125 }
126 rw.writeHeader(nil, str)
127 if rw.Body != nil {
128 rw.Body.WriteString(str)
129 }
130 return len(str), nil
131 }
132
133
134
// bodyAllowedForStatus reports whether a response with the given status
// code may carry a body. 1xx (informational), 204 (No Content) and
// 304 (Not Modified) responses may not; every other value is allowed.
func bodyAllowedForStatus(status int) bool {
	informational := status >= 100 && status < 200
	return !informational && status != 204 && status != 304
}
146
// checkWriteHeaderCode panics unless code is a three-digit value
// (100 through 999 inclusive). Only the three-digit shape is enforced:
// codes above 599 are accepted even though no standard HTTP status codes
// are defined there.
func checkWriteHeaderCode(code int) {
	const minCode, maxCode = 100, 999
	if code >= minCode && code <= maxCode {
		return
	}
	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
163
164
165 func (rw *ResponseRecorder) WriteHeader(code int) {
166 if rw.wroteHeader {
167 return
168 }
169
170 checkWriteHeaderCode(code)
171 rw.Code = code
172 rw.wroteHeader = true
173 if rw.HeaderMap == nil {
174 rw.HeaderMap = make(http.Header)
175 }
176 rw.snapHeader = rw.HeaderMap.Clone()
177 }
178
179
180
181 func (rw *ResponseRecorder) Flush() {
182 if !rw.wroteHeader {
183 rw.WriteHeader(200)
184 }
185 rw.Flushed = true
186 }
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
// Result returns the response recorded from the handler.
//
// The first call builds the *http.Response and caches it, and every later
// call returns that same value. Call Result only after the handler has
// finished: writes made after the first call are not reflected.
//
// The returned Response has:
//   - Proto "HTTP/1.1" and StatusCode taken from rw.Code. A Code of 0 is
//     reported as 200; this can only happen on a zero-value recorder that
//     never committed a status.
//   - Status set to the zero-padded three-digit code, a space, and
//     http.StatusText of the code.
//   - Header set to the HeaderMap snapshot taken by WriteHeader, or taken
//     here if no status was ever committed. Later header mutations do not
//     affect it.
//   - Body reading the bytes held in rw.Body at the first call (the reader
//     aliases the buffer's storage rather than copying it), or http.NoBody
//     when rw.Body is nil.
//   - ContentLength parsed from the snapshot's Content-Length header, or -1
//     when that header is absent or is not a valid non-negative integer.
//   - Trailer built from two sources. First, the names announced in the
//     snapshot's "Trailer" header, with values copied from the live
//     HeaderMap. Second, any live HeaderMap key starting with
//     http.TrailerPrefix, added with the prefix stripped. Trailer is nil
//     unless one of these sources exists.
func (rw *ResponseRecorder) Result() *http.Response {
	if rw.result != nil {
		return rw.result
	}
	// No status was ever committed, so take the header snapshot now.
	if rw.snapHeader == nil {
		rw.snapHeader = rw.HeaderMap.Clone()
	}
	res := &http.Response{
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		StatusCode: rw.Code,
		Header:     rw.snapHeader,
	}
	// Cached before it is fully populated; the assignments below mutate
	// the same cached value.
	rw.result = res
	if res.StatusCode == 0 {
		res.StatusCode = 200
	}
	res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
	if rw.Body != nil {
		res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
	} else {
		res.Body = http.NoBody
	}
	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))

	// Declared trailers: each "Trailer" header value may list several
	// comma-separated names. Values are read from the live HeaderMap, not
	// the snapshot, so values set after WriteHeader are included.
	if trailers, ok := rw.snapHeader["Trailer"]; ok {
		res.Trailer = make(http.Header, len(trailers))
		for _, k := range trailers {
			// This k shadows the outer k (one whole "Trailer" value).
			for k := range strings.SplitSeq(k, ",") {
				k = http.CanonicalHeaderKey(textproto.TrimString(k))
				if !httpguts.ValidTrailerHeader(k) {
					// Drop names that httpguts rejects as trailer fields
					// (presumably the RFC 7230 §4.1.2 forbidden list).
					continue
				}
				vv, ok := rw.HeaderMap[k]
				if !ok {
					continue
				}
				// Copy so later HeaderMap mutations do not alias into res.Trailer.
				vv2 := make([]string, len(vv))
				copy(vv2, vv)
				res.Trailer[k] = vv2
			}
		}
	}
	// Undeclared trailers: keys of the form http.TrailerPrefix + name.
	for k, vv := range rw.HeaderMap {
		if !strings.HasPrefix(k, http.TrailerPrefix) {
			continue
		}
		if res.Trailer == nil {
			res.Trailer = make(http.Header)
		}
		for _, v := range vv {
			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
		}
	}
	return res
}
261
262
263
264
265
266
// parseContentLength interprets a Content-Length header value. Surrounding
// ASCII whitespace is ignored. It returns -1 when the value is empty or is
// not a non-negative decimal integer that fits in 63 bits.
func parseContentLength(cl string) int64 {
	trimmed := textproto.TrimString(cl)
	if len(trimmed) == 0 {
		return -1
	}
	if n, err := strconv.ParseUint(trimmed, 10, 63); err == nil {
		return int64(n)
	}
	return -1
}
278
View as plain text