1
2
3
4
5 package http3
6
7 import (
8 "context"
9 "io"
10
11 "golang.org/x/net/quic"
12 )
13
14
// stream wraps a QUIC stream and tracks how many bytes of the HTTP/3 frame
// currently being read are still unread.
type stream struct {
	// stream is the underlying QUIC stream that reads and writes are
	// forwarded to.
	stream *quic.Stream

	// lim is the number of bytes remaining in the frame being read.
	// readFrameHeader sets it to the frame's declared length and
	// recordBytesRead decrements it as bytes are consumed.
	// A negative value (the constructors use -1) means no frame is in
	// progress and reads are not limited.
	lim int64
}
25
26
27
28
29
30
31
32 func newConnStream(ctx context.Context, qconn *quic.Conn, stype streamType) (*stream, error) {
33 var qs *quic.Stream
34 var err error
35 if stype == streamTypeRequest {
36
37 qs, err = qconn.NewStream(ctx)
38 } else {
39
40 qs, err = qconn.NewSendOnlyStream(ctx)
41 }
42 if err != nil {
43 return nil, err
44 }
45 st := &stream{
46 stream: qs,
47 lim: -1,
48 }
49 if stype != streamTypeRequest {
50
51 st.writeVarint(int64(stype))
52 }
53 return st, err
54 }
55
56 func newStream(qs *quic.Stream) *stream {
57 return &stream{
58 stream: qs,
59 lim: -1,
60 }
61 }
62
63
64
65
66
67 func (st *stream) readFrameHeader() (ftype frameType, err error) {
68 if st.lim >= 0 {
69
70 return 0, errH3FrameError
71 }
72 ftype, err = readVarint[frameType](st)
73 if err != nil {
74 return 0, err
75 }
76 size, err := st.readVarint()
77 if err != nil {
78 return 0, err
79 }
80 st.lim = size
81 return ftype, nil
82 }
83
84
85
86 func (st *stream) endFrame() error {
87 if st.lim != 0 {
88 return &connectionError{
89 code: errH3FrameError,
90 message: "invalid HTTP/3 frame",
91 }
92 }
93 st.lim = -1
94 return nil
95 }
96
97
98 func (st *stream) readFrameData() ([]byte, error) {
99 if st.lim < 0 {
100 return nil, errH3FrameError
101 }
102
103 b := make([]byte, st.lim)
104 _, err := io.ReadFull(st, b)
105 if err != nil {
106 return nil, err
107 }
108 return b, nil
109 }
110
111
112 func (st *stream) ReadByte() (b byte, err error) {
113 if err := st.recordBytesRead(1); err != nil {
114 return 0, err
115 }
116 b, err = st.stream.ReadByte()
117 if err != nil {
118 if err == io.EOF && st.lim < 0 {
119 return 0, io.EOF
120 }
121 return 0, errH3FrameError
122 }
123 return b, nil
124 }
125
126
127 func (st *stream) Read(b []byte) (int, error) {
128 n, err := st.stream.Read(b)
129 if e2 := st.recordBytesRead(n); e2 != nil {
130 return 0, e2
131 }
132 if err == io.EOF {
133 if st.lim == 0 {
134
135 return n, nil
136 } else if st.lim > 0 {
137
138 return 0, errH3FrameError
139 } else {
140
141 return n, io.EOF
142 }
143 }
144 if err != nil {
145 return 0, errH3FrameError
146 }
147 return n, nil
148 }
149
150
151
152
153
154
155 func (st *stream) discardUnknownFrame(ftype frameType) error {
156 switch ftype {
157 case frameTypeData,
158 frameTypeHeaders,
159 frameTypeCancelPush,
160 frameTypeSettings,
161 frameTypePushPromise,
162 frameTypeGoaway,
163 frameTypeMaxPushID:
164 return &connectionError{
165 code: errH3FrameUnexpected,
166 message: "unexpected " + ftype.String() + " frame",
167 }
168 }
169 return st.discardFrame()
170 }
171
172
173 func (st *stream) discardFrame() error {
174
175 for range st.lim {
176 _, err := st.stream.ReadByte()
177 if err != nil {
178 return &streamError{errH3FrameError, err.Error()}
179 }
180 }
181 st.lim = -1
182 return nil
183 }
184
185
186 func (st *stream) Write(b []byte) (int, error) { return st.stream.Write(b) }
187
188
189 func (st *stream) Flush() error { return st.stream.Flush() }
190
191
192 func (st *stream) readVarint() (v int64, err error) {
193 b, err := st.stream.ReadByte()
194 if err != nil {
195 return 0, err
196 }
197 v = int64(b & 0x3f)
198 n := 1 << (b >> 6)
199 for i := 1; i < n; i++ {
200 b, err := st.stream.ReadByte()
201 if err != nil {
202 return 0, errH3FrameError
203 }
204 v = (v << 8) | int64(b)
205 }
206 if err := st.recordBytesRead(n); err != nil {
207 return 0, err
208 }
209 return v, nil
210 }
211
212
213 func readVarint[T ~int64 | ~uint64](st *stream) (T, error) {
214 v, err := st.readVarint()
215 return T(v), err
216 }
217
218
219 func (st *stream) writeVarint(v int64) {
220 switch {
221 case v <= (1<<6)-1:
222 st.stream.WriteByte(byte(v))
223 case v <= (1<<14)-1:
224 st.stream.WriteByte((1 << 6) | byte(v>>8))
225 st.stream.WriteByte(byte(v))
226 case v <= (1<<30)-1:
227 st.stream.WriteByte((2 << 6) | byte(v>>24))
228 st.stream.WriteByte(byte(v >> 16))
229 st.stream.WriteByte(byte(v >> 8))
230 st.stream.WriteByte(byte(v))
231 case v <= (1<<62)-1:
232 st.stream.WriteByte((3 << 6) | byte(v>>56))
233 st.stream.WriteByte(byte(v >> 48))
234 st.stream.WriteByte(byte(v >> 40))
235 st.stream.WriteByte(byte(v >> 32))
236 st.stream.WriteByte(byte(v >> 24))
237 st.stream.WriteByte(byte(v >> 16))
238 st.stream.WriteByte(byte(v >> 8))
239 st.stream.WriteByte(byte(v))
240 default:
241 panic("varint too large")
242 }
243 }
244
245
246
247 func (st *stream) recordBytesRead(n int) error {
248 if st.lim < 0 {
249 return nil
250 }
251 st.lim -= int64(n)
252 if st.lim < 0 {
253 st.stream = nil
254 return &connectionError{
255 code: errH3FrameError,
256 message: "invalid HTTP/3 frame",
257 }
258 }
259 return nil
260 }
261
View as plain text