1
2
3
4
5 package httptest
6
7 import (
8 "bytes"
9 "fmt"
10 "io"
11 "net/http"
12 "net/textproto"
13 "strconv"
14 "strings"
15
16 "golang.org/x/net/http/httpguts"
17 )
18
19
20
// ResponseRecorder is an implementation of http.ResponseWriter that records
// what a handler writes to it so the result can be inspected afterwards
// (see Result).
type ResponseRecorder struct {
	// Code is the HTTP response code set by WriteHeader.
	//
	// NewRecorder initializes it to 200. Result reports 200 when Code is 0.
	Code int

	// HeaderMap contains the headers set by the handler through Header.
	// WriteHeader copies it into snapHeader; Result reports that copy as the
	// response headers, but reads trailer values from this live map.
	HeaderMap http.Header

	// Body is the buffer that Write and WriteString append to.
	// If nil, written data is discarded.
	Body *bytes.Buffer

	// Flushed is set to true once Flush has been called.
	Flushed bool

	result      *http.Response // response built on the first Result call and returned by later calls
	snapHeader  http.Header    // copy of HeaderMap taken by WriteHeader (or by Result if no header was written)
	wroteHeader bool           // whether WriteHeader has already taken effect; later calls are ignored
}
49
50
51 func NewRecorder() *ResponseRecorder {
52 return &ResponseRecorder{
53 HeaderMap: make(http.Header),
54 Body: new(bytes.Buffer),
55 Code: 200,
56 }
57 }
58
59
60
// DefaultRemoteAddr is the placeholder remote address "1.2.3.4".
// NOTE(review): it is not referenced in this section; presumably it is the
// default RemoteAddr for requests built elsewhere in the package — confirm.
const DefaultRemoteAddr = "1.2.3.4"
62
63
64
65
66
67 func (rw *ResponseRecorder) Header() http.Header {
68 m := rw.HeaderMap
69 if m == nil {
70 m = make(http.Header)
71 rw.HeaderMap = m
72 }
73 return m
74 }
75
76
77
78
79
80
81
82
83 func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
84 if rw.wroteHeader {
85 return
86 }
87 if len(str) > 512 {
88 str = str[:512]
89 }
90
91 m := rw.Header()
92
93 _, hasType := m["Content-Type"]
94 hasTE := m.Get("Transfer-Encoding") != ""
95 if !hasType && !hasTE {
96 if b == nil {
97 b = []byte(str)
98 }
99 m.Set("Content-Type", http.DetectContentType(b))
100 }
101
102 rw.WriteHeader(200)
103 }
104
105
106
107 func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
108 rw.writeHeader(buf, "")
109 if rw.Body != nil {
110 rw.Body.Write(buf)
111 }
112 return len(buf), nil
113 }
114
115
116
117 func (rw *ResponseRecorder) WriteString(str string) (int, error) {
118 rw.writeHeader(nil, str)
119 if rw.Body != nil {
120 rw.Body.WriteString(str)
121 }
122 return len(str), nil
123 }
124
// checkWriteHeaderCode panics unless code is a three-digit number, i.e. in
// the inclusive range 100 through 999. Three-digit values that have no
// registered meaning (for example 600-999) are deliberately still accepted;
// only the digit count is enforced.
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 999 {
		return
	}
	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
141
142
143 func (rw *ResponseRecorder) WriteHeader(code int) {
144 if rw.wroteHeader {
145 return
146 }
147
148 checkWriteHeaderCode(code)
149 rw.Code = code
150 rw.wroteHeader = true
151 if rw.HeaderMap == nil {
152 rw.HeaderMap = make(http.Header)
153 }
154 rw.snapHeader = rw.HeaderMap.Clone()
155 }
156
157
158
159 func (rw *ResponseRecorder) Flush() {
160 if !rw.wroteHeader {
161 rw.WriteHeader(200)
162 }
163 rw.Flushed = true
164 }
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181 func (rw *ResponseRecorder) Result() *http.Response {
182 if rw.result != nil {
183 return rw.result
184 }
185 if rw.snapHeader == nil {
186 rw.snapHeader = rw.HeaderMap.Clone()
187 }
188 res := &http.Response{
189 Proto: "HTTP/1.1",
190 ProtoMajor: 1,
191 ProtoMinor: 1,
192 StatusCode: rw.Code,
193 Header: rw.snapHeader,
194 }
195 rw.result = res
196 if res.StatusCode == 0 {
197 res.StatusCode = 200
198 }
199 res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
200 if rw.Body != nil {
201 res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
202 } else {
203 res.Body = http.NoBody
204 }
205 res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
206
207 if trailers, ok := rw.snapHeader["Trailer"]; ok {
208 res.Trailer = make(http.Header, len(trailers))
209 for _, k := range trailers {
210 for _, k := range strings.Split(k, ",") {
211 k = http.CanonicalHeaderKey(textproto.TrimString(k))
212 if !httpguts.ValidTrailerHeader(k) {
213
214 continue
215 }
216 vv, ok := rw.HeaderMap[k]
217 if !ok {
218 continue
219 }
220 vv2 := make([]string, len(vv))
221 copy(vv2, vv)
222 res.Trailer[k] = vv2
223 }
224 }
225 }
226 for k, vv := range rw.HeaderMap {
227 if !strings.HasPrefix(k, http.TrailerPrefix) {
228 continue
229 }
230 if res.Trailer == nil {
231 res.Trailer = make(http.Header)
232 }
233 for _, v := range vv {
234 res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
235 }
236 }
237 return res
238 }
239
240
241
242
243
244
// parseContentLength trims ASCII whitespace from cl and returns its value as
// a non-negative int64. It returns -1 when cl is empty after trimming, or is
// not a plain decimal number that fits in 63 bits. Unlike a strict parser, it
// never reports an error: an invalid header is simply treated as absent.
func parseContentLength(cl string) int64 {
	if trimmed := textproto.TrimString(cl); trimmed != "" {
		if n, err := strconv.ParseUint(trimmed, 10, 63); err == nil {
			return int64(n)
		}
	}
	return -1
}
256
View as plain text