Source file src/net/net_fake.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Fake networking for js/wasm and wasip1/wasm.
// It is intended to allow tests of other packages to pass.
     7  
     8  //go:build js || wasip1
     9  
    10  package net
    11  
    12  import (
    13  	"context"
    14  	"errors"
    15  	"io"
    16  	"os"
    17  	"sync"
    18  	"sync/atomic"
    19  	"syscall"
    20  	"time"
    21  )
    22  
    23  var (
    24  	sockets         sync.Map // fakeSockAddr → *netFD
    25  	fakePorts       sync.Map // int (port #) → *netFD
    26  	nextPortCounter atomic.Int32
    27  )
    28  
    29  const defaultBuffer = 65535
    30  
    31  type fakeSockAddr struct {
    32  	family  int
    33  	address string
    34  }
    35  
    36  func fakeAddr(sa sockaddr) fakeSockAddr {
    37  	return fakeSockAddr{
    38  		family:  sa.family(),
    39  		address: sa.String(),
    40  	}
    41  }
    42  
    43  // socket returns a network file descriptor that is ready for
    44  // I/O using the fake network.
    45  func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
    46  	if raddr != nil && ctrlCtxFn != nil {
    47  		return nil, os.NewSyscallError("socket", syscall.ENOTSUP)
    48  	}
    49  	switch sotype {
    50  	case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET, syscall.SOCK_DGRAM:
    51  	default:
    52  		return nil, os.NewSyscallError("socket", syscall.ENOTSUP)
    53  	}
    54  
    55  	fd := &netFD{
    56  		family: family,
    57  		sotype: sotype,
    58  		net:    net,
    59  	}
    60  	fd.fakeNetFD = newFakeNetFD(fd)
    61  
    62  	if raddr == nil {
    63  		if err := fakeListen(fd, laddr); err != nil {
    64  			fd.Close()
    65  			return nil, err
    66  		}
    67  		return fd, nil
    68  	}
    69  
    70  	if err := fakeConnect(ctx, fd, laddr, raddr); err != nil {
    71  		fd.Close()
    72  		return nil, err
    73  	}
    74  	return fd, nil
    75  }
    76  
    77  func validateResolvedAddr(net string, family int, sa sockaddr) error {
    78  	validateIP := func(ip IP) error {
    79  		switch family {
    80  		case syscall.AF_INET:
    81  			if len(ip) != 4 {
    82  				return &AddrError{
    83  					Err:  "non-IPv4 address",
    84  					Addr: ip.String(),
    85  				}
    86  			}
    87  		case syscall.AF_INET6:
    88  			if len(ip) != 16 {
    89  				return &AddrError{
    90  					Err:  "non-IPv6 address",
    91  					Addr: ip.String(),
    92  				}
    93  			}
    94  		default:
    95  			panic("net: unexpected address family in validateResolvedAddr")
    96  		}
    97  		return nil
    98  	}
    99  
   100  	switch net {
   101  	case "tcp", "tcp4", "tcp6":
   102  		sa, ok := sa.(*TCPAddr)
   103  		if !ok {
   104  			return &AddrError{
   105  				Err:  "non-TCP address for " + net + " network",
   106  				Addr: sa.String(),
   107  			}
   108  		}
   109  		if err := validateIP(sa.IP); err != nil {
   110  			return err
   111  		}
   112  		if sa.Port <= 0 || sa.Port >= 1<<16 {
   113  			return &AddrError{
   114  				Err:  "port out of range",
   115  				Addr: sa.String(),
   116  			}
   117  		}
   118  		return nil
   119  
   120  	case "udp", "udp4", "udp6":
   121  		sa, ok := sa.(*UDPAddr)
   122  		if !ok {
   123  			return &AddrError{
   124  				Err:  "non-UDP address for " + net + " network",
   125  				Addr: sa.String(),
   126  			}
   127  		}
   128  		if err := validateIP(sa.IP); err != nil {
   129  			return err
   130  		}
   131  		if sa.Port <= 0 || sa.Port >= 1<<16 {
   132  			return &AddrError{
   133  				Err:  "port out of range",
   134  				Addr: sa.String(),
   135  			}
   136  		}
   137  		return nil
   138  
   139  	case "unix", "unixgram", "unixpacket":
   140  		sa, ok := sa.(*UnixAddr)
   141  		if !ok {
   142  			return &AddrError{
   143  				Err:  "non-Unix address for " + net + " network",
   144  				Addr: sa.String(),
   145  			}
   146  		}
   147  		if sa.Name != "" {
   148  			i := len(sa.Name) - 1
   149  			for i > 0 && !os.IsPathSeparator(sa.Name[i]) {
   150  				i--
   151  			}
   152  			for i > 0 && os.IsPathSeparator(sa.Name[i]) {
   153  				i--
   154  			}
   155  			if i <= 0 {
   156  				return &AddrError{
   157  					Err:  "unix socket name missing path component",
   158  					Addr: sa.Name,
   159  				}
   160  			}
   161  			if _, err := os.Stat(sa.Name[:i+1]); err != nil {
   162  				return &AddrError{
   163  					Err:  err.Error(),
   164  					Addr: sa.Name,
   165  				}
   166  			}
   167  		}
   168  		return nil
   169  
   170  	default:
   171  		return &AddrError{
   172  			Err:  syscall.EAFNOSUPPORT.Error(),
   173  			Addr: sa.String(),
   174  		}
   175  	}
   176  }
   177  
   178  func matchIPFamily(family int, addr sockaddr) sockaddr {
   179  	convertIP := func(ip IP) IP {
   180  		switch family {
   181  		case syscall.AF_INET:
   182  			return ip.To4()
   183  		case syscall.AF_INET6:
   184  			return ip.To16()
   185  		default:
   186  			return ip
   187  		}
   188  	}
   189  
   190  	switch addr := addr.(type) {
   191  	case *TCPAddr:
   192  		ip := convertIP(addr.IP)
   193  		if ip == nil || len(ip) == len(addr.IP) {
   194  			return addr
   195  		}
   196  		return &TCPAddr{IP: ip, Port: addr.Port, Zone: addr.Zone}
   197  	case *UDPAddr:
   198  		ip := convertIP(addr.IP)
   199  		if ip == nil || len(ip) == len(addr.IP) {
   200  			return addr
   201  		}
   202  		return &UDPAddr{IP: ip, Port: addr.Port, Zone: addr.Zone}
   203  	default:
   204  		return addr
   205  	}
   206  }
   207  
   208  type fakeNetFD struct {
   209  	fd           *netFD
   210  	assignedPort int // 0 if no port has been assigned for this socket
   211  
   212  	queue         *packetQueue // incoming packets
   213  	peer          *netFD       // connected peer (for outgoing packets); nil for listeners and PacketConns
   214  	readDeadline  atomic.Pointer[deadlineTimer]
   215  	writeDeadline atomic.Pointer[deadlineTimer]
   216  
   217  	fakeAddr fakeSockAddr // cached fakeSockAddr equivalent of fd.laddr
   218  
   219  	// The incoming channels hold incoming connections that have not yet been accepted.
   220  	// All of these channels are 1-buffered.
   221  	incoming      chan []*netFD // holds the queue when it has >0 but <SOMAXCONN pending connections; closed when the Listener is closed
   222  	incomingFull  chan []*netFD // holds the queue when it has SOMAXCONN pending connections
   223  	incomingEmpty chan bool     // holds true when the incoming queue is empty
   224  }
   225  
   226  func newFakeNetFD(fd *netFD) *fakeNetFD {
   227  	ffd := &fakeNetFD{fd: fd}
   228  	ffd.readDeadline.Store(newDeadlineTimer(noDeadline))
   229  	ffd.writeDeadline.Store(newDeadlineTimer(noDeadline))
   230  	return ffd
   231  }
   232  
   233  func (ffd *fakeNetFD) Read(p []byte) (n int, err error) {
   234  	n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, false, nil)
   235  	return n, err
   236  }
   237  
   238  func (ffd *fakeNetFD) Write(p []byte) (nn int, err error) {
   239  	peer := ffd.peer
   240  	if peer == nil {
   241  		if ffd.fd.raddr == nil {
   242  			return 0, os.NewSyscallError("write", syscall.ENOTCONN)
   243  		}
   244  		peeri, _ := sockets.Load(fakeAddr(ffd.fd.raddr.(sockaddr)))
   245  		if peeri == nil {
   246  			return 0, os.NewSyscallError("write", syscall.ECONNRESET)
   247  		}
   248  		peer = peeri.(*netFD)
   249  		if peer.queue == nil {
   250  			return 0, os.NewSyscallError("write", syscall.ECONNRESET)
   251  		}
   252  	}
   253  
   254  	if peer.fakeNetFD == nil {
   255  		return 0, os.NewSyscallError("write", syscall.EINVAL)
   256  	}
   257  	return peer.queue.write(ffd.writeDeadline.Load(), p, ffd.fd.laddr.(sockaddr))
   258  }
   259  
