// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package httptest

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"net/textproto"
	"strconv"
	"strings"

	"golang.org/x/net/http/httpguts"
)

// ResponseRecorder is an implementation of [http.ResponseWriter] that
// records its mutations for later inspection in tests.
//
// A zero ResponseRecorder is usable: Header and WriteHeader allocate
// HeaderMap on demand, and body writes are discarded while Body is nil.
type ResponseRecorder struct {
	// Code is the HTTP response code set by WriteHeader.
	//
	// Note that if a Handler never calls WriteHeader or Write,
	// this might end up being 0, rather than the implicit
	// http.StatusOK. To get the implicit value, use the Result
	// method. (NewRecorder presets Code to 200, so the 0 case applies
	// to recorders that were not created through NewRecorder.)
	Code int

	// HeaderMap contains the headers explicitly set by the Handler.
	// It is an internal detail.
	//
	// Deprecated: HeaderMap exists for historical compatibility
	// and should not be used. To access the headers returned by a handler,
	// use the Response.Header map as returned by the Result method.
	HeaderMap http.Header

	// Body is the buffer to which the Handler's Write calls are sent.
	// If nil, the Writes are silently discarded.
	Body *bytes.Buffer

	// Flushed is whether the Handler called Flush.
	Flushed bool

	result      *http.Response // cache of Result's return value
	snapHeader  http.Header    // snapshot of HeaderMap taken by WriteHeader, or by Result if the header was never written
	wroteHeader bool           // whether WriteHeader has already taken effect
}

// NewRecorder returns an initialized [ResponseRecorder]: an empty, non-nil
// HeaderMap, an empty Body buffer, and Code preset to 200.
func NewRecorder() *ResponseRecorder {
	return &ResponseRecorder{
		HeaderMap: make(http.Header),
		Body:      new(bytes.Buffer),
		Code:      200,
	}
}

// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
// an explicit DefaultRemoteAddr isn't set on [ResponseRecorder].
//
// NOTE(review): nothing in the visible part of this file reads this constant,
// and ResponseRecorder has no RemoteAddr field or method here — confirm where
// (or whether) it is still consumed.
const DefaultRemoteAddr = "1.2.3.4"

// Header implements [http.ResponseWriter]. It returns the response
// headers to mutate within a handler. To test the headers that were
// written after a handler completes, use the [ResponseRecorder.Result]
// method and see the returned Response value's Header.
func (rw *ResponseRecorder) Header() http.Header { m := rw.HeaderMap if m == nil { m = make(http.Header) rw.HeaderMap = m } return m } // writeHeader writes a header if it was not written yet and // detects Content-Type if needed. // // bytes or str are the beginning of the response body. // We pass both to avoid unnecessarily generate garbage // in rw.WriteString which was created for performance reasons. // Non-nil bytes win. func (rw *ResponseRecorder) writeHeader(b []byte, str string) { if rw.wroteHeader { return } if len(str) > 512 { str = str[:512] } m := rw.Header() _, hasType := m["Content-Type"] hasTE := m.Get("Transfer-Encoding") != "" if !hasType && !hasTE { if b == nil { b = []byte(str) } m.Set("Content-Type", http.DetectContentType(b)) } rw.WriteHeader(200) } // Write implements http.ResponseWriter. The data in buf is written to // rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { rw.writeHeader(buf, "") if rw.Body != nil { rw.Body.Write(buf) } return len(buf), nil } // WriteString implements [io.StringWriter]. The data in str is written // to rw.Body, if not nil. func (rw *ResponseRecorder) WriteString(str string) (int, error) { rw.writeHeader(nil, str) if rw.Body != nil { rw.Body.WriteString(str) } return len(str), nil } func checkWriteHeaderCode(code int) { // Issue 22880: require valid WriteHeader status codes. // For now we only enforce that it's three digits. // In the future we might block things over 599 (600 and above aren't defined // at https://httpwg.org/specs/rfc7231.html#status.codes) // and we might block under 200 (once we have more mature 1xx support). // But for now any three digits. // // We used to send "HTTP/1.1 000 0" on the wire in responses but there's // no equivalent bogus thing we can realistically send in HTTP/2, // so we'll consistently panic instead and help people find their bugs // early. (We can't return an error from WriteHeader even if we wanted to.) 
if code < 100 || code > 999 { panic(fmt.Sprintf("invalid WriteHeader code %v", code)) } } // WriteHeader implements [http.ResponseWriter]. func (rw *ResponseRecorder) WriteHeader(code int) { if rw.wroteHeader { return } checkWriteHeaderCode(code) rw.Code = code rw.wroteHeader = true if rw.HeaderMap == nil { rw.HeaderMap = make(http.Header) } rw.snapHeader = rw.HeaderMap.Clone() } // Flush implements [http.Flusher]. To test whether Flush was // called, see rw.Flushed. func (rw *ResponseRecorder) Flush() { if !rw.wroteHeader { rw.WriteHeader(200) } rw.Flushed = true } // Result returns the response generated by the handler. // // The returned Response will have at least its StatusCode, // Header, Body, and optionally Trailer populated. // More fields may be populated in the future, so callers should // not DeepEqual the result in tests. // // The Response.Header is a snapshot of the headers at the time of the // first write call, or at the time of this call, if the handler never // did a write. // // The Response.Body is guaranteed to be non-nil and Body.Read call is // guaranteed to not return any error other than [io.EOF]. // // Result must only be called after the handler has finished running. 
func (rw *ResponseRecorder) Result() *http.Response { if rw.result != nil { return rw.result } if rw.snapHeader == nil { rw.snapHeader = rw.HeaderMap.Clone() } res := &http.Response{ Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, StatusCode: rw.Code, Header: rw.snapHeader, } rw.result = res if res.StatusCode == 0 { res.StatusCode = 200 } res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) if rw.Body != nil { res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) } else { res.Body = http.NoBody } res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) if trailers, ok := rw.snapHeader["Trailer"]; ok { res.Trailer = make(http.Header, len(trailers)) for _, k := range trailers { for _, k := range strings.Split(k, ",") { k = http.CanonicalHeaderKey(textproto.TrimString(k)) if !httpguts.ValidTrailerHeader(k) { // Ignore since forbidden by RFC 7230, section 4.1.2. continue } vv, ok := rw.HeaderMap[k] if !ok { continue } vv2 := make([]string, len(vv)) copy(vv2, vv) res.Trailer[k] = vv2 } } } for k, vv := range rw.HeaderMap { if !strings.HasPrefix(k, http.TrailerPrefix) { continue } if res.Trailer == nil { res.Trailer = make(http.Header) } for _, v := range vv { res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v) } } return res } // parseContentLength trims whitespace from s and returns -1 if no value // is set, or the value if it's >= 0. // // This a modified version of same function found in net/http/transfer.go. This // one just ignores an invalid header. func parseContentLength(cl string) int64 { cl = textproto.TrimString(cl) if cl == "" { return -1 } n, err := strconv.ParseUint(cl, 10, 63) if err != nil { return -1 } return int64(n) }