1
2
3
4
5 package nettest
6
7 import (
8 "bytes"
9 "errors"
10 "io"
11 "math"
12 "net"
13 "net/netip"
14 "os"
15 "time"
16 )
17
18
19 type Conn struct {
20
21
22
23
24
25
26
27
28 r, w *connHalf
29
30
31 peer *Conn
32 }
33
34
35 func NewConnPair() (*Conn, *Conn) {
36 return newConnPair(
37 net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:10000")),
38 net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:10001")),
39 )
40 }
41
42 func newConnPair(addr1, addr2 net.Addr) (*Conn, *Conn) {
43 h1 := newConnHalf(addr1)
44 h2 := newConnHalf(addr2)
45 c1 := &Conn{r: h1, w: h2}
46 c2 := &Conn{r: h2, w: h1}
47 c1.peer = c2
48 c2.peer = c1
49 c1.SetReadBufferSize(-1)
50 c2.SetReadBufferSize(-1)
51 return c1, c2
52 }
53
54
55 func (c *Conn) Peer() *Conn {
56 return c.peer
57 }
58
59
60 func (c *Conn) Read(b []byte) (n int, err error) {
61 n, err = c.r.read(b)
62 if err != nil && err != io.EOF {
63 err = &net.OpError{
64 Op: "read",
65 Net: "tcp",
66 Source: c.RemoteAddr(),
67 Addr: c.LocalAddr(),
68 Err: err,
69 }
70 }
71 return n, err
72 }
73
74
75 func (c *Conn) CanRead() bool {
76 return c.r.canRead()
77 }
78
79
80 func (c *Conn) Write(b []byte) (n int, err error) {
81 n, err = c.w.write(b)
82 if err != nil {
83 err = &net.OpError{
84 Op: "write",
85 Net: "tcp",
86 Source: c.LocalAddr(),
87 Addr: c.RemoteAddr(),
88 Err: err,
89 }
90 }
91 return n, err
92 }
93
94
95
96
97
98
99
100 func (c *Conn) IsClosed() bool {
101 c.r.lock()
102 readClosed := c.r.readClosed
103 c.r.unlock()
104 c.w.lock()
105 writeClosed := c.w.writeClosed
106 c.w.unlock()
107 return readClosed && writeClosed
108 }
109
var (
	// errClosedByPeer is returned by writes into a half whose reading
	// endpoint has closed its read side.
	errClosedByPeer = errors.New("connection closed by peer")
)
111
112
113 func (c *Conn) CloseRead() error {
114 c.r.lock()
115 defer c.r.unlock()
116 c.r.buf.Reset()
117 c.r.readClosed = true
118 return nil
119 }
120
121
122 func (c *Conn) CloseWrite() error {
123 c.w.lock()
124 defer c.w.unlock()
125 c.w.writeClosed = true
126 return nil
127 }
128
129
130 func (c *Conn) Close() error {
131 c.r.lock()
132 readClosed := c.r.readClosed
133 c.r.buf.Reset()
134 c.r.readClosed = true
135 err := c.r.closeErr
136 c.r.unlock()
137
138 c.w.lock()
139 writeClosed := c.w.writeClosed
140 c.w.writeClosed = true
141 c.w.unlock()
142
143 if readClosed && writeClosed {
144 err = net.ErrClosed
145 }
146 if err != nil {
147 err = &net.OpError{
148 Op: "close",
149 Net: "tcp",
150 Addr: c.LocalAddr(),
151 Err: err,
152 }
153 }
154 return err
155 }
156
157
158
159
160 func (c *Conn) SetCloseError(err error) {
161 c.r.lock()
162 c.r.closeErr = err
163 c.r.unlock()
164 }
165
166
167 func (c *Conn) LocalAddr() net.Addr {
168 c.r.lock()
169 defer c.r.unlock()
170 return c.r.addr
171 }
172
173
174
175
176 func (c *Conn) SetLocalAddr(addr net.Addr) {
177 c.r.lock()
178 defer c.r.unlock()
179 c.r.addr = addr
180 }
181
182
183 func (c *Conn) RemoteAddr() net.Addr {
184 c.r.lock()
185 defer c.r.unlock()
186 return c.w.addr
187 }
188
189
190 func (c *Conn) SetDeadline(t time.Time) error {
191 c.SetReadDeadline(t)
192 c.SetWriteDeadline(t)
193 return nil
194 }
195
196
197 func (c *Conn) SetReadDeadline(t time.Time) error {
198 c.r.readDeadline.setDeadline(c.r, t)
199 return nil
200 }
201
202
203 func (c *Conn) SetWriteDeadline(t time.Time) error {
204 c.w.writeDeadline.setDeadline(c.w, t)
205 return nil
206 }
207
208
209
210
211 func (c *Conn) SetReadBufferSize(size int) {
212 if size < 0 {
213 size = math.MaxInt
214 }
215 c.r.setBufferSize(size)
216 }
217
218
219
220
221
222
223 func (c *Conn) SetReadError(err error) {
224 c.r.lock()
225 defer c.r.unlock()
226 c.r.readErr = err
227 }
228
229
230
231
232
233 func (c *Conn) SetWriteError(err error) {
234 c.w.lock()
235 defer c.w.unlock()
236 c.w.writeErr = err
237 }
238
239
240
241
242 type connHalf struct {
243 addr net.Addr
244
245
246
247
248
249
250
251 lockr chan struct{}
252 lockw chan struct{}
253 lockrw chan struct{}
254 lockc chan struct{}
255
256
257 readDeadline, writeDeadline connDeadline
258
259 bufMax int
260 buf bytes.Buffer
261
262 readClosed, writeClosed bool
263 readErr, writeErr error
264 closeErr error
265 }
266
267 func newConnHalf(addr net.Addr) *connHalf {
268 h := &connHalf{
269 addr: addr,
270 lockw: make(chan struct{}, 1),
271 lockr: make(chan struct{}, 1),
272 lockrw: make(chan struct{}, 1),
273 lockc: make(chan struct{}, 1),
274 bufMax: math.MaxInt,
275 }
276 h.unlock()
277 return h
278 }
279
280
281 func (h *connHalf) lock() {
282 select {
283 case <-h.lockw:
284 case <-h.lockr:
285 case <-h.lockrw:
286 case <-h.lockc:
287 }
288 }
289
290
291 func (h *connHalf) unlock() {
292 canRead := h.canReadLocked()
293 canWrite := h.canWriteLocked()
294 switch {
295 case canRead && canWrite:
296 h.lockrw <- struct{}{}
297 case canRead:
298 h.lockr <- struct{}{}
299 case canWrite:
300 h.lockw <- struct{}{}
301 default:
302 h.lockc <- struct{}{}
303 }
304 }
305
306 func (h *connHalf) canRead() bool {
307 h.lock()
308 defer h.unlock()
309 return h.canReadLocked()
310 }
311
312 func (h *connHalf) canReadLocked() bool {
313 return h.readErr != nil || h.readDeadline.expired || h.buf.Len() > 0 || h.readClosed || h.writeClosed
314 }
315
316 func (h *connHalf) canWriteLocked() bool {
317 return h.writeErr != nil || h.writeDeadline.expired || h.bufMax > h.buf.Len() || h.readClosed || h.writeClosed
318 }
319
320
321 func (h *connHalf) waitAndLockForRead() {
322 select {
323 case <-h.lockr:
324
325 case <-h.lockrw:
326
327 }
328 }
329
330
331 func (h *connHalf) waitAndLockForWrite() {
332 select {
333 case <-h.lockw:
334
335 case <-h.lockrw:
336
337 }
338 }
339
340 func (h *connHalf) read(b []byte) (n int, err error) {
341 h.waitAndLockForRead()
342 defer h.unlock()
343 if h.readClosed {
344 return 0, net.ErrClosed
345 }
346 if h.readDeadline.expired {
347 return 0, os.ErrDeadlineExceeded
348 }
349 if h.buf.Len() > 0 {
350 return h.buf.Read(b)
351 }
352 if h.writeClosed {
353 return 0, io.EOF
354 }
355 return 0, h.readErr
356 }
357
358 func (h *connHalf) setBufferSize(size int) {
359 h.lock()
360 defer h.unlock()
361 h.bufMax = size
362 }
363
364 func (h *connHalf) write(b []byte) (n int, err error) {
365 for n < len(b) {
366 nn, err := h.writePartial(b[n:])
367 n += nn
368 if err != nil {
369 return n, err
370 }
371 }
372 return n, nil
373 }
374
375 func (h *connHalf) writePartial(b []byte) (n int, err error) {
376 h.waitAndLockForWrite()
377 defer h.unlock()
378 if h.writeClosed {
379 return 0, net.ErrClosed
380 }
381 if h.writeDeadline.expired {
382 return 0, os.ErrDeadlineExceeded
383 }
384 if h.readClosed {
385 return 0, errClosedByPeer
386 }
387 if h.writeErr != nil {
388 return 0, h.writeErr
389 }
390 writeMax := h.bufMax - h.buf.Len()
391 if writeMax < len(b) {
392 b = b[:writeMax]
393 }
394 return h.buf.Write(b)
395 }
396
// connDeadline tracks a read or write deadline for a connHalf. Its fields
// are modified only while holding the locker passed to setDeadline.
type connDeadline struct {
	timer   *time.Timer // pending expiry timer, or nil
	expired bool        // whether the deadline has passed
}

// locker is the lock interface setDeadline needs; *connHalf provides it.
type locker interface {
	lock()
	unlock()
}

// setDeadline sets the deadline to t while holding mu, cancelling any timer
// armed by a previous call. The outcome depends on t:
//   - a zero t clears the deadline,
//   - a t that has already passed marks it expired immediately,
//   - a future t arms a timer that marks it expired, under mu, when t arrives.
//
// A timer superseded by a later call makes no change even if it has already
// fired.
func (d *connDeadline) setDeadline(mu locker, t time.Time) {
	mu.lock()
	defer mu.unlock()

	if d.timer != nil {
		d.timer.Stop()
		d.timer = nil
	}

	if t.IsZero() {
		// No deadline.
		d.expired = false
		return
	}

	remaining := time.Until(t)
	d.expired = remaining <= 0
	if d.expired {
		return
	}

	var self *time.Timer
	self = time.AfterFunc(remaining, func() {
		mu.lock()
		defer mu.unlock()
		// A later setDeadline call may have replaced this timer; in that
		// case this expiry is stale and must be ignored.
		if d.timer != self {
			return
		}
		d.timer = nil
		d.expired = true
	})
	d.timer = self
}
438
View as plain text