// Close releases the fake-network resources held by the socket.
//
// It does the following:
//   - removes the socket's address from sockets, but only if the entry still
//     refers to this socket;
//   - closes the read side of its own packet queue and the write side of its
//     peer's queue;
//   - resets both deadlines to noDeadline;
//   - for a listener, closes pending connections that were never accepted;
//   - releases any automatically assigned port.
//
// It returns the first error reported while closing the queues, if any.
func (ffd *fakeNetFD) Close() (err error) {
	if ffd.fakeAddr != (fakeSockAddr{}) {
		// CompareAndDelete leaves the entry alone when another socket owns
		// the address (for example, when this socket's own registration in
		// fakeListen failed as a duplicate).
		sockets.CompareAndDelete(ffd.fakeAddr, ffd.fd)
	}

	if ffd.queue != nil {
		if closeErr := ffd.queue.closeRead(); err == nil {
			err = closeErr
		}
	}
	if ffd.peer != nil {
		// Mark the peer's incoming queue as closed for writing so its reads
		// end once buffered data is drained.
		if closeErr := ffd.peer.queue.closeWrite(); err == nil {
			err = closeErr
		}
	}
	ffd.readDeadline.Load().Reset(noDeadline)
	ffd.writeDeadline.Load().Reset(noDeadline)

	if ffd.incoming != nil {
		// Listener: take the pending-connection list from whichever channel
		// currently holds it. ok is false only if ffd.incoming has already
		// been closed (presumably by an earlier Close).
		var (
			incoming []*netFD
			ok       bool
		)
		select {
		case _, ok = <-ffd.incomingEmpty:
		case incoming, ok = <-ffd.incoming:
		case incoming, ok = <-ffd.incomingFull:
		}
		if ok {
			// Sends on ffd.incoming require a receive first.
			// Since we successfully received, no other goroutine may
			// send on it at this point, and we may safely close it.
			close(ffd.incoming)

			for _, c := range incoming {
				c.Close()
			}
		}
	}

	if ffd.assignedPort != 0 {
		// Release the reservation only if it still belongs to this socket.
		fakePorts.CompareAndDelete(ffd.assignedPort, ffd.fd)
	}

	return err
}
   306  
   307  func (ffd *fakeNetFD) closeRead() error {
   308  	return ffd.queue.closeRead()
   309  }
   310  
   311  func (ffd *fakeNetFD) closeWrite() error {
   312  	if ffd.peer == nil {
   313  		return os.NewSyscallError("closeWrite", syscall.ENOTCONN)
   314  	}
   315  	return ffd.peer.queue.closeWrite()
   316  }
   317  
