Source file src/internal/nettest/conn_test.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package nettest_test
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"internal/nettest"
    11  	"io"
    12  	"net"
    13  	"os"
    14  	"testing"
    15  	"testing/synctest"
    16  	"time"
    17  )
    18  
    19  func TestConnReadWrite(t *testing.T) {
    20  	synctest.Test(t, func(t *testing.T) {
    21  		cliConn, srvConn := nettest.NewConnPair()
    22  
    23  		cliData := []byte("hello")
    24  		srvData := []byte("HELLO")
    25  		if n, err := cliConn.Write(cliData); n != len(cliData) || err != nil {
    26  			t.Fatalf("cliConn.Write(%q) = %v, %v; want %v, nil", cliData, n, err, len(cliData))
    27  		}
    28  		if err := cliConn.CloseWrite(); err != nil {
    29  			t.Fatalf("cliConn.CloseWrite() = %v, want nil", err)
    30  		}
    31  		if n, err := srvConn.Write(srvData); n != len(srvData) || err != nil {
    32  			t.Fatalf("srvConn.Write(%q) = %v, %v; want %v, nil", srvData, n, err, len(srvData))
    33  		}
    34  		if err := srvConn.CloseWrite(); err != nil {
    35  			t.Fatalf("cliConn.CloseWrite() = %v, want nil", err)
    36  		}
    37  		gotCli, err := io.ReadAll(cliConn)
    38  		if !bytes.Equal(gotCli, srvData) || err != nil {
    39  			t.Fatalf("io.ReadAll(cliConn) = %q, %v; want %v, nil", gotCli, err, srvData)
    40  		}
    41  		gotSrv, err := io.ReadAll(srvConn)
    42  		if !bytes.Equal(gotSrv, cliData) || err != nil {
    43  			t.Fatalf("io.ReadAll(srvConn) = %q, %v; want %v, nil", gotSrv, err, cliData)
    44  		}
    45  	})
    46  }
    47  
    48  func TestConnZeroBuffer(t *testing.T) {
    49  	// Exercise the case where one side of the conn is blocked writing and the
    50  	// other side is blocked reading.
    51  	// This can only happen when the read buffer has been set to 0, blocking all writes.
    52  	synctest.Test(t, func(t *testing.T) {
    53  		rconn, wconn := nettest.NewConnPair()
    54  		rconn.SetReadBufferSize(0)
    55  		var readDone, writeDone bool
    56  		go func() {
    57  			rconn.Read(make([]byte, 100))
    58  			readDone = true
    59  		}()
    60  		go func() {
    61  			wconn.Write([]byte("a"))
    62  			writeDone = true
    63  		}()
    64  		synctest.Wait()
    65  		if readDone || writeDone {
    66  			t.Errorf("before unblocking: readDone=%v, writeDone=%v; want false", readDone, writeDone)
    67  		}
    68  		wconn.Close()
    69  		synctest.Wait()
    70  		if !readDone || !writeDone {
    71  			t.Errorf("after unblocking: readDone=%v, writeDone=%v; want true", readDone, writeDone)
    72  		}
    73  	})
    74  }
    75  
    76  func TestConnPartialWrite(t *testing.T) {
    77  	// A blocking write to a conn successfully writes some, but not all data.
    78  	synctest.Test(t, func(t *testing.T) {
    79  		const readSize = 5
    80  		data := []byte("0123456789")
    81  		rconn, wconn := nettest.NewConnPair()
    82  		rconn.SetReadBufferSize(1)
    83  		go func() {
    84  			got := make([]byte, readSize)
    85  			if n, err := io.ReadFull(rconn, got); n != readSize || err != nil {
    86  				t.Errorf("io.ReadFull() = %v, %v; want %v, nil", n, err, readSize)
    87  			}
    88  			if want := data[:readSize]; !bytes.Equal(got, want) {
    89  				t.Errorf("read %q, want %q", got, want)
    90  			}
    91  			rconn.Close()
    92  		}()
    93  		n, err := wconn.Write(data)
    94  		if n != readSize+1 || err == nil {
    95  			t.Errorf("Write() = %v, %v; want %v, error", n, err, readSize+1)
    96  		}
    97  	})
    98  }
    99  
   100  func TestConnReadDeadline(t *testing.T) {
   101  	for _, unblock := range []struct {
   102  		name string
   103  		f    func(*nettest.Conn)
   104  	}{{
   105  		name: "Write",
   106  		f: func(c *nettest.Conn) {
   107  			c.Write([]byte("x"))
   108  		},
   109  	}, {
   110  		name: "Close",
   111  		f: func(c *nettest.Conn) {
   112  			c.Close()
   113  		},
   114  	}, {
   115  		name: "CloseWrite",
   116  		f: func(c *nettest.Conn) {
   117  			c.CloseWrite()
   118  		},
   119  	}} {
   120  		for _, setDeadline := range []struct {
   121  			name string
   122  			f    func(*nettest.Conn, time.Time) error
   123  		}{{
   124  			name: "SetDeadline",
   125  			f:    (*nettest.Conn).SetDeadline,
   126  		}, {
   127  			name: "SetReadDeadline",
   128  			f:    (*nettest.Conn).SetReadDeadline,
   129  		}} {
   130  			t.Run(unblock.name+"/"+setDeadline.name, func(t *testing.T) {
   131  				testDeadline(t, func() deadlineTest {
   132  					rconn, wconn := nettest.NewConnPair()
   133  					return deadlineTest{
   134  						what: "Read()",
   135  						block: func() error {
   136  							_, err := rconn.Read(make([]byte, 1))
   137  							return err
   138  						},
   139  						unblock: func() {
   140  							unblock.f(wconn)
   141  						},
   142  						setDeadline: func(d time.Duration) {
   143  							setDeadline.f(rconn, time.Now().Add(d))
   144  						},
   145  					}
   146  				})
   147  			})
   148  		}
   149  	}
   150  }
   151  
   152  func TestConnWriteDeadline(t *testing.T) {
   153  	for _, unblock := range []struct {
   154  		name string
   155  		f    func(*nettest.Conn)
   156  	}{{
   157  		name: "Read",
   158  		f: func(c *nettest.Conn) {
   159  			io.Copy(io.Discard, c)
   160  		},
   161  	}, {
   162  		name: "Close",
   163  		f: func(c *nettest.Conn) {
   164  			c.Close()
   165  		},
   166  	}, {
   167  		name: "CloseRead",
   168  		f: func(c *nettest.Conn) {
   169  			c.CloseRead()
   170  		},
   171  	}} {
   172  		for _, setDeadline := range []struct {
   173  			name string
   174  			f    func(*nettest.Conn, time.Time) error
   175  		}{{
   176  			name: "SetDeadline",
   177  			f:    (*nettest.Conn).SetDeadline,
   178  		}, {
   179  			name: "SetWriteDeadline",
   180  			f:    (*nettest.Conn).SetWriteDeadline,
   181  		}} {
   182  			t.Run(unblock.name+"/"+setDeadline.name, func(t *testing.T) {
   183  				testDeadline(t, func() deadlineTest {
   184  					rconn, wconn := nettest.NewConnPair()
   185  					rconn.SetReadBufferSize(1)
   186  					return deadlineTest{
   187  						what: "Write()",
   188  						block: func() error {
   189  							_, err := wconn.Write([]byte("1234"))
   190  							wconn.Close()
   191  							return err
   192  						},
   193  						unblock: func() {
   194  							go unblock.f(rconn)
   195  						},
   196  						setDeadline: func(d time.Duration) {
   197  							setDeadline.f(wconn, time.Now().Add(d))
   198  						},
   199  					}
   200  				})
   201  			})
   202  		}
   203  	}
   204  }
   205  
   206  func TestConnCanRead(t *testing.T) {
   207  	synctest.Test(t, func(t *testing.T) {
   208  		rconn, wconn := nettest.NewConnPair()
   209  		if got, want := rconn.CanRead(), false; got != want {
   210  			t.Fatalf("before writing data: rconn.CanRead() = %v, want %v", got, want)
   211  		}
   212  		wconn.Write([]byte("a"))
   213  		if got, want := rconn.CanRead(), true; got != want {
   214  			t.Fatalf("after writing data: rconn.CanRead() = %v, want %v", got, want)
   215  		}
   216  		rconn.Read(make([]byte, 1))
   217  		if got, want := rconn.CanRead(), false; got != want {
   218  			t.Fatalf("after reading data: rconn.CanRead() = %v, want %v", got, want)
   219  		}
   220  		wconn.Close()
   221  		if got, want := rconn.CanRead(), true; got != want {
   222  			t.Fatalf("after closing: rconn.CanRead() = %v, want %v", got, want)
   223  		}
   224  	})
   225  }
   226  
   227  func TestConnIsClosed(t *testing.T) {
   228  	for _, test := range []struct {
   229  		name string
   230  		f    func() *nettest.Conn
   231  		want bool
   232  	}{{
   233  		name: "unclosed",
   234  		f: func() *nettest.Conn {
   235  			conn, _ := nettest.NewConnPair()
   236  			return conn
   237  		},
   238  		want: false,
   239  	}, {
   240  		name: "closed",
   241  		f: func() *nettest.Conn {
   242  			conn, _ := nettest.NewConnPair()
   243  			conn.Close()
   244  			return conn
   245  		},
   246  		want: true,
   247  	}, {
   248  		name: "read-closed",
   249  		f: func() *nettest.Conn {
   250  			conn, _ := nettest.NewConnPair()
   251  			conn.CloseRead()
   252  			return conn
   253  		},
   254  		want: false,
   255  	}, {
   256  		name: "write-closed",
   257  		f: func() *nettest.Conn {
   258  			conn, _ := nettest.NewConnPair()
   259  			conn.CloseWrite()
   260  			return conn
   261  		},
   262  		want: false,
   263  	}, {
   264  		name: "read-write-closed",
   265  		f: func() *nettest.Conn {
   266  			conn, _ := nettest.NewConnPair()
   267  			conn.CloseRead()
   268  			conn.CloseWrite()
   269  			return conn
   270  		},
   271  		want: true,
   272  	}} {
   273  		synctestSubtest(t, test.name, func(t *testing.T) {
   274  			conn := test.f()
   275  			if got, want := conn.IsClosed(), test.want; got != want {
   276  				t.Fatalf("conn.IsClosed() = %v, want %v", got, want)
   277  			}
   278  			if got, want := conn.Peer().IsClosed(), false; got != want {
   279  				t.Fatalf("conn.Peer().IsClosed() = %v, want %v", got, want)
   280  			}
   281  		})
   282  	}
   283  }
   284  
   285  var anyError = errors.New("any") // anyError is passed to isOpError to match any error
   286  
   287  func isOpError(err, want error) bool {
   288  	oe, ok := err.(*net.OpError)
   289  	return ok && (oe.Err == want || want == anyError)
   290  }
   291  
   292  func wantConnReadBytes(t *testing.T, c *nettest.Conn, want []byte) {
   293  	t.Helper()
   294  	got := make([]byte, len(want))
   295  	n, err := io.ReadFull(c, got)
   296  	if n < len(want) || err != nil {
   297  		t.Fatalf("io.ReadFull = %v, %v; want %v, nil", n, err, len(want))
   298  	}
   299  
   300  	if !bytes.Equal(got, want) {
   301  		t.Fatalf("io.ReadFull read %q, want %q", got, want)
   302  	}
   303  }
   304  
   305  func wantConnReadErr(t *testing.T, c *nettest.Conn, want error) {
   306  	t.Helper()
   307  	n, err := c.Read(make([]byte, 1))
   308  	if want == io.EOF {
   309  		if n != 0 || err != io.EOF {
   310  			t.Fatalf("c.Read() = %v, %v; want 0, io.EOF", n, err)
   311  		}
   312  	} else {
   313  		if n != 0 || !isOpError(err, want) {
   314  			t.Fatalf("c.Read() = %v, %v; want 0, OpError{Err: %q}", n, err, want)
   315  		}
   316  	}
   317  }
   318  
   319  func wantConnReadBlocked(t *testing.T, c *nettest.Conn) {
   320  	done := false
   321  	go func() {
   322  		n, err := c.Read(make([]byte, 1))
   323  		if n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
   324  			t.Errorf("c.Read() = %v, %v; want 0, ErrDeadlineExceeded", n, err)
   325  		}
   326  		done = true
   327  	}()
   328  	synctest.Wait()
   329  	if done {
   330  		t.Fatalf("Read unexpectedly returned before setting deadline")
   331  	}
   332  	c.SetReadDeadline(time.Now().Add(-1 * time.Second))
   333  	synctest.Wait()
   334  	c.SetReadDeadline(time.Time{})
   335  	if !done {
   336  		t.Fatalf("Read unexpectedly did not return after setting deadline")
   337  	}
   338  }
   339  
   340  func TestConnSetReadError(t *testing.T) {
   341  	synctest.Test(t, func(t *testing.T) {
   342  		wantErr := errors.New("error")
   343  		rconn, wconn := nettest.NewConnPair()
   344  		rconn.SetReadError(wantErr)
   345  
   346  		// Consume buffer before returning error.
   347  		wconn.Write([]byte("one"))
   348  		wantConnReadBytes(t, rconn, []byte("one"))
   349  		wantConnReadErr(t, rconn, wantErr)
   350  
   351  		// Write more to the buffer, suppressing error until buffer drains again.
   352  		wconn.Write([]byte("two"))
   353  		wantConnReadBytes(t, rconn, []byte("two"))
   354  		wantConnReadErr(t, rconn, wantErr)
   355  
   356  		// Error may be cleared.
   357  		rconn.SetReadError(nil)
   358  		wantConnReadBlocked(t, rconn)
   359  
   360  		// Close overrides read error.
   361  		rconn.SetReadError(wantErr)
   362  		wconn.Write([]byte("three"))
   363  		wconn.Close()
   364  		wantConnReadBytes(t, rconn, []byte("three"))
   365  		wantConnReadErr(t, rconn, io.EOF)
   366  
   367  		// Setting another read error does not override Close.
   368  		rconn.SetReadError(nil)
   369  		wantConnReadErr(t, rconn, io.EOF)
   370  		rconn.SetReadError(wantErr)
   371  		wantConnReadErr(t, rconn, io.EOF)
   372  
   373  		// ErrClosed takes precedence over read error.
   374  		rconn.Close()
   375  		wantConnReadErr(t, rconn, net.ErrClosed)
   376  	})
   377  }
   378  
   379  func wantConnWriteBytes(t *testing.T, c *nettest.Conn, b []byte) {
   380  	t.Helper()
   381  	if n, err := c.Write(b); n != len(b) || err != nil {
   382  		t.Fatalf("c.Write() = %v, %v; want %v, nil", n, err, len(b))
   383  	}
   384  }
   385  
   386  func wantConnWriteErr(t *testing.T, c *nettest.Conn, want error) {
   387  	t.Helper()
   388  	n, err := c.Write(make([]byte, 1))
   389  	if n != 0 || !isOpError(err, want) {
   390  		t.Fatalf("c.Write() = %v, %v; want 0, OpError{Err: %q}", n, err, want)
   391  	}
   392  }
   393  
   394  func TestConnSetWriteError(t *testing.T) {
   395  	synctest.Test(t, func(t *testing.T) {
   396  		wantErr := errors.New("error")
   397  		rconn, wconn := nettest.NewConnPair()
   398  		wconn.SetWriteError(wantErr)
   399  
   400  		// Error blocks writes.
   401  		wantConnWriteErr(t, wconn, wantErr)
   402  		wantConnReadBlocked(t, rconn)
   403  
   404  		// Error may be cleared.
   405  		wconn.SetWriteError(nil)
   406  		wantConnWriteBytes(t, wconn, []byte("one"))
   407  
   408  		// Restoring error does not prevent reading buffered data.
   409  		wconn.SetWriteError(wantErr)
   410  		wantConnWriteErr(t, wconn, wantErr)
   411  		wantConnReadBytes(t, rconn, []byte("one"))
   412  
   413  		// Error does not interfere with closing the conn.
   414  		wconn.Close()
   415  		wantConnReadErr(t, rconn, io.EOF)
   416  	})
   417  }
   418  
   419  func TestConnSetCloseError(t *testing.T) {
   420  	synctest.Test(t, func(t *testing.T) {
   421  		wantErr := errors.New("error")
   422  		rconn, wconn := nettest.NewConnPair()
   423  
   424  		wconn.SetCloseError(wantErr)
   425  		if _, err := wconn.Write([]byte("one")); err != nil {
   426  			t.Fatalf("wconn.Write = %v, want success", err)
   427  		}
   428  		if err := wconn.Close(); !isOpError(err, wantErr) {
   429  			t.Fatalf("wconn.Close = %v, want OpError{Err: %v}", err, wantErr)
   430  		}
   431  		if err := wconn.Close(); !isOpError(err, net.ErrClosed) {
   432  			t.Fatalf("wconn.Close = %v, want OpError{Err: net.ErrClosed}", err)
   433  		}
   434  		wantConnReadBytes(t, rconn, []byte("one"))
   435  		wantConnReadErr(t, rconn, io.EOF)
   436  	})
   437  }
   438  
   439  func TestConnCloseReadWriteError(t *testing.T) {
   440  	synctest.Test(t, func(t *testing.T) {
   441  		conn, _ := nettest.NewConnPair()
   442  		conn.SetCloseError(errors.New("error"))
   443  		if err := conn.CloseRead(); err != nil {
   444  			t.Fatalf("conn.CloseRead = %v, want nil", err)
   445  		}
   446  		if err := conn.CloseWrite(); err != nil {
   447  			t.Fatalf("conn.CloseRead = %v, want nil", err)
   448  		}
   449  	})
   450  }
   451  

View as plain text