Source file src/os/writeto_linux_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os_test
     6  
     7  import (
     8  	"bytes"
     9  	"internal/poll"
    10  	"io"
    11  	"math/rand"
    12  	"net"
    13  	. "os"
    14  	"strconv"
    15  	"syscall"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  func TestSendFile(t *testing.T) {
    21  	sizes := []int{
    22  		1,
    23  		42,
    24  		1025,
    25  		syscall.Getpagesize() + 1,
    26  		32769,
    27  	}
    28  	t.Run("sendfile-to-unix", func(t *testing.T) {
    29  		for _, size := range sizes {
    30  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    31  				testSendFile(t, "unix", int64(size))
    32  			})
    33  		}
    34  	})
    35  	t.Run("sendfile-to-tcp", func(t *testing.T) {
    36  		for _, size := range sizes {
    37  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    38  				testSendFile(t, "tcp", int64(size))
    39  			})
    40  		}
    41  	})
    42  }
    43  
    44  func testSendFile(t *testing.T, proto string, size int64) {
    45  	dst, src, recv, data, hook := newSendFileTest(t, proto, size)
    46  
    47  	// Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile
    48  	n, err := io.Copy(dst, src)
    49  	if err != nil {
    50  		t.Fatalf("io.Copy error: %v", err)
    51  	}
    52  
    53  	// We should have called poll.Splice with the right file descriptor arguments.
    54  	if n > 0 && !hook.called {
    55  		t.Fatal("expected to called poll.SendFile")
    56  	}
    57  	if hook.called && hook.srcfd != int(src.Fd()) {
    58  		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
    59  	}
    60  	sc, ok := dst.(syscall.Conn)
    61  	if !ok {
    62  		t.Fatalf("destination is not a syscall.Conn")
    63  	}
    64  	rc, err := sc.SyscallConn()
    65  	if err != nil {
    66  		t.Fatalf("destination SyscallConn error: %v", err)
    67  	}
    68  	if err = rc.Control(func(fd uintptr) {
    69  		if hook.called && hook.dstfd != int(fd) {
    70  			t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
    71  		}
    72  	}); err != nil {
    73  		t.Fatalf("destination Conn Control error: %v", err)
    74  	}
    75  
    76  	// Verify the data size and content.
    77  	dataSize := len(data)
    78  	dstData := make([]byte, dataSize)
    79  	m, err := io.ReadFull(recv, dstData)
    80  	if err != nil {
    81  		t.Fatalf("server Conn Read error: %v", err)
    82  	}
    83  	if n != int64(dataSize) {
    84  		t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
    85  	}
    86  	if m != dataSize {
    87  		t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
    88  	}
    89  	if !bytes.Equal(dstData, data) {
    90  		t.Errorf("data mismatch, got %s, want %s", dstData, data)
    91  	}
    92  }
    93  
    94  // newSendFileTest initializes a new test for sendfile.
    95  //
    96  // It creates source file and destination sockets, and populates the source file
    97  // with random data of the specified size. It also hooks package os' call
    98  // to poll.Sendfile and returns the hook so it can be inspected.
    99  func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
   100  	t.Helper()
   101  
   102  	hook := hookSendFile(t)
   103  
   104  	client, server := createSocketPair(t, proto)
   105  	tempFile, data := createTempFile(t, size)
   106  
   107  	return client, tempFile, server, data, hook
   108  }
   109  
   110  func hookSendFile(t *testing.T) *sendFileHook {
   111  	h := new(sendFileHook)
   112  	orig := poll.TestHookDidSendFile
   113  	t.Cleanup(func() {
   114  		poll.TestHookDidSendFile = orig
   115  	})
   116  	poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
   117  		h.called = true
   118  		h.dstfd = dstFD.Sysfd
   119  		h.srcfd = src
   120  		h.written = written
   121  		h.err = err
   122  		h.handled = handled
   123  	}
   124  	return h
   125  }
   126  
   127  type sendFileHook struct {
   128  	called bool
   129  	dstfd  int
   130  	srcfd  int
   131  
   132  	written int64
   133  	handled bool
   134  	err     error
   135  }
   136  
   137  func createTempFile(t *testing.T, size int64) (*File, []byte) {
   138  	f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
   139  	if err != nil {
   140  		t.Fatalf("failed to create temporary file: %v", err)
   141  	}
   142  	t.Cleanup(func() {
   143  		f.Close()
   144  	})
   145  
   146  	randSeed := time.Now().Unix()
   147  	t.Logf("random data seed: %d\n", randSeed)
   148  	prng := rand.New(rand.NewSource(randSeed))
   149  	data := make([]byte, size)
   150  	prng.Read(data)
   151  	if _, err := f.Write(data); err != nil {
   152  		t.Fatalf("failed to create and feed the file: %v", err)
   153  	}
   154  	if err := f.Sync(); err != nil {
   155  		t.Fatalf("failed to save the file: %v", err)
   156  	}
   157  	if _, err := f.Seek(0, io.SeekStart); err != nil {
   158  		t.Fatalf("failed to rewind the file: %v", err)
   159  	}
   160  
   161  	return f, data
   162  }
   163  

View as plain text