// Copyright 2026 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package nettest import ( "bytes" "errors" "io" "math" "net" "net/netip" "os" "time" ) // Conn is an in-memory test implementation of net.Conn. type Conn struct { // Conns come in pairs. // Writes to one Conn are read by its peer, and vice-versa. // // A connHalf handles one direction of data flow. // A Conn consists of read and write halves. // A Conn's peer has the same halves, only swapped. // // A Conn reads from r and writes to w. r, w *connHalf // peer is the other endpoint. peer *Conn } // NewConnPair returns a pair of connected Conns. func NewConnPair() (*Conn, *Conn) { return newConnPair( net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:10000")), net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:10001")), ) } func newConnPair(addr1, addr2 net.Addr) (*Conn, *Conn) { h1 := newConnHalf(addr1) h2 := newConnHalf(addr2) c1 := &Conn{r: h1, w: h2} c2 := &Conn{r: h2, w: h1} c1.peer = c2 c2.peer = c1 c1.SetReadBufferSize(-1) c2.SetReadBufferSize(-1) return c1, c2 } // Peer returns the other end of the connection. func (c *Conn) Peer() *Conn { return c.peer } // Read reads data from the connection. func (c *Conn) Read(b []byte) (n int, err error) { n, err = c.r.read(b) if err != nil && err != io.EOF { err = &net.OpError{ Op: "read", Net: "tcp", Source: c.RemoteAddr(), Addr: c.LocalAddr(), Err: err, } } return n, err } // CanRead reports whether Read can proceed without blocking. func (c *Conn) CanRead() bool { return c.r.canRead() } // Write writes data to the connection. func (c *Conn) Write(b []byte) (n int, err error) { n, err = c.w.write(b) if err != nil { err = &net.OpError{ Op: "write", Net: "tcp", Source: c.LocalAddr(), Addr: c.RemoteAddr(), Err: err, } } return n, err } // IsClosed reports whether the connection has been closed. // A connection is closed if [CloseRead] and [CloseWrite] are both called, // or if [Close] is called. // // To identify when the other side of the Conn has been closed, // use Conn.Peer().IsClosed(). func (c *Conn) IsClosed() bool { c.r.lock() readClosed := c.r.readClosed c.r.unlock() c.w.lock() writeClosed := c.w.writeClosed c.w.unlock() return readClosed && writeClosed } var errClosedByPeer = errors.New("connection closed by peer") // CloseRead shuts down the reading side of the connection. func (c *Conn) CloseRead() error { c.r.lock() defer c.r.unlock() c.r.buf.Reset() // discard unread data c.r.readClosed = true return nil } // CloseWrite shuts down the writing side of the connection. func (c *Conn) CloseWrite() error { c.w.lock() defer c.w.unlock() c.w.writeClosed = true return nil } // Close closes the connection. func (c *Conn) Close() error { c.r.lock() readClosed := c.r.readClosed c.r.buf.Reset() // discard unread data c.r.readClosed = true err := c.r.closeErr c.r.unlock() c.w.lock() writeClosed := c.w.writeClosed c.w.writeClosed = true c.w.unlock() if readClosed && writeClosed { err = net.ErrClosed } if err != nil { err = &net.OpError{ Op: "close", Net: "tcp", Addr: c.LocalAddr(), Err: err, } } return err } // SetCloseError sets the error returned by Close. // Close still closes the connection. // A nil error restores the usual behavior. func (c *Conn) SetCloseError(err error) { c.r.lock() c.r.closeErr = err c.r.unlock() } // LocalAddr returns the (fake) local network address. func (c *Conn) LocalAddr() net.Addr { c.r.lock() defer c.r.unlock() return c.r.addr } // SetLocalAddr sets the local address. // // To set the remote address, set the local address of Conn's peer. func (c *Conn) SetLocalAddr(addr net.Addr) { c.r.lock() defer c.r.unlock() c.r.addr = addr } // LocalAddr returns the (fake) remote network address. func (c *Conn) RemoteAddr() net.Addr { c.r.lock() defer c.r.unlock() return c.w.addr } // SetDeadline sets the read and write deadlines for the connection. func (c *Conn) SetDeadline(t time.Time) error { c.SetReadDeadline(t) c.SetWriteDeadline(t) return nil } // SetReadDeadline sets the read deadline for the connection. func (c *Conn) SetReadDeadline(t time.Time) error { c.r.readDeadline.setDeadline(c.r, t) return nil } // SetWriteDeadline sets the write deadline for the connection. func (c *Conn) SetWriteDeadline(t time.Time) error { c.w.writeDeadline.setDeadline(c.w, t) return nil } // SetReadBufferSize sets the connection's read buffer. // Writes to the other end of the connection will block so long as the buffer is full. // Setting the size to 0 blocks all writes until the size is increased. func (c *Conn) SetReadBufferSize(size int) { if size < 0 { size = math.MaxInt } c.r.setBufferSize(size) } // SetReadError causes any currently blocked and future Read calls to return // a net.OpError wrapping err. It does not affect the other half of the connection. // Reads will return any buffered data before returning the error, // including data written after the error is set and io.EOF after the other end is closed. // A nil error restores the usual behavior. func (c *Conn) SetReadError(err error) { c.r.lock() defer c.r.unlock() c.r.readErr = err } // SetWriteError causes any currently blocked and future Write calls to return // a net.OpError wrapping err. It does not affect the other half of the connection. // Writes will not write data to the connection buffer while an error is set. // A nil error restores the usual behavior. func (c *Conn) SetWriteError(err error) { c.w.lock() defer c.w.unlock() c.w.writeErr = err } // connHalf is one direction data flow in a Conn. // The connHalf contains a buffer. // Writes to the connHalf push to the buffer and reads pull from it. type connHalf struct { addr net.Addr // A half can be readable and/or writable. // // These four channels act as a lock, // and allow waiting for readability/writability. // When the half is unlocked, exactly one channel contains a value. // When the half is locked, all channels are empty. lockr chan struct{} // readable lockw chan struct{} // writable lockrw chan struct{} // readable and writable lockc chan struct{} // neither readable nor writable // Read and write timeouts. readDeadline, writeDeadline connDeadline bufMax int // maximum buffer size buf bytes.Buffer readClosed, writeClosed bool readErr, writeErr error // errors returned by reads/writes closeErr error // error returned by closing the conn reading from this half } func newConnHalf(addr net.Addr) *connHalf { h := &connHalf{ addr: addr, lockw: make(chan struct{}, 1), lockr: make(chan struct{}, 1), lockrw: make(chan struct{}, 1), lockc: make(chan struct{}, 1), bufMax: math.MaxInt, // unlimited } h.unlock() return h } // lock locks h. func (h *connHalf) lock() { select { case <-h.lockw: // writable case <-h.lockr: // readable case <-h.lockrw: // readable and writable case <-h.lockc: // neither readable nor writable } } // unlock unlocks h. func (h *connHalf) unlock() { canRead := h.canReadLocked() canWrite := h.canWriteLocked() switch { case canRead && canWrite: h.lockrw <- struct{}{} // readable and writable case canRead: h.lockr <- struct{}{} // readable case canWrite: h.lockw <- struct{}{} // writable default: h.lockc <- struct{}{} // neither readable nor writable } } func (h *connHalf) canRead() bool { h.lock() defer h.unlock() return h.canReadLocked() } func (h *connHalf) canReadLocked() bool { return h.readErr != nil || h.readDeadline.expired || h.buf.Len() > 0 || h.readClosed || h.writeClosed } func (h *connHalf) canWriteLocked() bool { return h.writeErr != nil || h.writeDeadline.expired || h.bufMax > h.buf.Len() || h.readClosed || h.writeClosed } // waitAndLockForRead waits until h is readable and locks it. func (h *connHalf) waitAndLockForRead() { select { case <-h.lockr: // readable case <-h.lockrw: // readable and writable } } // waitAndLockForWrite waits until h is writable and locks it. func (h *connHalf) waitAndLockForWrite() { select { case <-h.lockw: // writable case <-h.lockrw: // readable and writable } } func (h *connHalf) read(b []byte) (n int, err error) { h.waitAndLockForRead() defer h.unlock() if h.readClosed { return 0, net.ErrClosed } if h.readDeadline.expired { return 0, os.ErrDeadlineExceeded } if h.buf.Len() > 0 { return h.buf.Read(b) } if h.writeClosed { return 0, io.EOF } return 0, h.readErr } func (h *connHalf) setBufferSize(size int) { h.lock() defer h.unlock() h.bufMax = size } func (h *connHalf) write(b []byte) (n int, err error) { for n < len(b) { nn, err := h.writePartial(b[n:]) n += nn if err != nil { return n, err } } return n, nil } func (h *connHalf) writePartial(b []byte) (n int, err error) { h.waitAndLockForWrite() defer h.unlock() if h.writeClosed { return 0, net.ErrClosed } if h.writeDeadline.expired { return 0, os.ErrDeadlineExceeded } if h.readClosed { return 0, errClosedByPeer } if h.writeErr != nil { return 0, h.writeErr } writeMax := h.bufMax - h.buf.Len() if writeMax < len(b) { b = b[:writeMax] } return h.buf.Write(b) } type connDeadline struct { timer *time.Timer expired bool } type locker interface { lock() unlock() } func (d *connDeadline) setDeadline(mu locker, t time.Time) { mu.lock() defer mu.unlock() if d.timer != nil { d.timer.Stop() d.timer = nil } if t.IsZero() { // No deadline. d.expired = false return } expiry := time.Until(t) if expiry <= 0 { // Deadline has already passed. d.expired = true return } // Deadline is in the future. d.expired = false var timer *time.Timer timer = time.AfterFunc(expiry, func() { mu.lock() defer mu.unlock() if d.timer == timer { d.timer = nil d.expired = true } }) d.timer = timer }