Source file src/net/mockserver_test.go

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package net
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"internal/testenv"
    12  	"log"
    13  	"os"
    14  	"path/filepath"
    15  	"runtime"
    16  	"strconv"
    17  	"sync"
    18  	"testing"
    19  	"time"
    20  )
    21  
    22  // testUnixAddr uses os.MkdirTemp to get a name that is unique.
    23  func testUnixAddr(t testing.TB) string {
    24  	// Pass an empty pattern to get a directory name that is as short as possible.
    25  	// If we end up with a name longer than the sun_path field in the sockaddr_un
    26  	// struct, we won't be able to make the syscall to open the socket.
    27  	d, err := os.MkdirTemp("", "")
    28  	if err != nil {
    29  		t.Fatal(err)
    30  	}
    31  	t.Cleanup(func() {
    32  		if err := os.RemoveAll(d); err != nil {
    33  			t.Error(err)
    34  		}
    35  	})
    36  	return filepath.Join(d, "sock")
    37  }
    38  
    39  func newLocalListener(t testing.TB, network string, lcOpt ...*ListenConfig) Listener {
    40  	var lc *ListenConfig
    41  	switch len(lcOpt) {
    42  	case 0:
    43  		lc = new(ListenConfig)
    44  	case 1:
    45  		lc = lcOpt[0]
    46  	default:
    47  		t.Helper()
    48  		t.Fatal("too many ListenConfigs passed to newLocalListener: want 0 or 1")
    49  	}
    50  
    51  	listen := func(net, addr string) Listener {
    52  		ln, err := lc.Listen(context.Background(), net, addr)
    53  		if err != nil {
    54  			t.Helper()
    55  			t.Fatal(err)
    56  		}
    57  		return ln
    58  	}
    59  
    60  	switch network {
    61  	case "tcp":
    62  		if supportsIPv4() {
    63  			return listen("tcp4", "127.0.0.1:0")
    64  		}
    65  		if supportsIPv6() {
    66  			return listen("tcp6", "[::1]:0")
    67  		}
    68  	case "tcp4":
    69  		if supportsIPv4() {
    70  			return listen("tcp4", "127.0.0.1:0")
    71  		}
    72  	case "tcp6":
    73  		if supportsIPv6() {
    74  			return listen("tcp6", "[::1]:0")
    75  		}
    76  	case "unix", "unixpacket":
    77  		return listen(network, testUnixAddr(t))
    78  	}
    79  
    80  	t.Helper()
    81  	t.Fatalf("%s is not supported", network)
    82  	return nil
    83  }
    84  
    85  func newDualStackListener() (lns []*TCPListener, err error) {
    86  	var args = []struct {
    87  		network string
    88  		TCPAddr
    89  	}{
    90  		{"tcp4", TCPAddr{IP: IPv4(127, 0, 0, 1)}},
    91  		{"tcp6", TCPAddr{IP: IPv6loopback}},
    92  	}
    93  	for i := 0; i < 64; i++ {
    94  		var port int
    95  		var lns []*TCPListener
    96  		for _, arg := range args {
    97  			arg.TCPAddr.Port = port
    98  			ln, err := ListenTCP(arg.network, &arg.TCPAddr)
    99  			if err != nil {
   100  				continue
   101  			}
   102  			port = ln.Addr().(*TCPAddr).Port
   103  			lns = append(lns, ln)
   104  		}
   105  		if len(lns) != len(args) {
   106  			for _, ln := range lns {
   107  				ln.Close()
   108  			}
   109  			continue
   110  		}
   111  		return lns, nil
   112  	}
   113  	return nil, errors.New("no dualstack port available")
   114  }
   115  
   116  type localServer struct {
   117  	lnmu sync.RWMutex
   118  	Listener
   119  	done chan bool // signal that indicates server stopped
   120  	cl   []Conn    // accepted connection list
   121  }
   122  
   123  func (ls *localServer) buildup(handler func(*localServer, Listener)) error {
   124  	go func() {
   125  		handler(ls, ls.Listener)
   126  		close(ls.done)
   127  	}()
   128  	return nil
   129  }
   130  
   131  func (ls *localServer) teardown() error {
   132  	ls.lnmu.Lock()
   133  	defer ls.lnmu.Unlock()
   134  	if ls.Listener != nil {
   135  		network := ls.Listener.Addr().Network()
   136  		address := ls.Listener.Addr().String()
   137  		ls.Listener.Close()
   138  		for _, c := range ls.cl {
   139  			if err := c.Close(); err != nil {
   140  				return err
   141  			}
   142  		}
   143  		<-ls.done
   144  		ls.Listener = nil
   145  		switch network {
   146  		case "unix", "unixpacket":
   147  			os.Remove(address)
   148  		}
   149  	}
   150  	return nil
   151  }
   152  
   153  func newLocalServer(t testing.TB, network string) *localServer {
   154  	t.Helper()
   155  	ln := newLocalListener(t, network)
   156  	return &localServer{Listener: ln, done: make(chan bool)}
   157  }
   158  
   159  type streamListener struct {
   160  	network, address string
   161  	Listener
   162  	done chan bool // signal that indicates server stopped
   163  }
   164  
   165  func (sl *streamListener) newLocalServer() *localServer {
   166  	return &localServer{Listener: sl.Listener, done: make(chan bool)}
   167  }
   168  
   169  type dualStackServer struct {
   170  	lnmu sync.RWMutex
   171  	lns  []streamListener
   172  	port string
   173  
   174  	cmu sync.RWMutex
   175  	cs  []Conn // established connections at the passive open side
   176  }
   177  
   178  func (dss *dualStackServer) buildup(handler func(*dualStackServer, Listener)) error {
   179  	for i := range dss.lns {
   180  		go func(i int) {
   181  			handler(dss, dss.lns[i].Listener)
   182  			close(dss.lns[i].done)
   183  		}(i)
   184  	}
   185  	return nil
   186  }
   187  
   188  func (dss *dualStackServer) teardownNetwork(network string) error {
   189  	dss.lnmu.Lock()
   190  	for i := range dss.lns {
   191  		if network == dss.lns[i].network && dss.lns[i].Listener != nil {
   192  			dss.lns[i].Listener.Close()
   193  			<-dss.lns[i].done
   194  			dss.lns[i].Listener = nil
   195  		}
   196  	}
   197  	dss.lnmu.Unlock()
   198  	return nil
   199  }
   200  
   201  func (dss *dualStackServer) teardown() error {
   202  	dss.lnmu.Lock()
   203  	for i := range dss.lns {
   204  		if dss.lns[i].Listener != nil {
   205  			dss.lns[i].Listener.Close()
   206  			<-dss.lns[i].done
   207  		}
   208  	}
   209  	dss.lns = dss.lns[:0]
   210  	dss.lnmu.Unlock()
   211  	dss.cmu.Lock()
   212  	for _, c := range dss.cs {
   213  		c.Close()
   214  	}
   215  	dss.cs = dss.cs[:0]
   216  	dss.cmu.Unlock()
   217  	return nil
   218  }
   219  
   220  func newDualStackServer() (*dualStackServer, error) {
   221  	lns, err := newDualStackListener()
   222  	if err != nil {
   223  		return nil, err
   224  	}
   225  	_, port, err := SplitHostPort(lns[0].Addr().String())
   226  	if err != nil {
   227  		lns[0].Close()
   228  		lns[1].Close()
   229  		return nil, err
   230  	}
   231  	return &dualStackServer{
   232  		lns: []streamListener{
   233  			{network: "tcp4", address: lns[0].Addr().String(), Listener: lns[0], done: make(chan bool)},
   234  			{network: "tcp6", address: lns[1].Addr().String(), Listener: lns[1], done: make(chan bool)},
   235  		},
   236  		port: port,
   237  	}, nil
   238  }
   239  
   240  func (ls *localServer) transponder(ln Listener, ch chan<- error) {
   241  	defer close(ch)
   242  
   243  	switch ln := ln.(type) {
   244  	case *TCPListener:
   245  		ln.SetDeadline(time.Now().Add(someTimeout))
   246  	case *UnixListener:
   247  		ln.SetDeadline(time.Now().Add(someTimeout))
   248  	}
   249  	c, err := ln.Accept()
   250  	if err != nil {
   251  		if perr := parseAcceptError(err); perr != nil {
   252  			ch <- perr
   253  		}
   254  		ch <- err
   255  		return
   256  	}
   257  	ls.cl = append(ls.cl, c)
   258  
   259  	network := ln.Addr().Network()
   260  	if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network {
   261  		ch <- fmt.Errorf("got %v->%v; expected %v->%v", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
   262  		return
   263  	}
   264  	c.SetDeadline(time.Now().Add(someTimeout))
   265  	c.SetReadDeadline(time.Now().Add(someTimeout))
   266  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   267  
   268  	b := make([]byte, 256)
   269  	n, err := c.Read(b)
   270  	if err != nil {
   271  		if perr := parseReadError(err); perr != nil {
   272  			ch <- perr
   273  		}
   274  		ch <- err
   275  		return
   276  	}
   277  	if _, err := c.Write(b[:n]); err != nil {
   278  		if perr := parseWriteError(err); perr != nil {
   279  			ch <- perr
   280  		}
   281  		ch <- err
   282  		return
   283  	}
   284  }
   285  
   286  func transceiver(c Conn, wb []byte, ch chan<- error) {
   287  	defer close(ch)
   288  
   289  	c.SetDeadline(time.Now().Add(someTimeout))
   290  	c.SetReadDeadline(time.Now().Add(someTimeout))
   291  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   292  
   293  	n, err := c.Write(wb)
   294  	if err != nil {
   295  		if perr := parseWriteError(err); perr != nil {
   296  			ch <- perr
   297  		}
   298  		ch <- err
   299  		return
   300  	}
   301  	if n != len(wb) {
   302  		ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
   303  	}
   304  	rb := make([]byte, len(wb))
   305  	n, err = c.Read(rb)
   306  	if err != nil {
   307  		if perr := parseReadError(err); perr != nil {
   308  			ch <- perr
   309  		}
   310  		ch <- err
   311  		return
   312  	}
   313  	if n != len(wb) {
   314  		ch <- fmt.Errorf("read %d; want %d", n, len(wb))
   315  	}
   316  }
   317  
   318  func newLocalPacketListener(t testing.TB, network string, lcOpt ...*ListenConfig) PacketConn {
   319  	var lc *ListenConfig
   320  	switch len(lcOpt) {
   321  	case 0:
   322  		lc = new(ListenConfig)
   323  	case 1:
   324  		lc = lcOpt[0]
   325  	default:
   326  		t.Helper()
   327  		t.Fatal("too many ListenConfigs passed to newLocalListener: want 0 or 1")
   328  	}
   329  
   330  	listenPacket := func(net, addr string) PacketConn {
   331  		c, err := lc.ListenPacket(context.Background(), net, addr)
   332  		if err != nil {
   333  			t.Helper()
   334  			t.Fatal(err)
   335  		}
   336  		return c
   337  	}
   338  
   339  	t.Helper()
   340  	switch network {
   341  	case "udp":
   342  		if supportsIPv4() {
   343  			return listenPacket("udp4", "127.0.0.1:0")
   344  		}
   345  		if supportsIPv6() {
   346  			return listenPacket("udp6", "[::1]:0")
   347  		}
   348  	case "udp4":
   349  		if supportsIPv4() {
   350  			return listenPacket("udp4", "127.0.0.1:0")
   351  		}
   352  	case "udp6":
   353  		if supportsIPv6() {
   354  			return listenPacket("udp6", "[::1]:0")
   355  		}
   356  	case "unixgram":
   357  		return listenPacket(network, testUnixAddr(t))
   358  	}
   359  
   360  	t.Fatalf("%s is not supported", network)
   361  	return nil
   362  }
   363  
   364  func newDualStackPacketListener() (cs []*UDPConn, err error) {
   365  	var args = []struct {
   366  		network string
   367  		UDPAddr
   368  	}{
   369  		{"udp4", UDPAddr{IP: IPv4(127, 0, 0, 1)}},
   370  		{"udp6", UDPAddr{IP: IPv6loopback}},
   371  	}
   372  	for i := 0; i < 64; i++ {
   373  		var port int
   374  		var cs []*UDPConn
   375  		for _, arg := range args {
   376  			arg.UDPAddr.Port = port
   377  			c, err := ListenUDP(arg.network, &arg.UDPAddr)
   378  			if err != nil {
   379  				continue
   380  			}
   381  			port = c.LocalAddr().(*UDPAddr).Port
   382  			cs = append(cs, c)
   383  		}
   384  		if len(cs) != len(args) {
   385  			for _, c := range cs {
   386  				c.Close()
   387  			}
   388  			continue
   389  		}
   390  		return cs, nil
   391  	}
   392  	return nil, errors.New("no dualstack port available")
   393  }
   394  
   395  type localPacketServer struct {
   396  	pcmu sync.RWMutex
   397  	PacketConn
   398  	done chan bool // signal that indicates server stopped
   399  }
   400  
   401  func (ls *localPacketServer) buildup(handler func(*localPacketServer, PacketConn)) error {
   402  	go func() {
   403  		handler(ls, ls.PacketConn)
   404  		close(ls.done)
   405  	}()
   406  	return nil
   407  }
   408  
   409  func (ls *localPacketServer) teardown() error {
   410  	ls.pcmu.Lock()
   411  	if ls.PacketConn != nil {
   412  		network := ls.PacketConn.LocalAddr().Network()
   413  		address := ls.PacketConn.LocalAddr().String()
   414  		ls.PacketConn.Close()
   415  		<-ls.done
   416  		ls.PacketConn = nil
   417  		switch network {
   418  		case "unixgram":
   419  			os.Remove(address)
   420  		}
   421  	}
   422  	ls.pcmu.Unlock()
   423  	return nil
   424  }
   425  
   426  func newLocalPacketServer(t testing.TB, network string) *localPacketServer {
   427  	t.Helper()
   428  	c := newLocalPacketListener(t, network)
   429  	return &localPacketServer{PacketConn: c, done: make(chan bool)}
   430  }
   431  
   432  type packetListener struct {
   433  	PacketConn
   434  }
   435  
   436  func (pl *packetListener) newLocalServer() *localPacketServer {
   437  	return &localPacketServer{PacketConn: pl.PacketConn, done: make(chan bool)}
   438  }
   439  
   440  func packetTransponder(c PacketConn, ch chan<- error) {
   441  	defer close(ch)
   442  
   443  	c.SetDeadline(time.Now().Add(someTimeout))
   444  	c.SetReadDeadline(time.Now().Add(someTimeout))
   445  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   446  
   447  	b := make([]byte, 256)
   448  	n, peer, err := c.ReadFrom(b)
   449  	if err != nil {
   450  		if perr := parseReadError(err); perr != nil {
   451  			ch <- perr
   452  		}
   453  		ch <- err
   454  		return
   455  	}
   456  	if peer == nil { // for connected-mode sockets
   457  		switch c.LocalAddr().Network() {
   458  		case "udp":
   459  			peer, err = ResolveUDPAddr("udp", string(b[:n]))
   460  		case "unixgram":
   461  			peer, err = ResolveUnixAddr("unixgram", string(b[:n]))
   462  		}
   463  		if err != nil {
   464  			ch <- err
   465  			return
   466  		}
   467  	}
   468  	if _, err := c.WriteTo(b[:n], peer); err != nil {
   469  		if perr := parseWriteError(err); perr != nil {
   470  			ch <- perr
   471  		}
   472  		ch <- err
   473  		return
   474  	}
   475  }
   476  
   477  func packetTransceiver(c PacketConn, wb []byte, dst Addr, ch chan<- error) {
   478  	defer close(ch)
   479  
   480  	c.SetDeadline(time.Now().Add(someTimeout))
   481  	c.SetReadDeadline(time.Now().Add(someTimeout))
   482  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   483  
   484  	n, err := c.WriteTo(wb, dst)
   485  	if err != nil {
   486  		if perr := parseWriteError(err); perr != nil {
   487  			ch <- perr
   488  		}
   489  		ch <- err
   490  		return
   491  	}
   492  	if n != len(wb) {
   493  		ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
   494  	}
   495  	rb := make([]byte, len(wb))
   496  	n, _, err = c.ReadFrom(rb)
   497  	if err != nil {
   498  		if perr := parseReadError(err); perr != nil {
   499  			ch <- perr
   500  		}
   501  		ch <- err
   502  		return
   503  	}
   504  	if n != len(wb) {
   505  		ch <- fmt.Errorf("read %d; want %d", n, len(wb))
   506  	}
   507  }
   508  
   509  func spawnTestSocketPair(t testing.TB, net string) (client, server Conn) {
   510  	t.Helper()
   511  
   512  	ln := newLocalListener(t, net)
   513  	defer ln.Close()
   514  	var cerr, serr error
   515  	acceptDone := make(chan struct{})
   516  	go func() {
   517  		server, serr = ln.Accept()
   518  		acceptDone <- struct{}{}
   519  	}()
   520  	client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
   521  	<-acceptDone
   522  	if cerr != nil {
   523  		if server != nil {
   524  			server.Close()
   525  		}
   526  		t.Fatal(cerr)
   527  	}
   528  	if serr != nil {
   529  		if client != nil {
   530  			client.Close()
   531  		}
   532  		t.Fatal(serr)
   533  	}
   534  	return client, server
   535  }
   536  
   537  func startTestSocketPeer(t testing.TB, conn Conn, op string, chunkSize, totalSize int) (func(t testing.TB), error) {
   538  	t.Helper()
   539  
   540  	if runtime.GOOS == "windows" {
   541  		// TODO(panjf2000): Windows has not yet implemented FileConn,
   542  		//		remove this when it's implemented in https://go.dev/issues/9503.
   543  		t.Fatalf("startTestSocketPeer is not supported on %s", runtime.GOOS)
   544  	}
   545  
   546  	f, err := conn.(interface{ File() (*os.File, error) }).File()
   547  	if err != nil {
   548  		return nil, err
   549  	}
   550  
   551  	cmd := testenv.Command(t, os.Args[0])
   552  	cmd.Env = []string{
   553  		"GO_NET_TEST_TRANSFER=1",
   554  		"GO_NET_TEST_TRANSFER_OP=" + op,
   555  		"GO_NET_TEST_TRANSFER_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
   556  		"GO_NET_TEST_TRANSFER_TOTAL_SIZE=" + strconv.Itoa(totalSize),
   557  		"TMPDIR=" + os.Getenv("TMPDIR"),
   558  	}
   559  	cmd.ExtraFiles = append(cmd.ExtraFiles, f)
   560  	cmd.Stdout = os.Stdout
   561  	cmd.Stderr = os.Stderr
   562  
   563  	if err := cmd.Start(); err != nil {
   564  		return nil, err
   565  	}
   566  
   567  	cmdCh := make(chan error, 1)
   568  	go func() {
   569  		err := cmd.Wait()
   570  		conn.Close()
   571  		f.Close()
   572  		cmdCh <- err
   573  	}()
   574  
   575  	return func(tb testing.TB) {
   576  		err := <-cmdCh
   577  		if err != nil {
   578  			tb.Errorf("process exited with error: %v", err)
   579  		}
   580  	}, nil
   581  }
   582  
   583  func init() {
   584  	if os.Getenv("GO_NET_TEST_TRANSFER") == "" {
   585  		return
   586  	}
   587  	defer os.Exit(0)
   588  
   589  	f := os.NewFile(uintptr(3), "splice-test-conn")
   590  	defer f.Close()
   591  
   592  	conn, err := FileConn(f)
   593  	if err != nil {
   594  		log.Fatal(err)
   595  	}
   596  
   597  	var chunkSize int
   598  	if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_TRANSFER_CHUNK_SIZE")); err != nil {
   599  		log.Fatal(err)
   600  	}
   601  	buf := make([]byte, chunkSize)
   602  
   603  	var totalSize int
   604  	if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_TRANSFER_TOTAL_SIZE")); err != nil {
   605  		log.Fatal(err)
   606  	}
   607  
   608  	var fn func([]byte) (int, error)
   609  	switch op := os.Getenv("GO_NET_TEST_TRANSFER_OP"); op {
   610  	case "r":
   611  		fn = conn.Read
   612  	case "w":
   613  		defer conn.Close()
   614  
   615  		fn = conn.Write
   616  	default:
   617  		log.Fatalf("unknown op %q", op)
   618  	}
   619  
   620  	var n int
   621  	for count := 0; count < totalSize; count += n {
   622  		if count+chunkSize > totalSize {
   623  			buf = buf[:totalSize-count]
   624  		}
   625  
   626  		var err error
   627  		if n, err = fn(buf); err != nil {
   628  			return
   629  		}
   630  	}
   631  }
   632  

View as plain text