1
2
3
4
5 package tls
6
7 import (
8 "context"
9 "errors"
10 "fmt"
11 )
12
13
14
// QUICEncryptionLevel identifies one of the encryption levels named by the
// constants declared below.
type QUICEncryptionLevel int

// The defined encryption levels, numbered consecutively starting at zero.
const (
	QUICEncryptionLevelInitial QUICEncryptionLevel = iota
	QUICEncryptionLevelEarly
	QUICEncryptionLevelHandshake
	QUICEncryptionLevelApplication
)

// String returns the short name of a defined level ("Initial", "Early",
// "Handshake" or "Application"). Any other value is rendered as
// "QUICEncryptionLevel(N)", where N is the value as an integer.
func (l QUICEncryptionLevel) String() string {
	// Indexed by level; the keys pin each name to its constant.
	names := [...]string{
		QUICEncryptionLevelInitial:     "Initial",
		QUICEncryptionLevelEarly:       "Early",
		QUICEncryptionLevelHandshake:   "Handshake",
		QUICEncryptionLevelApplication: "Application",
	}
	if l < 0 || int(l) >= len(names) {
		return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
	}
	return names[l]
}
38
39
40
41
42
// A QUICConn wraps a *Conn whose quic field has been populated (see
// newQUICConn) and exposes the QUIC-facing operations defined in this file.
//
// NOTE(review): the methods read and write this struct and the quic state
// without any visible locking of their own; presumably a QUICConn is not safe
// for concurrent use — confirm.
type QUICConn struct {
	// conn is the wrapped connection; all QUIC state lives in conn.quic.
	conn *Conn

	// sessionTicketSent is set by SendSessionTicket so that a second call
	// is rejected.
	sessionTicketSent bool
}
48
49
// A QUICConfig carries the settings used by QUICClient and QUICServer to
// create a QUICConn.
type QUICConfig struct {
	// TLSConfig is handed to Client or Server when the connection is created.
	TLSConfig *Config

	// EnableSessionEvents is copied into the connection's QUIC state when
	// the connection is created.
	//
	// NOTE(review): presumably this controls whether QUICResumeSession and
	// QUICStoreSession events are generated; the code that consults the flag
	// is outside this section — confirm.
	EnableSessionEvents bool
}
60
61
// A QUICEventKind is the kind of a QUICEvent; it determines which of the
// event's other fields are populated.
type QUICEventKind int