// accept removes and returns the oldest pending connection on the listener.
//
// It blocks until one arrives, the listener is closed (ErrClosed), or the
// read deadline expires (os.ErrDeadlineExceeded). It fails with EINVAL if the
// socket is not a stream or seqpacket listener. The laddr argument is unused.
func (ffd *fakeNetFD) accept(laddr Addr) (*netFD, error) {
	if ffd.incoming == nil {
		return nil, os.NewSyscallError("accept", syscall.EINVAL)
	}

	var (
		incoming []*netFD
		ok       bool
	)
	// Capture the expiry channel once, so that every check below observes
	// the same deadline timer.
	expired := ffd.readDeadline.Load().expired
	select {
	case <-expired:
		return nil, os.ErrDeadlineExceeded
	case incoming, ok = <-ffd.incoming:
		if !ok {
			return nil, ErrClosed
		}
		// If the deadline also expired, report that and hand the list back
		// unchanged.
		select {
		case <-expired:
			ffd.incoming <- incoming
			return nil, os.ErrDeadlineExceeded
		default:
		}
	case incoming, ok = <-ffd.incomingFull:
		select {
		case <-expired:
			ffd.incomingFull <- incoming
			return nil, os.ErrDeadlineExceeded
		default:
		}
	}

	// Pop the oldest connection and return the remainder to the channel that
	// matches its new length (it is now below the backlog limit).
	peer := incoming[0]
	incoming = incoming[1:]
	if len(incoming) == 0 {
		ffd.incomingEmpty <- true
	} else {
		ffd.incoming <- incoming
	}
	return peer, nil
}
   359  
   360  func (ffd *fakeNetFD) SetDeadline(t time.Time) error {
   361  	err1 := ffd.SetReadDeadline(t)
   362  	err2 := ffd.SetWriteDeadline(t)
   363  	if err1 != nil {
   364  		return err1
   365  	}
   366  	return err2
   367  }
   368  
   369  func (ffd *fakeNetFD) SetReadDeadline(t time.Time) error {
   370  	dt := ffd.readDeadline.Load()
   371  	if !dt.Reset(t) {
   372  		ffd.readDeadline.Store(newDeadlineTimer(t))
   373  	}
   374  	return nil
   375  }
   376  
   377  func (ffd *fakeNetFD) SetWriteDeadline(t time.Time) error {
   378  	dt := ffd.writeDeadline.Load()
   379  	if !dt.Reset(t) {
   380  		ffd.writeDeadline.Store(newDeadlineTimer(t))
   381  	}
   382  	return nil
   383  }
   384  
   385  const maxPacketSize = 65535
   386  
   387  type packet struct {
   388  	buf       []byte
   389  	bufOffset int
   390  	next      *packet
   391  	from      sockaddr
   392  }
   393  
   394  func (p *packet) clear() {
   395  	p.buf = p.buf[:0]
   396  	p.bufOffset = 0
   397  	p.next = nil
   398  	p.from = nil
   399  }
   400  
   401  var packetPool = sync.Pool{
   402  	New: func() any { return new(packet) },
   403  }
   404  
   405  type packetQueueState struct {
   406  	head, tail      *packet // unqueued packets
   407  	nBytes          int     // number of bytes enqueued in the packet buffers starting from head
   408  	readBufferBytes int     // soft limit on nbytes; no more packets may be enqueued when the limit is exceeded
   409  	readClosed      bool    // true if the reader of the queue has stopped reading
   410  	writeClosed     bool    // true if the writer of the queue has stopped writing; the reader sees either io.EOF or syscall.ECONNRESET when they have read all buffered packets
   411  	noLinger        bool    // if true, the reader sees ECONNRESET instead of EOF
   412  }
   413  
   414  // A packetQueue is a set of 1-buffered channels implementing a FIFO queue
   415  // of packets.
   416  type packetQueue struct {
   417  	empty chan packetQueueState // contains configuration parameters when the queue is empty and not closed
   418  	ready chan packetQueueState // contains the packets when non-empty or closed
   419  	full  chan packetQueueState // contains the packets when buffer is full and not closed
   420  }
   421  
   422  func newPacketQueue(readBufferBytes int) *packetQueue {
   423  	pq := &packetQueue{
   424  		empty: make(chan packetQueueState, 1),
   425  		ready: make(chan packetQueueState, 1),
   426  		full:  make(chan packetQueueState, 1),
   427  	}
   428  	pq.put(packetQueueState{
   429  		readBufferBytes: readBufferBytes,
   430  	})
   431  	return pq
   432  }
   433  
   434  func (pq *packetQueue) get() packetQueueState {
   435  	var q packetQueueState
   436  	select {
   437  	case q = <-pq.empty:
   438  	case q = <-pq.ready:
   439  	case q = <-pq.full:
   440  	}
   441  	return q
   442  }
   443  
   444  func (pq *packetQueue) put(q packetQueueState) {
   445  	switch {
   446  	case q.readClosed || q.writeClosed:
   447  		pq.ready <- q
   448  	case q.nBytes >= q.readBufferBytes:
   449  		pq.full <- q
   450  	case q.head == nil:
   451  		if q.nBytes > 0 {
   452  			defer panic("net: put with nil packet list and nonzero nBytes")
   453  		}
   454  		pq.empty <- q
   455  	default:
   456  		pq.ready <- q
   457  	}
   458  }
   459  
   460  func (pq *packetQueue) closeRead() error {
   461  	q := pq.get()
   462  	q.readClosed = true
   463  	pq.put(q)
   464  	return nil
   465  }
   466  
   467  func (pq *packetQueue) closeWrite() error {
   468  	q := pq.get()
   469  	q.writeClosed = true
   470  	pq.put(q)
   471  	return nil
   472  }
   473  
   474  func (pq *packetQueue) setLinger(linger bool) error {
   475  	q := pq.get()
   476  	defer func() { pq.put(q) }()
   477  
   478  	if q.writeClosed {
   479  		return ErrClosed
   480  	}
   481  	q.noLinger = !linger
   482  	return nil
   483  }
   484  
   485  func (pq *packetQueue) write(dt *deadlineTimer, b []byte, from sockaddr) (n int, err error) {
   486  	for {
   487  		dn := len(b)
   488  		if dn > maxPacketSize {
   489  			dn = maxPacketSize
   490  		}
   491  
   492  		dn, err = pq.send(dt, b[:dn], from, true)
   493  		n += dn
   494  		if err != nil {
   495  			return n, err
   496  		}
   497  
   498  		b = b[dn:]
   499  		if len(b) == 0 {
   500  			return n, nil
   501  		}
   502  	}
   503  }
   504  
