Source file
src/os/readfrom_linux_test.go
1
2
3
4
5 package os_test
6
7 import (
8 "bytes"
9 "errors"
10 "internal/poll"
11 "internal/testpty"
12 "io"
13 "math/rand"
14 "net"
15 . "os"
16 "path/filepath"
17 "strconv"
18 "sync"
19 "syscall"
20 "testing"
21 "time"
22 )
23
24 func TestSpliceFile(t *testing.T) {
25 sizes := []int{
26 1,
27 42,
28 1025,
29 syscall.Getpagesize() + 1,
30 32769,
31 }
32 t.Run("Basic-TCP", func(t *testing.T) {
33 for _, size := range sizes {
34 t.Run(strconv.Itoa(size), func(t *testing.T) {
35 testSpliceFile(t, "tcp", int64(size), -1)
36 })
37 }
38 })
39 t.Run("Basic-Unix", func(t *testing.T) {
40 for _, size := range sizes {
41 t.Run(strconv.Itoa(size), func(t *testing.T) {
42 testSpliceFile(t, "unix", int64(size), -1)
43 })
44 }
45 })
46 t.Run("TCP-To-TTY", func(t *testing.T) {
47 testSpliceToTTY(t, "tcp", 32768)
48 })
49 t.Run("Unix-To-TTY", func(t *testing.T) {
50 testSpliceToTTY(t, "unix", 32768)
51 })
52 t.Run("Limited", func(t *testing.T) {
53 t.Run("OneLess-TCP", func(t *testing.T) {
54 for _, size := range sizes {
55 t.Run(strconv.Itoa(size), func(t *testing.T) {
56 testSpliceFile(t, "tcp", int64(size), int64(size)-1)
57 })
58 }
59 })
60 t.Run("OneLess-Unix", func(t *testing.T) {
61 for _, size := range sizes {
62 t.Run(strconv.Itoa(size), func(t *testing.T) {
63 testSpliceFile(t, "unix", int64(size), int64(size)-1)
64 })
65 }
66 })
67 t.Run("Half-TCP", func(t *testing.T) {
68 for _, size := range sizes {
69 t.Run(strconv.Itoa(size), func(t *testing.T) {
70 testSpliceFile(t, "tcp", int64(size), int64(size)/2)
71 })
72 }
73 })
74 t.Run("Half-Unix", func(t *testing.T) {
75 for _, size := range sizes {
76 t.Run(strconv.Itoa(size), func(t *testing.T) {
77 testSpliceFile(t, "unix", int64(size), int64(size)/2)
78 })
79 }
80 })
81 t.Run("More-TCP", func(t *testing.T) {
82 for _, size := range sizes {
83 t.Run(strconv.Itoa(size), func(t *testing.T) {
84 testSpliceFile(t, "tcp", int64(size), int64(size)+1)
85 })
86 }
87 })
88 t.Run("More-Unix", func(t *testing.T) {
89 for _, size := range sizes {
90 t.Run(strconv.Itoa(size), func(t *testing.T) {
91 testSpliceFile(t, "unix", int64(size), int64(size)+1)
92 })
93 }
94 })
95 })
96 }
97
98 func testSpliceFile(t *testing.T, proto string, size, limit int64) {
99 dst, src, data, hook, cleanup := newSpliceFileTest(t, proto, size)
100 defer cleanup()
101
102
103 var (
104 r io.Reader
105 lr *io.LimitedReader
106 )
107 if limit >= 0 {
108 lr = &io.LimitedReader{N: limit, R: src}
109 r = lr
110 if limit < int64(len(data)) {
111 data = data[:limit]
112 }
113 } else {
114 r = src
115 }
116
117 n, err := io.Copy(dst, r)
118 if err != nil {
119 t.Fatal(err)
120 }
121
122
123 if n > 0 && !hook.called {
124 t.Fatal("expected to called poll.Splice")
125 }
126 if hook.called && hook.dstfd != int(dst.Fd()) {
127 t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
128 }
129 sc, ok := src.(syscall.Conn)
130 if !ok {
131 t.Fatalf("server Conn is not a syscall.Conn")
132 }
133 rc, err := sc.SyscallConn()
134 if err != nil {
135 t.Fatalf("server Conn SyscallConn error: %v", err)
136 }
137 if err = rc.Control(func(fd uintptr) {
138 if hook.called && hook.srcfd != int(fd) {
139 t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, int(fd))
140 }
141 }); err != nil {
142 t.Fatalf("server Conn Control error: %v", err)
143 }
144
145
146
147
148 dstoff, err := dst.Seek(0, io.SeekCurrent)
149 if err != nil {
150 t.Fatal(err)
151 }
152 if dstoff != int64(len(data)) {
153 t.Errorf("dstoff = %d, want %d", dstoff, len(data))
154 }
155 if n != int64(len(data)) {
156 t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
157 }
158 mustSeekStart(t, dst)
159 mustContainData(t, dst, data)
160
161
162 if lr != nil {
163 if want := limit - n; lr.N != want {
164 t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
165 }
166 }
167 }
168
169
170 func testSpliceToTTY(t *testing.T, proto string, size int64) {
171 var wg sync.WaitGroup
172
173
174
175
176 defer wg.Wait()
177
178 pty, ttyName, err := testpty.Open()
179 if err != nil {
180 t.Skipf("skipping test because pty open failed: %v", err)
181 }
182 defer pty.Close()
183
184
185
186
187 ttyFD, err := syscall.Open(ttyName, syscall.O_RDWR, 0)
188 if err != nil {
189 t.Skipf("skipping test because failed to open tty: %v", err)
190 }
191 defer syscall.Close(ttyFD)
192
193 tty := NewFile(uintptr(ttyFD), "tty")
194 defer tty.Close()
195
196 client, server := createSocketPair(t, proto)
197
198 data := bytes.Repeat([]byte{'a'}, int(size))
199
200 wg.Add(1)
201 go func() {
202 defer wg.Done()
203
204
205
206 for i := 0; i < len(data); i += 1024 {
207 if _, err := client.Write(data[i : i+1024]); err != nil {
208
209
210 if !errors.Is(err, net.ErrClosed) {
211 t.Errorf("error writing to socket: %v", err)
212 }
213 return
214 }
215 }
216 client.Close()
217 }()
218
219 wg.Add(1)
220 go func() {
221 defer wg.Done()
222 buf := make([]byte, 32)
223 for {
224 if _, err := pty.Read(buf); err != nil {
225 if err != io.EOF && !errors.Is(err, ErrClosed) {
226
227
228 t.Logf("error reading from pty: %v", err)
229 }
230 return
231 }
232 }
233 }()
234
235
236 defer client.Close()
237
238 _, err = io.Copy(tty, server)
239 if err != nil {
240 t.Fatal(err)
241 }
242 }
243
244 var (
245 copyFileTests = []copyFileTestFunc{newCopyFileRangeTest, newSendfileOverCopyFileRangeTest}
246 copyFileHooks = []copyFileTestHook{hookCopyFileRange, hookSendFileOverCopyFileRange}
247 )
248
249 func testCopyFiles(t *testing.T, size, limit int64) {
250 testCopyFileRange(t, size, limit)
251 testSendfileOverCopyFileRange(t, size, limit)
252 }
253
254 func testCopyFileRange(t *testing.T, size int64, limit int64) {
255 dst, src, data, hook, name := newCopyFileRangeTest(t, size)
256 testCopyFile(t, dst, src, data, hook, limit, name)
257 }
258
259 func testSendfileOverCopyFileRange(t *testing.T, size int64, limit int64) {
260 dst, src, data, hook, name := newSendfileOverCopyFileRangeTest(t, size)
261 testCopyFile(t, dst, src, data, hook, limit, name)
262 }
263
264
265
266
267
268 func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileHook, name string) {
269 t.Helper()
270
271 name = "newCopyFileRangeTest"
272
273 dst, src, data = newCopyFileTest(t, size)
274 hook, _ = hookCopyFileRange(t)
275
276 return
277 }
278
279
280
281
282 func newSendfileOverCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileHook, name string) {
283 t.Helper()
284
285 name = "newSendfileOverCopyFileRangeTest"
286
287 dst, src, data = newCopyFileTest(t, size)
288 hook, _ = hookSendFileOverCopyFileRange(t)
289
290 return
291 }
292
293
294
295
296
297
298 func newSpliceFileTest(t *testing.T, proto string, size int64) (*File, net.Conn, []byte, *spliceFileHook, func()) {
299 t.Helper()
300
301 hook := hookSpliceFile(t)
302
303 client, server := createSocketPair(t, proto)
304
305 dst, err := CreateTemp(t.TempDir(), "dst-splice-file-test")
306 if err != nil {
307 t.Fatal(err)
308 }
309 t.Cleanup(func() { dst.Close() })
310
311 randSeed := time.Now().Unix()
312 t.Logf("random data seed: %d\n", randSeed)
313 prng := rand.New(rand.NewSource(randSeed))
314 data := make([]byte, size)
315 prng.Read(data)
316
317 done := make(chan struct{})
318 go func() {
319 client.Write(data)
320 client.Close()
321 close(done)
322 }()
323
324 return dst, server, data, hook, func() { <-done }
325 }
326
327 func hookCopyFileRange(t *testing.T) (hook *copyFileHook, name string) {
328 name = "hookCopyFileRange"
329
330 hook = new(copyFileHook)
331 orig := *PollCopyFileRangeP
332 t.Cleanup(func() {
333 *PollCopyFileRangeP = orig
334 })
335 *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
336 hook.called = true
337 hook.dstfd = dst.Sysfd
338 hook.srcfd = src.Sysfd
339 hook.written, hook.handled, hook.err = orig(dst, src, remain)
340 return hook.written, hook.handled, hook.err
341 }
342 return
343 }
344
345 func hookSendFileOverCopyFileRange(t *testing.T) (*copyFileHook, string) {
346 return hookSendFileTB(t), "hookSendFileOverCopyFileRange"
347 }
348
349 func hookSendFileTB(tb testing.TB) *copyFileHook {
350
351 originalCopyFileRange := *PollCopyFileRangeP
352 *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (written int64, handled bool, err error) {
353 return 0, false, nil
354 }
355
356 hook := new(copyFileHook)
357 orig := poll.TestHookDidSendFile
358 tb.Cleanup(func() {
359 *PollCopyFileRangeP = originalCopyFileRange
360 poll.TestHookDidSendFile = orig
361 })
362 poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
363 hook.called = true
364 hook.dstfd = dstFD.Sysfd
365 hook.srcfd = src
366 hook.written = written
367 hook.err = err
368 hook.handled = handled
369 }
370 return hook
371 }
372
373 func hookSpliceFile(t *testing.T) *spliceFileHook {
374 h := new(spliceFileHook)
375 h.install()
376 t.Cleanup(h.uninstall)
377 return h
378 }
379
380 type spliceFileHook struct {
381 called bool
382 dstfd int
383 srcfd int
384 remain int64
385
386 written int64
387 handled bool
388 err error
389
390 original func(dst, src *poll.FD, remain int64) (int64, bool, error)
391 }
392
393 func (h *spliceFileHook) install() {
394 h.original = *PollSpliceFile
395 *PollSpliceFile = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
396 h.called = true
397 h.dstfd = dst.Sysfd
398 h.srcfd = src.Sysfd
399 h.remain = remain
400 h.written, h.handled, h.err = h.original(dst, src, remain)
401 return h.written, h.handled, h.err
402 }
403 }
404
405 func (h *spliceFileHook) uninstall() {
406 *PollSpliceFile = h.original
407 }
408
409
// TestProcCopy checks that io.Copy from a /proc file into a regular file
// reproduces the contents that ReadFile reports for the same /proc file.
// It is skipped when /proc/self/cmdline cannot be read.
func TestProcCopy(t *testing.T) {
	t.Parallel()

	const cmdlineFile = "/proc/self/cmdline"
	want, err := ReadFile(cmdlineFile)
	if err != nil {
		t.Skipf("can't read /proc file: %v", err)
	}
	in, err := Open(cmdlineFile)
	if err != nil {
		t.Fatal(err)
	}
	defer in.Close()
	outFile := filepath.Join(t.TempDir(), "cmdline")
	out, err := Create(outFile)
	if err != nil {
		t.Fatal(err)
	}
	// Make sure out is closed even when io.Copy fails. On the success path
	// the explicit Close below still reports close errors. This deferred
	// second Close then just returns an error that is ignored.
	defer out.Close()
	if _, err := io.Copy(out, in); err != nil {
		t.Fatal(err)
	}
	if err := out.Close(); err != nil {
		t.Fatal(err)
	}
	got, err := ReadFile(outFile)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(want, got) {
		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, got, want)
	}
}
442
443 func TestGetPollFDAndNetwork(t *testing.T) {
444 t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
445 t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
446 }
447
448 func testGetPollFDAndNetwork(t *testing.T, proto string) {
449 _, server := createSocketPair(t, proto)
450 sc, ok := server.(syscall.Conn)
451 if !ok {
452 t.Fatalf("server Conn is not a syscall.Conn")
453 }
454 rc, err := sc.SyscallConn()
455 if err != nil {
456 t.Fatalf("server SyscallConn error: %v", err)
457 }
458 if err = rc.Control(func(fd uintptr) {
459 pfd, network := GetPollFDAndNetwork(server)
460 if pfd == nil {
461 t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
462 }
463 if string(network) != proto {
464 t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
465 }
466 if pfd.Sysfd != int(fd) {
467 t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
468 }
469 if !pfd.IsStream {
470 t.Fatalf("expected IsStream to be true")
471 }
472 if err = pfd.Init(proto, true); err == nil {
473 t.Fatalf("Init should have failed with the initialized poll.FD and return EEXIST error")
474 }
475 }); err != nil {
476 t.Fatalf("server Control error: %v", err)
477 }
478 }
479
View as plain text