const (
	// QUICNoEvent is the zero value. NextEvent returns an event of this
	// kind when the queue is empty.
	QUICNoEvent QUICEventKind = iota

	// QUICSetReadSecret and QUICSetWriteSecret carry a secret for an
	// encryption level. The event's Level and Suite fields are set, and
	// Data holds the secret.
	QUICSetReadSecret
	QUICSetWriteSecret

	// QUICWriteData carries handshake bytes queued for an encryption level.
	// Level and Data are set. Back-to-back writes at the same level are
	// merged into a single event.
	QUICWriteData

	// QUICTransportParameters carries transport parameters in Data.
	// NOTE(review): presumably these are the peer's parameters; the caller
	// that queues this event is outside this section — confirm.
	QUICTransportParameters

	// QUICTransportParametersRequired is queued when the handshake asks for
	// transport parameters and none have been set. The handshake then waits
	// until QUICConn.SetTransportParameters supplies them.
	QUICTransportParametersRequired

	// QUICRejectedEarlyData is queued with no additional fields set.
	// NOTE(review): presumably it reports that the server rejected 0-RTT
	// data — confirm against the caller of quicRejectedEarlyData.
	QUICRejectedEarlyData

	// QUICHandshakeDone is queued with no additional fields set.
	// NOTE(review): presumably queued once the handshake has completed —
	// confirm against the caller of quicHandshakeComplete.
	QUICHandshakeDone

	// QUICResumeSession carries a session in SessionState. After queueing
	// it, the handshake waits until NextEvent has been called with the
	// queue fully consumed.
	QUICResumeSession

	// QUICStoreSession carries a session in SessionState.
	// NOTE(review): presumably the application is expected to pass it to
	// QUICConn.StoreSession — confirm.
	QUICStoreSession
)
121
122
123
124
125
// A QUICEvent is one entry of the event queue drained by QUICConn.NextEvent.
// Kind says what happened; the remaining fields are set only for the kinds
// noted on each field.
type QUICEvent struct {
	Kind QUICEventKind

	// Level is set for QUICSetReadSecret, QUICSetWriteSecret and
	// QUICWriteData.
	Level QUICEncryptionLevel

	// Data is set for QUICSetReadSecret and QUICSetWriteSecret (the secret),
	// QUICWriteData (handshake bytes) and QUICTransportParameters.
	//
	// NOTE(review): NextEvent contains code that overwrites the first byte
	// of the previously returned event's data, so presumably Data must not
	// be retained across NextEvent calls — confirm.
	Data []byte

	// Suite is set for QUICSetReadSecret and QUICSetWriteSecret.
	Suite uint16

	// SessionState is set for QUICResumeSession and QUICStoreSession.
	SessionState *SessionState
}
142
// quicState is the QUIC-specific state that newQUICConn attaches to a Conn.
type quicState struct {
	// events is the queue of pending events; the quic* helpers on Conn
	// append to it and NextEvent consumes from it.
	events []QUICEvent
	// nextEvent is the index in events of the next event NextEvent returns.
	nextEvent int

	// eventArr is the initial backing array of events (newQUICConn sets
	// events to eventArr[:0]), so up to len(eventArr) events can be queued
	// before append has to allocate.
	eventArr [8]QUICEvent

	// started is set by QUICConn.Start.
	started bool
	// signalc: quicWaitForSignal sends on it while parked; QUICConn methods
	// receive from it to let the handshake goroutine continue.
	signalc chan struct{}
	// blockedc: quicWaitForSignal sends on it when the handshake goroutine
	// parks; QUICConn methods receive from it to wait for that moment. The
	// code in this file also relies on it being closed at some point
	// (closing happens outside this section).
	blockedc chan struct{}
	// cancelc aborts a wait in quicWaitForSignal when it becomes readable.
	// It is assigned outside this section.
	cancelc <-chan struct{}
	// cancel is called by QUICConn.Close. It is assigned outside this
	// section; Close treats nil as "nothing to cancel".
	cancel context.CancelFunc

	// waitingForDrain is set by quicResumeSession and cleared by NextEvent
	// once every queued event has been consumed.
	waitingForDrain bool

	// readbuf holds the bytes most recently passed to HandleData until they
	// are copied into the connection's handshake buffer.
	readbuf []byte

	// transportParams holds the value given to SetTransportParameters; nil
	// means "not set yet" (SetTransportParameters never stores nil).
	transportParams []byte

	// enableSessionEvents is copied from QUICConfig.EnableSessionEvents.
	enableSessionEvents bool
}
170
171
172
173
174
175 func QUICClient(config *QUICConfig) *QUICConn {
176 return newQUICConn(Client(nil, config.TLSConfig), config)
177 }
178
179
180
181
182
183 func QUICServer(config *QUICConfig) *QUICConn {
184 return newQUICConn(Server(nil, config.TLSConfig), config)
185 }
186
187 func newQUICConn(conn *Conn, config *QUICConfig) *QUICConn {
188 conn.quic = &quicState{
189 signalc: make(chan struct{}),
190 blockedc: make(chan struct{}),
191 enableSessionEvents: config.EnableSessionEvents,
192 }
193 conn.quic.events = conn.quic.eventArr[:0]
194 return &QUICConn{
195 conn: conn,
196 }
197 }
198
199
200
201
202
203 func (q *QUICConn) Start(ctx context.Context) error {
204 if q.conn.quic.started {
205 return quicError(errors.New("tls: Start called more than once"))
206 }
207 q.conn.quic.started = true
208 if q.conn.config.MinVersion < VersionTLS13 {
209 return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13"))
210 }
211 go q.conn.HandshakeContext(ctx)
212 if _, ok := <-q.conn.quic.blockedc; !ok {
213 return q.conn.handshakeErr
214 }
215 return nil
216 }
217
218
219
220 func (q *QUICConn) NextEvent() QUICEvent {
221 qs := q.conn.quic
222 if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
223
224
225 qs.events[last].Data[0] = 0
226 }
227 if qs.nextEvent >= len(qs.events) && qs.waitingForDrain {
228 qs.waitingForDrain = false
229 <-qs.signalc
230 <-qs.blockedc
231 }
232 if qs.nextEvent >= len(qs.events) {
233 qs.events = qs.events[:0]
234 qs.nextEvent = 0
235 return QUICEvent{Kind: QUICNoEvent}
236 }
237 e := qs.events[qs.nextEvent]
238 qs.events[qs.nextEvent] = QUICEvent{}
239 qs.nextEvent++
240 return e
241 }
242
243
244 func (q *QUICConn) Close() error {
245 if q.conn.quic.cancel == nil {
246 return nil
247 }
248 q.conn.quic.cancel()
249 for range q.conn.quic.blockedc {
250
251 }
252 return q.conn.handshakeErr
253 }
254
255
256
// HandleData passes handshake bytes received at the given encryption level
// to the connection.
//
// It returns an error if level differs from the connection's current input
// level. Otherwise the bytes are handed to the handshake goroutine; if that
// goroutine parks again, HandleData returns nil. If the blocked channel is
// found closed instead, HandleData processes the buffered bytes itself as
// post-handshake messages and returns any resulting error (wrapped by
// quicError).
func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
	c := q.conn
	if c.in.level != level {
		return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
	}
	// Publish the bytes, then receive from signalc: quicWaitForSignal, on
	// the sending side of signalc, copies readbuf into c.hand and clears it.
	c.quic.readbuf = data
	<-c.quic.signalc
	_, ok := <-c.quic.blockedc
	if ok {
		// The handshake goroutine has parked in quicWaitForSignal again.
		return nil
	}
	// blockedc is closed. NOTE(review): presumably the handshake goroutine
	// has exited (the close happens outside this section) — confirm. From
	// here on the bytes are consumed on the caller's goroutine.
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()
	c.hand.Write(c.quic.readbuf)
	c.quic.readbuf = nil
	// Each message starts with a 4-byte header; bytes 1..3 are the
	// big-endian body length.
	for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
		b := q.conn.hand.Bytes()
		n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
		if n > maxHandshake {
			q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
			break
		}
		if len(b) < 4+n {
			// Message incomplete; keep what we have and wait for more data.
			return nil
		}
		if err := q.conn.handlePostHandshakeMessage(); err != nil {
			// Recording the error also terminates the loop.
			q.conn.handshakeErr = err
		}
	}
	if q.conn.handshakeErr != nil {
		return quicError(q.conn.handshakeErr)
	}
	return nil
}
293
// QUICSessionTicketOptions are the options accepted by
// QUICConn.SendSessionTicket.
type QUICSessionTicketOptions struct {
	// EarlyData is forwarded as the first argument of the connection's
	// sendSessionTicket.
	// NOTE(review): presumably it states whether the ticket may be used for
	// 0-RTT — confirm.
	EarlyData bool

	// Extra is forwarded as the second argument of the connection's
	// sendSessionTicket.
	Extra [][]byte
}
299
300
301
302
303 func (q *QUICConn) SendSessionTicket(opts QUICSessionTicketOptions) error {
304 c := q.conn
305 if !c.isHandshakeComplete.Load() {
306 return quicError(errors.New("tls: SendSessionTicket called before handshake completed"))
307 }
308 if c.isClient {
309 return quicError(errors.New("tls: SendSessionTicket called on the client"))
310 }
311 if q.sessionTicketSent {
312 return quicError(errors.New("tls: SendSessionTicket called multiple times"))
313 }
314 q.sessionTicketSent = true
315 return quicError(c.sendSessionTicket(opts.EarlyData, opts.Extra))
316 }
317
318
319
320
321
322 func (q *QUICConn) StoreSession(session *SessionState) error {
323 c := q.conn
324 if !c.isClient {
325 return quicError(errors.New("tls: StoreSessionTicket called on the server"))
326 }
327 cacheKey := c.clientSessionCacheKey()
328 if cacheKey == "" {
329 return nil
330 }
331 cs := &ClientSessionState{session: session}
332 c.config.ClientSessionCache.Put(cacheKey, cs)
333 return nil
334 }
335
336
337 func (q *QUICConn) ConnectionState() ConnectionState {
338 return q.conn.ConnectionState()
339 }
340
341
342
343
344
345 func (q *QUICConn) SetTransportParameters(params []byte) {
346 if params == nil {
347 params = []byte{}
348 }
349 q.conn.quic.transportParams = params
350 if q.conn.quic.started {
351 <-q.conn.quic.signalc
352 <-q.conn.quic.blockedc
353 }
354 }
355
356
357
358 func quicError(err error) error {
359 if err == nil {
360 return nil
361 }
362 var ae AlertError
363 if errors.As(err, &ae) {
364 return err
365 }
366 var a alert
367 if !errors.As(err, &a) {
368 a = alertInternalError
369 }
370
371
372 return fmt.Errorf("%w%.0w", err, AlertError(a))
373 }
374
375 func (c *Conn) quicReadHandshakeBytes(n int) error {
376 for c.hand.Len() < n {
377 if err := c.quicWaitForSignal(); err != nil {
378 return err
379 }
380 }
381 return nil
382 }
383
384 func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
385 c.quic.events = append(c.quic.events, QUICEvent{
386 Kind: QUICSetReadSecret,
387 Level: level,
388 Suite: suite,
389 Data: secret,
390 })
391 }
392
393 func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
394 c.quic.events = append(c.quic.events, QUICEvent{
395 Kind: QUICSetWriteSecret,
396 Level: level,
397 Suite: suite,
398 Data: secret,
399 })
400 }
401
402 func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
403 var last *QUICEvent
404 if len(c.quic.events) > 0 {
405 last = &c.quic.events[len(c.quic.events)-1]
406 }
407 if last == nil || last.Kind != QUICWriteData || last.Level != level {
408 c.quic.events = append(c.quic.events, QUICEvent{
409 Kind: QUICWriteData,
410 Level: level,
411 })
412 last = &c.quic.events[len(c.quic.events)-1]
413 }
414 last.Data = append(last.Data, data...)
415 }
416
417 func (c *Conn) quicResumeSession(session *SessionState) error {
418 c.quic.events = append(c.quic.events, QUICEvent{
419 Kind: QUICResumeSession,
420 SessionState: session,
421 })
422 c.quic.waitingForDrain = true
423 for c.quic.waitingForDrain {
424 if err := c.quicWaitForSignal(); err != nil {
425 return err
426 }
427 }
428 return nil
429 }
430
431 func (c *Conn) quicStoreSession(session *SessionState) {
432 c.quic.events = append(c.quic.events, QUICEvent{
433 Kind: QUICStoreSession,
434 SessionState: session,
435 })
436 }
437
438 func (c *Conn) quicSetTransportParameters(params []byte) {
439 c.quic.events = append(c.quic.events, QUICEvent{
440 Kind: QUICTransportParameters,
441 Data: params,
442 })
443 }
444
445 func (c *Conn) quicGetTransportParameters() ([]byte, error) {
446 if c.quic.transportParams == nil {
447 c.quic.events = append(c.quic.events, QUICEvent{
448 Kind: QUICTransportParametersRequired,
449 })
450 }
451 for c.quic.transportParams == nil {
452 if err := c.quicWaitForSignal(); err != nil {
453 return nil, err
454 }
455 }
456 return c.quic.transportParams, nil
457 }
458
459 func (c *Conn) quicHandshakeComplete() {
460 c.quic.events = append(c.quic.events, QUICEvent{
461 Kind: QUICHandshakeDone,
462 })
463 }
464
465 func (c *Conn) quicRejectedEarlyData() {
466 c.quic.events = append(c.quic.events, QUICEvent{
467 Kind: QUICRejectedEarlyData,
468 })
469 }
470
471
472
473
474
475
// quicWaitForSignal parks the calling goroutine until a QUICConn method
// lets it continue: it first announces that it is blocked (a send on
// blockedc) and then waits to be signalled (a send on signalc). Once
// signalled, it moves any bytes left in readbuf into the handshake buffer.
//
// If cancelc becomes readable during either wait, it instead returns the
// result of sending a close_notify alert.
//
// NOTE(review): the function unlocks handshakeMutex on entry, so presumably
// the caller must hold it — confirm.
func (c *Conn) quicWaitForSignal() error {
	// Release handshakeMutex for the duration of the wait; the deferred
	// Lock re-acquires it before returning, on every path.
	c.handshakeMutex.Unlock()
	defer c.handshakeMutex.Lock()
	// Announce that the handshake is blocked. QUICConn methods (Start,
	// NextEvent, HandleData, SetTransportParameters) receive from blockedc
	// to wait for this point.
	select {
	case c.quic.blockedc <- struct{}{}:
	case <-c.quic.cancelc:
		return c.sendAlertLocked(alertCloseNotify)
	}
	// Wait for a QUICConn method to receive from signalc, which tells us
	// there may be new input; then take over whatever HandleData left in
	// readbuf.
	select {
	case c.quic.signalc <- struct{}{}:
		c.hand.Write(c.quic.readbuf)
		c.quic.readbuf = nil
	case <-c.quic.cancelc:
		return c.sendAlertLocked(alertCloseNotify)
	}
	return nil
}
501
View as plain text