// send enqueues b as a single packet from sender address from.
//
// When block is false and the queue is full, send fails immediately with
// ENOBUFS; when block is true, it waits for the queue to leave the full
// state. Other failures:
//   - EINVAL for a nil sender;
//   - EMSGSIZE if b exceeds maxPacketSize;
//   - os.ErrDeadlineExceeded once dt has expired;
//   - ErrClosed if the write side is closed;
//   - ECONNRESET if the read side is closed and the buffer limit is reached.
//
// On success it returns len(b).
func (pq *packetQueue) send(dt *deadlineTimer, b []byte, from sockaddr, block bool) (n int, err error) {
	if from == nil {
		return 0, os.NewSyscallError("send", syscall.EINVAL)
	}
	if len(b) > maxPacketSize {
		return 0, os.NewSyscallError("send", syscall.EMSGSIZE)
	}

	var q packetQueueState
	// A nil channel is never ready in a select, so the full state is only
	// considered (and rejected) when the caller asked not to block.
	var full chan packetQueueState
	if !block {
		full = pq.full
	}

	select {
	case <-dt.expired:
		return 0, os.ErrDeadlineExceeded

	case q = <-full:
		pq.put(q)
		return 0, os.NewSyscallError("send", syscall.ENOBUFS)

	case q = <-pq.empty:
	case q = <-pq.ready:
	}
	// Hand the (possibly updated) state back on every path below.
	defer func() { pq.put(q) }()

	// Don't allow a packet to be sent if the deadline has expired,
	// even if the select above chose a different branch.
	select {
	case <-dt.expired:
		return 0, os.ErrDeadlineExceeded
	default:
	}
	if q.writeClosed {
		return 0, ErrClosed
	} else if q.readClosed && q.nBytes >= q.readBufferBytes {
		return 0, os.NewSyscallError("send", syscall.ECONNRESET)
	}

	// Copy b into a pooled packet so the queue does not alias the caller's
	// buffer, then append it to the tail of the list.
	p := packetPool.Get().(*packet)
	p.buf = append(p.buf[:0], b...)
	p.from = from

	if q.head == nil {
		q.head = p
	} else {
		q.tail.next = p
	}
	q.tail = p
	q.nBytes += len(p.buf)

	return len(b), nil
}
   559  
