Source file
src/net/http/transport_internal_test.go
1
2
3
4
5
6
7 package http
8
9 import (
10 "bytes"
11 "context"
12 "crypto/tls"
13 "errors"
14 "io"
15 "net"
16 "net/http/internal/testcert"
17 "strings"
18 "testing"
19 )
20
21
22 func TestTransportPersistConnReadLoopEOF(t *testing.T) {
23 ln := newLocalListener(t)
24 defer ln.Close()
25
26 connc := make(chan net.Conn, 1)
27 go func() {
28 defer close(connc)
29 c, err := ln.Accept()
30 if err != nil {
31 t.Error(err)
32 return
33 }
34 connc <- c
35 }()
36
37 tr := new(Transport)
38 req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
39 req = req.WithT(t)
40 ctx, cancel := context.WithCancelCause(context.Background())
41 treq := &transportRequest{Request: req, ctx: ctx, cancel: cancel}
42 cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
43 pc, err := tr.getConn(treq, cm)
44 if err != nil {
45 t.Fatal(err)
46 }
47 defer pc.close(errors.New("test over"))
48
49 conn := <-connc
50 if conn == nil {
51
52 return
53 }
54 conn.Close()
55
56 _, err = pc.roundTrip(treq)
57 if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
58 t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err)
59 }
60
61 <-pc.closech
62 err = pc.closed
63 if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
64 t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError, or nothingWrittenError", err, err)
65 }
66 }
67
68 func isNothingWrittenError(err error) bool {
69 _, ok := err.(nothingWrittenError)
70 return ok
71 }
72
73 func isTransportReadFromServerError(err error) bool {
74 _, ok := err.(transportReadFromServerError)
75 return ok
76 }
77
// newLocalListener returns a TCP listener on an OS-chosen port of the IPv4
// loopback address, falling back to the IPv6 loopback address if that fails.
// If neither can be opened, the test is stopped and both errors are reported.
// Callers are responsible for closing the returned listener.
func newLocalListener(t *testing.T) net.Listener {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return ln
	}
	ln, err6 := net.Listen("tcp6", "[::1]:0")
	if err6 != nil {
		// Report both failures rather than only the fallback's.
		t.Fatal(errors.Join(err, err6))
	}
	return ln
}
88
89 func dummyRequest(method string) *Request {
90 req, err := NewRequest(method, "http://fake.tld/", nil)
91 if err != nil {
92 panic(err)
93 }
94 return req
95 }
96 func dummyRequestWithBody(method string) *Request {
97 req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo"))
98 if err != nil {
99 panic(err)
100 }
101 return req
102 }
103
104 func dummyRequestWithBodyNoGetBody(method string) *Request {
105 req := dummyRequestWithBody(method)
106 req.GetBody = nil
107 return req
108 }
109
110
// issue22091Error is a test-only error named after Go issue 22091. Besides
// Error, it has an empty IsHTTP2NoCachedConnError marker method.
// NOTE(review): presumably this lets shouldRetryRequest recognize an HTTP/2
// "no cached connection" error by that method alone, without the concrete
// http2 type — confirm against transport.go.
type issue22091Error struct{}

// Error implements the error interface.
func (issue22091Error) Error() string { return "issue22091Error" }

// IsHTTP2NoCachedConnError is an empty marker; only its presence matters.
func (issue22091Error) IsHTTP2NoCachedConnError() {}
115
116 func TestTransportShouldRetryRequest(t *testing.T) {
117 tests := []struct {
118 pc *persistConn
119 req *Request
120
121 err error
122 want bool
123 }{
124 0: {
125 pc: &persistConn{reused: false},
126 req: dummyRequest("POST"),
127 err: nothingWrittenError{},
128 want: false,
129 },
130 1: {
131 pc: &persistConn{reused: true},
132 req: dummyRequest("POST"),
133 err: nothingWrittenError{},
134 want: true,
135 },
136 2: {
137 pc: &persistConn{reused: true},
138 req: dummyRequest("POST"),
139 err: http2ErrNoCachedConn,
140 want: true,
141 },
142 3: {
143 pc: nil,
144 req: nil,
145 err: issue22091Error{},
146 want: true,
147 },
148 4: {
149 pc: &persistConn{reused: true},
150 req: dummyRequest("POST"),
151 err: errMissingHost,
152 want: false,
153 },
154 5: {
155 pc: &persistConn{reused: true},
156 req: dummyRequest("POST"),
157 err: transportReadFromServerError{},
158 want: false,
159 },
160 6: {
161 pc: &persistConn{reused: true},
162 req: dummyRequest("GET"),
163 err: transportReadFromServerError{},
164 want: true,
165 },
166 7: {
167 pc: &persistConn{reused: true},
168 req: dummyRequest("GET"),
169 err: errServerClosedIdle,
170 want: true,
171 },
172 8: {
173 pc: &persistConn{reused: true},
174 req: dummyRequestWithBody("POST"),
175 err: nothingWrittenError{},
176 want: true,
177 },
178 9: {
179 pc: &persistConn{reused: true},
180 req: dummyRequestWithBodyNoGetBody("POST"),
181 err: nothingWrittenError{},
182 want: false,
183 },
184 }
185 for i, tt := range tests {
186 got := tt.pc.shouldRetryRequest(tt.req, tt.err)
187 if got != tt.want {
188 t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
189 }
190 }
191 }
192
193 type roundTripFunc func(r *Request) (*Response, error)
194
195 func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
196 return f(r)
197 }
198
199
200 func TestTransportBodyAltRewind(t *testing.T) {
201 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
202 if err != nil {
203 t.Fatal(err)
204 }
205 ln := newLocalListener(t)
206 defer ln.Close()
207
208 go func() {
209 tln := tls.NewListener(ln, &tls.Config{
210 NextProtos: []string{"foo"},
211 Certificates: []tls.Certificate{cert},
212 })
213 for i := 0; i < 2; i++ {
214 sc, err := tln.Accept()
215 if err != nil {
216 t.Error(err)
217 return
218 }
219 if err := sc.(*tls.Conn).Handshake(); err != nil {
220 t.Error(err)
221 return
222 }
223 sc.Close()
224 }
225 }()
226
227 addr := ln.Addr().String()
228 req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
229 roundTripped := false
230 tr := &Transport{
231 DisableKeepAlives: true,
232 TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
233 "foo": func(authority string, c *tls.Conn) RoundTripper {
234 return roundTripFunc(func(r *Request) (*Response, error) {
235 n, _ := io.Copy(io.Discard, r.Body)
236 if n == 0 {
237 t.Error("body length is zero")
238 }
239 if roundTripped {
240 return &Response{
241 Body: NoBody,
242 StatusCode: 200,
243 }, nil
244 }
245 roundTripped = true
246 return nil, http2noCachedConnError{}
247 })
248 },
249 },
250 DialTLS: func(_, _ string) (net.Conn, error) {
251 tc, err := tls.Dial("tcp", addr, &tls.Config{
252 InsecureSkipVerify: true,
253 NextProtos: []string{"foo"},
254 })
255 if err != nil {
256 return nil, err
257 }
258 if err := tc.Handshake(); err != nil {
259 return nil, err
260 }
261 return tc, nil
262 },
263 }
264 c := &Client{Transport: tr}
265 _, err = c.Do(req)
266 if err != nil {
267 t.Error(err)
268 }
269 }
270
View as plain text