1
2
3
4
5 package http3
6
7 import (
8 "context"
9 "io"
10 "sync"
11
12 "golang.org/x/net/quic"
13 )
14
15 type streamHandler interface {
16 handleControlStream(*stream) error
17 handlePushStream(*stream) error
18 handleEncoderStream(*stream) error
19 handleDecoderStream(*stream) error
20 handleRequestStream(*stream) error
21 abort(error)
22 }
23
24 type genericConn struct {
25 mu sync.Mutex
26
27
28
29
30 streamsCreated uint8
31 }
32
// acceptStreams accepts streams from qconn in a loop and hands each one to h
// on a new goroutine. Read-only streams go through
// handleUnidirectionalStream; all others go through handleRequestStream.
//
// AcceptStream is called with a background context, so this loop only ends
// when AcceptStream returns an error (presumably because the connection is
// closing).
func (c *genericConn) acceptStreams(qconn *quic.Conn, h streamHandler) {
	ctx := context.Background()
	for {
		qst, err := qconn.AcceptStream(ctx)
		if err != nil {
			return
		}
		if !qst.IsReadOnly() {
			go c.handleRequestStream(newStream(qst), h)
			continue
		}
		go c.handleUnidirectionalStream(newStream(qst), h)
	}
}
48
// handleUnidirectionalStream reads the stream-type header of a unidirectional
// stream and dispatches the stream to the matching streamHandler method.
//
// The header is a single variable-length integer carrying the stream type
// (RFC 9114, Section 6.2). Processing proceeds as follows:
//   - If the header cannot be read, the stream is discarded; the connection
//     is not aborted.
//   - If checkStreamCreation rejects the type (for example, a duplicate
//     singleton stream), its error is passed to h.abort.
//   - Unknown stream types are discarded.
//   - An io.EOF from a type-specific handler is converted into an
//     errH3ClosedCriticalStream connection error.
//
// The final outcome is passed to handleStreamError.
func (c *genericConn) handleUnidirectionalStream(st *stream, h streamHandler) {
	v, err := st.readVarint()
	if err != nil {
		// "A receiver MUST tolerate unidirectional streams being closed or
		// reset prior to the reception of the unidirectional stream header."
		// https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2
		//
		// A QUIC variable-length integer has no invalid encodings (RFC 9000,
		// Section 16). A read failure here therefore presumably means the
		// stream ended or was reset (or the connection went away) before a
		// full header arrived. That is not a protocol error by the peer, so
		// discard the stream the same way unknown stream types are discarded
		// below, instead of aborting the whole connection.
		c.handleStreamError(st, h, nil)
		return
	}
	stype := streamType(v)
	if err := c.checkStreamCreation(stype); err != nil {
		h.abort(err)
		return
	}
	switch stype {
	case streamTypeControl:
		err = h.handleControlStream(st)
	case streamTypePush:
		err = h.handlePushStream(st)
	case streamTypeEncoder:
		err = h.handleEncoderStream(st)
	case streamTypeDecoder:
		err = h.handleDecoderStream(st)
	default:
		// "Recipients of unknown stream types MUST either abort reading of the
		// stream or discard incoming data without further processing."
		// https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2
		//
		// A nil error makes handleStreamError close the stream without
		// reporting a failure.
		err = nil
	}
	if err == io.EOF {
		// The peer closed a stream the handler expected to stay open.
		//
		// NOTE(review): push streams are not critical streams (RFC 9114,
		// Section 6.2.2), but an io.EOF from handlePushStream is mapped here
		// as well. Confirm that handlePushStream never returns io.EOF.
		err = &connectionError{
			code:    errH3ClosedCriticalStream,
			message: stype.String() + " stream closed",
		}
	}
	c.handleStreamError(st, h, err)
}
92
93 func (c *genericConn) handleRequestStream(st *stream, h streamHandler) {
94 c.handleStreamError(st, h, h.handleRequestStream(st))
95 }
96
97 func (c *genericConn) handleStreamError(st *stream, h streamHandler, err error) {
98 switch err := err.(type) {
99 case *connectionError:
100 h.abort(err)
101 case nil:
102 st.stream.CloseRead()
103 st.stream.CloseWrite()
104 case *streamError:
105 st.stream.CloseRead()
106 st.stream.Reset(uint64(err.code))
107 default:
108 st.stream.CloseRead()
109 st.stream.Reset(uint64(errH3InternalError))
110 }
111 }
112
113 func (c *genericConn) checkStreamCreation(stype streamType) error {
114 switch stype {
115 case streamTypeControl, streamTypeEncoder, streamTypeDecoder:
116
117 default:
118 return nil
119 }
120 c.mu.Lock()
121 defer c.mu.Unlock()
122 bit := uint8(1) << stype
123 if c.streamsCreated&bit != 0 {
124 return &connectionError{
125 code: errH3StreamCreationError,
126 message: "multiple " + stype.String() + " streams created",
127 }
128 }
129 c.streamsCreated |= bit
130 return nil
131 }
132
View as plain text