// recvfrom copies data from the head packet into b and returns the sender.
// It blocks until a packet is available, the queue is closed, or dt expires.
//
// If wholePacket is true, the head packet is removed even when b is too small
// to hold it, and the excess is discarded. Otherwise the unread bytes remain
// for the next call. If checkFrom is non-nil, it is called with the head
// packet's sender before any data is copied; a non-nil error from it is
// returned and the packet stays queued.
//
// Return values:
//   - ErrClosed once the read side is closed;
//   - io.EOF, or ECONNRESET when linger is disabled, after the write side is
//     closed and every buffered packet has been read;
//   - (0, nil, nil) immediately for a zero-length b on an empty, open queue.
func (pq *packetQueue) recvfrom(dt *deadlineTimer, b []byte, wholePacket bool, checkFrom func(sockaddr) error) (n int, from sockaddr, err error) {
	var q packetQueueState
	// empty stays nil (never selectable) unless b is zero-length.
	var empty chan packetQueueState
	if len(b) == 0 {
		// For consistency with the implementation on Unix platforms,
		// allow a zero-length Read to proceed if the queue is empty.
		// (Without this, TestZeroByteRead deadlocks.)
		empty = pq.empty
	}

	select {
	case <-dt.expired:
		return 0, nil, os.ErrDeadlineExceeded
	case q = <-empty:
	case q = <-pq.ready:
	case q = <-pq.full:
	}
	// Hand the state back on every path, including the panic below.
	defer func() { pq.put(q) }()

	if q.readClosed {
		return 0, nil, ErrClosed
	}

	p := q.head
	if p == nil {
		switch {
		case q.writeClosed:
			if q.noLinger {
				return 0, nil, os.NewSyscallError("recvfrom", syscall.ECONNRESET)
			}
			return 0, nil, io.EOF
		case len(b) == 0:
			return 0, nil, nil
		default:
			// This should be impossible: pq.full should only contain a non-empty list,
			// pq.ready should either contain a non-empty list or indicate that the
			// connection is closed, and we should only receive from pq.empty if
			// len(b) == 0.
			panic("net: nil packet list from non-closed packetQueue")
		}
	}

	// Re-check the deadline: the select above may have taken the queue even
	// though dt had also expired.
	select {
	case <-dt.expired:
		return 0, nil, os.ErrDeadlineExceeded
	default:
	}

	if checkFrom != nil {
		if err := checkFrom(p.from); err != nil {
			return 0, nil, err
		}
	}

	n = copy(b, p.buf[p.bufOffset:])
	from = p.from
	if wholePacket || p.bufOffset+n == len(p.buf) {
		// The packet is finished (or its remainder is discarded): unlink it
		// and recycle it.
		q.head = p.next
		q.nBytes -= len(p.buf)
		p.clear()
		packetPool.Put(p)
	} else {
		p.bufOffset += n
	}

	return n, from, nil
}
   627  
   628  // setReadBuffer sets a soft limit on the number of bytes available to read
   629  // from the pipe.
   630  func (pq *packetQueue) setReadBuffer(bytes int) error {
   631  	if bytes <= 0 {
   632  		return os.NewSyscallError("setReadBuffer", syscall.EINVAL)
   633  	}
   634  	q := pq.get() // Use the queue as a lock.
   635  	q.readBufferBytes = bytes
   636  	pq.put(q)
   637  	return nil
   638  }
   639  
   640  type deadlineTimer struct {
   641  	timer   chan *time.Timer
   642  	expired chan struct{}
   643  }
   644  
   645  func newDeadlineTimer(deadline time.Time) *deadlineTimer {
   646  	dt := &deadlineTimer{
   647  		timer:   make(chan *time.Timer, 1),
   648  		expired: make(chan struct{}),
   649  	}
   650  	dt.timer <- nil
   651  	dt.Reset(deadline)
   652  	return dt
   653  }
   654  
// Reset attempts to reset the timer.
// If the timer has already expired, Reset returns false.
//
// A deadline equal to noDeadline stops the running timer, if any. Any other
// deadline (re)arms the timer to close dt.expired at that time. When Reset
// returns false, callers in this file (SetReadDeadline, SetWriteDeadline)
// replace the deadlineTimer with a new one.
func (dt *deadlineTimer) Reset(deadline time.Time) bool {
	// Receiving from dt.timer serializes concurrent Resets; the deferred send
	// stores the (possibly updated) timer back when Reset returns.
	timer := <-dt.timer
	defer func() { dt.timer <- timer }()

	if deadline.Equal(noDeadline) {
		// If Stop fails, the timer has already fired: timer stays non-nil and
		// false is reported.
		if timer != nil && timer.Stop() {
			timer = nil
		}
		return timer == nil
	}

	d := time.Until(deadline)
	if d < 0 {
		// Ensure that a deadline in the past takes effect immediately.
		// Deferred calls run last-in first-out, so this wait for expired
		// finishes before the timer is stored back in dt.timer.
		defer func() { <-dt.expired }()
	}

	if timer == nil {
		timer = time.AfterFunc(d, func() { close(dt.expired) })
		return true
	}
	if !timer.Stop() {
		// The timer already fired; expired is closed (or about to be).
		return false
	}
	timer.Reset(d)
	return true
}
   684  
   685  func sysSocket(family, sotype, proto int) (int, error) {
   686  	return 0, os.NewSyscallError("sysSocket", syscall.ENOSYS)
   687  }
   688  
