1
2
3
4
5 package httpcommon
6
7 import (
8 "context"
9 "errors"
10 "fmt"
11 "net/http/httptrace"
12 "net/textproto"
13 "net/url"
14 "sort"
15 "strconv"
16 "strings"
17
18 "golang.org/x/net/http/httpguts"
19 "golang.org/x/net/http2/hpack"
20 )
21
// ErrRequestHeaderListSize is returned by EncodeHeaders when the encoded
// size of the request's header list exceeds
// EncodeHeadersParam.PeerMaxHeaderListSize.
var ErrRequestHeaderListSize = errors.New("request header list larger than peer's advertised limit")
25
26
27
28
// Request is the subset of a client HTTP request that EncodeHeaders needs
// in order to produce the request's header fields.
type Request struct {
	// URL supplies :path and :scheme, and the authority when Host is
	// empty. EncodeHeaders rejects a nil URL.
	URL *url.URL

	// Method is the request method; EncodeHeaders sends "GET" when empty.
	Method string

	// Host, when non-empty, overrides URL.Host as the :authority value.
	Host string

	// Header holds the request header fields keyed by field name. A
	// ":protocol" entry supplies the extended-CONNECT :protocol value.
	Header map[string][]string

	// Trailer declares the trailer fields; EncodeHeaders uses its keys to
	// build the "trailer" header and validates its values.
	Trailer map[string][]string

	// ActualContentLength is the body length in bytes. Zero means no body;
	// a negative value means a body whose length is not advertised (no
	// content-length field is sent).
	ActualContentLength int64
}
37
38
39 type EncodeHeadersParam struct {
40 Request Request
41
42
43
44 AddGzipHeader bool
45
46
47 PeerMaxHeaderListSize uint64
48
49
50
51 DefaultUserAgent string
52 }
53
54
// EncodeHeadersResult reports properties of the request that EncodeHeaders
// discovered while encoding it.
type EncodeHeadersResult struct {
	// HasBody is true when Request.ActualContentLength is non-zero.
	HasBody bool

	// HasTrailers is true when the request declared trailer fields.
	HasTrailers bool
}
59
60
61
62
63
64 func EncodeHeaders(ctx context.Context, param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) {
65 req := param.Request
66
67
68 if err := checkConnHeaders(req.Header); err != nil {
69 return res, err
70 }
71
72 if req.URL == nil {
73 return res, errors.New("Request.URL is nil")
74 }
75
76 host := req.Host
77 if host == "" {
78 host = req.URL.Host
79 }
80 host, err := httpguts.PunycodeHostPort(host)
81 if err != nil {
82 return res, err
83 }
84 if !httpguts.ValidHostHeader(host) {
85 return res, errors.New("invalid Host header")
86 }
87
88
89 isNormalConnect := false
90 var protocol string
91 if vv := req.Header[":protocol"]; len(vv) > 0 {
92 protocol = vv[0]
93 }
94 if req.Method == "CONNECT" && protocol == "" {
95 isNormalConnect = true
96 } else if protocol != "" && req.Method != "CONNECT" {
97 return res, errors.New("invalid :protocol header in non-CONNECT request")
98 }
99
100
101 var path string
102 if !isNormalConnect {
103 path = req.URL.RequestURI()
104 if !validPseudoPath(path) {
105 orig := path
106 path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
107 if !validPseudoPath(path) {
108 if req.URL.Opaque != "" {
109 return res, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
110 } else {
111 return res, fmt.Errorf("invalid request :path %q", orig)
112 }
113 }
114 }
115 }
116
117
118
119
120 if err := validateHeaders(req.Header); err != "" {
121 return res, fmt.Errorf("invalid HTTP header %s", err)
122 }
123 if err := validateHeaders(req.Trailer); err != "" {
124 return res, fmt.Errorf("invalid HTTP trailer %s", err)
125 }
126
127 trailers, err := commaSeparatedTrailers(req.Trailer)
128 if err != nil {
129 return res, err
130 }
131
132 enumerateHeaders := func(f func(name, value string)) {
133
134
135
136
137
138 f(":authority", host)
139 m := req.Method
140 if m == "" {
141 m = "GET"
142 }
143 f(":method", m)
144 if !isNormalConnect {
145 f(":path", path)
146 f(":scheme", req.URL.Scheme)
147 }
148 if protocol != "" {
149 f(":protocol", protocol)
150 }
151 if trailers != "" {
152 f("trailer", trailers)
153 }
154
155 var didUA bool
156 for k, vv := range req.Header {
157 if asciiEqualFold(k, "host") || asciiEqualFold(k, "content-length") {
158
159
160 continue
161 } else if asciiEqualFold(k, "connection") ||
162 asciiEqualFold(k, "proxy-connection") ||
163 asciiEqualFold(k, "transfer-encoding") ||
164 asciiEqualFold(k, "upgrade") ||
165 asciiEqualFold(k, "keep-alive") {
166
167
168
169
170 continue
171 } else if asciiEqualFold(k, "user-agent") {
172
173
174
175
176 didUA = true
177 if len(vv) < 1 {
178 continue
179 }
180 vv = vv[:1]
181 if vv[0] == "" {
182 continue
183 }
184 } else if asciiEqualFold(k, "cookie") {
185
186
187
188 for _, v := range vv {
189 for {
190 p := strings.IndexByte(v, ';')
191 if p < 0 {
192 break
193 }
194 f("cookie", v[:p])
195 p++
196
197 for p+1 <= len(v) && v[p] == ' ' {
198 p++
199 }
200 v = v[p:]
201 }
202 if len(v) > 0 {
203 f("cookie", v)
204 }
205 }
206 continue
207 } else if k == ":protocol" {
208
209 continue
210 }
211
212 for _, v := range vv {
213 f(k, v)
214 }
215 }
216 if shouldSendReqContentLength(req.Method, req.ActualContentLength) {
217 f("content-length", strconv.FormatInt(req.ActualContentLength, 10))
218 }
219 if param.AddGzipHeader {
220 f("accept-encoding", "gzip")
221 }
222 if !didUA {
223 f("user-agent", param.DefaultUserAgent)
224 }
225 }
226
227
228
229
230
231 if param.PeerMaxHeaderListSize > 0 {
232 hlSize := uint64(0)
233 enumerateHeaders(func(name, value string) {
234 hf := hpack.HeaderField{Name: name, Value: value}
235 hlSize += uint64(hf.Size())
236 })
237
238 if hlSize > param.PeerMaxHeaderListSize {
239 return res, ErrRequestHeaderListSize
240 }
241 }
242
243 trace := httptrace.ContextClientTrace(ctx)
244
245
246 enumerateHeaders(func(name, value string) {
247 name, ascii := LowerHeader(name)
248 if !ascii {
249
250
251 return
252 }
253
254 headerf(name, value)
255
256 if trace != nil && trace.WroteHeaderField != nil {
257 trace.WroteHeaderField(name, []string{value})
258 }
259 })
260
261 res.HasBody = req.ActualContentLength != 0
262 res.HasTrailers = trailers != ""
263 return res, nil
264 }
265
266
267
// IsRequestGzip reports whether a request qualifies for a transparently
// added gzip Accept-Encoding. It does when compression is not disabled, the
// method is not HEAD, and the header (looked up by canonical key) sets
// neither Accept-Encoding nor Range.
//
// Presumably callers feed the result into EncodeHeadersParam.AddGzipHeader;
// confirm at the call sites.
func IsRequestGzip(method string, header map[string][]string, disableCompression bool) bool {
	if disableCompression || method == "HEAD" {
		return false
	}
	// An explicit Accept-Encoding or a Range request opts out of the
	// transparent behavior.
	return len(header["Accept-Encoding"]) == 0 && len(header["Range"]) == 0
}
290
291
292
293
294
295
296
297
298 func checkConnHeaders(h map[string][]string) error {
299 if vv := h["Upgrade"]; len(vv) > 0 && (vv[0] != "" && vv[0] != "chunked") {
300 return fmt.Errorf("invalid Upgrade request header: %q", vv)
301 }
302 if vv := h["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
303 return fmt.Errorf("invalid Transfer-Encoding request header: %q", vv)
304 }
305 if vv := h["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) {
306 return fmt.Errorf("invalid Connection request header: %q", vv)
307 }
308 return nil
309 }
310
311 func commaSeparatedTrailers(trailer map[string][]string) (string, error) {
312 keys := make([]string, 0, len(trailer))
313 for k := range trailer {
314 k = CanonicalHeader(k)
315 switch k {
316 case "Transfer-Encoding", "Trailer", "Content-Length":
317 return "", fmt.Errorf("invalid Trailer key %q", k)
318 }
319 keys = append(keys, k)
320 }
321 if len(keys) > 0 {
322 sort.Strings(keys)
323 return strings.Join(keys, ","), nil
324 }
325 return "", nil
326 }
327
328
329
330
331
332
333
334
335
336
337
338
339
340
// validPseudoPath reports whether v is an acceptable :path pseudo-header
// value: either the asterisk form "*" or a non-empty path starting with "/".
func validPseudoPath(v string) bool {
	switch {
	case v == "*":
		return true
	case v == "":
		return false
	default:
		return v[0] == '/'
	}
}
344
345 func validateHeaders(hdrs map[string][]string) string {
346 for k, vv := range hdrs {
347 if !httpguts.ValidHeaderFieldName(k) && k != ":protocol" {
348 return fmt.Sprintf("name %q", k)
349 }
350 for _, v := range vv {
351 if !httpguts.ValidHeaderFieldValue(v) {
352
353
354 return fmt.Sprintf("value for header %q", k)
355 }
356 }
357 }
358 return ""
359 }
360
361
362
363
364
365
// shouldSendReqContentLength reports whether a content-length field should
// accompany a request. A positive length is always sent and a negative
// (unadvertised) length never is. A zero length is sent only for POST, PUT
// and PATCH, where an empty body is meaningful.
func shouldSendReqContentLength(method string, contentLength int64) bool {
	switch {
	case contentLength > 0:
		return true
	case contentLength < 0:
		return false
	}
	return method == "POST" || method == "PUT" || method == "PATCH"
}
382
383
// ServerRequestParam carries the pseudo-header values and regular header
// fields of a request received by a server; it is the input to
// NewServerRequest.
type ServerRequestParam struct {
	// Method is the :method pseudo-header value.
	Method string

	// Scheme, Authority and Path are the :scheme, :authority and :path
	// pseudo-header values.
	Scheme    string
	Authority string
	Path      string

	// Protocol is the :protocol pseudo-header value (extended CONNECT), or
	// empty when absent.
	Protocol string

	// Header holds the regular header fields. NewServerRequest modifies
	// this map in place.
	Header map[string][]string
}
390
391
// ServerRequestResult is the outcome of NewServerRequest.
type ServerRequestResult struct {
	// URL is the parsed request target; for a plain CONNECT it holds only
	// the authority as Host.
	URL *url.URL

	// RequestURI is the raw request target (the :path value, or the
	// authority for a plain CONNECT).
	RequestURI string

	// Trailer declares the trailer fields announced by the request's
	// Trailer header; all values are nil.
	Trailer map[string][]string

	// NeedsContinue is true when the request carried
	// "Expect: 100-continue".
	NeedsContinue bool

	// InvalidReason is non-empty when the request is malformed; it then
	// names the problem and the other fields are unset.
	InvalidReason string
}
406
407 func NewServerRequest(rp ServerRequestParam) ServerRequestResult {
408 needsContinue := httpguts.HeaderValuesContainsToken(rp.Header["Expect"], "100-continue")
409 if needsContinue {
410 delete(rp.Header, "Expect")
411 }
412
413 if cookies := rp.Header["Cookie"]; len(cookies) > 1 {
414 rp.Header["Cookie"] = []string{strings.Join(cookies, "; ")}
415 }
416
417
418 var trailer map[string][]string
419 for _, v := range rp.Header["Trailer"] {
420 for _, key := range strings.Split(v, ",") {
421 key = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(key))
422 switch key {
423 case "Transfer-Encoding", "Trailer", "Content-Length":
424
425
426 default:
427 if trailer == nil {
428 trailer = make(map[string][]string)
429 }
430 trailer[key] = nil
431 }
432 }
433 }
434 delete(rp.Header, "Trailer")
435
436
437
438
439 if strings.IndexByte(rp.Authority, '@') != -1 && (rp.Scheme == "http" || rp.Scheme == "https") {
440 return ServerRequestResult{
441 InvalidReason: "userinfo_in_authority",
442 }
443 }
444
445 var url_ *url.URL
446 var requestURI string
447 if rp.Method == "CONNECT" && rp.Protocol == "" {
448 url_ = &url.URL{Host: rp.Authority}
449 requestURI = rp.Authority
450 } else {
451
452
453 if rp.Path == "" || (rp.Path[0] != '/' && rp.Path != "*") {
454 return ServerRequestResult{
455 InvalidReason: "bad_path",
456 }
457 }
458
459 var err error
460 url_, err = url.ParseRequestURI(rp.Path)
461 if err != nil {
462 return ServerRequestResult{
463 InvalidReason: "bad_path",
464 }
465 }
466 requestURI = rp.Path
467 }
468
469 return ServerRequestResult{
470 URL: url_,
471 NeedsContinue: needsContinue,
472 RequestURI: requestURI,
473 Trailer: trailer,
474 }
475 }
476
View as plain text