1
2
3
4
5
6
7 package quic
8
9 import (
10 "encoding/binary"
11 "net"
12 "net/netip"
13 "sync"
14 "syscall"
15 "unsafe"
16 )
17
18
19
// netUDPConn wraps a *net.UDPConn together with the local address that is
// stamped on datagrams read from it.
type netUDPConn struct {
	// c is the underlying UDP socket used for all reads and writes.
	c *net.UDPConn

	// localAddr is the address recorded on each received datagram before
	// control messages are parsed. newNetUDPConn leaves its Addr invalid
	// (port only) when the socket is bound to an unspecified address.
	localAddr netip.AddrPort
}
24
25 func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) {
26 a, _ := uc.LocalAddr().(*net.UDPAddr)
27 localAddr := a.AddrPort()
28 if localAddr.Addr().IsUnspecified() {
29
30
31
32
33 localAddr = netip.AddrPortFrom(netip.Addr{}, localAddr.Port())
34 }
35
36 sc, err := uc.SyscallConn()
37 if err != nil {
38 return nil, err
39 }
40 sc.Control(func(fd uintptr) {
41
42
43
44
45
46 syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, ip_recvtos, 1)
47 syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_RECVTCLASS, 1)
48 if !localAddr.IsValid() {
49 syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1)
50 syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, ipv6_recvpktinfo, 1)
51 }
52 })
53
54 return &netUDPConn{
55 c: uc,
56 localAddr: localAddr,
57 }, nil
58 }
59
60 func (c *netUDPConn) Close() error { return c.c.Close() }
61
62 func (c *netUDPConn) LocalAddr() netip.AddrPort {
63 a, _ := c.c.LocalAddr().(*net.UDPAddr)
64 return a.AddrPort()
65 }
66
67 func (c *netUDPConn) Read(f func(*datagram)) {
68
69
70 const (
71 inPktinfoSize = 12
72 in6PktinfoSize = 20
73 ipTOSSize = 4
74 ipv6TclassSize = 4
75 )
76 control := make([]byte, 0+
77 syscall.CmsgSpace(inPktinfoSize)+
78 syscall.CmsgSpace(in6PktinfoSize)+
79 syscall.CmsgSpace(ipTOSSize)+
80 syscall.CmsgSpace(ipv6TclassSize))
81
82 for {
83 d := newDatagram()
84 n, controlLen, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(d.b, control)
85 if err != nil {
86 return
87 }
88 if n == 0 {
89 continue
90 }
91 d.localAddr = c.localAddr
92 d.peerAddr = unmapAddrPort(peerAddr)
93 d.b = d.b[:n]
94 parseControl(d, control[:controlLen])
95 f(d)
96 }
97 }
98
// cmsgPool holds reusable control-message buffers for Write.
// Entries are *[]byte; a freshly created entry points at an empty slice.
var cmsgPool = sync.Pool{
	New: func() any {
		var buf []byte
		return &buf
	},
}
104
105 func (c *netUDPConn) Write(dgram datagram) error {
106 controlp := cmsgPool.Get().(*[]byte)
107 control := *controlp
108 defer func() {
109 *controlp = control[:0]
110 cmsgPool.Put(controlp)
111 }()
112
113 localIP := dgram.localAddr.Addr()
114 if localIP.IsValid() {
115 if localIP.Is4() {
116 control = appendCmsgIPSourceAddrV4(control, localIP)
117 } else {
118 control = appendCmsgIPSourceAddrV6(control, localIP)
119 }
120 }
121 if dgram.ecn != ecnNotECT {
122 if dgram.peerAddr.Addr().Is4() {
123 control = appendCmsgECNv4(control, dgram.ecn)
124 } else {
125 control = appendCmsgECNv6(control, dgram.ecn)
126 }
127 }
128
129 _, _, err := c.c.WriteMsgUDPAddrPort(dgram.b, control, dgram.peerAddr)
130 return err
131 }
132
133 func parseControl(d *datagram, control []byte) {
134 msgs, err := syscall.ParseSocketControlMessage(control)
135 if err != nil {
136 return
137 }
138 for _, m := range msgs {
139 switch m.Header.Level {
140 case syscall.IPPROTO_IP:
141 switch m.Header.Type {
142 case syscall.IP_TOS, ip_recvtos:
143
144
145 if ecn, ok := parseIPTOS(m.Data); ok {
146 d.ecn = ecn
147 }
148 case syscall.IP_PKTINFO:
149 if a, ok := parseInPktinfo(m.Data); ok {
150 d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port())
151 }
152 }
153 case syscall.IPPROTO_IPV6:
154 switch m.Header.Type {
155 case syscall.IPV6_TCLASS:
156
157
158 if ecn, ok := parseIPv6TCLASS(m.Data); ok {
159 d.ecn = ecn
160 }
161 case ipv6_pktinfo:
162 if a, ok := parseIn6Pktinfo(m.Data); ok {
163 d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port())
164 }
165 }
166 }
167 }
168 }
169
170
171
172 func parseIPv6TCLASS(b []byte) (ecnBits, bool) {
173 if len(b) != 4 {
174 return 0, false
175 }
176 return ecnBits(binary.NativeEndian.Uint32(b) & ecnMask), true
177 }
178
179 func appendCmsgECNv6(b []byte, ecn ecnBits) []byte {
180 b, data := appendCmsg(b, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, 4)
181 binary.NativeEndian.PutUint32(data, uint32(ecn))
182 return b
183 }
184
185
186
187
188
189
190
191
// parseInPktinfo returns the IPv4 address stored in the last four bytes of a
// 12-byte IP_PKTINFO payload. It reports false if b is not exactly 12 bytes.
//
// NOTE(review): the offset assumes the in_pktinfo layout from ip(7)
// (ifindex, spec_dst, addr), where bytes 8..11 are the destination address of
// the received packet — confirm for every target platform.
func parseInPktinfo(b []byte) (dst netip.Addr, ok bool) {
	const (
		pktinfoSize = 12
		addrOffset  = 8
	)
	if len(b) != pktinfoSize {
		return netip.Addr{}, false
	}
	var ip4 [4]byte
	copy(ip4[:], b[addrOffset:addrOffset+len(ip4)])
	return netip.AddrFrom4(ip4), true
}
198
199
200
201 func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte {
202
203
204
205
206
207 b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_PKTINFO, 12)
208 ip := src.As4()
209 copy(data[4:], ip[:])
210 return b
211 }
212
213
214
215
216
217
218
// parseIn6Pktinfo returns the address stored in the first 16 bytes of a
// 20-byte IPv6 packet-info payload, with any IPv4-mapped form unmapped to
// plain IPv4. It reports false if b is not exactly 20 bytes.
func parseIn6Pktinfo(b []byte) (netip.Addr, bool) {
	const pktinfoSize = 20
	if len(b) != pktinfoSize {
		return netip.Addr{}, false
	}
	var ip6 [16]byte
	copy(ip6[:], b) // copies only the leading 16 address bytes
	return netip.AddrFrom16(ip6).Unmap(), true
}
225
226
227
228 func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte {
229 b, data := appendCmsg(b, syscall.IPPROTO_IPV6, ipv6_pktinfo, 20)
230 ip := src.As16()
231 copy(data[0:], ip[:])
232 return b
233 }
234
235
236
// appendCmsg appends one zero-filled control message with the given level,
// type, and payload size to b.
//
// It returns the extended buffer and data, a size-byte slice aliasing the
// new message's payload so the caller can fill it in place.
func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) {
	start := len(b)
	space := syscall.CmsgSpace(size)
	b = append(b, make([]byte, space)...)

	// Write the header directly into the freshly appended region.
	hdr := (*syscall.Cmsghdr)(unsafe.Pointer(&b[start]))
	hdr.Level, hdr.Type = level, typ
	hdr.SetLen(syscall.CmsgLen(size))

	// The payload begins right after the (aligned) header.
	payload := b[start+syscall.CmsgSpace(0):]
	return b, payload[:size]
}
246
View as plain text