// fakeListen binds fd to laddr on the fake network and registers it in
// sockets under its local address.
//
// Stream and seqpacket sockets become listeners with an empty
// pending-connection list. Datagram sockets get a packet queue for incoming
// data. Address assignment, including an automatic port for TCP/UDP
// addresses with port 0, is done by assignFakeAddr.
//
// Errors are wrapped as follows:
//   - Errno values become os.SyscallErrors;
//   - EADDRINUSE is returned as-is;
//   - anything else is wrapped in an AddrError naming laddr when laddr is
//     non-nil.
//
// On failure, the state created here is closed.
func fakeListen(fd *netFD, laddr sockaddr) (err error) {
	wrapErr := func(err error) error {
		if errno, ok := err.(syscall.Errno); ok {
			err = os.NewSyscallError("listen", errno)
		}
		if errors.Is(err, syscall.EADDRINUSE) {
			return err
		}
		if laddr != nil {
			if _, ok := err.(*AddrError); !ok {
				err = &AddrError{
					Err:  err.Error(),
					Addr: laddr.String(),
				}
			}
		}
		return err
	}

	// ffd is installed in fd only just before registration; the deferred
	// cleanup treats "fd.fakeNetFD is not ffd" as failure.
	ffd := newFakeNetFD(fd)
	defer func() {
		if fd.fakeNetFD != ffd {
			// Failed to register listener; clean up.
			ffd.Close()
		}
	}()

	if err := ffd.assignFakeAddr(matchIPFamily(fd.family, laddr)); err != nil {
		return wrapErr(err)
	}

	ffd.fakeAddr = fakeAddr(fd.laddr.(sockaddr))
	switch fd.sotype {
	case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
		ffd.incoming = make(chan []*netFD, 1)
		ffd.incomingFull = make(chan []*netFD, 1)
		ffd.incomingEmpty = make(chan bool, 1)
		ffd.incomingEmpty <- true
	case syscall.SOCK_DGRAM:
		ffd.queue = newPacketQueue(defaultBuffer)
	default:
		return wrapErr(syscall.EINVAL)
	}

	fd.fakeNetFD = ffd
	if _, dup := sockets.LoadOrStore(ffd.fakeAddr, fd); dup {
		// Another socket already owns this address. Uninstall ffd so the
		// deferred cleanup closes it.
		fd.fakeNetFD = nil
		return wrapErr(syscall.EADDRINUSE)
	}

	return nil
}
   741  
   742  func fakeConnect(ctx context.Context, fd *netFD, laddr, raddr sockaddr) error {
   743  	wrapErr := func(err error) error {
   744  		if errno, ok := err.(syscall.Errno); ok {
   745  			err = os.NewSyscallError("connect", errno)
   746  		}
   747  		if errors.Is(err, syscall.EADDRINUSE) {
   748  			return err
   749  		}
   750  		if terr, ok := err.(interface{ Timeout() bool }); !ok || !terr.Timeout() {
   751  			// For consistency with the net implementation on other platforms,
   752  			// if we don't need to preserve the Timeout-ness of err we should
   753  			// wrap it in an AddrError. (Unfortunately we can't wrap errors
   754  			// that convey structured information, because AddrError reduces
   755  			// the wrapped Err to a flat string.)
   756  			if _, ok := err.(*AddrError); !ok {
   757  				err = &AddrError{
   758  					Err:  err.Error(),
   759  					Addr: raddr.String(),
   760  				}
   761  			}
   762  		}
   763  		return err
   764  	}
   765  
   766  	if fd.isConnected {
   767  		return wrapErr(syscall.EISCONN)
   768  	}
   769  	if ctx.Err() != nil {
   770  		return wrapErr(syscall.ETIMEDOUT)
   771  	}
   772  
   773  	fd.raddr = matchIPFamily(fd.family, raddr)
   774  	if err := validateResolvedAddr(fd.net, fd.family, fd.raddr.(sockaddr)); err != nil {
   775  		return wrapErr(err)
   776  	}
   777  
   778  	if err := fd.fakeNetFD.assignFakeAddr(laddr); err != nil {
   779  		return wrapErr(err)
   780  	}
   781  	fd.fakeNetFD.queue = newPacketQueue(defaultBuffer)
   782  
   783  	switch fd.sotype {
   784  	case syscall.SOCK_DGRAM:
   785  		if ua, ok := fd.laddr.(*UnixAddr); !ok || ua.Name != "" {
   786  			fd.fakeNetFD.fakeAddr = fakeAddr(fd.laddr.(sockaddr))
   787  			if _, dup := sockets.LoadOrStore(fd.fakeNetFD.fakeAddr, fd); dup {
   788  				return wrapErr(syscall.EADDRINUSE)
   789  			}
   790  		}
   791  		fd.isConnected = true
   792  		return nil
   793  
   794  	case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
   795  	default:
   796  		return wrapErr(syscall.EINVAL)
   797  	}
   798  
   799  	fa := fakeAddr(raddr)
   800  	lni, ok := sockets.Load(fa)
   801  	if !ok {
   802  		return wrapErr(syscall.ECONNREFUSED)
   803  	}
   804  	ln := lni.(*netFD)
   805  	if ln.sotype != fd.sotype {
   806  		return wrapErr(syscall.EPROTOTYPE)
   807  	}
   808  	if ln.incoming == nil {
   809  		return wrapErr(syscall.ECONNREFUSED)
   810  	}
   811  
   812  	peer := &netFD{
   813  		family:      ln.family,
   814  		sotype:      ln.sotype,
   815  		net:         ln.net,
   816  		laddr:       ln.laddr,
   817  		raddr:       fd.laddr,
   818  		isConnected: true,
   819  	}
   820  	peer.fakeNetFD = newFakeNetFD(fd)
   821  	peer.fakeNetFD.queue = newPacketQueue(defaultBuffer)
   822  	defer func() {
   823  		if fd.peer != peer {
   824  			// Failed to connect; clean up.
   825  			peer.Close()
   826  		}
   827  	}()
   828  
   829  	var incoming []*netFD
   830  	select {
   831  	case <-ctx.Done():
   832  		return wrapErr(syscall.ETIMEDOUT)
   833  	case ok = <-ln.incomingEmpty:
   834  	case incoming, ok = <-ln.incoming:
   835  	}
   836  	if !ok {
   837  		return wrapErr(syscall.ECONNREFUSED)
   838  	}
   839  
   840  	fd.isConnected = true
   841  	fd.peer = peer
   842  	peer.peer = fd
   843  
   844  	incoming = append(incoming, peer)
   845  	if len(incoming) >= listenerBacklog() {
   846  		ln.incomingFull <- incoming
   847  	} else {
   848  		ln.incoming <- incoming
   849  	}
   850  	return nil
   851  }
   852  
