1
2
3
4
5 package nettest
6
7 import (
8 "context"
9 "errors"
10 "internal/gate"
11 "net"
12 "os"
13 "slices"
14 "sync"
15 "time"
16 )
17
18
19 type PacketNet struct {
20 mu sync.Mutex
21 conns map[netAddr]*PacketConn
22 }
23
// netAddr is the comparable map key identifying an address on a PacketNet.
//
// The field order matters: callers build keys with positional literals of
// the form netAddr{a.Network(), a.String()}.
type netAddr struct {
	network, addr string
}
28
29
30 func NewPacketNet() *PacketNet {
31 return &PacketNet{
32 conns: make(map[netAddr]*PacketConn),
33 }
34 }
35
36
37
38 func (n *PacketNet) NewConn(a net.Addr) (*PacketConn, error) {
39 n.mu.Lock()
40 defer n.mu.Unlock()
41 addrKey := netAddr{a.Network(), a.String()}
42 if _, ok := n.conns[addrKey]; ok {
43 return nil, &net.OpError{
44 Op: "listen",
45 Net: "udp",
46 Addr: a,
47 Err: errors.New("address is in use"),
48 }
49 }
50 p := &PacketConn{
51 gate: gate.New(false),
52 addr: a,
53 net: n,
54 }
55 n.conns[addrKey] = p
56 return p, nil
57 }
58
59 type PacketConn struct {
60 gate gate.Gate
61 queue queue[*packet]
62 closed bool
63 readErr error
64 writeErr error
65 closeErr error
66 readDeadline connDeadline
67
68 net *PacketNet
69 addr net.Addr
70 }
71
// packet is a single datagram queued on a PacketConn.
type packet struct {
	// b is the payload.
	b []byte
	// src is the local address of the connection that sent the packet.
	src net.Addr
}
76
77
78 func (p *PacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
79 p.gate.WaitAndLock(context.Background())
80 defer p.unlock()
81
82 switch {
83 case p.closed:
84 err = net.ErrClosed
85 case p.readDeadline.expired:
86 err = os.ErrDeadlineExceeded
87 case p.queue.len() == 0 && p.readErr != nil:
88 err = p.readErr
89 }
90 if err != nil {
91 return 0, nil, &net.OpError{
92 Op: "read",
93 Net: "udp",
94 Addr: p.addr,
95 Err: err,
96 }
97 }
98 pkt := p.queue.pop()
99 n = copy(b, pkt.b)
100 return n, pkt.src, nil
101 }
102
103
104
105
106
107
108
109 func (p *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
110 p.gate.Lock()
111 switch {
112 case p.closed:
113 err = net.ErrClosed
114 case p.writeErr != nil:
115 err = p.writeErr
116 }
117 p.unlock()
118 if err != nil {
119 return 0, &net.OpError{
120 Op: "write",
121 Net: "udp",
122 Source: p.addr,
123 Addr: addr,
124 Err: err,
125 }
126 }
127
128 p.net.mu.Lock()
129 dst := p.net.conns[netAddr{addr.Network(), addr.String()}]
130 p.net.mu.Unlock()
131
132 if dst == nil {
133
134
135 return len(b), nil
136 }
137
138 dst.lock()
139 if !dst.closed {
140 dst.queue.push(&packet{b: slices.Clone(b), src: p.addr})
141 }
142 dst.unlock()
143 return len(b), nil
144 }
145
146
147 func (p *PacketConn) Close() error {
148 p.net.mu.Lock()
149 delete(p.net.conns, netAddr{p.addr.Network(), p.addr.String()})
150 p.net.mu.Unlock()
151
152 p.lock()
153 defer p.unlock()
154 err := p.closeErr
155 p.closed = true
156 p.readErr = net.ErrClosed
157 p.writeErr = net.ErrClosed
158 p.closeErr = net.ErrClosed
159 if err != nil {
160 return &net.OpError{
161 Op: "close",
162 Net: "udp",
163 Addr: p.addr,
164 Err: err,
165 }
166 }
167 return err
168 }
169
170
171 func (p *PacketConn) LocalAddr() net.Addr {
172 p.lock()
173 defer p.unlock()
174 return p.addr
175 }
176
177
178
179 func (p *PacketConn) SetDeadline(t time.Time) error {
180 return p.SetReadDeadline(t)
181 }
182
183
184 func (p *PacketConn) SetReadDeadline(t time.Time) error {
185 p.readDeadline.setDeadline(p, t)
186 return nil
187 }
188
189
190
191 func (p *PacketConn) SetWriteDeadline(t time.Time) error {
192 return nil
193 }
194
195
196
197
198
199
200 func (c *PacketConn) SetReadError(err error) {
201 c.lock()
202 defer c.unlock()
203 c.readErr = err
204 }
205
206
207
208
209
210 func (c *PacketConn) SetWriteError(err error) {
211 c.lock()
212 defer c.unlock()
213 c.writeErr = err
214 }
215
216
217
218
219 func (c *PacketConn) SetCloseError(err error) {
220 c.lock()
221 defer c.unlock()
222 c.closeErr = err
223 }
224
225
226
227 func (p *PacketConn) CanRead() bool {
228 p.lock()
229 defer p.unlock()
230 return p.canReadLocked()
231 }
232
233 func (p *PacketConn) canReadLocked() bool {
234 return p.queue.len() > 0 || p.readDeadline.expired || p.closed || p.readErr != nil
235 }
236
237
238 func (p *PacketConn) IsClosed() bool {
239 p.lock()
240 defer p.unlock()
241 return p.closed
242 }
243
244 func (p *PacketConn) lock() {
245 p.gate.Lock()
246 }
247
248 func (p *PacketConn) unlock() {
249 p.gate.Unlock(p.canReadLocked())
250 }
251
View as plain text