Source file
src/os/writeto_linux_test.go
1
2
3
4
5 package os_test
6
7 import (
8 "bytes"
9 "internal/poll"
10 "io"
11 "math/rand"
12 "net"
13 . "os"
14 "strconv"
15 "syscall"
16 "testing"
17 "time"
18 )
19
20 func TestSendFile(t *testing.T) {
21 sizes := []int{
22 1,
23 42,
24 1025,
25 syscall.Getpagesize() + 1,
26 32769,
27 }
28 t.Run("sendfile-to-unix", func(t *testing.T) {
29 for _, size := range sizes {
30 t.Run(strconv.Itoa(size), func(t *testing.T) {
31 testSendFile(t, "unix", int64(size))
32 })
33 }
34 })
35 t.Run("sendfile-to-tcp", func(t *testing.T) {
36 for _, size := range sizes {
37 t.Run(strconv.Itoa(size), func(t *testing.T) {
38 testSendFile(t, "tcp", int64(size))
39 })
40 }
41 })
42 }
43
44 func testSendFile(t *testing.T, proto string, size int64) {
45 dst, src, recv, data, hook := newSendFileTest(t, proto, size)
46
47
48 n, err := io.Copy(dst, src)
49 if err != nil {
50 t.Fatalf("io.Copy error: %v", err)
51 }
52
53
54 if n > 0 && !hook.called {
55 t.Fatal("expected to called poll.SendFile")
56 }
57 if hook.called && hook.srcfd != int(src.Fd()) {
58 t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
59 }
60 sc, ok := dst.(syscall.Conn)
61 if !ok {
62 t.Fatalf("destination is not a syscall.Conn")
63 }
64 rc, err := sc.SyscallConn()
65 if err != nil {
66 t.Fatalf("destination SyscallConn error: %v", err)
67 }
68 if err = rc.Control(func(fd uintptr) {
69 if hook.called && hook.dstfd != int(fd) {
70 t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
71 }
72 }); err != nil {
73 t.Fatalf("destination Conn Control error: %v", err)
74 }
75
76
77 dataSize := len(data)
78 dstData := make([]byte, dataSize)
79 m, err := io.ReadFull(recv, dstData)
80 if err != nil {
81 t.Fatalf("server Conn Read error: %v", err)
82 }
83 if n != int64(dataSize) {
84 t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
85 }
86 if m != dataSize {
87 t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
88 }
89 if !bytes.Equal(dstData, data) {
90 t.Errorf("data mismatch, got %s, want %s", dstData, data)
91 }
92 }
93
94
95
96
97
98
99 func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
100 t.Helper()
101
102 hook := hookSendFile(t)
103
104 client, server := createSocketPair(t, proto)
105 tempFile, data := createTempFile(t, size)
106
107 return client, tempFile, server, data, hook
108 }
109
110 func hookSendFile(t *testing.T) *sendFileHook {
111 h := new(sendFileHook)
112 orig := poll.TestHookDidSendFile
113 t.Cleanup(func() {
114 poll.TestHookDidSendFile = orig
115 })
116 poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
117 h.called = true
118 h.dstfd = dstFD.Sysfd
119 h.srcfd = src
120 h.written = written
121 h.err = err
122 h.handled = handled
123 }
124 return h
125 }
126
// sendFileHook holds what the poll.TestHookDidSendFile replacement installed
// by hookSendFile last observed.
type sendFileHook struct {
	// called reports whether the hook ran at all.
	called bool

	// Arguments: destination socket descriptor and source file descriptor.
	dstfd, srcfd int

	// Results passed to the hook: the byte count, whether the call was
	// handled (presumably: whether sendfile took over the copy), and the error.
	written int64
	handled bool
	err     error
}
136
// createTempFile creates a file under t.TempDir and fills it with size bytes
// of pseudo-random data. The file offset is rewound to the start. It returns
// the open file and the bytes that were written. The file is closed by
// t.Cleanup when the test finishes.
func createTempFile(t *testing.T, size int64) (*File, []byte) {
	t.Helper()

	f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
	if err != nil {
		t.Fatalf("failed to create temporary file: %v", err)
	}
	t.Cleanup(func() {
		f.Close()
	})

	// Seed with nanosecond resolution so subtests started within the same
	// second still get different data. The seed is logged so a failing run's
	// data can be regenerated. t.Logf supplies the trailing newline itself.
	randSeed := time.Now().UnixNano()
	t.Logf("random data seed: %d", randSeed)
	prng := rand.New(rand.NewSource(randSeed))
	data := make([]byte, size)
	prng.Read(data)
	if _, err := f.Write(data); err != nil {
		t.Fatalf("failed to create and feed the file: %v", err)
	}
	if err := f.Sync(); err != nil {
		t.Fatalf("failed to save the file: %v", err)
	}
	// Rewind so the caller reads (or sends) the file from the beginning.
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		t.Fatalf("failed to rewind the file: %v", err)
	}

	return f, data
}
163
View as plain text