Source file src/internal/nettest/conn.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package nettest
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"io"
    11  	"math"
    12  	"net"
    13  	"net/netip"
    14  	"os"
    15  	"time"
    16  )
    17  
    18  // Conn is an in-memory test implementation of net.Conn.
    19  type Conn struct {
    20  	// Conns come in pairs.
    21  	// Writes to one Conn are read by its peer, and vice-versa.
    22  	//
    23  	// A connHalf handles one direction of data flow.
    24  	// A Conn consists of read and write halves.
    25  	// A Conn's peer has the same halves, only swapped.
    26  	//
    27  	// A Conn reads from r and writes to w.
    28  	r, w *connHalf
    29  
    30  	// peer is the other endpoint.
    31  	peer *Conn
    32  }
    33  
    34  // NewConnPair returns a pair of connected Conns.
    35  func NewConnPair() (*Conn, *Conn) {
    36  	return newConnPair(
    37  		net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:10000")),
    38  		net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:10001")),
    39  	)
    40  }
    41  
    42  func newConnPair(addr1, addr2 net.Addr) (*Conn, *Conn) {
    43  	h1 := newConnHalf(addr1)
    44  	h2 := newConnHalf(addr2)
    45  	c1 := &Conn{r: h1, w: h2}
    46  	c2 := &Conn{r: h2, w: h1}
    47  	c1.peer = c2
    48  	c2.peer = c1
    49  	c1.SetReadBufferSize(-1)
    50  	c2.SetReadBufferSize(-1)
    51  	return c1, c2
    52  }
    53  
    54  // Peer returns the other end of the connection.
    55  func (c *Conn) Peer() *Conn {
    56  	return c.peer
    57  }
    58  
    59  // Read reads data from the connection.
    60  func (c *Conn) Read(b []byte) (n int, err error) {
    61  	n, err = c.r.read(b)
    62  	if err != nil && err != io.EOF {
    63  		err = &net.OpError{
    64  			Op:     "read",
    65  			Net:    "tcp",
    66  			Source: c.RemoteAddr(),
    67  			Addr:   c.LocalAddr(),
    68  			Err:    err,
    69  		}
    70  	}
    71  	return n, err
    72  }
    73  
    74  // CanRead reports whether Read can proceed without blocking.
    75  func (c *Conn) CanRead() bool {
    76  	return c.r.canRead()
    77  }
    78  
    79  // Write writes data to the connection.
    80  func (c *Conn) Write(b []byte) (n int, err error) {
    81  	n, err = c.w.write(b)
    82  	if err != nil {
    83  		err = &net.OpError{
    84  			Op:     "write",
    85  			Net:    "tcp",
    86  			Source: c.LocalAddr(),
    87  			Addr:   c.RemoteAddr(),
    88  			Err:    err,
    89  		}
    90  	}
    91  	return n, err
    92  }
    93  
    94  // IsClosed reports whether the connection has been closed.
    95  // A connection is closed if [CloseRead] and [CloseWrite] are both called,
    96  // or if [Close] is called.
    97  //
    98  // To identify when the other side of the Conn has been closed,
    99  // use Conn.Peer().IsClosed().
   100  func (c *Conn) IsClosed() bool {
   101  	c.r.lock()
   102  	readClosed := c.r.readClosed
   103  	c.r.unlock()
   104  	c.w.lock()
   105  	writeClosed := c.w.writeClosed
   106  	c.w.unlock()
   107  	return readClosed && writeClosed
   108  }
   109  
   110  var errClosedByPeer = errors.New("connection closed by peer")
   111  
   112  // CloseRead shuts down the reading side of the connection.
   113  func (c *Conn) CloseRead() error {
   114  	c.r.lock()
   115  	defer c.r.unlock()
   116  	c.r.buf.Reset() // discard unread data
   117  	c.r.readClosed = true
   118  	return nil
   119  }
   120  
   121  // CloseWrite shuts down the writing side of the connection.
   122  func (c *Conn) CloseWrite() error {
   123  	c.w.lock()
   124  	defer c.w.unlock()
   125  	c.w.writeClosed = true
   126  	return nil
   127  }
   128  
   129  // Close closes the connection.
   130  func (c *Conn) Close() error {
   131  	c.r.lock()
   132  	readClosed := c.r.readClosed
   133  	c.r.buf.Reset() // discard unread data
   134  	c.r.readClosed = true
   135  	err := c.r.closeErr
   136  	c.r.unlock()
   137  
   138  	c.w.lock()
   139  	writeClosed := c.w.writeClosed
   140  	c.w.writeClosed = true
   141  	c.w.unlock()
   142  
   143  	if readClosed && writeClosed {
   144  		err = net.ErrClosed
   145  	}
   146  	if err != nil {
   147  		err = &net.OpError{
   148  			Op:   "close",
   149  			Net:  "tcp",
   150  			Addr: c.LocalAddr(),
   151  			Err:  err,
   152  		}
   153  	}
   154  	return err
   155  }
   156  
   157  // SetCloseError sets the error returned by Close.
   158  // Close still closes the connection.
   159  // A nil error restores the usual behavior.
   160  func (c *Conn) SetCloseError(err error) {
   161  	c.r.lock()
   162  	c.r.closeErr = err
   163  	c.r.unlock()
   164  }
   165  
   166  // LocalAddr returns the (fake) local network address.
   167  func (c *Conn) LocalAddr() net.Addr {
   168  	c.r.lock()
   169  	defer c.r.unlock()
   170  	return c.r.addr
   171  }
   172  
   173  // SetLocalAddr sets the local address.
   174  //
   175  // To set the remote address, set the local address of Conn's peer.
   176  func (c *Conn) SetLocalAddr(addr net.Addr) {
   177  	c.r.lock()
   178  	defer c.r.unlock()
   179  	c.r.addr = addr
   180  }
   181  
   182  // LocalAddr returns the (fake) remote network address.
   183  func (c *Conn) RemoteAddr() net.Addr {
   184  	c.r.lock()
   185  	defer c.r.unlock()
   186  	return c.w.addr
   187  }
   188  
   189  // SetDeadline sets the read and write deadlines for the connection.
   190  func (c *Conn) SetDeadline(t time.Time) error {
   191  	c.SetReadDeadline(t)
   192  	c.SetWriteDeadline(t)
   193  	return nil
   194  }
   195  
   196  // SetReadDeadline sets the read deadline for the connection.
   197  func (c *Conn) SetReadDeadline(t time.Time) error {
   198  	c.r.readDeadline.setDeadline(c.r, t)
   199  	return nil
   200  }
   201  
   202  // SetWriteDeadline sets the write deadline for the connection.
   203  func (c *Conn) SetWriteDeadline(t time.Time) error {
   204  	c.w.writeDeadline.setDeadline(c.w, t)
   205  	return nil
   206  }
   207  
   208  // SetReadBufferSize sets the connection's read buffer.
   209  // Writes to the other end of the connection will block so long as the buffer is full.
   210  // Setting the size to 0 blocks all writes until the size is increased.
   211  func (c *Conn) SetReadBufferSize(size int) {
   212  	if size < 0 {
   213  		size = math.MaxInt
   214  	}
   215  	c.r.setBufferSize(size)
   216  }
   217  
   218  // SetReadError causes any currently blocked and future Read calls to return
   219  // a net.OpError wrapping err. It does not affect the other half of the connection.
   220  // Reads will return any buffered data before returning the error,
   221  // including data written after the error is set and io.EOF after the other end is closed.
   222  // A nil error restores the usual behavior.
   223  func (c *Conn) SetReadError(err error) {
   224  	c.r.lock()
   225  	defer c.r.unlock()
   226  	c.r.readErr = err
   227  }
   228  
   229  // SetWriteError causes any currently blocked and future Write calls to return
   230  // a net.OpError wrapping err. It does not affect the other half of the connection.
   231  // Writes will not write data to the connection buffer while an error is set.
   232  // A nil error restores the usual behavior.
   233  func (c *Conn) SetWriteError(err error) {
   234  	c.w.lock()
   235  	defer c.w.unlock()
   236  	c.w.writeErr = err
   237  }
   238  
   239  // connHalf is one direction data flow in a Conn.
   240  // The connHalf contains a buffer.
   241  // Writes to the connHalf push to the buffer and reads pull from it.
   242  type connHalf struct {
   243  	addr net.Addr
   244  
   245  	// A half can be readable and/or writable.
   246  	//
   247  	// These four channels act as a lock,
   248  	// and allow waiting for readability/writability.
   249  	// When the half is unlocked, exactly one channel contains a value.
   250  	// When the half is locked, all channels are empty.
   251  	lockr  chan struct{} // readable
   252  	lockw  chan struct{} // writable
   253  	lockrw chan struct{} // readable and writable
   254  	lockc  chan struct{} // neither readable nor writable
   255  
   256  	// Read and write timeouts.
   257  	readDeadline, writeDeadline connDeadline
   258  
   259  	bufMax int // maximum buffer size
   260  	buf    bytes.Buffer
   261  
   262  	readClosed, writeClosed bool
   263  	readErr, writeErr       error // errors returned by reads/writes
   264  	closeErr                error // error returned by closing the conn reading from this half
   265  }
   266  
   267  func newConnHalf(addr net.Addr) *connHalf {
   268  	h := &connHalf{
   269  		addr:   addr,
   270  		lockw:  make(chan struct{}, 1),
   271  		lockr:  make(chan struct{}, 1),
   272  		lockrw: make(chan struct{}, 1),
   273  		lockc:  make(chan struct{}, 1),
   274  		bufMax: math.MaxInt, // unlimited
   275  	}
   276  	h.unlock()
   277  	return h
   278  }
   279  
   280  // lock locks h.
   281  func (h *connHalf) lock() {
   282  	select {
   283  	case <-h.lockw: // writable
   284  	case <-h.lockr: // readable
   285  	case <-h.lockrw: // readable and writable
   286  	case <-h.lockc: // neither readable nor writable
   287  	}
   288  }
   289  
   290  // unlock unlocks h.
   291  func (h *connHalf) unlock() {
   292  	canRead := h.canReadLocked()
   293  	canWrite := h.canWriteLocked()
   294  	switch {
   295  	case canRead && canWrite:
   296  		h.lockrw <- struct{}{} // readable and writable
   297  	case canRead:
   298  		h.lockr <- struct{}{} // readable
   299  	case canWrite:
   300  		h.lockw <- struct{}{} // writable
   301  	default:
   302  		h.lockc <- struct{}{} // neither readable nor writable
   303  	}
   304  }
   305  
   306  func (h *connHalf) canRead() bool {
   307  	h.lock()
   308  	defer h.unlock()
   309  	return h.canReadLocked()
   310  }
   311  
   312  func (h *connHalf) canReadLocked() bool {
   313  	return h.readErr != nil || h.readDeadline.expired || h.buf.Len() > 0 || h.readClosed || h.writeClosed
   314  }
   315  
   316  func (h *connHalf) canWriteLocked() bool {
   317  	return h.writeErr != nil || h.writeDeadline.expired || h.bufMax > h.buf.Len() || h.readClosed || h.writeClosed
   318  }
   319  
   320  // waitAndLockForRead waits until h is readable and locks it.
   321  func (h *connHalf) waitAndLockForRead() {
   322  	select {
   323  	case <-h.lockr:
   324  		// readable
   325  	case <-h.lockrw:
   326  		// readable and writable
   327  	}
   328  }
   329  
   330  // waitAndLockForWrite waits until h is writable and locks it.
   331  func (h *connHalf) waitAndLockForWrite() {
   332  	select {
   333  	case <-h.lockw:
   334  		// writable
   335  	case <-h.lockrw:
   336  		// readable and writable
   337  	}
   338  }
   339  
   340  func (h *connHalf) read(b []byte) (n int, err error) {
   341  	h.waitAndLockForRead()
   342  	defer h.unlock()
   343  	if h.readClosed {
   344  		return 0, net.ErrClosed
   345  	}
   346  	if h.readDeadline.expired {
   347  		return 0, os.ErrDeadlineExceeded
   348  	}
   349  	if h.buf.Len() > 0 {
   350  		return h.buf.Read(b)
   351  	}
   352  	if h.writeClosed {
   353  		return 0, io.EOF
   354  	}
   355  	return 0, h.readErr
   356  }
   357  
   358  func (h *connHalf) setBufferSize(size int) {
   359  	h.lock()
   360  	defer h.unlock()
   361  	h.bufMax = size
   362  }
   363  
   364  func (h *connHalf) write(b []byte) (n int, err error) {
   365  	for n < len(b) {
   366  		nn, err := h.writePartial(b[n:])
   367  		n += nn
   368  		if err != nil {
   369  			return n, err
   370  		}
   371  	}
   372  	return n, nil
   373  }
   374  
   375  func (h *connHalf) writePartial(b []byte) (n int, err error) {
   376  	h.waitAndLockForWrite()
   377  	defer h.unlock()
   378  	if h.writeClosed {
   379  		return 0, net.ErrClosed
   380  	}
   381  	if h.writeDeadline.expired {
   382  		return 0, os.ErrDeadlineExceeded
   383  	}
   384  	if h.readClosed {
   385  		return 0, errClosedByPeer
   386  	}
   387  	if h.writeErr != nil {
   388  		return 0, h.writeErr
   389  	}
   390  	writeMax := h.bufMax - h.buf.Len()
   391  	if writeMax < len(b) {
   392  		b = b[:writeMax]
   393  	}
   394  	return h.buf.Write(b)
   395  }
   396  
   397  type connDeadline struct {
   398  	timer   *time.Timer
   399  	expired bool
   400  }
   401  
   402  type locker interface {
   403  	lock()
   404  	unlock()
   405  }
   406  
   407  func (d *connDeadline) setDeadline(mu locker, t time.Time) {
   408  	mu.lock()
   409  	defer mu.unlock()
   410  	if d.timer != nil {
   411  		d.timer.Stop()
   412  		d.timer = nil
   413  	}
   414  	if t.IsZero() {
   415  		// No deadline.
   416  		d.expired = false
   417  		return
   418  	}
   419  	expiry := time.Until(t)
   420  	if expiry <= 0 {
   421  		// Deadline has already passed.
   422  		d.expired = true
   423  		return
   424  	}
   425  	// Deadline is in the future.
   426  	d.expired = false
   427  	var timer *time.Timer
   428  	timer = time.AfterFunc(expiry, func() {
   429  		mu.lock()
   430  		defer mu.unlock()
   431  		if d.timer == timer {
   432  			d.timer = nil
   433  			d.expired = true
   434  		}
   435  	})
   436  	d.timer = timer
   437  }
   438  

View as plain text