Source file src/os/writeto_linux_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os_test
     6  
     7  import (
     8  	"bytes"
     9  	"internal/poll"
    10  	"io"
    11  	"net"
    12  	. "os"
    13  	"strconv"
    14  	"syscall"
    15  	"testing"
    16  )
    17  
    18  func TestSendFile(t *testing.T) {
    19  	sizes := []int{
    20  		1,
    21  		42,
    22  		1025,
    23  		syscall.Getpagesize() + 1,
    24  		32769,
    25  	}
    26  	t.Run("sendfile-to-unix", func(t *testing.T) {
    27  		for _, size := range sizes {
    28  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    29  				testSendFile(t, "unix", int64(size))
    30  			})
    31  		}
    32  	})
    33  	t.Run("sendfile-to-tcp", func(t *testing.T) {
    34  		for _, size := range sizes {
    35  			t.Run(strconv.Itoa(size), func(t *testing.T) {
    36  				testSendFile(t, "tcp", int64(size))
    37  			})
    38  		}
    39  	})
    40  }
    41  
    42  func testSendFile(t *testing.T, proto string, size int64) {
    43  	dst, src, recv, data, hook := newSendFileTest(t, proto, size)
    44  
    45  	// Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile
    46  	n, err := io.Copy(dst, src)
    47  	if err != nil {
    48  		t.Fatalf("io.Copy error: %v", err)
    49  	}
    50  
    51  	// We should have called poll.Splice with the right file descriptor arguments.
    52  	if n > 0 && !hook.called {
    53  		t.Fatal("expected to called poll.SendFile")
    54  	}
    55  	if hook.called && hook.srcfd != int(src.Fd()) {
    56  		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
    57  	}
    58  	sc, ok := dst.(syscall.Conn)
    59  	if !ok {
    60  		t.Fatalf("destination is not a syscall.Conn")
    61  	}
    62  	rc, err := sc.SyscallConn()
    63  	if err != nil {
    64  		t.Fatalf("destination SyscallConn error: %v", err)
    65  	}
    66  	if err = rc.Control(func(fd uintptr) {
    67  		if hook.called && hook.dstfd != int(fd) {
    68  			t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
    69  		}
    70  	}); err != nil {
    71  		t.Fatalf("destination Conn Control error: %v", err)
    72  	}
    73  
    74  	// Verify the data size and content.
    75  	dataSize := len(data)
    76  	dstData := make([]byte, dataSize)
    77  	m, err := io.ReadFull(recv, dstData)
    78  	if err != nil {
    79  		t.Fatalf("server Conn Read error: %v", err)
    80  	}
    81  	if n != int64(dataSize) {
    82  		t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
    83  	}
    84  	if m != dataSize {
    85  		t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
    86  	}
    87  	if !bytes.Equal(dstData, data) {
    88  		t.Errorf("data mismatch, got %s, want %s", dstData, data)
    89  	}
    90  }
    91  
    92  // newSendFileTest initializes a new test for sendfile.
    93  //
    94  // It creates source file and destination sockets, and populates the source file
    95  // with random data of the specified size. It also hooks package os' call
    96  // to poll.Sendfile and returns the hook so it can be inspected.
    97  func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
    98  	t.Helper()
    99  
   100  	hook := hookSendFile(t)
   101  
   102  	client, server := createSocketPair(t, proto)
   103  	tempFile, data := createTempFile(t, "writeto-sendfile-to-socket", size)
   104  
   105  	return client, tempFile, server, data, hook
   106  }
   107  
   108  func hookSendFile(t *testing.T) *sendFileHook {
   109  	h := new(sendFileHook)
   110  	orig := poll.TestHookDidSendFile
   111  	t.Cleanup(func() {
   112  		poll.TestHookDidSendFile = orig
   113  	})
   114  	poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
   115  		h.called = true
   116  		h.dstfd = dstFD.Sysfd
   117  		h.srcfd = src
   118  		h.written = written
   119  		h.err = err
   120  		h.handled = handled
   121  	}
   122  	return h
   123  }
   124  
   125  type sendFileHook struct {
   126  	called bool
   127  	dstfd  int
   128  	srcfd  int
   129  
   130  	written int64
   131  	handled bool
   132  	err     error
   133  }
   134  

View as plain text