1
2
3
4
5 package http3
6
7 import (
8 "errors"
9 "fmt"
10 "io"
11 "net"
12 "net/http"
13 "net/textproto"
14 "strings"
15 "sync"
16
17 "golang.org/x/net/http/httpguts"
18 )
19
20
21
22
23 func extractTrailerFromHeader(header, trailer http.Header) {
24 for _, names := range header["Trailer"] {
25 names = textproto.TrimString(names)
26 for name := range strings.SplitSeq(names, ",") {
27 name = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(name))
28 if !httpguts.ValidTrailerHeader(name) {
29 continue
30 }
31 trailer[name] = nil
32 }
33 }
34 }
35
36
37
// bodyWriter writes a message body to an HTTP/3 stream as DATA frames and,
// on Close, an optional trailer as a HEADERS frame.
type bodyWriter struct {
	// st is the stream the frames are written to.
	st *stream
	// remain is the number of body bytes still permitted by the declared
	// content length. A negative value disables the length checks made by
	// write and Close.
	remain int64
	// flush, when true, makes write flush st after each successful write.
	flush bool
	// name labels the body in error messages (presumably "request" or
	// "response" — TODO confirm against callers).
	name string
	// trailer holds the fields encoded and sent as a HEADERS frame by Close.
	trailer http.Header
	// enc is the QPACK encoder used to encode the trailer.
	enc *qpackEncoder
}
46
47 func (w *bodyWriter) write(ps ...[]byte) (n int, err error) {
48 var size int64
49 for _, p := range ps {
50 size += int64(len(p))
51 }
52
53
54 if size == 0 {
55 return 0, nil
56 }
57 if w.remain >= 0 && size > w.remain {
58 return 0, &streamError{
59 code: errH3InternalError,
60 message: w.name + " body longer than specified content length",
61 }
62 }
63 w.st.writeVarint(int64(frameTypeData))
64 w.st.writeVarint(size)
65 for _, p := range ps {
66 var n2 int
67 n2, err = w.st.Write(p)
68 n += n2
69 if w.remain >= 0 {
70 w.remain -= int64(n)
71 }
72 if err != nil {
73 break
74 }
75 }
76 if w.flush && err == nil {
77 err = w.st.Flush()
78 }
79 if err != nil {
80 err = fmt.Errorf("writing %v body: %w", w.name, err)
81 }
82 return n, err
83 }
84
85 func (w *bodyWriter) Write(p []byte) (n int, err error) {
86 return w.write(p)
87 }
88
89 func (w *bodyWriter) Close() error {
90 if w.remain > 0 {
91 return errors.New(w.name + " body shorter than specified content length")
92 }
93 if len(w.trailer) > 0 {
94 encTrailer := w.enc.encode(func(f func(itype indexType, name, value string)) {
95 for name, values := range w.trailer {
96 if !httpguts.ValidHeaderFieldName(name) {
97 continue
98 }
99 for _, val := range values {
100 if !httpguts.ValidHeaderFieldValue(val) {
101 continue
102 }
103 f(mayIndex, name, val)
104 }
105 }
106 })
107 w.st.writeVarint(int64(frameTypeHeaders))
108 w.st.writeVarint(int64(len(encTrailer)))
109 w.st.Write(encTrailer)
110 }
111 if w.st != nil && w.st.stream != nil {
112 w.st.stream.CloseWrite()
113 }
114 return nil
115 }
116
117
// bodyReader reads a message body, and any trailer that follows it, from an
// HTTP/3 stream.
type bodyReader struct {
	st *stream

	// Read holds mu for its whole duration while it reads and updates
	// remain, err, send100Continue and trailer.
	mu sync.Mutex
	// remain is the number of body bytes still expected from the
	// Content-Length. A negative value means the length is not checked.
	remain int64
	// err is sticky: once set, Read returns it without touching the stream.
	err error

	// send100Continue, if non-nil, is called by the first Read and then
	// cleared. Presumably it sends a 100 Continue interim response — confirm
	// with the code that sets it.
	send100Continue func()

	// trailer's keys are the trailer field names the reader accepts.
	// Matching values are added when a HEADERS frame following the body is
	// decoded; fields with other names are dropped.
	trailer http.Header
}
134
// Read implements io.Reader for the message body.
//
// It returns payload bytes from DATA frames and skips unknown frame types. A
// HEADERS frame is decoded as the trailer section. Only fields whose
// canonical names are already keys of r.trailer are kept, and after the
// trailer Read returns io.EOF.
//
// When r.remain is non-negative, a DATA frame larger than the remaining
// Content-Length is an error. So is a body that ends before it is reached,
// whether by EOF or by a trailer. Both cases return a streamError with
// errH3MessageError.
//
// Errors are sticky. The deferred function saves any non-nil error (including
// io.EOF) in r.err, and later calls return it immediately.
func (r *bodyReader) Read(p []byte) (n int, err error) {
	// Serialize concurrent Read calls; remain, err, send100Continue and
	// trailer are all accessed under mu.
	r.mu.Lock()
	defer r.mu.Unlock()
	// Run the one-shot send100Continue callback the first time the body is read.
	if r.send100Continue != nil {
		r.send100Continue()
		r.send100Continue = nil
	}
	if r.err != nil {
		return 0, r.err
	}
	defer func() {
		if err != nil {
			r.err = err
		}
	}()
	if r.st.lim == 0 {
		// The current frame's payload has been fully consumed; finish it so
		// the next frame header can be read.
		if err := r.st.endFrame(); err != nil {
			return 0, err
		}
	}

	// Read frame headers until a DATA frame is open. A negative r.st.lim
	// appears to mean "no frame in progress" and, after a DATA header, lim
	// holds that frame's payload length (inferred from use here; confirm
	// against stream).
	for r.st.lim < 0 {
		ftype, err := r.st.readFrameHeader()
		if err == io.EOF && r.remain > 0 {
			return 0, &streamError{
				code:    errH3MessageError,
				message: "body shorter than content-length",
			}
		}
		if err != nil {
			return 0, err
		}
		switch ftype {
		case frameTypeData:
			// Reject a DATA frame whose payload alone exceeds what the
			// Content-Length still allows.
			if r.remain >= 0 && r.st.lim > r.remain {
				return 0, &streamError{
					code:    errH3MessageError,
					message: "body longer than content-length",
				}
			}

		case frameTypeHeaders:
			// A HEADERS frame here carries the trailer. It must not arrive
			// before all Content-Length bytes have been read.
			if r.remain > 0 {
				return 0, &streamError{
					code:    errH3MessageError,
					message: "body shorter than content-length",
				}
			}
			var dec qpackDecoder
			if err := dec.decode(r.st, func(_ indexType, name, value string) error {
				name = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(name))
				// Keep only fields whose names were announced in advance.
				if _, ok := r.trailer[name]; ok {
					r.trailer.Add(name, value)
				}
				return nil
			}); err != nil {
				return 0, err
			}
			// Presumably skips whatever remains of the HEADERS frame.
			if err := r.st.discardFrame(); err != nil {
				return 0, err
			}
			return 0, io.EOF
		default:
			// Skip frame types this reader does not handle.
			if err := r.st.discardUnknownFrame(ftype); err != nil {
				return 0, err
			}
		}
	}

	// Never read past the end of the current DATA frame.
	if int64(len(p)) > r.st.lim {
		p = p[:r.st.lim]
	}
	n, err = r.st.Read(p)
	if r.remain > 0 {
		r.remain -= int64(n)
	}
	return n, err
}
221
222 func (r *bodyReader) Close() error {
223
224
225 r.st.stream.CloseRead()
226
227
228 r.err = net.ErrClosed
229 r.remain = 0
230 return nil
231 }
232
View as plain text