1
2
3
4
5 package http3
6
7 import (
8 "errors"
9 "io"
10 "net/http"
11 "net/http/httptrace"
12 "net/textproto"
13 "strconv"
14 "sync"
15
16 "golang.org/x/net/http/httpguts"
17 "golang.org/x/net/internal/httpcommon"
18 )
19
// roundTripState holds the state of a single request/response exchange
// started by clientConn.RoundTrip.
type roundTripState struct {
	cc *clientConn // connection the request was issued on
	st *stream     // stream created for this request

	// Request body. closeReqBody closes reqBody at most once,
	// guarded by onceCloseReqBody.
	onceCloseReqBody sync.Once
	reqBody          io.ReadCloser

	// reqBodyWriter is configured by writeBodyAndTrailer and is the
	// destination the request body is copied into.
	reqBodyWriter bodyWriter

	// respBody backs the Read and Close methods of transportResponseBody.
	// RoundTrip sets it just before returning the response.
	respBody io.ReadCloser

	// trace holds the httptrace hooks taken from the request context.
	// It may be nil; the maybeCall* helpers check before calling a hook.
	trace *httptrace.ClientTrace

	// err is the error recorded by the first call to abort.
	// errOnce ensures only that first call has any effect.
	errOnce sync.Once
	err     error
}
38
39
40
41 func (rt *roundTripState) abort(err error) error {
42 rt.errOnce.Do(func() {
43 rt.err = err
44 switch e := err.(type) {
45 case *connectionError:
46 rt.cc.abort(e)
47 case *streamError:
48 rt.st.stream.CloseRead()
49 rt.st.stream.Reset(uint64(e.code))
50 default:
51 rt.st.stream.CloseRead()
52 rt.st.stream.Reset(uint64(errH3NoError))
53 }
54 })
55 return rt.err
56 }
57
58
59 func (rt *roundTripState) closeReqBody() {
60 if rt.reqBody != nil {
61 rt.onceCloseReqBody.Do(func() {
62 rt.reqBody.Close()
63 })
64 }
65 }
66
67
68 func (rt *roundTripState) maybeCallGot1xxResponse(status int, h http.Header) error {
69 if rt.trace == nil || rt.trace.Got1xxResponse == nil {
70 return nil
71 }
72 return rt.trace.Got1xxResponse(status, textproto.MIMEHeader(h))
73 }
74
75 func (rt *roundTripState) maybeCallGot100Continue() {
76 if rt.trace == nil || rt.trace.Got100Continue == nil {
77 return
78 }
79 rt.trace.Got100Continue()
80 }
81
82 func (rt *roundTripState) maybeCallWait100Continue() {
83 if rt.trace == nil || rt.trace.Wait100Continue == nil {
84 return
85 }
86 rt.trace.Wait100Continue()
87 }
88
89
90 func (cc *clientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) {
91
92 st, err := newConnStream(req.Context(), cc.qconn, streamTypeRequest)
93 if err != nil {
94 return nil, err
95 }
96 rt := &roundTripState{
97 cc: cc,
98 st: st,
99 trace: httptrace.ContextClientTrace(req.Context()),
100 reqBody: req.Body,
101 }
102 if rt.reqBody == nil {
103 rt.reqBody = http.NoBody
104 }
105 defer func() {
106 if err != nil {
107 err = rt.abort(err)
108 }
109 }()
110
111
112 st.stream.SetReadContext(req.Context())
113 st.stream.SetWriteContext(req.Context())
114
115 headers := cc.enc.encode(func(yield func(itype indexType, name, value string)) {
116 _, err = httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{
117 Request: httpcommon.Request{
118 URL: req.URL,
119 Method: req.Method,
120 Host: req.Host,
121 Header: req.Header,
122 Trailer: req.Trailer,
123 ActualContentLength: actualContentLength(req),
124 },
125 AddGzipHeader: false,
126 PeerMaxHeaderListSize: 0,
127 DefaultUserAgent: "Go-http-client/3",
128 }, func(name, value string) {
129
130 yield(mayIndex, name, value)
131 })
132 })
133 if err != nil {
134 return nil, err
135 }
136
137
138 st.writeVarint(int64(frameTypeHeaders))
139 st.writeVarint(int64(len(headers)))
140 st.Write(headers)
141 if err := st.Flush(); err != nil {
142 return nil, err
143 }
144
145 var bodyAndTrailerWritten bool
146 is100ContinueReq := httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue")
147 if is100ContinueReq {
148 rt.maybeCallWait100Continue()
149 } else {
150 bodyAndTrailerWritten = true
151 go cc.writeBodyAndTrailer(rt, req)
152 }
153
154
155 for {
156 ftype, err := st.readFrameHeader()
157 if err != nil {
158 return nil, err
159 }
160 switch ftype {
161 case frameTypeHeaders:
162 statusCode, h, err := cc.handleHeaders(st)
163 if err != nil {
164 return nil, err
165 }
166
167
168 if isInfoStatus(statusCode) {
169 if err := rt.maybeCallGot1xxResponse(statusCode, h); err != nil {
170 return nil, err
171 }
172 switch statusCode {
173 case 100:
174 rt.maybeCallGot100Continue()
175 if is100ContinueReq && !bodyAndTrailerWritten {
176 bodyAndTrailerWritten = true
177 go cc.writeBodyAndTrailer(rt, req)
178 continue
179 }
180
181
182
183 default:
184 continue
185 }
186 }
187
188
189
190 contentLength, err := parseResponseContentLength(req.Method, statusCode, h)
191 if err != nil {
192 return nil, err
193 }
194
195 trailer := make(http.Header)
196 extractTrailerFromHeader(h, trailer)
197 delete(h, "Trailer")
198
199 if (contentLength != 0 && req.Method != http.MethodHead) || len(trailer) > 0 {
200 rt.respBody = &bodyReader{
201 st: st,
202 remain: contentLength,
203 trailer: trailer,
204 }
205 } else {
206 rt.respBody = http.NoBody
207 }
208 resp := &http.Response{
209 Proto: "HTTP/3.0",
210 ProtoMajor: 3,
211 Header: h,
212 StatusCode: statusCode,
213 Status: strconv.Itoa(statusCode) + " " + http.StatusText(statusCode),
214 ContentLength: contentLength,
215 Trailer: trailer,
216 Body: (*transportResponseBody)(rt),
217 }
218
219 return resp, nil
220 case frameTypePushPromise:
221 if err := cc.handlePushPromise(st); err != nil {
222 return nil, err
223 }
224 default:
225 if err := st.discardUnknownFrame(ftype); err != nil {
226 return nil, err
227 }
228 }
229 }
230 }
231
232
233
// actualContentLength reports the request body length in a normalized form:
// 0 means there is no body (nil or http.NoBody), a non-zero
// req.ContentLength is returned unchanged, and -1 means a body is present
// but its length is not declared.
func actualContentLength(req *http.Request) int64 {
	switch {
	case req.Body == nil, req.Body == http.NoBody:
		return 0
	case req.ContentLength != 0:
		return req.ContentLength
	default:
		return -1
	}
}
243
244
245
246 func (cc *clientConn) writeBodyAndTrailer(rt *roundTripState, req *http.Request) {
247 defer rt.closeReqBody()
248
249 declaredTrailer := req.Trailer.Clone()
250
251 rt.reqBodyWriter.st = rt.st
252 rt.reqBodyWriter.remain = actualContentLength(req)
253 rt.reqBodyWriter.flush = true
254 rt.reqBodyWriter.name = "request"
255 rt.reqBodyWriter.trailer = req.Trailer
256 rt.reqBodyWriter.enc = &cc.enc
257
258 if _, err := io.Copy(&rt.reqBodyWriter, rt.reqBody); err != nil {
259 rt.abort(err)
260 }
261
262
263
264 for name := range req.Trailer {
265 if _, ok := declaredTrailer[name]; !ok {
266 delete(req.Trailer, name)
267 }
268 }
269 if err := rt.reqBodyWriter.Close(); err != nil {
270 rt.abort(err)
271 }
272 }
273
274
// transportResponseBody is the type RoundTrip uses for Response.Body.
// It shares its representation with roundTripState, so RoundTrip hands out
// the round-trip state itself, converted, as the body.
type transportResponseBody roundTripState
276
277
278 func (b *transportResponseBody) Read(p []byte) (n int, err error) {
279 return b.respBody.Read(p)
280 }
281
// errRespBodyClosed is the error transportResponseBody.Close passes to abort
// when closing the response body produced no error of its own. Close maps it
// back to a nil return.
var errRespBodyClosed = errors.New("response body closed")
283
284
285
286 func (b *transportResponseBody) Close() error {
287 rt := (*roundTripState)(b)
288
289
290 rt.closeReqBody()
291
292
293 rt.st.stream.Reset(uint64(errH3NoError))
294
295 err := rt.respBody.Close()
296 if err == nil {
297 err = errRespBodyClosed
298 }
299 err = rt.abort(err)
300 if err == errRespBodyClosed {
301
302
303 return nil
304 }
305 return err
306 }
307
308 func parseResponseContentLength(method string, statusCode int, h http.Header) (int64, error) {
309 clens := h["Content-Length"]
310 if len(clens) == 0 {
311 return -1, nil
312 }
313
314
315
316 for _, v := range clens[1:] {
317 if clens[0] != v {
318 return -1, &streamError{errH3MessageError, "mismatching Content-Length headers"}
319 }
320 }
321
322
323
324
325
326
327 if (statusCode >= 100 && statusCode < 200) ||
328 statusCode == 204 ||
329 (method == "CONNECT" && statusCode >= 200 && statusCode < 300) {
330
331
332 return -1, nil
333 }
334
335 contentLen, err := strconv.ParseUint(clens[0], 10, 63)
336 if err != nil {
337 return -1, &streamError{errH3MessageError, "invalid Content-Length header"}
338 }
339 return int64(contentLen), nil
340 }
341
342 func (cc *clientConn) handleHeaders(st *stream) (statusCode int, h http.Header, err error) {
343 haveStatus := false
344 cookie := ""
345
346
347 err = cc.dec.decode(st, func(_ indexType, name, value string) error {
348 if !httpguts.ValidHeaderFieldValue(value) {
349 return &streamError{errH3MessageError, "invalid field value"}
350 }
351 switch {
352 case name == ":status":
353 if haveStatus {
354 return &streamError{errH3MessageError, "duplicate :status"}
355 }
356 haveStatus = true
357 statusCode, err = strconv.Atoi(value)
358 if err != nil {
359 return &streamError{errH3MessageError, "invalid :status"}
360 }
361 case name[0] == ':':
362
363
364
365
366 return &streamError{errH3MessageError, "undefined pseudo-header"}
367 case name == "cookie":
368
369
370
371
372 if cookie == "" {
373 cookie = value
374 } else {
375 cookie += "; " + value
376 }
377 default:
378 if !validWireHeaderFieldName(name) {
379 return &streamError{errH3MessageError, "invalid field name"}
380 }
381 if h == nil {
382 h = make(http.Header)
383 }
384
385
386
387 cname := httpcommon.CanonicalHeader(name)
388
389
390
391
392
393
394 h[cname] = append(h[cname], value)
395 }
396 return nil
397 })
398 if !haveStatus {
399
400
401 err = errH3MessageError
402 }
403 if cookie != "" {
404 if h == nil {
405 h = make(http.Header)
406 }
407 h["Cookie"] = []string{cookie}
408 }
409 if err := st.endFrame(); err != nil {
410 return 0, nil, err
411 }
412 return statusCode, h, err
413 }
414
415 func (cc *clientConn) handlePushPromise(st *stream) error {
416
417
418
419 return &connectionError{
420 code: errH3IDError,
421 message: "PUSH_PROMISE received when no MAX_PUSH_ID has been sent",
422 }
423 }
424
View as plain text