1
2
3
4
5 package nettest_test
6
7 import (
8 "bytes"
9 "errors"
10 "internal/nettest"
11 "io"
12 "net"
13 "os"
14 "testing"
15 "testing/synctest"
16 "time"
17 )
18
19 func TestConnReadWrite(t *testing.T) {
20 synctest.Test(t, func(t *testing.T) {
21 cliConn, srvConn := nettest.NewConnPair()
22
23 cliData := []byte("hello")
24 srvData := []byte("HELLO")
25 if n, err := cliConn.Write(cliData); n != len(cliData) || err != nil {
26 t.Fatalf("cliConn.Write(%q) = %v, %v; want %v, nil", cliData, n, err, len(cliData))
27 }
28 if err := cliConn.CloseWrite(); err != nil {
29 t.Fatalf("cliConn.CloseWrite() = %v, want nil", err)
30 }
31 if n, err := srvConn.Write(srvData); n != len(srvData) || err != nil {
32 t.Fatalf("srvConn.Write(%q) = %v, %v; want %v, nil", srvData, n, err, len(srvData))
33 }
34 if err := srvConn.CloseWrite(); err != nil {
35 t.Fatalf("cliConn.CloseWrite() = %v, want nil", err)
36 }
37 gotCli, err := io.ReadAll(cliConn)
38 if !bytes.Equal(gotCli, srvData) || err != nil {
39 t.Fatalf("io.ReadAll(cliConn) = %q, %v; want %v, nil", gotCli, err, srvData)
40 }
41 gotSrv, err := io.ReadAll(srvConn)
42 if !bytes.Equal(gotSrv, cliData) || err != nil {
43 t.Fatalf("io.ReadAll(srvConn) = %q, %v; want %v, nil", gotSrv, err, cliData)
44 }
45 })
46 }
47
48 func TestConnZeroBuffer(t *testing.T) {
49
50
51
52 synctest.Test(t, func(t *testing.T) {
53 rconn, wconn := nettest.NewConnPair()
54 rconn.SetReadBufferSize(0)
55 var readDone, writeDone bool
56 go func() {
57 rconn.Read(make([]byte, 100))
58 readDone = true
59 }()
60 go func() {
61 wconn.Write([]byte("a"))
62 writeDone = true
63 }()
64 synctest.Wait()
65 if readDone || writeDone {
66 t.Errorf("before unblocking: readDone=%v, writeDone=%v; want false", readDone, writeDone)
67 }
68 wconn.Close()
69 synctest.Wait()
70 if !readDone || !writeDone {
71 t.Errorf("after unblocking: readDone=%v, writeDone=%v; want true", readDone, writeDone)
72 }
73 })
74 }
75
76 func TestConnPartialWrite(t *testing.T) {
77
78 synctest.Test(t, func(t *testing.T) {
79 const readSize = 5
80 data := []byte("0123456789")
81 rconn, wconn := nettest.NewConnPair()
82 rconn.SetReadBufferSize(1)
83 go func() {
84 got := make([]byte, readSize)
85 if n, err := io.ReadFull(rconn, got); n != readSize || err != nil {
86 t.Errorf("io.ReadFull() = %v, %v; want %v, nil", n, err, readSize)
87 }
88 if want := data[:readSize]; !bytes.Equal(got, want) {
89 t.Errorf("read %q, want %q", got, want)
90 }
91 rconn.Close()
92 }()
93 n, err := wconn.Write(data)
94 if n != readSize+1 || err == nil {
95 t.Errorf("Write() = %v, %v; want %v, error", n, err, readSize+1)
96 }
97 })
98 }
99
100 func TestConnReadDeadline(t *testing.T) {
101 for _, unblock := range []struct {
102 name string
103 f func(*nettest.Conn)
104 }{{
105 name: "Write",
106 f: func(c *nettest.Conn) {
107 c.Write([]byte("x"))
108 },
109 }, {
110 name: "Close",
111 f: func(c *nettest.Conn) {
112 c.Close()
113 },
114 }, {
115 name: "CloseWrite",
116 f: func(c *nettest.Conn) {
117 c.CloseWrite()
118 },
119 }} {
120 for _, setDeadline := range []struct {
121 name string
122 f func(*nettest.Conn, time.Time) error
123 }{{
124 name: "SetDeadline",
125 f: (*nettest.Conn).SetDeadline,
126 }, {
127 name: "SetReadDeadline",
128 f: (*nettest.Conn).SetReadDeadline,
129 }} {
130 t.Run(unblock.name+"/"+setDeadline.name, func(t *testing.T) {
131 testDeadline(t, func() deadlineTest {
132 rconn, wconn := nettest.NewConnPair()
133 return deadlineTest{
134 what: "Read()",
135 block: func() error {
136 _, err := rconn.Read(make([]byte, 1))
137 return err
138 },
139 unblock: func() {
140 unblock.f(wconn)
141 },
142 setDeadline: func(d time.Duration) {
143 setDeadline.f(rconn, time.Now().Add(d))
144 },
145 }
146 })
147 })
148 }
149 }
150 }
151
152 func TestConnWriteDeadline(t *testing.T) {
153 for _, unblock := range []struct {
154 name string
155 f func(*nettest.Conn)
156 }{{
157 name: "Read",
158 f: func(c *nettest.Conn) {
159 io.Copy(io.Discard, c)
160 },
161 }, {
162 name: "Close",
163 f: func(c *nettest.Conn) {
164 c.Close()
165 },
166 }, {
167 name: "CloseRead",
168 f: func(c *nettest.Conn) {
169 c.CloseRead()
170 },
171 }} {
172 for _, setDeadline := range []struct {
173 name string
174 f func(*nettest.Conn, time.Time) error
175 }{{
176 name: "SetDeadline",
177 f: (*nettest.Conn).SetDeadline,
178 }, {
179 name: "SetWriteDeadline",
180 f: (*nettest.Conn).SetWriteDeadline,
181 }} {
182 t.Run(unblock.name+"/"+setDeadline.name, func(t *testing.T) {
183 testDeadline(t, func() deadlineTest {
184 rconn, wconn := nettest.NewConnPair()
185 rconn.SetReadBufferSize(1)
186 return deadlineTest{
187 what: "Write()",
188 block: func() error {
189 _, err := wconn.Write([]byte("1234"))
190 wconn.Close()
191 return err
192 },
193 unblock: func() {
194 go unblock.f(rconn)
195 },
196 setDeadline: func(d time.Duration) {
197 setDeadline.f(wconn, time.Now().Add(d))
198 },
199 }
200 })
201 })
202 }
203 }
204 }
205
206 func TestConnCanRead(t *testing.T) {
207 synctest.Test(t, func(t *testing.T) {
208 rconn, wconn := nettest.NewConnPair()
209 if got, want := rconn.CanRead(), false; got != want {
210 t.Fatalf("before writing data: rconn.CanRead() = %v, want %v", got, want)
211 }
212 wconn.Write([]byte("a"))
213 if got, want := rconn.CanRead(), true; got != want {
214 t.Fatalf("after writing data: rconn.CanRead() = %v, want %v", got, want)
215 }
216 rconn.Read(make([]byte, 1))
217 if got, want := rconn.CanRead(), false; got != want {
218 t.Fatalf("after reading data: rconn.CanRead() = %v, want %v", got, want)
219 }
220 wconn.Close()
221 if got, want := rconn.CanRead(), true; got != want {
222 t.Fatalf("after closing: rconn.CanRead() = %v, want %v", got, want)
223 }
224 })
225 }
226
227 func TestConnIsClosed(t *testing.T) {
228 for _, test := range []struct {
229 name string
230 f func() *nettest.Conn
231 want bool
232 }{{
233 name: "unclosed",
234 f: func() *nettest.Conn {
235 conn, _ := nettest.NewConnPair()
236 return conn
237 },
238 want: false,
239 }, {
240 name: "closed",
241 f: func() *nettest.Conn {
242 conn, _ := nettest.NewConnPair()
243 conn.Close()
244 return conn
245 },
246 want: true,
247 }, {
248 name: "read-closed",
249 f: func() *nettest.Conn {
250 conn, _ := nettest.NewConnPair()
251 conn.CloseRead()
252 return conn
253 },
254 want: false,
255 }, {
256 name: "write-closed",
257 f: func() *nettest.Conn {
258 conn, _ := nettest.NewConnPair()
259 conn.CloseWrite()
260 return conn
261 },
262 want: false,
263 }, {
264 name: "read-write-closed",
265 f: func() *nettest.Conn {
266 conn, _ := nettest.NewConnPair()
267 conn.CloseRead()
268 conn.CloseWrite()
269 return conn
270 },
271 want: true,
272 }} {
273 synctestSubtest(t, test.name, func(t *testing.T) {
274 conn := test.f()
275 if got, want := conn.IsClosed(), test.want; got != want {
276 t.Fatalf("conn.IsClosed() = %v, want %v", got, want)
277 }
278 if got, want := conn.Peer().IsClosed(), false; got != want {
279 t.Fatalf("conn.Peer().IsClosed() = %v, want %v", got, want)
280 }
281 })
282 }
283 }
284
// anyError is a sentinel that, passed as want to isOpError, matches a
// *net.OpError regardless of the error it wraps.
var anyError = errors.New("any")