// assignFakeAddr resolves addr into a concrete local address for ffd's
// network. It checks the result with validateResolvedAddr and, if that
// passes, stores it in ffd.fd.laddr.
//
// For TCP and UDP:
//   - a nil address or nil IP defaults to 127.0.0.1;
//   - the IP is converted to the form matching ffd.fd.family when possible;
//   - a zero port is replaced by the next free port from nextPortCounter.
//     That port is reserved in fakePorts and recorded in ffd.assignedPort.
//
// EADDRINUSE is returned if every port has been tried. For Unix networks,
// only the Name is kept, with Net set to ffd.fd.net. Non-Unix addresses are
// rejected.
//
// NOTE(review): a port reserved here stays reserved if validation then fails.
// It is released by Close (via ffd.assignedPort).
func (ffd *fakeNetFD) assignFakeAddr(addr sockaddr) error {
	// validate checks sa and, only on success, records it as fd's laddr.
	validate := func(sa sockaddr) error {
		if err := validateResolvedAddr(ffd.fd.net, ffd.fd.family, sa); err != nil {
			return err
		}
		ffd.fd.laddr = sa
		return nil
	}

	// assignIP fills in the IP and port of a TCP or UDP address; other
	// address kinds are validated as-is.
	assignIP := func(addr sockaddr) error {
		var (
			ip   IP
			port int
			zone string
		)
		switch addr := addr.(type) {
		case *TCPAddr:
			if addr != nil {
				ip = addr.IP
				port = addr.Port
				zone = addr.Zone
			}
		case *UDPAddr:
			if addr != nil {
				ip = addr.IP
				port = addr.Port
				zone = addr.Zone
			}
		default:
			return validate(addr)
		}

		if ip == nil {
			ip = IPv4(127, 0, 0, 1)
		}
		switch ffd.fd.family {
		case syscall.AF_INET:
			if ip4 := ip.To4(); ip4 != nil {
				ip = ip4
			}
		case syscall.AF_INET6:
			if ip16 := ip.To16(); ip16 != nil {
				ip = ip16
			}
		}
		// Defensive: ip defaulted above, and the conversions only replace it
		// with non-nil values.
		if ip == nil {
			return syscall.EINVAL
		}

		if port == 0 {
			var prevPort int32
			portWrapped := false
			// nextPort returns the next candidate port in [1, 65535]. It
			// returns false once the counter has wrapped around a second time,
			// meaning the whole port space has been scanned.
			nextPort := func() (int, bool) {
				for {
					port := nextPortCounter.Add(1)
					if port <= 0 || port >= 1<<16 {
						// nextPortCounter ran off the end of the port space.
						// Bump it back into range.
						for {
							if nextPortCounter.CompareAndSwap(port, 0) {
								break
							}
							if port = nextPortCounter.Load(); port >= 0 && port < 1<<16 {
								break
							}
						}
						if portWrapped {
							// This is the second wraparound, so we've scanned the whole port space
							// at least once already and it's time to give up.
							return 0, false
						}
						portWrapped = true
						prevPort = 0
						continue
					}

					if port <= prevPort {
						// nextPortCounter has wrapped around since the last time we read it.
						if portWrapped {
							// This is the second wraparound, so we've scanned the whole port space
							// at least once already and it's time to give up.
							return 0, false
						} else {
							portWrapped = true
						}
					}

					prevPort = port
					return int(port), true
				}
			}

			// Keep drawing candidates until one is not already reserved in
			// fakePorts.
			for {
				var ok bool
				port, ok = nextPort()
				if !ok {
					ffd.assignedPort = 0
					return syscall.EADDRINUSE
				}

				ffd.assignedPort = int(port)
				if _, dup := fakePorts.LoadOrStore(ffd.assignedPort, ffd.fd); !dup {
					break
				}
			}
		}

		switch addr.(type) {
		case *TCPAddr:
			return validate(&TCPAddr{IP: ip, Port: port, Zone: zone})
		case *UDPAddr:
			return validate(&UDPAddr{IP: ip, Port: port, Zone: zone})
		default:
			panic("unreachable")
		}
	}

	switch ffd.fd.net {
	case "tcp", "tcp4", "tcp6":
		if addr == nil {
			return assignIP(new(TCPAddr))
		}
		return assignIP(addr)

	case "udp", "udp4", "udp6":
		if addr == nil {
			return assignIP(new(UDPAddr))
		}
		return assignIP(addr)

	case "unix", "unixgram", "unixpacket":
		uaddr, ok := addr.(*UnixAddr)
		if !ok && addr != nil {
			return &AddrError{
				Err:  "non-Unix address for " + ffd.fd.net + " network",
				Addr: addr.String(),
			}
		}
		if uaddr == nil {
			return validate(&UnixAddr{Net: ffd.fd.net})
		}
		return validate(&UnixAddr{Net: ffd.fd.net, Name: uaddr.Name})

	default:
		// NOTE(review): addr.String() would panic here if addr is a nil
		// interface.
		return &AddrError{
			Err:  syscall.EAFNOSUPPORT.Error(),
			Addr: addr.String(),
		}
	}
}
  1003  
  1004  func (ffd *fakeNetFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
  1005  	if ffd.queue == nil {
  1006  		return 0, nil, os.NewSyscallError("readFrom", syscall.EINVAL)
  1007  	}
  1008  
  1009  	n, from, err := ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, nil)
  1010  
  1011  	if from != nil {
  1012  		// Convert the net.sockaddr to a syscall.Sockaddr type.
  1013  		var saErr error
  1014  		sa, saErr = from.sockaddr(ffd.fd.family)
  1015  		if err == nil {
  1016  			err = saErr
  1017  		}
  1018  	}
  1019  
  1020  	return n, sa, err
  1021  }
  1022  
  1023  func (ffd *fakeNetFD) readFromInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
  1024  	n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, func(from sockaddr) error {
  1025  		fromSA, err := from.sockaddr(syscall.AF_INET)
  1026  		if err != nil {
  1027  			return err
  1028  		}
  1029  		if fromSA == nil {
  1030  			return os.NewSyscallError("readFromInet4", syscall.EINVAL)
  1031  		}
  1032  		*sa = *(fromSA.(*syscall.SockaddrInet4))
  1033  		return nil
  1034  	})
  1035  	return n, err
  1036  }
  1037  
  1038  func (ffd *fakeNetFD) readFromInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
  1039  	n, _, err = ffd.queue.recvfrom(ffd.readDeadline.Load(), p, true, func(from sockaddr) error {
  1040  		fromSA, err := from.sockaddr(syscall.AF_INET6)
  1041  		if err != nil {
  1042  			return err
  1043  		}
  1044  		if fromSA == nil {
  1045  			return os.NewSyscallError("readFromInet6", syscall.EINVAL)
  1046  		}
  1047  		*sa = *(fromSA.(*syscall.SockaddrInet6))
  1048  		return nil
  1049  	})
  1050  	return n, err
  1051  }
  1052  
  1053  func (ffd *fakeNetFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
  1054  	if flags != 0 {
  1055  		return 0, 0, 0, nil, os.NewSyscallError("readMsg", syscall.ENOTSUP)
  1056  	}
  1057  	n, sa, err = ffd.readFrom(p)
  1058  	return n, 0, 0, sa, err
  1059  }
  1060  
  1061  func (ffd *fakeNetFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
  1062  	if flags != 0 {
  1063  		return 0, 0, 0, os.NewSyscallError("readMsgInet4", syscall.ENOTSUP)
  1064  	}
  1065  	n, err = ffd.readFromInet4(p, sa)
  1066  	return n, 0, 0, err
  1067  }
  1068  
  1069  func (ffd *fakeNetFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
  1070  	if flags != 0 {
  1071  		return 0, 0, 0, os.NewSyscallError("readMsgInet6", syscall.ENOTSUP)
  1072  	}
  1073  	n, err = ffd.readFromInet6(p, sa)
  1074  	return n, 0, 0, err
  1075  }
  1076  
  1077  func (ffd *fakeNetFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
  1078  	if len(oob) > 0 {
  1079  		return 0, 0, os.NewSyscallError("writeMsg", syscall.ENOTSUP)
  1080  	}
  1081  	n, err = ffd.writeTo(p, sa)
  1082  	return n, 0, err
  1083  }
  1084  
  1085  func (ffd *fakeNetFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
  1086  	return ffd.writeMsg(p, oob, sa)
  1087  }
  1088  
  1089  func (ffd *fakeNetFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
  1090  	return ffd.writeMsg(p, oob, sa)
  1091  }
  1092  
  1093  func (ffd *fakeNetFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
  1094  	raddr := ffd.fd.raddr
  1095  	if sa != nil {
  1096  		if ffd.fd.isConnected {
  1097  			return 0, os.NewSyscallError("writeTo", syscall.EISCONN)
  1098  		}
  1099  		raddr = ffd.fd.addrFunc()(sa)
  1100  	}
  1101  	if raddr == nil {
  1102  		return 0, os.NewSyscallError("writeTo", syscall.EINVAL)
  1103  	}
  1104  
  1105  	peeri, _ := sockets.Load(fakeAddr(raddr.(sockaddr)))
  1106  	if peeri == nil {
  1107  		if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" {
  1108  			return len(p), nil
  1109  		}
  1110  		return 0, os.NewSyscallError("writeTo", syscall.ECONNRESET)
  1111  	}
  1112  	peer := peeri.(*netFD)
  1113  	if peer.queue == nil {
  1114  		if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" {
  1115  			return len(p), nil
  1116  		}
  1117  		return 0, os.NewSyscallError("writeTo", syscall.ECONNRESET)
  1118  	}
  1119  
  1120  	block := true
  1121  	if len(ffd.fd.net) >= 3 && ffd.fd.net[:3] == "udp" {
  1122  		block = false
  1123  	}
  1124  	return peer.queue.send(ffd.writeDeadline.Load(), p, ffd.fd.laddr.(sockaddr), block)
  1125  }
  1126  
  1127  func (ffd *fakeNetFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
  1128  	return ffd.writeTo(p, sa)
  1129  }
  1130  
  1131  func (ffd *fakeNetFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
  1132  	return ffd.writeTo(p, sa)
  1133  }
  1134  
  1135  func (ffd *fakeNetFD) dup() (f *os.File, err error) {
  1136  	return nil, os.NewSyscallError("dup", syscall.ENOSYS)
  1137  }
  1138  
  1139  func (ffd *fakeNetFD) setReadBuffer(bytes int) error {
  1140  	if ffd.queue == nil {
  1141  		return os.NewSyscallError("setReadBuffer", syscall.EINVAL)
  1142  	}
  1143  	ffd.queue.setReadBuffer(bytes)
  1144  	return nil
  1145  }
  1146  
  1147  func (ffd *fakeNetFD) setWriteBuffer(bytes int) error {
  1148  	return os.NewSyscallError("setWriteBuffer", syscall.ENOTSUP)
  1149  }
  1150  
  1151  func (ffd *fakeNetFD) setLinger(sec int) error {
  1152  	if sec < 0 || ffd.peer == nil {
  1153  		return os.NewSyscallError("setLinger", syscall.EINVAL)
  1154  	}
  1155  	ffd.peer.queue.setLinger(sec > 0)
  1156  	return nil
  1157  }
  1158  

View as plain text