1
2
3
4
5
6
7 package http2
8
9 import (
10 "context"
11 "errors"
12 "net"
13 "slices"
14 "sync"
15 )
16
17
18 type ClientConnPool interface {
19
20
21
22
23
24
25 GetClientConn(req *ClientRequest, addr string) (*ClientConn, error)
26 MarkDead(*ClientConn)
27 }
28
29
30
31 type clientConnPoolIdleCloser interface {
32 ClientConnPool
33 closeIdleConnections()
34 }
35
36 var (
37 _ clientConnPoolIdleCloser = (*clientConnPool)(nil)
38 _ clientConnPoolIdleCloser = noDialClientConnPool{}
39 )
40
41
42 type clientConnPool struct {
43 t *Transport
44
45 mu sync.Mutex
46
47
48 conns map[string][]*ClientConn
49 dialing map[string]*dialCall
50 keys map[*ClientConn][]string
51 addConnCalls map[string]*addConnCall
52 }
53
54 func (p *clientConnPool) GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) {
55 return p.getClientConn(req, addr, dialOnMiss)
56 }
57
// Values for getClientConn's dialOnMiss parameter: whether a pool miss may
// start a new dial, or must instead fail with ErrNoCachedConn.
const (
	dialOnMiss   = true
	noDialOnMiss = !dialOnMiss
)
62
// getClientConn returns a connection to addr for req.
//
// When isConnectionCloseRequest(req) reports true and dialOnMiss is set, the
// request gets a freshly dialed single-use connection that is never added to
// the pool.
//
// Otherwise it hands out the first pooled connection for addr on which
// ReserveNewRequest succeeds. On a miss it returns ErrNoCachedConn when
// dialOnMiss is false. When dialOnMiss is true, it starts (or joins) a shared
// dial to addr via getStartDialLocked and waits for it. It then loops until a
// request can be reserved on some connection, or the dial fails with an error
// that shouldRetryDial does not consider retryable.
func (p *clientConnPool) getClientConn(req *ClientRequest, addr string, dialOnMiss bool) (*ClientConn, error) {
	if isConnectionCloseRequest(req) && dialOnMiss {
		// This request gets its own connection; it is not shared via the pool.
		traceGetConn(req, addr)
		const singleUse = true
		cc, err := p.t.dialClientConn(req.Context, addr, singleUse)
		if err != nil {
			return nil, err
		}
		return cc, nil
	}
	for {
		p.mu.Lock()
		for _, cc := range p.conns[addr] {
			if cc.ReserveNewRequest() {
				// Connections added through addConnIfNeeded arrive with
				// getConnCalled set (see addConnCall.run). For those, the
				// GetConn trace is skipped once — presumably because the hook
				// already fired wherever the connection was created (TODO
				// confirm). The flag is then cleared so later requests on the
				// same connection are traced normally.
				if !cc.getConnCalled {
					traceGetConn(req, addr)
				}
				cc.getConnCalled = false
				p.mu.Unlock()
				return cc, nil
			}
		}
		if !dialOnMiss {
			p.mu.Unlock()
			return nil, ErrNoCachedConn
		}
		traceGetConn(req, addr)
		// Start a dial for addr, or join the one already in flight.
		call := p.getStartDialLocked(req.Context, addr)
		p.mu.Unlock()
		// Wait without holding mu: dialCall.dial needs mu to publish its
		// result. NOTE(review): this wait does not observe req.Context. A
		// request that joined another request's dial therefore waits for that
		// dial to finish even if its own context is already done.
		<-call.done
		if shouldRetryDial(call, req) {
			// The shared dial failed only because the starting request's
			// context ended; go around again. That may start a new dial.
			continue
		}
		cc, err := call.res, call.err
		if err != nil {
			return nil, err
		}
		// A successful dial also put the connection in the pool, so other
		// waiters may already have used up its capacity; if so, loop.
		if cc.ReserveNewRequest() {
			return cc, nil
		}
	}
}
110
111
112 type dialCall struct {
113 _ incomparable
114 p *clientConnPool
115
116
117 ctx context.Context
118 done chan struct{}
119 res *ClientConn
120 err error
121 }
122
123
124 func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
125 if call, ok := p.dialing[addr]; ok {
126
127 return call
128 }
129 call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
130 if p.dialing == nil {
131 p.dialing = make(map[string]*dialCall)
132 }
133 p.dialing[addr] = call
134 go call.dial(call.ctx, addr)
135 return call
136 }
137
138
139 func (c *dialCall) dial(ctx context.Context, addr string) {
140 const singleUse = false
141 c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
142
143 c.p.mu.Lock()
144 delete(c.p.dialing, addr)
145 if c.err == nil {
146 c.p.addConnLocked(addr, c.res)
147 }
148 c.p.mu.Unlock()
149
150 close(c.done)
151 }
152
153
154
155
156
157
158
159
160
161 func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) {
162 p.mu.Lock()
163 for _, cc := range p.conns[key] {
164 if cc.CanTakeNewRequest() {
165 p.mu.Unlock()
166 return false, nil
167 }
168 }
169 call, dup := p.addConnCalls[key]
170 if !dup {
171 if p.addConnCalls == nil {
172 p.addConnCalls = make(map[string]*addConnCall)
173 }
174 call = &addConnCall{
175 p: p,
176 done: make(chan struct{}),
177 }
178 p.addConnCalls[key] = call
179 go call.run(t, key, c)
180 }
181 p.mu.Unlock()
182
183 <-call.done
184 if call.err != nil {
185 return false, call.err
186 }
187 return !dup, nil
188 }
189
190 type addConnCall struct {
191 _ incomparable
192 p *clientConnPool
193 done chan struct{}
194 err error
195 }
196
197 func (c *addConnCall) run(t *Transport, key string, nc net.Conn) {
198 cc, err := t.newClientConn(nc, t.disableKeepAlives(), nil)
199
200 p := c.p
201 p.mu.Lock()
202 if err != nil {
203 c.err = err
204 } else {
205 cc.getConnCalled = true
206 p.addConnLocked(key, cc)
207 }
208 delete(p.addConnCalls, key)
209 p.mu.Unlock()
210 close(c.done)
211 }
212
213
214 func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
215 if slices.Contains(p.conns[key], cc) {
216 return
217 }
218 if p.conns == nil {
219 p.conns = make(map[string][]*ClientConn)
220 }
221 if p.keys == nil {
222 p.keys = make(map[*ClientConn][]string)
223 }
224 p.conns[key] = append(p.conns[key], cc)
225 p.keys[cc] = append(p.keys[cc], key)
226 }
227
228 func (p *clientConnPool) MarkDead(cc *ClientConn) {
229 p.mu.Lock()
230 defer p.mu.Unlock()
231 for _, key := range p.keys[cc] {
232 vv, ok := p.conns[key]
233 if !ok {
234 continue
235 }
236 newList := filterOutClientConn(vv, cc)
237 if len(newList) > 0 {
238 p.conns[key] = newList
239 } else {
240 delete(p.conns, key)
241 }
242 }
243 delete(p.keys, cc)
244 }
245
246 func (p *clientConnPool) closeIdleConnections() {
247 p.mu.Lock()
248 defer p.mu.Unlock()
249
250
251
252
253
254
255 for _, vv := range p.conns {
256 for _, cc := range vv {
257 cc.closeIfIdle()
258 }
259 }
260 }
261
262 func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn {
263 out := in[:0]
264 for _, v := range in {
265 if v != exclude {
266 out = append(out, v)
267 }
268 }
269
270
271 if len(in) != len(out) {
272 in[len(in)-1] = nil
273 }
274 return out
275 }
276
277
278
279
280 type noDialClientConnPool struct{ *clientConnPool }
281
282 func (p noDialClientConnPool) GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) {
283 return p.getClientConn(req, addr, noDialOnMiss)
284 }
285
286
287
288
289
290 func shouldRetryDial(call *dialCall, req *ClientRequest) bool {
291 if call.err == nil {
292
293 return false
294 }
295 if call.ctx == req.Context {
296
297
298
299 return false
300 }
301 if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
302
303
304 return false
305 }
306
307
308 return call.ctx.Err() != nil
309 }
310
View as plain text