Source file src/os/readfrom_linux_test.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os_test
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"internal/poll"
    11  	"internal/testpty"
    12  	"io"
    13  	"math/rand"
    14  	"net"
    15  	. "os"
    16  	"path/filepath"
    17  	"strconv"
    18  	"strings"
    19  	"sync"
    20  	"syscall"
    21  	"testing"
    22  	"time"
    23  )
    24  
    25  func TestCopyFileRange(t *testing.T) {
    26  	sizes := []int{
    27  		1,
    28  		42,
    29  		1025,
    30  		syscall.Getpagesize() + 1,
    31  		32769,
    32  	}
    33  	t.Run("Basic", func(t *testing.T) {
    34  		for _, size := range sizes {
    35  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    36  				testCopyFileRange(t, int64(size), -1)
    37  			})
    38  		}
    39  	})
    40  	t.Run("Limited", func(t *testing.T) {
    41  		t.Run("OneLess", func(t *testing.T) {
    42  			for _, size := range sizes {
    43  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    44  					testCopyFileRange(t, int64(size), int64(size)-1)
    45  				})
    46  			}
    47  		})
    48  		t.Run("Half", func(t *testing.T) {
    49  			for _, size := range sizes {
    50  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    51  					testCopyFileRange(t, int64(size), int64(size)/2)
    52  				})
    53  			}
    54  		})
    55  		t.Run("More", func(t *testing.T) {
    56  			for _, size := range sizes {
    57  				t.Run(strconv.Itoa(size), func(t *testing.T) {
    58  					testCopyFileRange(t, int64(size), int64(size)+7)
    59  				})
    60  			}
    61  		})
    62  	})
    63  	t.Run("DoesntTryInAppendMode", func(t *testing.T) {
    64  		dst, src, data, hook := newCopyFileRangeTest(t, 42)
    65  
    66  		dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
    67  		if err != nil {
    68  			t.Fatal(err)
    69  		}
    70  		defer dst2.Close()
    71  
    72  		if _, err := io.Copy(dst2, src); err != nil {
    73  			t.Fatal(err)
    74  		}
    75  		if hook.called {
    76  			t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode")
    77  		}
    78  		mustSeekStart(t, dst2)
    79  		mustContainData(t, dst2, data) // through traditional means
    80  	})
    81  	t.Run("CopyFileItself", func(t *testing.T) {
    82  		hook := hookCopyFileRange(t)
    83  
    84  		f, err := CreateTemp("", "file-readfrom-itself-test")
    85  		if err != nil {
    86  			t.Fatalf("failed to create tmp file: %v", err)
    87  		}
    88  		t.Cleanup(func() {
    89  			f.Close()
    90  			Remove(f.Name())
    91  		})
    92  
    93  		data := []byte("hello world!")
    94  		if _, err := f.Write(data); err != nil {
    95  			t.Fatalf("failed to create and feed the file: %v", err)
    96  		}
    97  
    98  		if err := f.Sync(); err != nil {
    99  			t.Fatalf("failed to save the file: %v", err)
   100  		}
   101  
   102  		// Rewind it.
   103  		if _, err := f.Seek(0, io.SeekStart); err != nil {
   104  			t.Fatalf("failed to rewind the file: %v", err)
   105  		}
   106  
   107  		// Read data from the file itself.
   108  		if _, err := io.Copy(f, f); err != nil {
   109  			t.Fatalf("failed to read from the file: %v", err)
   110  		}
   111  
   112  		if !hook.called || hook.written != 0 || hook.handled || hook.err != nil {
   113  			t.Fatalf("poll.CopyFileRange should be called and return the EINVAL error, but got hook.called=%t, hook.err=%v", hook.called, hook.err)
   114  		}
   115  
   116  		// Rewind it.
   117  		if _, err := f.Seek(0, io.SeekStart); err != nil {
   118  			t.Fatalf("failed to rewind the file: %v", err)
   119  		}
   120  
   121  		data2, err := io.ReadAll(f)
   122  		if err != nil {
   123  			t.Fatalf("failed to read from the file: %v", err)
   124  		}
   125  
   126  		// It should wind up a double of the original data.
   127  		if strings.Repeat(string(data), 2) != string(data2) {
   128  			t.Fatalf("data mismatch: %s != %s", string(data), string(data2))
   129  		}
   130  	})
   131  	t.Run("NotRegular", func(t *testing.T) {
   132  		t.Run("BothPipes", func(t *testing.T) {
   133  			hook := hookCopyFileRange(t)
   134  
   135  			pr1, pw1, err := Pipe()
   136  			if err != nil {
   137  				t.Fatal(err)
   138  			}
   139  			defer pr1.Close()
   140  			defer pw1.Close()
   141  
   142  			pr2, pw2, err := Pipe()
   143  			if err != nil {
   144  				t.Fatal(err)
   145  			}
   146  			defer pr2.Close()
   147  			defer pw2.Close()
   148  
   149  			// The pipe is empty, and PIPE_BUF is large enough
   150  			// for this, by (POSIX) definition, so there is no
   151  			// need for an additional goroutine.
   152  			data := []byte("hello")
   153  			if _, err := pw1.Write(data); err != nil {
   154  				t.Fatal(err)
   155  			}
   156  			pw1.Close()
   157  
   158  			n, err := io.Copy(pw2, pr1)
   159  			if err != nil {
   160  				t.Fatal(err)
   161  			}
   162  			if n != int64(len(data)) {
   163  				t.Fatalf("transferred %d, want %d", n, len(data))
   164  			}
   165  			if !hook.called {
   166  				t.Fatalf("should have called poll.CopyFileRange")
   167  			}
   168  			pw2.Close()
   169  			mustContainData(t, pr2, data)
   170  		})
   171  		t.Run("DstPipe", func(t *testing.T) {
   172  			dst, src, data, hook := newCopyFileRangeTest(t, 255)
   173  			dst.Close()
   174  
   175  			pr, pw, err := Pipe()
   176  			if err != nil {
   177  				t.Fatal(err)
   178  			}
   179  			defer pr.Close()
   180  			defer pw.Close()
   181  
   182  			n, err := io.Copy(pw, src)
   183  			if err != nil {
   184  				t.Fatal(err)
   185  			}
   186  			if n != int64(len(data)) {
   187  				t.Fatalf("transferred %d, want %d", n, len(data))
   188  			}
   189  			if !hook.called {
   190  				t.Fatalf("should have called poll.CopyFileRange")
   191  			}
   192  			pw.Close()
   193  			mustContainData(t, pr, data)
   194  		})
   195  		t.Run("SrcPipe", func(t *testing.T) {
   196  			dst, src, data, hook := newCopyFileRangeTest(t, 255)
   197  			src.Close()
   198  
   199  			pr, pw, err := Pipe()
   200  			if err != nil {
   201  				t.Fatal(err)
   202  			}
   203  			defer pr.Close()
   204  			defer pw.Close()
   205  
   206  			// The pipe is empty, and PIPE_BUF is large enough
   207  			// for this, by (POSIX) definition, so there is no
   208  			// need for an additional goroutine.
   209  			if _, err := pw.Write(data); err != nil {
   210  				t.Fatal(err)
   211  			}
   212  			pw.Close()
   213  
   214  			n, err := io.Copy(dst, pr)
   215  			if err != nil {
   216  				t.Fatal(err)
   217  			}
   218  			if n != int64(len(data)) {
   219  				t.Fatalf("transferred %d, want %d", n, len(data))
   220  			}
   221  			if !hook.called {
   222  				t.Fatalf("should have called poll.CopyFileRange")
   223  			}
   224  			mustSeekStart(t, dst)
   225  			mustContainData(t, dst, data)
   226  		})
   227  	})
   228  	t.Run("Nil", func(t *testing.T) {
   229  		var nilFile *File
   230  		anyFile, err := CreateTemp("", "")
   231  		if err != nil {
   232  			t.Fatal(err)
   233  		}
   234  		defer Remove(anyFile.Name())
   235  		defer anyFile.Close()
   236  
   237  		if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid {
   238  			t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid)
   239  		}
   240  		if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid {
   241  			t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid)
   242  		}
   243  		if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid {
   244  			t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid)
   245  		}
   246  
   247  		if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid {
   248  			t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
   249  		}
   250  		if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid {
   251  			t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
   252  		}
   253  		if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid {
   254  			t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid)
   255  		}
   256  	})
   257  }
   258  
   259  func TestSpliceFile(t *testing.T) {
   260  	sizes := []int{
   261  		1,
   262  		42,
   263  		1025,
   264  		syscall.Getpagesize() + 1,
   265  		32769,
   266  	}
   267  	t.Run("Basic-TCP", func(t *testing.T) {
   268  		for _, size := range sizes {
   269  			t.Run(strconv.Itoa(size), func(t *testing.T) {
   270  				testSpliceFile(t, "tcp", int64(size), -1)
   271  			})
   272  		}
   273  	})
   274  	t.Run("Basic-Unix", func(t *testing.T) {
   275  		for _, size := range sizes {
   276  			t.Run(strconv.Itoa(size), func(t *testing.T) {
   277  				testSpliceFile(t, "unix", int64(size), -1)
   278  			})
   279  		}
   280  	})
   281  	t.Run("TCP-To-TTY", func(t *testing.T) {
   282  		testSpliceToTTY(t, "tcp", 32768)
   283  	})
   284  	t.Run("Unix-To-TTY", func(t *testing.T) {
   285  		testSpliceToTTY(t, "unix", 32768)
   286  	})
   287  	t.Run("Limited", func(t *testing.T) {
   288  		t.Run("OneLess-TCP", func(t *testing.T) {
   289  			for _, size := range sizes {
   290  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   291  					testSpliceFile(t, "tcp", int64(size), int64(size)-1)
   292  				})
   293  			}
   294  		})
   295  		t.Run("OneLess-Unix", func(t *testing.T) {
   296  			for _, size := range sizes {
   297  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   298  					testSpliceFile(t, "unix", int64(size), int64(size)-1)
   299  				})
   300  			}
   301  		})
   302  		t.Run("Half-TCP", func(t *testing.T) {
   303  			for _, size := range sizes {
   304  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   305  					testSpliceFile(t, "tcp", int64(size), int64(size)/2)
   306  				})
   307  			}
   308  		})
   309  		t.Run("Half-Unix", func(t *testing.T) {
   310  			for _, size := range sizes {
   311  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   312  					testSpliceFile(t, "unix", int64(size), int64(size)/2)
   313  				})
   314  			}
   315  		})
   316  		t.Run("More-TCP", func(t *testing.T) {
   317  			for _, size := range sizes {
   318  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   319  					testSpliceFile(t, "tcp", int64(size), int64(size)+1)
   320  				})
   321  			}
   322  		})
   323  		t.Run("More-Unix", func(t *testing.T) {
   324  			for _, size := range sizes {
   325  				t.Run(strconv.Itoa(size), func(t *testing.T) {
   326  					testSpliceFile(t, "unix", int64(size), int64(size)+1)
   327  				})
   328  			}
   329  		})
   330  	})
   331  }
   332  
   333  func testSpliceFile(t *testing.T, proto string, size, limit int64) {
   334  	dst, src, data, hook, cleanup := newSpliceFileTest(t, proto, size)
   335  	defer cleanup()
   336  
   337  	// If we have a limit, wrap the reader.
   338  	var (
   339  		r  io.Reader
   340  		lr *io.LimitedReader
   341  	)
   342  	if limit >= 0 {
   343  		lr = &io.LimitedReader{N: limit, R: src}
   344  		r = lr
   345  		if limit < int64(len(data)) {
   346  			data = data[:limit]
   347  		}
   348  	} else {
   349  		r = src
   350  	}
   351  	// Now call ReadFrom (through io.Copy), which will hopefully call poll.Splice
   352  	n, err := io.Copy(dst, r)
   353  	if err != nil {
   354  		t.Fatal(err)
   355  	}
   356  
   357  	// We should have called poll.Splice with the right file descriptor arguments.
   358  	if n > 0 && !hook.called {
   359  		t.Fatal("expected to called poll.Splice")
   360  	}
   361  	if hook.called && hook.dstfd != int(dst.Fd()) {
   362  		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
   363  	}
   364  	sc, ok := src.(syscall.Conn)
   365  	if !ok {
   366  		t.Fatalf("server Conn is not a syscall.Conn")
   367  	}
   368  	rc, err := sc.SyscallConn()
   369  	if err != nil {
   370  		t.Fatalf("server Conn SyscallConn error: %v", err)
   371  	}
   372  	if err = rc.Control(func(fd uintptr) {
   373  		if hook.called && hook.srcfd != int(fd) {
   374  			t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, int(fd))
   375  		}
   376  	}); err != nil {
   377  		t.Fatalf("server Conn Control error: %v", err)
   378  	}
   379  
   380  	// Check that the offsets after the transfer make sense, that the size
   381  	// of the transfer was reported correctly, and that the destination
   382  	// file contains exactly the bytes we expect it to contain.
   383  	dstoff, err := dst.Seek(0, io.SeekCurrent)
   384  	if err != nil {
   385  		t.Fatal(err)
   386  	}
   387  	if dstoff != int64(len(data)) {
   388  		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
   389  	}
   390  	if n != int64(len(data)) {
   391  		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
   392  	}
   393  	mustSeekStart(t, dst)
   394  	mustContainData(t, dst, data)
   395  
   396  	// If we had a limit, check that it was updated.
   397  	if lr != nil {
   398  		if want := limit - n; lr.N != want {
   399  			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
   400  		}
   401  	}
   402  }
   403  
   404  // Issue #59041.
   405  func testSpliceToTTY(t *testing.T, proto string, size int64) {
   406  	var wg sync.WaitGroup
   407  
   408  	// Call wg.Wait as the final deferred function,
   409  	// because the goroutines may block until some of
   410  	// the deferred Close calls.
   411  	defer wg.Wait()
   412  
   413  	pty, ttyName, err := testpty.Open()
   414  	if err != nil {
   415  		t.Skipf("skipping test because pty open failed: %v", err)
   416  	}
   417  	defer pty.Close()
   418  
   419  	// Open the tty directly, rather than via OpenFile.
   420  	// This bypasses the non-blocking support and is required
   421  	// to recreate the problem in the issue (#59041).
   422  	ttyFD, err := syscall.Open(ttyName, syscall.O_RDWR, 0)
   423  	if err != nil {
   424  		t.Skipf("skipping test because failed to open tty: %v", err)
   425  	}
   426  	defer syscall.Close(ttyFD)
   427  
   428  	tty := NewFile(uintptr(ttyFD), "tty")
   429  	defer tty.Close()
   430  
   431  	client, server := createSocketPair(t, proto)
   432  
   433  	data := bytes.Repeat([]byte{'a'}, int(size))
   434  
   435  	wg.Add(1)
   436  	go func() {
   437  		defer wg.Done()
   438  		// The problem (issue #59041) occurs when writing
   439  		// a series of blocks of data. It does not occur
   440  		// when all the data is written at once.
   441  		for i := 0; i < len(data); i += 1024 {
   442  			if _, err := client.Write(data[i : i+1024]); err != nil {
   443  				// If we get here because the client was
   444  				// closed, skip the error.
   445  				if !errors.Is(err, net.ErrClosed) {
   446  					t.Errorf("error writing to socket: %v", err)
   447  				}
   448  				return
   449  			}
   450  		}
   451  		client.Close()
   452  	}()
   453  
   454  	wg.Add(1)
   455  	go func() {
   456  		defer wg.Done()
   457  		buf := make([]byte, 32)
   458  		for {
   459  			if _, err := pty.Read(buf); err != nil {
   460  				if err != io.EOF && !errors.Is(err, ErrClosed) {
   461  					// An error here doesn't matter for
   462  					// our test.
   463  					t.Logf("error reading from pty: %v", err)
   464  				}
   465  				return
   466  			}
   467  		}
   468  	}()
   469  
   470  	// Close Client to wake up the writing goroutine if necessary.
   471  	defer client.Close()
   472  
   473  	_, err = io.Copy(tty, server)
   474  	if err != nil {
   475  		t.Fatal(err)
   476  	}
   477  }
   478  
   479  func testCopyFileRange(t *testing.T, size int64, limit int64) {
   480  	dst, src, data, hook := newCopyFileRangeTest(t, size)
   481  
   482  	// If we have a limit, wrap the reader.
   483  	var (
   484  		realsrc io.Reader
   485  		lr      *io.LimitedReader
   486  	)
   487  	if limit >= 0 {
   488  		lr = &io.LimitedReader{N: limit, R: src}
   489  		realsrc = lr
   490  		if limit < int64(len(data)) {
   491  			data = data[:limit]
   492  		}
   493  	} else {
   494  		realsrc = src
   495  	}
   496  
   497  	// Now call ReadFrom (through io.Copy), which will hopefully call
   498  	// poll.CopyFileRange.
   499  	n, err := io.Copy(dst, realsrc)
   500  	if err != nil {
   501  		t.Fatal(err)
   502  	}
   503  
   504  	// If we didn't have a limit, we should have called poll.CopyFileRange
   505  	// with the right file descriptor arguments.
   506  	if limit > 0 && !hook.called {
   507  		t.Fatal("never called poll.CopyFileRange")
   508  	}
   509  	if hook.called && hook.dstfd != int(dst.Fd()) {
   510  		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
   511  	}
   512  	if hook.called && hook.srcfd != int(src.Fd()) {
   513  		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
   514  	}
   515  
   516  	// Check that the offsets after the transfer make sense, that the size
   517  	// of the transfer was reported correctly, and that the destination
   518  	// file contains exactly the bytes we expect it to contain.
   519  	dstoff, err := dst.Seek(0, io.SeekCurrent)
   520  	if err != nil {
   521  		t.Fatal(err)
   522  	}
   523  	srcoff, err := src.Seek(0, io.SeekCurrent)
   524  	if err != nil {
   525  		t.Fatal(err)
   526  	}
   527  	if dstoff != srcoff {
   528  		t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff)
   529  	}
   530  	if dstoff != int64(len(data)) {
   531  		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
   532  	}
   533  	if n != int64(len(data)) {
   534  		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
   535  	}
   536  	mustSeekStart(t, dst)
   537  	mustContainData(t, dst, data)
   538  
   539  	// If we had a limit, check that it was updated.
   540  	if lr != nil {
   541  		if want := limit - n; lr.N != want {
   542  			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
   543  		}
   544  	}
   545  }
   546  
   547  // newCopyFileRangeTest initializes a new test for copy_file_range.
   548  //
   549  // It creates source and destination files, and populates the source file
   550  // with random data of the specified size. It also hooks package os' call
   551  // to poll.CopyFileRange and returns the hook so it can be inspected.
   552  func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) {
   553  	t.Helper()
   554  
   555  	hook = hookCopyFileRange(t)
   556  	tmp := t.TempDir()
   557  
   558  	src, err := Create(filepath.Join(tmp, "src"))
   559  	if err != nil {
   560  		t.Fatal(err)
   561  	}
   562  	t.Cleanup(func() { src.Close() })
   563  
   564  	dst, err = Create(filepath.Join(tmp, "dst"))
   565  	if err != nil {
   566  		t.Fatal(err)
   567  	}
   568  	t.Cleanup(func() { dst.Close() })
   569  
   570  	// Populate the source file with data, then rewind it, so it can be
   571  	// consumed by copy_file_range(2).
   572  	prng := rand.New(rand.NewSource(time.Now().Unix()))
   573  	data = make([]byte, size)
   574  	prng.Read(data)
   575  	if _, err := src.Write(data); err != nil {
   576  		t.Fatal(err)
   577  	}
   578  	if _, err := src.Seek(0, io.SeekStart); err != nil {
   579  		t.Fatal(err)
   580  	}
   581  
   582  	return dst, src, data, hook
   583  }
   584  
   585  // newSpliceFileTest initializes a new test for splice.
   586  //
   587  // It creates source sockets and destination file, and populates the source sockets
   588  // with random data of the specified size. It also hooks package os' call
   589  // to poll.Splice and returns the hook so it can be inspected.
   590  func newSpliceFileTest(t *testing.T, proto string, size int64) (*File, net.Conn, []byte, *spliceFileHook, func()) {
   591  	t.Helper()
   592  
   593  	hook := hookSpliceFile(t)
   594  
   595  	client, server := createSocketPair(t, proto)
   596  
   597  	dst, err := CreateTemp(t.TempDir(), "dst-splice-file-test")
   598  	if err != nil {
   599  		t.Fatal(err)
   600  	}
   601  	t.Cleanup(func() { dst.Close() })
   602  
   603  	randSeed := time.Now().Unix()
   604  	t.Logf("random data seed: %d\n", randSeed)
   605  	prng := rand.New(rand.NewSource(randSeed))
   606  	data := make([]byte, size)
   607  	prng.Read(data)
   608  
   609  	done := make(chan struct{})
   610  	go func() {
   611  		client.Write(data)
   612  		client.Close()
   613  		close(done)
   614  	}()
   615  
   616  	return dst, server, data, hook, func() { <-done }
   617  }
   618  
   619  // mustContainData ensures that the specified file contains exactly the
   620  // specified data.
   621  func mustContainData(t *testing.T, f *File, data []byte) {
   622  	t.Helper()
   623  
   624  	got := make([]byte, len(data))
   625  	if _, err := io.ReadFull(f, got); err != nil {
   626  		t.Fatal(err)
   627  	}
   628  	if !bytes.Equal(got, data) {
   629  		t.Fatalf("didn't get the same data back from %s", f.Name())
   630  	}
   631  	if _, err := f.Read(make([]byte, 1)); err != io.EOF {
   632  		t.Fatalf("not at EOF")
   633  	}
   634  }
   635  
   636  func mustSeekStart(t *testing.T, f *File) {
   637  	if _, err := f.Seek(0, io.SeekStart); err != nil {
   638  		t.Fatal(err)
   639  	}
   640  }
   641  
   642  func hookCopyFileRange(t *testing.T) *copyFileRangeHook {
   643  	h := new(copyFileRangeHook)
   644  	h.install()
   645  	t.Cleanup(h.uninstall)
   646  	return h
   647  }
   648  
   649  type copyFileRangeHook struct {
   650  	called bool
   651  	dstfd  int
   652  	srcfd  int
   653  	remain int64
   654  
   655  	written int64
   656  	handled bool
   657  	err     error
   658  
   659  	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
   660  }
   661  
   662  func (h *copyFileRangeHook) install() {
   663  	h.original = *PollCopyFileRangeP
   664  	*PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
   665  		h.called = true
   666  		h.dstfd = dst.Sysfd
   667  		h.srcfd = src.Sysfd
   668  		h.remain = remain
   669  		h.written, h.handled, h.err = h.original(dst, src, remain)
   670  		return h.written, h.handled, h.err
   671  	}
   672  }
   673  
   674  func (h *copyFileRangeHook) uninstall() {
   675  	*PollCopyFileRangeP = h.original
   676  }
   677  
   678  func hookSpliceFile(t *testing.T) *spliceFileHook {
   679  	h := new(spliceFileHook)
   680  	h.install()
   681  	t.Cleanup(h.uninstall)
   682  	return h
   683  }
   684  
   685  type spliceFileHook struct {
   686  	called bool
   687  	dstfd  int
   688  	srcfd  int
   689  	remain int64
   690  
   691  	written int64
   692  	handled bool
   693  	err     error
   694  
   695  	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
   696  }
   697  
   698  func (h *spliceFileHook) install() {
   699  	h.original = *PollSpliceFile
   700  	*PollSpliceFile = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
   701  		h.called = true
   702  		h.dstfd = dst.Sysfd
   703  		h.srcfd = src.Sysfd
   704  		h.remain = remain
   705  		h.written, h.handled, h.err = h.original(dst, src, remain)
   706  		return h.written, h.handled, h.err
   707  	}
   708  }
   709  
   710  func (h *spliceFileHook) uninstall() {
   711  	*PollSpliceFile = h.original
   712  }
   713  
   714  // On some kernels copy_file_range fails on files in /proc.
   715  func TestProcCopy(t *testing.T) {
   716  	t.Parallel()
   717  
   718  	const cmdlineFile = "/proc/self/cmdline"
   719  	cmdline, err := ReadFile(cmdlineFile)
   720  	if err != nil {
   721  		t.Skipf("can't read /proc file: %v", err)
   722  	}
   723  	in, err := Open(cmdlineFile)
   724  	if err != nil {
   725  		t.Fatal(err)
   726  	}
   727  	defer in.Close()
   728  	outFile := filepath.Join(t.TempDir(), "cmdline")
   729  	out, err := Create(outFile)
   730  	if err != nil {
   731  		t.Fatal(err)
   732  	}
   733  	if _, err := io.Copy(out, in); err != nil {
   734  		t.Fatal(err)
   735  	}
   736  	if err := out.Close(); err != nil {
   737  		t.Fatal(err)
   738  	}
   739  	copy, err := ReadFile(outFile)
   740  	if err != nil {
   741  		t.Fatal(err)
   742  	}
   743  	if !bytes.Equal(cmdline, copy) {
   744  		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, copy, cmdline)
   745  	}
   746  }
   747  
   748  func TestGetPollFDAndNetwork(t *testing.T) {
   749  	t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
   750  	t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
   751  }
   752  
   753  func testGetPollFDAndNetwork(t *testing.T, proto string) {
   754  	_, server := createSocketPair(t, proto)
   755  	sc, ok := server.(syscall.Conn)
   756  	if !ok {
   757  		t.Fatalf("server Conn is not a syscall.Conn")
   758  	}
   759  	rc, err := sc.SyscallConn()
   760  	if err != nil {
   761  		t.Fatalf("server SyscallConn error: %v", err)
   762  	}
   763  	if err = rc.Control(func(fd uintptr) {
   764  		pfd, network := GetPollFDAndNetwork(server)
   765  		if pfd == nil {
   766  			t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
   767  		}
   768  		if string(network) != proto {
   769  			t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
   770  		}
   771  		if pfd.Sysfd != int(fd) {
   772  			t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
   773  		}
   774  		if !pfd.IsStream {
   775  			t.Fatalf("expected IsStream to be true")
   776  		}
   777  		if err = pfd.Init(proto, true); err == nil {
   778  			t.Fatalf("Init should have failed with the initialized poll.FD and return EEXIST error")
   779  		}
   780  	}); err != nil {
   781  		t.Fatalf("server Control error: %v", err)
   782  	}
   783  }
   784  

View as plain text