Source file src/net/http/httptest/recorder.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"net/textproto"
    13  	"strconv"
    14  	"strings"
    15  
    16  	"golang.org/x/net/http/httpguts"
    17  )
    18  
// ResponseRecorder is an implementation of [http.ResponseWriter] that
// records its mutations for later inspection in tests.
//
// The zero value can be used: Header allocates HeaderMap on demand and a
// nil Body silently discards writes. [NewRecorder] additionally presets
// Code to 200 and allocates an empty Body.
type ResponseRecorder struct {
	// Code is the HTTP response code set by WriteHeader.
	//
	// Note that if a Handler never calls WriteHeader or Write,
	// this might end up being 0, rather than the implicit
	// http.StatusOK. To get the implicit value, use the Result
	// method.
	Code int

	// HeaderMap contains the headers explicitly set by the Handler.
	// It is an internal detail.
	//
	// Deprecated: HeaderMap exists for historical compatibility
	// and should not be used. To access the headers returned by a handler,
	// use the Response.Header map as returned by the Result method.
	HeaderMap http.Header

	// Body is the buffer to which the Handler's Write calls are sent.
	// If nil, the Writes are silently discarded.
	Body *bytes.Buffer

	// Flushed is whether the Handler called Flush.
	Flushed bool

	result      *http.Response // cache of Result's return value; later Result calls return it
	snapHeader  http.Header    // clone of HeaderMap taken by the first effective WriteHeader (or by Result if none)
	wroteHeader bool           // set by the first WriteHeader; later WriteHeader calls are ignored
}
    49  
    50  // NewRecorder returns an initialized [ResponseRecorder].
    51  func NewRecorder() *ResponseRecorder {
    52  	return &ResponseRecorder{
    53  		HeaderMap: make(http.Header),
    54  		Body:      new(bytes.Buffer),
    55  		Code:      200,
    56  	}
    57  }
    58  
