// Source file: src/net/http/transport_dial_test.go
5 package http_test
6
7 import (
8 "context"
9 "io"
10 "net"
11 "net/http"
12 "net/http/httptrace"
13 "testing"
14 )
15
// TestTransportPoolConnReusePriorConnection checks that a request made after a
// previous request has fully completed reuses that request's connection.
func TestTransportPoolConnReusePriorConnection(t *testing.T) {
	dt := newTransportDialTester(t, http1Mode)

	// The first request has no connection to reuse, so it dials a new one (c1).
	rt1 := dt.roundTrip()
	c1 := dt.wantDial()
	c1.finish(nil)
	rt1.wantDone(c1)
	// Finishing rt1 closes its request body and drains/closes its response,
	// so that c1 can be handed back to the Transport for reuse.
	rt1.finish()

	// The second request is expected to be served on c1 without a new dial.
	rt2 := dt.roundTrip()
	rt2.wantDone(c1)
	rt2.finish()
}
31
// TestTransportPoolConnCannotReuseConnectionInUse checks that a request does
// not reuse a connection that is still busy with an earlier request.
func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
	dt := newTransportDialTester(t, http1Mode)

	// The first request dials and is served on a new connection, c1.
	rt1 := dt.roundTrip()
	c1 := dt.wantDial()
	c1.finish(nil)
	rt1.wantDone(c1)

	// rt1 is deliberately not finished here: its request body is still open,
	// so c1 remains in use. The second request therefore has to dial its own
	// connection, c2. (The Cleanup registered by roundTrip finishes both
	// round trips when the test ends.)
	rt2 := dt.roundTrip()
	c2 := dt.wantDial()
	c2.finish(nil)
	rt2.wantDone(c2)
}
48
// TestTransportPoolConnConnectionBecomesAvailableDuringDial checks that a
// request waiting on an in-progress dial is served by an existing connection
// that becomes idle first, and that the abandoned dial's result is later used
// by another request.
func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
	dt := newTransportDialTester(t, http1Mode)

	// The first request dials and is served on a new connection, c1.
	rt1 := dt.roundTrip()
	c1 := dt.wantDial()
	c1.finish(nil)
	rt1.wantDone(c1)

	// The second request starts while c1 is still busy, so a dial (c2) begins.
	// While that dial is held open, rt1 finishes and c1 becomes available;
	// rt2 is expected to take c1 instead of waiting for c2.
	rt2 := dt.roundTrip()
	c2 := dt.wantDial()
	rt1.finish()
	rt2.wantDone(c1)

	// A third request starts while c1 is busy with rt2 and c2 is still pending.
	// The test expects rt3 to start a dial of its own (c3), yet be served by c2
	// once c2 completes.
	// NOTE(review): this expectation is tied to the current Transport
	// implementation's handling of in-progress dials.
	rt3 := dt.roundTrip()
	c3 := dt.wantDial()
	c2.finish(nil)
	rt3.wantDone(c2)

	// Complete the remaining dial so its DialContext call does not stay blocked
	// waiting for a result.
	c3.finish(nil)
}
77
78
79 type transportDialTester struct {
80 t *testing.T
81 cst *clientServerTest
82
83 dials chan *transportDialTesterConn
84
85 roundTripCount int
86 dialCount int
87 }
88
89
90 type transportDialTesterRoundTrip struct {
91 t *testing.T
92
93 roundTripID int
94 cancel context.CancelFunc
95 reqBody io.WriteCloser
96 finished bool
97
98 done chan struct{}
99 res *http.Response
100 err error
101 conn *transportDialTesterConn
102 }
103
104
105
// transportDialTesterConn is a connection dialed by the Transport under test.
// It wraps the real network connection so the test can tell which connection
// served a given request.
type transportDialTesterConn struct {
	t *testing.T

	connID int        // assigned by wantDial; identifies the dial in logs
	ready  chan error // receives the dial's outcome from finish

	// Underlying connection, set only after the test lets the dial proceed.
	net.Conn
}
114
115 func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
116 t.Helper()
117 dt := &transportDialTester{
118 t: t,
119 dials: make(chan *transportDialTesterConn),
120 }
121 dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
122
123 http.NewResponseController(w).EnableFullDuplex()
124 w.WriteHeader(200)
125 http.NewResponseController(w).Flush()
126
127
128 io.ReadAll(r.Body)
129 }), func(tr *http.Transport) {
130 tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
131 c := &transportDialTesterConn{
132 t: t,
133 ready: make(chan error),
134 }
135
136
137 dt.dials <- c
138 if err := <-c.ready; err != nil {
139 return nil, err
140 }
141 nc, err := net.Dial(network, address)
142 if err != nil {
143 return nil, err
144 }
145
146
147 c.Conn = nc
148 return c, err
149 }
150 })
151 return dt
152 }
153
154
155
156 func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
157 dt.t.Helper()
158 ctx, cancel := context.WithCancel(context.Background())
159 pr, pw := io.Pipe()
160 rt := &transportDialTesterRoundTrip{
161 t: dt.t,
162 roundTripID: dt.roundTripCount,
163 done: make(chan struct{}),
164 reqBody: pw,
165 cancel: cancel,
166 }
167 dt.roundTripCount++
168 dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
169 dt.t.Cleanup(func() {
170 rt.cancel()
171 rt.finish()
172 })
173 go func() {
174 ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
175 GotConn: func(info httptrace.GotConnInfo) {
176 rt.conn = info.Conn.(*transportDialTesterConn)
177 },
178 })
179 req, _ := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr)
180 req.Header.Set("Content-Type", "text/plain")
181 rt.res, rt.err = dt.cst.tr.RoundTrip(req)
182 dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
183 close(rt.done)
184 }()
185 return rt
186 }
187
188
189 func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
190 rt.t.Helper()
191 <-rt.done
192 if rt.err != nil {
193 rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
194 }
195 if rt.conn != c {
196 rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID)
197 }
198 }
199
200
201
202 func (rt *transportDialTesterRoundTrip) finish() {
203 rt.t.Helper()
204
205 if rt.finished {
206 return
207 }
208 rt.finished = true
209
210 <-rt.done
211
212 if rt.err != nil {
213 return
214 }
215 rt.reqBody.Close()
216 io.ReadAll(rt.res.Body)
217 rt.res.Body.Close()
218 rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
219 }
220
221
222 func (dt *transportDialTester) wantDial() *transportDialTesterConn {
223 c := <-dt.dials
224 c.connID = dt.dialCount
225 dt.dialCount++
226 dt.t.Logf("Dial %v: started", c.connID)
227 return c
228 }
229
230
231 func (c *transportDialTesterConn) finish(err error) {
232 c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
233 c.ready <- err
234 close(c.ready)
235 }
236
View as plain text