Source file src/os/readfrom_linux_test.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os_test
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"internal/poll"
    11  	"internal/testpty"
    12  	"io"
    13  	"math/rand"
    14  	"net"
    15  	. "os"
    16  	"path/filepath"
    17  	"strconv"
    18  	"sync"
    19  	"syscall"
    20  	"testing"
    21  	"time"
    22  )
    23  
    24  func TestSpliceFile(t *testing.T) {
    25  	sizes := []int{
    26  		1,
    27  		42,
    28  		1025,
    29  		syscall.Getpagesize() + 1,
    30  		32769,
    31  	}
    32  	t.Run("Basic-TCP", func(t *testing.T) {
    33  		for _, size := range sizes {
    34  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    35  				testSpliceFile(t, "tcp", int64(size), -1)
    36  			})
    37  		}
    38  	})
    39  	t.Run("Basic-Unix", func(t *testing.T) {
    40  		for _, size := range sizes {
    41  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    42  				testSpliceFile(t, "unix", int64(size), -1)
    43  			})
    44  		}
    45  	})
    46  	t.Run("TCP-To-TTY", func(t *testing.T) {
    47  		testSpliceToTTY(t, "tcp", 32768)
    48  	})
    49  	t.Run("Unix-To-TTY", func(t *testing.T) {
    50  		testSpliceToTTY(t, "unix", 32768)
    51  	})
    52  	t.Run("Limited", func(t *testing.T) {
    53  		t.Run("OneLess-TCP", func(t *testing.T) {
    54  			for _, size := range sizes {
    55  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    56  					testSpliceFile(t, "tcp", int64(size), int64(size)-1)
    57  				})
    58  			}
    59  		})
    60  		t.Run("OneLess-Unix", func(t *testing.T) {
    61  			for _, size := range sizes {
    62  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    63  					testSpliceFile(t, "unix", int64(size), int64(size)-1)
    64  				})
    65  			}
    66  		})
    67  		t.Run("Half-TCP", func(t *testing.T) {
    68  			for _, size := range sizes {
    69  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    70  					testSpliceFile(t, "tcp", int64(size), int64(size)/2)
    71  				})
    72  			}
    73  		})
    74  		t.Run("Half-Unix", func(t *testing.T) {
    75  			for _, size := range sizes {
    76  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    77  					testSpliceFile(t, "unix", int64(size), int64(size)/2)
    78  				})
    79  			}
    80  		})
    81  		t.Run("More-TCP", func(t *testing.T) {
    82  			for _, size := range sizes {
    83  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    84  					testSpliceFile(t, "tcp", int64(size), int64(size)+1)
    85  				})
    86  			}
    87  		})
    88  		t.Run("More-Unix", func(t *testing.T) {
    89  			for _, size := range sizes {
    90  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    91  					testSpliceFile(t, "unix", int64(size), int64(size)+1)
    92  				})
    93  			}
    94  		})
    95  	})
    96  }
    97  
    98  func testSpliceFile(t *testing.T, proto string, size, limit int64) {
    99  	dst, src, data, hook, cleanup := newSpliceFileTest(t, proto, size)
   100  	defer cleanup()
   101  
   102  	// If we have a limit, wrap the reader.
   103  	var (
   104  		r  io.Reader
   105  		lr *io.LimitedReader
   106  	)
   107  	if limit >= 0 {
   108  		lr = &io.LimitedReader{N: limit, R: src}
   109  		r = lr
   110  		if limit < int64(len(data)) {
   111  			data = data[:limit]
   112  		}
   113  	} else {
   114  		r = src
   115  	}
   116  	// Now call ReadFrom (through io.Copy), which will hopefully call poll.Splice
   117  	n, err := io.Copy(dst, r)
   118  	if err != nil {
   119  		t.Fatal(err)
   120  	}
   121  
   122  	// We should have called poll.Splice with the right file descriptor arguments.
   123  	if n > 0 && !hook.called {
   124  		t.Fatal("expected to called poll.Splice")
   125  	}
   126  	if hook.called && hook.dstfd != int(dst.Fd()) {
   127  		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
   128  	}
   129  	sc, ok := src.(syscall.Conn)
   130  	if !ok {
   131  		t.Fatalf("server Conn is not a syscall.Conn")
   132  	}
   133  	rc, err := sc.SyscallConn()
   134  	if err != nil {
   135  		t.Fatalf("server Conn SyscallConn error: %v", err)
   136  	}
   137  	if err = rc.Control(func(fd uintptr) {
   138  		if hook.called && hook.srcfd != int(fd) {
   139  			t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, int(fd))
   140  		}
   141  	}); err != nil {
   142  		t.Fatalf("server Conn Control error: %v", err)
   143  	}
   144  
   145  	// Check that the offsets after the transfer make sense, that the size
   146  	// of the transfer was reported correctly, and that the destination
   147  	// file contains exactly the bytes we expect it to contain.
   148  	dstoff, err := dst.Seek(0, io.SeekCurrent)
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  	if dstoff != int64(len(data)) {
   153  		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
   154  	}
   155  	if n != int64(len(data)) {
   156  		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
   157  	}
   158  	mustSeekStart(t, dst)
   159  	mustContainData(t, dst, data)
   160  
   161  	// If we had a limit, check that it was updated.
   162  	if lr != nil {
   163  		if want := limit - n; lr.N != want {
   164  			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
   165  		}
   166  	}
   167  }
   168  
   169  // Issue #59041.
   170  func testSpliceToTTY(t *testing.T, proto string, size int64) {
   171  	var wg sync.WaitGroup
   172  
   173  	// Call wg.Wait as the final deferred function,
   174  	// because the goroutines may block until some of
   175  	// the deferred Close calls.
   176  	defer wg.Wait()
   177  
   178  	pty, ttyName, err := testpty.Open()
   179  	if err != nil {
   180  		t.Skipf("skipping test because pty open failed: %v", err)
   181  	}
   182  	defer pty.Close()
   183  
   184  	// Open the tty directly, rather than via OpenFile.
   185  	// This bypasses the non-blocking support and is required
   186  	// to recreate the problem in the issue (#59041).
   187  	ttyFD, err := syscall.Open(ttyName, syscall.O_RDWR, 0)
   188  	if err != nil {
   189  		t.Skipf("skipping test because failed to open tty: %v", err)
   190  	}
   191  	defer syscall.Close(ttyFD)
   192  
   193  	tty := NewFile(uintptr(ttyFD), "tty")
   194  	defer tty.Close()
   195  
   196  	client, server := createSocketPair(t, proto)
   197  
   198  	data := bytes.Repeat([]byte{'a'}, int(size))
   199  
   200  	wg.Add(1)
   201  	go func() {
   202  		defer wg.Done()
   203  		// The problem (issue #59041) occurs when writing
   204  		// a series of blocks of data. It does not occur
   205  		// when all the data is written at once.
   206  		for i := 0; i < len(data); i += 1024 {
   207  			if _, err := client.Write(data[i : i+1024]); err != nil {
   208  				// If we get here because the client was
   209  				// closed, skip the error.
   210  				if !errors.Is(err, net.ErrClosed) {
   211  					t.Errorf("error writing to socket: %v", err)
   212  				}
   213  				return
   214  			}
   215  		}
   216  		client.Close()
   217  	}()
   218  
   219  	wg.Add(1)
   220  	go func() {
   221  		defer wg.Done()
   222  		buf := make([]byte, 32)
   223  		for {
   224  			if _, err := pty.Read(buf); err != nil {
   225  				if err != io.EOF && !errors.Is(err, ErrClosed) {
   226  					// An error here doesn't matter for
   227  					// our test.
   228  					t.Logf("error reading from pty: %v", err)
   229  				}
   230  				return
   231  			}
   232  		}
   233  	}()
   234  
   235  	// Close Client to wake up the writing goroutine if necessary.
   236  	defer client.Close()
   237  
   238  	_, err = io.Copy(tty, server)
   239  	if err != nil {
   240  		t.Fatal(err)
   241  	}
   242  }
   243  
   244  var (
   245  	copyFileTests = []copyFileTestFunc{newCopyFileRangeTest, newSendfileOverCopyFileRangeTest}
   246  	copyFileHooks = []copyFileTestHook{hookCopyFileRange, hookSendFileOverCopyFileRange}
   247  )
   248  
   249  func testCopyFiles(t *testing.T, size, limit int64) {
   250  	testCopyFileRange(t, size, limit)
   251  	testSendfileOverCopyFileRange(t, size, limit)
   252  }
   253  
   254  func testCopyFileRange(t *testing.T, size int64, limit int64) {
   255  	dst, src, data, hook, name := newCopyFileRangeTest(t, size)
   256  	testCopyFile(t, dst, src, data, hook, limit, name)
   257  }
   258  
   259  func testSendfileOverCopyFileRange(t *testing.T, size int64, limit int64) {
   260  	dst, src, data, hook, name := newSendfileOverCopyFileRangeTest(t, size)
   261  	testCopyFile(t, dst, src, data, hook, limit, name)
   262  }
   263  
   264  // newCopyFileRangeTest initializes a new test for copy_file_range.
   265  //
   266  // It hooks package os' call to poll.CopyFileRange and returns the hook,
   267  // so it can be inspected.
   268  func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileHook, name string) {
   269  	t.Helper()
   270  
   271  	name = "newCopyFileRangeTest"
   272  
   273  	dst, src, data = newCopyFileTest(t, size)
   274  	hook, _ = hookCopyFileRange(t)
   275  
   276  	return
   277  }
   278  
   279  // newSendfileOverCopyFileRangeTest initializes a new test for sendfile over copy_file_range.
   280  // It hooks package os' call to poll.SendFile and returns the hook,
   281  // so it can be inspected.
   282  func newSendfileOverCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileHook, name string) {
   283  	t.Helper()
   284  
   285  	name = "newSendfileOverCopyFileRangeTest"
   286  
   287  	dst, src, data = newCopyFileTest(t, size)
   288  	hook, _ = hookSendFileOverCopyFileRange(t)
   289  
   290  	return
   291  }
   292  
   293  // newSpliceFileTest initializes a new test for splice.
   294  //
   295  // It creates source sockets and destination file, and populates the source sockets
   296  // with random data of the specified size. It also hooks package os' call
   297  // to poll.Splice and returns the hook so it can be inspected.
   298  func newSpliceFileTest(t *testing.T, proto string, size int64) (*File, net.Conn, []byte, *spliceFileHook, func()) {
   299  	t.Helper()
   300  
   301  	hook := hookSpliceFile(t)
   302  
   303  	client, server := createSocketPair(t, proto)
   304  
   305  	dst, err := CreateTemp(t.TempDir(), "dst-splice-file-test")
   306  	if err != nil {
   307  		t.Fatal(err)
   308  	}
   309  	t.Cleanup(func() { dst.Close() })
   310  
   311  	randSeed := time.Now().Unix()
   312  	t.Logf("random data seed: %d\n", randSeed)
   313  	prng := rand.New(rand.NewSource(randSeed))
   314  	data := make([]byte, size)
   315  	prng.Read(data)
   316  
   317  	done := make(chan struct{})
   318  	go func() {
   319  		client.Write(data)
   320  		client.Close()
   321  		close(done)
   322  	}()
   323  
   324  	return dst, server, data, hook, func() { <-done }
   325  }
   326  
   327  func hookCopyFileRange(t *testing.T) (hook *copyFileHook, name string) {
   328  	name = "hookCopyFileRange"
   329  
   330  	hook = new(copyFileHook)
   331  	orig := *PollCopyFileRangeP
   332  	t.Cleanup(func() {
   333  		*PollCopyFileRangeP = orig
   334  	})
   335  	*PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
   336  		hook.called = true
   337  		hook.dstfd = dst.Sysfd
   338  		hook.srcfd = src.Sysfd
   339  		hook.written, hook.handled, hook.err = orig(dst, src, remain)
   340  		return hook.written, hook.handled, hook.err
   341  	}
   342  	return
   343  }
   344  
   345  func hookSendFileOverCopyFileRange(t *testing.T) (*copyFileHook, string) {
   346  	return hookSendFileTB(t), "hookSendFileOverCopyFileRange"
   347  }
   348  
   349  func hookSendFileTB(tb testing.TB) *copyFileHook {
   350  	// Disable poll.CopyFileRange to force the fallback to poll.SendFile.
   351  	originalCopyFileRange := *PollCopyFileRangeP
   352  	*PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (written int64, handled bool, err error) {
   353  		return 0, false, nil
   354  	}
   355  
   356  	hook := new(copyFileHook)
   357  	orig := poll.TestHookDidSendFile
   358  	tb.Cleanup(func() {
   359  		*PollCopyFileRangeP = originalCopyFileRange
   360  		poll.TestHookDidSendFile = orig
   361  	})
   362  	poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
   363  		hook.called = true
   364  		hook.dstfd = dstFD.Sysfd
   365  		hook.srcfd = src
   366  		hook.written = written
   367  		hook.err = err
   368  		hook.handled = handled
   369  	}
   370  	return hook
   371  }
   372  
   373  func hookSpliceFile(t *testing.T) *spliceFileHook {
   374  	h := new(spliceFileHook)
   375  	h.install()
   376  	t.Cleanup(h.uninstall)
   377  	return h
   378  }
   379  
   380  type spliceFileHook struct {
   381  	called bool
   382  	dstfd  int
   383  	srcfd  int
   384  	remain int64
   385  
   386  	written int64
   387  	handled bool
   388  	err     error
   389  
   390  	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
   391  }
   392  
   393  func (h *spliceFileHook) install() {
   394  	h.original = *PollSpliceFile
   395  	*PollSpliceFile = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
   396  		h.called = true
   397  		h.dstfd = dst.Sysfd
   398  		h.srcfd = src.Sysfd
   399  		h.remain = remain
   400  		h.written, h.handled, h.err = h.original(dst, src, remain)
   401  		return h.written, h.handled, h.err
   402  	}
   403  }
   404  
   405  func (h *spliceFileHook) uninstall() {
   406  	*PollSpliceFile = h.original
   407  }
   408  
   409  // On some kernels copy_file_range fails on files in /proc.
   410  func TestProcCopy(t *testing.T) {
   411  	t.Parallel()
   412  
   413  	const cmdlineFile = "/proc/self/cmdline"
   414  	cmdline, err := ReadFile(cmdlineFile)
   415  	if err != nil {
   416  		t.Skipf("can't read /proc file: %v", err)
   417  	}
   418  	in, err := Open(cmdlineFile)
   419  	if err != nil {
   420  		t.Fatal(err)
   421  	}
   422  	defer in.Close()
   423  	outFile := filepath.Join(t.TempDir(), "cmdline")
   424  	out, err := Create(outFile)
   425  	if err != nil {
   426  		t.Fatal(err)
   427  	}
   428  	if _, err := io.Copy(out, in); err != nil {
   429  		t.Fatal(err)
   430  	}
   431  	if err := out.Close(); err != nil {
   432  		t.Fatal(err)
   433  	}
   434  	copy, err := ReadFile(outFile)
   435  	if err != nil {
   436  		t.Fatal(err)
   437  	}
   438  	if !bytes.Equal(cmdline, copy) {
   439  		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, copy, cmdline)
   440  	}
   441  }
   442  
   443  func TestGetPollFDAndNetwork(t *testing.T) {
   444  	t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
   445  	t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
   446  }
   447  
   448  func testGetPollFDAndNetwork(t *testing.T, proto string) {
   449  	_, server := createSocketPair(t, proto)
   450  	sc, ok := server.(syscall.Conn)
   451  	if !ok {
   452  		t.Fatalf("server Conn is not a syscall.Conn")
   453  	}
   454  	rc, err := sc.SyscallConn()
   455  	if err != nil {
   456  		t.Fatalf("server SyscallConn error: %v", err)
   457  	}
   458  	if err = rc.Control(func(fd uintptr) {
   459  		pfd, network := GetPollFDAndNetwork(server)
   460  		if pfd == nil {
   461  			t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
   462  		}
   463  		if string(network) != proto {
   464  			t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
   465  		}
   466  		if pfd.Sysfd != int(fd) {
   467  			t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
   468  		}
   469  		if !pfd.IsStream {
   470  			t.Fatalf("expected IsStream to be true")
   471  		}
   472  		if err = pfd.Init(proto, true); err == nil {
   473  			t.Fatalf("Init should have failed with the initialized poll.FD and return EEXIST error")
   474  		}
   475  	}); err != nil {
   476  		t.Fatalf("server Control error: %v", err)
   477  	}
   478  }
   479  

View as plain text