// DefaultRemoteAddr is the default remote address, "1.2.3.4", used when no
// explicit remote address has been supplied.
//
// NOTE(review): despite earlier wording, ResponseRecorder has no RemoteAddr
// field and nothing in this file reads this constant; it is presumably
// consumed elsewhere in the package (e.g. for requests built there) — confirm.
const DefaultRemoteAddr = "1.2.3.4"
    62  
    63  // Header implements [http.ResponseWriter]. It returns the response
    64  // headers to mutate within a handler. To test the headers that were
    65  // written after a handler completes, use the [ResponseRecorder.Result] method and see
    66  // the returned Response value's Header.
    67  func (rw *ResponseRecorder) Header() http.Header {
    68  	m := rw.HeaderMap
    69  	if m == nil {
    70  		m = make(http.Header)
    71  		rw.HeaderMap = m
    72  	}
    73  	return m
    74  }
    75  
    76  // writeHeader writes a header if it was not written yet and
    77  // detects Content-Type if needed.
    78  //
    79  // bytes or str are the beginning of the response body.
    80  // We pass both to avoid unnecessarily generate garbage
    81  // in rw.WriteString which was created for performance reasons.
    82  // Non-nil bytes win.
    83  func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
    84  	if rw.wroteHeader {
    85  		return
    86  	}
    87  	if len(str) > 512 {
    88  		str = str[:512]
    89  	}
    90  
    91  	m := rw.Header()
    92  
    93  	_, hasType := m["Content-Type"]
    94  	hasTE := m.Get("Transfer-Encoding") != ""
    95  	if !hasType && !hasTE {
    96  		if b == nil {
    97  			b = []byte(str)
    98  		}
    99  		m.Set("Content-Type", http.DetectContentType(b))
   100  	}
   101  
   102  	rw.WriteHeader(200)
   103  }
   104  
   105  // Write implements http.ResponseWriter. The data in buf is written to
   106  // rw.Body, if not nil.
   107  func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
   108  	rw.writeHeader(buf, "")
   109  	if rw.Body != nil {
   110  		rw.Body.Write(buf)
   111  	}
   112  	return len(buf), nil
   113  }
   114  
   115  // WriteString implements [io.StringWriter]. The data in str is written
   116  // to rw.Body, if not nil.
   117  func (rw *ResponseRecorder) WriteString(str string) (int, error) {
   118  	rw.writeHeader(nil, str)
   119  	if rw.Body != nil {
   120  		rw.Body.WriteString(str)
   121  	}
   122  	return len(str), nil
   123  }
   124  
   125  func checkWriteHeaderCode(code int) {
   126  	// Issue 22880: require valid WriteHeader status codes.
   127  	// For now we only enforce that it's three digits.
   128  	// In the future we might block things over 599 (600 and above aren't defined
   129  	// at https://httpwg.org/specs/rfc7231.html#status.codes)
   130  	// and we might block under 200 (once we have more mature 1xx support).
   131  	// But for now any three digits.
   132  	//
   133  	// We used to send "HTTP/1.1 000 0" on the wire in responses but there's
   134  	// no equivalent bogus thing we can realistically send in HTTP/2,
   135  	// so we'll consistently panic instead and help people find their bugs
   136  	// early. (We can't return an error from WriteHeader even if we wanted to.)
   137  	if code < 100 || code > 999 {
   138  		panic(fmt.Sprintf("invalid WriteHeader code %v", code))
   139  	}
   140  }
   141  
   142  // WriteHeader implements [http.ResponseWriter].
   143  func (rw *ResponseRecorder) WriteHeader(code int) {
   144  	if rw.wroteHeader {
   145  		return
   146  	}
   147  
   148  	checkWriteHeaderCode(code)
   149  	rw.Code = code
   150  	rw.wroteHeader = true
   151  	if rw.HeaderMap == nil {
   152  		rw.HeaderMap = make(http.Header)
   153  	}
   154  	rw.snapHeader = rw.HeaderMap.Clone()
   155  }
   156  
   157  // Flush implements [http.Flusher]. To test whether Flush was
   158  // called, see rw.Flushed.
   159  func (rw *ResponseRecorder) Flush() {
   160  	if !rw.wroteHeader {
   161  		rw.WriteHeader(200)
   162  	}
   163  	rw.Flushed = true
   164  }
   165  
   166  // Result returns the response generated by the handler.
   167  //
   168  // The returned Response will have at least its StatusCode,
   169  // Header, Body, and optionally Trailer populated.
   170  // More fields may be populated in the future, so callers should
   171  // not DeepEqual the result in tests.
   172  //
   173  // The Response.Header is a snapshot of the headers at the time of the
   174  // first write call, or at the time of this call, if the handler never
   175  // did a write.
   176  //
   177  // The Response.Body is guaranteed to be non-nil and Body.Read call is
   178  // guaranteed to not return any error other than [io.EOF].
   179  //
   180  // Result must only be called after the handler has finished running.
   181  func (rw *ResponseRecorder) Result() *http.Response {
   182  	if rw.result != nil {
   183  		return rw.result
   184  	}
   185  	if rw.snapHeader == nil {
   186  		rw.snapHeader = rw.HeaderMap.Clone()
   187  	}
   188  	res := &http.Response{
   189  		Proto:      "HTTP/1.1",
   190  		ProtoMajor: 1,
   191  		ProtoMinor: 1,
   192  		StatusCode: rw.Code,
   193  		Header:     rw.snapHeader,
   194  	}
   195  	rw.result = res
   196  	if res.StatusCode == 0 {
   197  		res.StatusCode = 200
   198  	}
   199  	res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
   200  	if rw.Body != nil {
   201  		res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
   202  	} else {
   203  		res.Body = http.NoBody
   204  	}
   205  	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
   206  
   207  	if trailers, ok := rw.snapHeader["Trailer"]; ok {
   208  		res.Trailer = make(http.Header, len(trailers))
   209  		for _, k := range trailers {
   210  			for _, k := range strings.Split(k, ",") {
   211  				k = http.CanonicalHeaderKey(textproto.TrimString(k))
   212  				if !httpguts.ValidTrailerHeader(k) {
   213  					// Ignore since forbidden by RFC 7230, section 4.1.2.
   214  					continue
   215  				}
   216  				vv, ok := rw.HeaderMap[k]
   217  				if !ok {
   218  					continue
   219  				}
   220  				vv2 := make([]string, len(vv))
   221  				copy(vv2, vv)
   222  				res.Trailer[k] = vv2
   223  			}
   224  		}
   225  	}
   226  	for k, vv := range rw.HeaderMap {
   227  		if !strings.HasPrefix(k, http.TrailerPrefix) {
   228  			continue
   229  		}
   230  		if res.Trailer == nil {
   231  			res.Trailer = make(http.Header)
   232  		}
   233  		for _, v := range vv {
   234  			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
   235  		}
   236  	}
   237  	return res
   238  }
   239  
   240  // parseContentLength trims whitespace from s and returns -1 if no value
   241  // is set, or the value if it's >= 0.
   242  //
   243  // This a modified version of same function found in net/http/transfer.go. This
   244  // one just ignores an invalid header.
   245  func parseContentLength(cl string) int64 {
   246  	cl = textproto.TrimString(cl)
   247  	if cl == "" {
   248  		return -1
   249  	}
   250  	n, err := strconv.ParseUint(cl, 10, 63)
   251  	if err != nil {
   252  		return -1
   253  	}
   254  	return int64(n)
   255  }
   256  

View as plain text