// isOpError reports whether err is a *net.OpError whose Err field equals want,
// or, when want is anyError, whether err is any *net.OpError at all.
func isOpError(err, want error) bool {
	opErr, ok := err.(*net.OpError)
	if !ok {
		return false
	}
	return opErr.Err == want || want == anyError
}
291
292 func wantConnReadBytes(t *testing.T, c *nettest.Conn, want []byte) {
293 t.Helper()
294 got := make([]byte, len(want))
295 n, err := io.ReadFull(c, got)
296 if n < len(want) || err != nil {
297 t.Fatalf("io.ReadFull = %v, %v; want %v, nil", n, err, len(want))
298 }
299
300 if !bytes.Equal(got, want) {
301 t.Fatalf("io.ReadFull read %q, want %q", got, want)
302 }
303 }
304
305 func wantConnReadErr(t *testing.T, c *nettest.Conn, want error) {
306 t.Helper()
307 n, err := c.Read(make([]byte, 1))
308 if want == io.EOF {
309 if n != 0 || err != io.EOF {
310 t.Fatalf("c.Read() = %v, %v; want 0, io.EOF", n, err)
311 }
312 } else {
313 if n != 0 || !isOpError(err, want) {
314 t.Fatalf("c.Read() = %v, %v; want 0, OpError{Err: %q}", n, err, want)
315 }
316 }
317 }
318
319 func wantConnReadBlocked(t *testing.T, c *nettest.Conn) {
320 done := false
321 go func() {
322 n, err := c.Read(make([]byte, 1))
323 if n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
324 t.Errorf("c.Read() = %v, %v; want 0, ErrDeadlineExceeded", n, err)
325 }
326 done = true
327 }()
328 synctest.Wait()
329 if done {
330 t.Fatalf("Read unexpectedly returned before setting deadline")
331 }
332 c.SetReadDeadline(time.Now().Add(-1 * time.Second))
333 synctest.Wait()
334 c.SetReadDeadline(time.Time{})
335 if !done {
336 t.Fatalf("Read unexpectedly did not return after setting deadline")
337 }
338 }
339
340 func TestConnSetReadError(t *testing.T) {
341 synctest.Test(t, func(t *testing.T) {
342 wantErr := errors.New("error")
343 rconn, wconn := nettest.NewConnPair()
344 rconn.SetReadError(wantErr)
345
346
347 wconn.Write([]byte("one"))
348 wantConnReadBytes(t, rconn, []byte("one"))
349 wantConnReadErr(t, rconn, wantErr)
350
351
352 wconn.Write([]byte("two"))
353 wantConnReadBytes(t, rconn, []byte("two"))
354 wantConnReadErr(t, rconn, wantErr)
355
356
357 rconn.SetReadError(nil)
358 wantConnReadBlocked(t, rconn)
359
360
361 rconn.SetReadError(wantErr)
362 wconn.Write([]byte("three"))
363 wconn.Close()
364 wantConnReadBytes(t, rconn, []byte("three"))
365 wantConnReadErr(t, rconn, io.EOF)
366
367
368 rconn.SetReadError(nil)
369 wantConnReadErr(t, rconn, io.EOF)
370 rconn.SetReadError(wantErr)
371 wantConnReadErr(t, rconn, io.EOF)
372
373
374 rconn.Close()
375 wantConnReadErr(t, rconn, net.ErrClosed)
376 })
377 }
378
379 func wantConnWriteBytes(t *testing.T, c *nettest.Conn, b []byte) {
380 t.Helper()
381 if n, err := c.Write(b); n != len(b) || err != nil {
382 t.Fatalf("c.Write() = %v, %v; want %v, nil", n, err, len(b))
383 }
384 }
385
386 func wantConnWriteErr(t *testing.T, c *nettest.Conn, want error) {
387 t.Helper()
388 n, err := c.Write(make([]byte, 1))
389 if n != 0 || !isOpError(err, want) {
390 t.Fatalf("c.Write() = %v, %v; want 0, OpError{Err: %q}", n, err, want)
391 }
392 }
393
394 func TestConnSetWriteError(t *testing.T) {
395 synctest.Test(t, func(t *testing.T) {
396 wantErr := errors.New("error")
397 rconn, wconn := nettest.NewConnPair()
398 wconn.SetWriteError(wantErr)
399
400
401 wantConnWriteErr(t, wconn, wantErr)
402 wantConnReadBlocked(t, rconn)
403
404
405 wconn.SetWriteError(nil)
406 wantConnWriteBytes(t, wconn, []byte("one"))
407
408
409 wconn.SetWriteError(wantErr)
410 wantConnWriteErr(t, wconn, wantErr)
411 wantConnReadBytes(t, rconn, []byte("one"))
412
413
414 wconn.Close()
415 wantConnReadErr(t, rconn, io.EOF)
416 })
417 }
418
419 func TestConnSetCloseError(t *testing.T) {
420 synctest.Test(t, func(t *testing.T) {
421 wantErr := errors.New("error")
422 rconn, wconn := nettest.NewConnPair()
423
424 wconn.SetCloseError(wantErr)
425 if _, err := wconn.Write([]byte("one")); err != nil {
426 t.Fatalf("wconn.Write = %v, want success", err)
427 }
428 if err := wconn.Close(); !isOpError(err, wantErr) {
429 t.Fatalf("wconn.Close = %v, want OpError{Err: %v}", err, wantErr)
430 }
431 if err := wconn.Close(); !isOpError(err, net.ErrClosed) {
432 t.Fatalf("wconn.Close = %v, want OpError{Err: net.ErrClosed}", err)
433 }
434 wantConnReadBytes(t, rconn, []byte("one"))
435 wantConnReadErr(t, rconn, io.EOF)
436 })
437 }
438
439 func TestConnCloseReadWriteError(t *testing.T) {
440 synctest.Test(t, func(t *testing.T) {
441 conn, _ := nettest.NewConnPair()
442 conn.SetCloseError(errors.New("error"))
443 if err := conn.CloseRead(); err != nil {
444 t.Fatalf("conn.CloseRead = %v, want nil", err)
445 }
446 if err := conn.CloseWrite(); err != nil {
447 t.Fatalf("conn.CloseRead = %v, want nil", err)
448 }
449 })
450 }
451
View as plain text