Source file
src/os/readfrom_linux_test.go
1
2
3
4
5 package os_test
6
7 import (
8 "bytes"
9 "errors"
10 "internal/poll"
11 "internal/testpty"
12 "io"
13 "math/rand"
14 "net"
15 . "os"
16 "path/filepath"
17 "strconv"
18 "strings"
19 "sync"
20 "syscall"
21 "testing"
22 "time"
23 )
24
25 func TestCopyFileRange(t *testing.T) {
26 sizes := []int{
27 1,
28 42,
29 1025,
30 syscall.Getpagesize() + 1,
31 32769,
32 }
33 t.Run("Basic", func(t *testing.T) {
34 for _, size := range sizes {
35 t.Run(strconv.Itoa(size), func(t *testing.T) {
36 testCopyFileRange(t, int64(size), -1)
37 })
38 }
39 })
40 t.Run("Limited", func(t *testing.T) {
41 t.Run("OneLess", func(t *testing.T) {
42 for _, size := range sizes {
43 t.Run(strconv.Itoa(size), func(t *testing.T) {
44 testCopyFileRange(t, int64(size), int64(size)-1)
45 })
46 }
47 })
48 t.Run("Half", func(t *testing.T) {
49 for _, size := range sizes {
50 t.Run(strconv.Itoa(size), func(t *testing.T) {
51 testCopyFileRange(t, int64(size), int64(size)/2)
52 })
53 }
54 })
55 t.Run("More", func(t *testing.T) {
56 for _, size := range sizes {
57 t.Run(strconv.Itoa(size), func(t *testing.T) {
58 testCopyFileRange(t, int64(size), int64(size)+7)
59 })
60 }
61 })
62 })
63 t.Run("DoesntTryInAppendMode", func(t *testing.T) {
64 dst, src, data, hook := newCopyFileRangeTest(t, 42)
65
66 dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
67 if err != nil {
68 t.Fatal(err)
69 }
70 defer dst2.Close()
71
72 if _, err := io.Copy(dst2, src); err != nil {
73 t.Fatal(err)
74 }
75 if hook.called {
76 t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode")
77 }
78 mustSeekStart(t, dst2)
79 mustContainData(t, dst2, data)
80 })
81 t.Run("CopyFileItself", func(t *testing.T) {
82 hook := hookCopyFileRange(t)
83
84 f, err := CreateTemp("", "file-readfrom-itself-test")
85 if err != nil {
86 t.Fatalf("failed to create tmp file: %v", err)
87 }
88 t.Cleanup(func() {
89 f.Close()
90 Remove(f.Name())
91 })
92
93 data := []byte("hello world!")
94 if _, err := f.Write(data); err != nil {
95 t.Fatalf("failed to create and feed the file: %v", err)
96 }
97
98 if err := f.Sync(); err != nil {
99 t.Fatalf("failed to save the file: %v", err)
100 }
101
102
103 if _, err := f.Seek(0, io.SeekStart); err != nil {
104 t.Fatalf("failed to rewind the file: %v", err)
105 }
106
107
108 if _, err := io.Copy(f, f); err != nil {
109 t.Fatalf("failed to read from the file: %v", err)
110 }
111
112 if !hook.called || hook.written != 0 || hook.handled || hook.err != nil {
113 t.Fatalf("poll.CopyFileRange should be called and return the EINVAL error, but got hook.called=%t, hook.err=%v", hook.called, hook.err)
114 }
115
116
117 if _, err := f.Seek(0, io.SeekStart); err != nil {
118 t.Fatalf("failed to rewind the file: %v", err)
119 }
120
121 data2, err := io.ReadAll(f)
122 if err != nil {
123 t.Fatalf("failed to read from the file: %v", err)
124 }
125
126
127 if strings.Repeat(string(data), 2) != string(data2) {
128 t.Fatalf("data mismatch: %s != %s", string(data), string(data2))
129 }
130 })
131 t.Run("NotRegular", func(t *testing.T) {
132 t.Run("BothPipes", func(t *testing.T) {
133 hook := hookCopyFileRange(t)
134
135 pr1, pw1, err := Pipe()
136 if err != nil {
137 t.Fatal(err)
138 }
139 defer pr1.Close()
140 defer pw1.Close()
141
142 pr2, pw2, err := Pipe()
143 if err != nil {
144 t.Fatal(err)
145 }
146 defer pr2.Close()
147 defer pw2.Close()
148
149
150
151
152 data := []byte("hello")
153 if _, err := pw1.Write(data); err != nil {
154 t.Fatal(err)
155 }
156 pw1.Close()
157
158 n, err := io.Copy(pw2, pr1)
159 if err != nil {
160 t.Fatal(err)
161 }
162 if n != int64(len(data)) {
163 t.Fatalf("transferred %d, want %d", n, len(data))
164 }
165 if !hook.called {
166 t.Fatalf("should have called poll.CopyFileRange")
167 }
168 pw2.Close()
169 mustContainData(t, pr2, data)
170 })
171 t.Run("DstPipe", func(t *testing.T) {
172 dst, src, data, hook := newCopyFileRangeTest(t, 255)
173 dst.Close()
174
175 pr, pw, err := Pipe()
176 if err != nil {
177 t.Fatal(err)
178 }
179 defer pr.Close()
180 defer pw.Close()
181
182 n, err := io.Copy(pw, src)
183 if err != nil {
184 t.Fatal(err)
185 }
186 if n != int64(len(data)) {
187 t.Fatalf("transferred %d, want %d", n, len(data))
188 }
189 if !hook.called {
190 t.Fatalf("should have called poll.CopyFileRange")
191 }
192 pw.Close()
193 mustContainData(t, pr, data)
194 })
195 t.Run("SrcPipe", func(t *testing.T) {
196 dst, src, data, hook := newCopyFileRangeTest(t, 255)
197 src.Close()
198
199 pr, pw, err := Pipe()
200 if err != nil {
201 t.Fatal(err)
202 }
203 defer pr.Close()
204 defer pw.Close()
205
206
207
208
209 if _, err := pw.Write(data); err != nil {
210 t.Fatal(err)
211 }
212 pw.Close()
213
214 n, err := io.Copy(dst, pr)
215 if err != nil {
216 t.Fatal(err)
217 }
218 if n != int64(len(data)) {
219 t.Fatalf("transferred %d, want %d", n, len(data))
220 }
221 if !hook.called {
222 t.Fatalf("should have called poll.CopyFileRange")
223 }
224 mustSeekStart(t, dst)
225 mustContainData(t, dst, data)
226 })
227 })
228 t.Run("Nil", func(t *testing.T) {
229 var nilFile *File
230 anyFile, err := CreateTemp("", "")
231 if err != nil {
232 t.Fatal(err)
233 }
234 defer Remove(anyFile.Name())
235 defer anyFile.Close()
236
237 if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid {
238 t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid)
239 }
240 if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid {
241 t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid)
242 }
243 if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid {
244 t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid)
245 }
246
247 if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid {
248 t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
249 }
250 if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid {
251 t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
252 }
253 if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid {
254 t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid)
255 }
256 })
257 }
258
259 func TestSpliceFile(t *testing.T) {
260 sizes := []int{
261 1,
262 42,
263 1025,
264 syscall.Getpagesize() + 1,
265 32769,
266 }
267 t.Run("Basic-TCP", func(t *testing.T) {
268 for _, size := range sizes {
269 t.Run(strconv.Itoa(size), func(t *testing.T) {
270 testSpliceFile(t, "tcp", int64(size), -1)
271 })
272 }
273 })
274 t.Run("Basic-Unix", func(t *testing.T) {
275 for _, size := range sizes {
276 t.Run(strconv.Itoa(size), func(t *testing.T) {
277 testSpliceFile(t, "unix", int64(size), -1)
278 })
279 }
280 })
281 t.Run("TCP-To-TTY", func(t *testing.T) {
282 testSpliceToTTY(t, "tcp", 32768)
283 })
284 t.Run("Unix-To-TTY", func(t *testing.T) {
285 testSpliceToTTY(t, "unix", 32768)
286 })
287 t.Run("Limited", func(t *testing.T) {
288 t.Run("OneLess-TCP", func(t *testing.T) {
289 for _, size := range sizes {
290 t.Run(strconv.Itoa(size), func(t *testing.T) {
291 testSpliceFile(t, "tcp", int64(size), int64(size)-1)
292 })
293 }
294 })
295 t.Run("OneLess-Unix", func(t *testing.T) {
296 for _, size := range sizes {
297 t.Run(strconv.Itoa(size), func(t *testing.T) {
298 testSpliceFile(t, "unix", int64(size), int64(size)-1)
299 })
300 }
301 })
302 t.Run("Half-TCP", func(t *testing.T) {
303 for _, size := range sizes {
304 t.Run(strconv.Itoa(size), func(t *testing.T) {
305 testSpliceFile(t, "tcp", int64(size), int64(size)/2)
306 })
307 }
308 })
309 t.Run("Half-Unix", func(t *testing.T) {
310 for _, size := range sizes {
311 t.Run(strconv.Itoa(size), func(t *testing.T) {
312 testSpliceFile(t, "unix", int64(size), int64(size)/2)
313 })
314 }
315 })
316 t.Run("More-TCP", func(t *testing.T) {
317 for _, size := range sizes {
318 t.Run(strconv.Itoa(size), func(t *testing.T) {
319 testSpliceFile(t, "tcp", int64(size), int64(size)+1)
320 })
321 }
322 })
323 t.Run("More-Unix", func(t *testing.T) {
324 for _, size := range sizes {
325 t.Run(strconv.Itoa(size), func(t *testing.T) {
326 testSpliceFile(t, "unix", int64(size), int64(size)+1)
327 })
328 }
329 })
330 })
331 }
332
333 func testSpliceFile(t *testing.T, proto string, size, limit int64) {
334 dst, src, data, hook, cleanup := newSpliceFileTest(t, proto, size)
335 defer cleanup()
336
337
338 var (
339 r io.Reader
340 lr *io.LimitedReader
341 )
342 if limit >= 0 {
343 lr = &io.LimitedReader{N: limit, R: src}
344 r = lr
345 if limit < int64(len(data)) {
346 data = data[:limit]
347 }
348 } else {
349 r = src
350 }
351
352 n, err := io.Copy(dst, r)
353 if err != nil {
354 t.Fatal(err)
355 }
356
357
358 if n > 0 && !hook.called {
359 t.Fatal("expected to called poll.Splice")
360 }
361 if hook.called && hook.dstfd != int(dst.Fd()) {
362 t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
363 }
364 sc, ok := src.(syscall.Conn)
365 if !ok {
366 t.Fatalf("server Conn is not a syscall.Conn")
367 }
368 rc, err := sc.SyscallConn()
369 if err != nil {
370 t.Fatalf("server Conn SyscallConn error: %v", err)
371 }
372 if err = rc.Control(func(fd uintptr) {
373 if hook.called && hook.srcfd != int(fd) {
374 t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, int(fd))
375 }
376 }); err != nil {
377 t.Fatalf("server Conn Control error: %v", err)
378 }
379
380
381
382
383 dstoff, err := dst.Seek(0, io.SeekCurrent)
384 if err != nil {
385 t.Fatal(err)
386 }
387 if dstoff != int64(len(data)) {
388 t.Errorf("dstoff = %d, want %d", dstoff, len(data))
389 }
390 if n != int64(len(data)) {
391 t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
392 }
393 mustSeekStart(t, dst)
394 mustContainData(t, dst, data)
395
396
397 if lr != nil {
398 if want := limit - n; lr.N != want {
399 t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
400 }
401 }
402 }
403
404
405 func testSpliceToTTY(t *testing.T, proto string, size int64) {
406 var wg sync.WaitGroup
407
408
409
410
411 defer wg.Wait()
412
413 pty, ttyName, err := testpty.Open()
414 if err != nil {
415 t.Skipf("skipping test because pty open failed: %v", err)
416 }
417 defer pty.Close()
418
419
420
421
422 ttyFD, err := syscall.Open(ttyName, syscall.O_RDWR, 0)
423 if err != nil {
424 t.Skipf("skipping test because failed to open tty: %v", err)
425 }
426 defer syscall.Close(ttyFD)
427
428 tty := NewFile(uintptr(ttyFD), "tty")
429 defer tty.Close()
430
431 client, server := createSocketPair(t, proto)
432
433 data := bytes.Repeat([]byte{'a'}, int(size))
434
435 wg.Add(1)
436 go func() {
437 defer wg.Done()
438
439
440
441 for i := 0; i < len(data); i += 1024 {
442 if _, err := client.Write(data[i : i+1024]); err != nil {
443
444
445 if !errors.Is(err, net.ErrClosed) {
446 t.Errorf("error writing to socket: %v", err)
447 }
448 return
449 }
450 }
451 client.Close()
452 }()
453
454 wg.Add(1)
455 go func() {
456 defer wg.Done()
457 buf := make([]byte, 32)
458 for {
459 if _, err := pty.Read(buf); err != nil {
460 if err != io.EOF && !errors.Is(err, ErrClosed) {
461
462
463 t.Logf("error reading from pty: %v", err)
464 }
465 return
466 }
467 }
468 }()
469
470
471 defer client.Close()
472
473 _, err = io.Copy(tty, server)
474 if err != nil {
475 t.Fatal(err)
476 }
477 }
478
479 func testCopyFileRange(t *testing.T, size int64, limit int64) {
480 dst, src, data, hook := newCopyFileRangeTest(t, size)
481
482
483 var (
484 realsrc io.Reader
485 lr *io.LimitedReader
486 )
487 if limit >= 0 {
488 lr = &io.LimitedReader{N: limit, R: src}
489 realsrc = lr
490 if limit < int64(len(data)) {
491 data = data[:limit]
492 }
493 } else {
494 realsrc = src
495 }
496
497
498
499 n, err := io.Copy(dst, realsrc)
500 if err != nil {
501 t.Fatal(err)
502 }
503
504
505
506 if limit > 0 && !hook.called {
507 t.Fatal("never called poll.CopyFileRange")
508 }
509 if hook.called && hook.dstfd != int(dst.Fd()) {
510 t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
511 }
512 if hook.called && hook.srcfd != int(src.Fd()) {
513 t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
514 }
515
516
517
518
519 dstoff, err := dst.Seek(0, io.SeekCurrent)
520 if err != nil {
521 t.Fatal(err)
522 }
523 srcoff, err := src.Seek(0, io.SeekCurrent)
524 if err != nil {
525 t.Fatal(err)
526 }
527 if dstoff != srcoff {
528 t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff)
529 }
530 if dstoff != int64(len(data)) {
531 t.Errorf("dstoff = %d, want %d", dstoff, len(data))
532 }
533 if n != int64(len(data)) {
534 t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
535 }
536 mustSeekStart(t, dst)
537 mustContainData(t, dst, data)
538
539
540 if lr != nil {
541 if want := limit - n; lr.N != want {
542 t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
543 }
544 }
545 }
546
547
548
549
550
551
552 func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) {
553 t.Helper()
554
555 hook = hookCopyFileRange(t)
556 tmp := t.TempDir()
557
558 src, err := Create(filepath.Join(tmp, "src"))
559 if err != nil {
560 t.Fatal(err)
561 }
562 t.Cleanup(func() { src.Close() })
563
564 dst, err = Create(filepath.Join(tmp, "dst"))
565 if err != nil {
566 t.Fatal(err)
567 }
568 t.Cleanup(func() { dst.Close() })
569
570
571
572 prng := rand.New(rand.NewSource(time.Now().Unix()))
573 data = make([]byte, size)
574 prng.Read(data)
575 if _, err := src.Write(data); err != nil {
576 t.Fatal(err)
577 }
578 if _, err := src.Seek(0, io.SeekStart); err != nil {
579 t.Fatal(err)
580 }
581
582 return dst, src, data, hook
583 }
584
585
586
587
588
589
590 func newSpliceFileTest(t *testing.T, proto string, size int64) (*File, net.Conn, []byte, *spliceFileHook, func()) {
591 t.Helper()
592
593 hook := hookSpliceFile(t)
594
595 client, server := createSocketPair(t, proto)
596
597 dst, err := CreateTemp(t.TempDir(), "dst-splice-file-test")
598 if err != nil {
599 t.Fatal(err)
600 }
601 t.Cleanup(func() { dst.Close() })
602
603 randSeed := time.Now().Unix()
604 t.Logf("random data seed: %d\n", randSeed)
605 prng := rand.New(rand.NewSource(randSeed))
606 data := make([]byte, size)
607 prng.Read(data)
608
609 done := make(chan struct{})
610 go func() {
611 client.Write(data)
612 client.Close()
613 close(done)
614 }()
615
616 return dst, server, data, hook, func() { <-done }
617 }
618
619
620
// mustContainData fails the test unless reading f from its current offset
// yields exactly data and is then immediately at EOF.
func mustContainData(t *testing.T, f *File, data []byte) {
	t.Helper()

	buf := make([]byte, len(data))
	_, err := io.ReadFull(f, buf)
	switch {
	case err != nil:
		t.Fatal(err)
	case !bytes.Equal(buf, data):
		t.Fatalf("didn't get the same data back from %s", f.Name())
	}

	// One more byte must not be available.
	var probe [1]byte
	if _, err := f.Read(probe[:]); err != io.EOF {
		t.Fatalf("not at EOF")
	}
}
635
// mustSeekStart rewinds f to offset 0 and fails the test if the seek fails.
// It is marked as a helper so failures are reported at the caller's line.
func mustSeekStart(t *testing.T, f *File) {
	t.Helper()

	if _, err := f.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}
}
641
642 func hookCopyFileRange(t *testing.T) *copyFileRangeHook {
643 h := new(copyFileRangeHook)
644 h.install()
645 t.Cleanup(h.uninstall)
646 return h
647 }
648
649 type copyFileRangeHook struct {
650 called bool
651 dstfd int
652 srcfd int
653 remain int64
654
655 written int64
656 handled bool
657 err error
658
659 original func(dst, src *poll.FD, remain int64) (int64, bool, error)
660 }
661
662 func (h *copyFileRangeHook) install() {
663 h.original = *PollCopyFileRangeP
664 *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
665 h.called = true
666 h.dstfd = dst.Sysfd
667 h.srcfd = src.Sysfd
668 h.remain = remain
669 h.written, h.handled, h.err = h.original(dst, src, remain)
670 return h.written, h.handled, h.err
671 }
672 }
673
674 func (h *copyFileRangeHook) uninstall() {
675 *PollCopyFileRangeP = h.original
676 }
677
678 func hookSpliceFile(t *testing.T) *spliceFileHook {
679 h := new(spliceFileHook)
680 h.install()
681 t.Cleanup(h.uninstall)
682 return h
683 }
684
685 type spliceFileHook struct {
686 called bool
687 dstfd int
688 srcfd int
689 remain int64
690
691 written int64
692 handled bool
693 err error
694
695 original func(dst, src *poll.FD, remain int64) (int64, bool, error)
696 }
697
698 func (h *spliceFileHook) install() {
699 h.original = *PollSpliceFile
700 *PollSpliceFile = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
701 h.called = true
702 h.dstfd = dst.Sysfd
703 h.srcfd = src.Sysfd
704 h.remain = remain
705 h.written, h.handled, h.err = h.original(dst, src, remain)
706 return h.written, h.handled, h.err
707 }
708 }
709
710 func (h *spliceFileHook) uninstall() {
711 *PollSpliceFile = h.original
712 }
713
714
// TestProcCopy checks that io.Copy from /proc/self/cmdline into a regular
// file reproduces exactly what ReadFile returns for the same path. It is
// skipped when the /proc file cannot be read.
// NOTE(review): presumably this guards zero-copy paths against /proc files
// whose reported size does not reflect their contents — confirm.
func TestProcCopy(t *testing.T) {
	t.Parallel()

	const cmdlineFile = "/proc/self/cmdline"
	cmdline, err := ReadFile(cmdlineFile)
	if err != nil {
		t.Skipf("can't read /proc file: %v", err)
	}
	in, err := Open(cmdlineFile)
	if err != nil {
		t.Fatal(err)
	}
	defer in.Close()

	outFile := filepath.Join(t.TempDir(), "cmdline")
	out, err := Create(outFile)
	if err != nil {
		t.Fatal(err)
	}
	// Release out even if io.Copy fails below. On success the explicit,
	// error-checked Close further down runs first, and this deferred second
	// Close only returns an error that is ignored.
	defer out.Close()
	if _, err := io.Copy(out, in); err != nil {
		t.Fatal(err)
	}
	if err := out.Close(); err != nil {
		t.Fatal(err)
	}

	// Named got rather than copy to avoid shadowing the builtin.
	got, err := ReadFile(outFile)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(cmdline, got) {
		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, got, cmdline)
	}
}
747
748 func TestGetPollFDAndNetwork(t *testing.T) {
749 t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
750 t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
751 }
752
753 func testGetPollFDAndNetwork(t *testing.T, proto string) {
754 _, server := createSocketPair(t, proto)
755 sc, ok := server.(syscall.Conn)
756 if !ok {
757 t.Fatalf("server Conn is not a syscall.Conn")
758 }
759 rc, err := sc.SyscallConn()
760 if err != nil {
761 t.Fatalf("server SyscallConn error: %v", err)
762 }
763 if err = rc.Control(func(fd uintptr) {
764 pfd, network := GetPollFDAndNetwork(server)
765 if pfd == nil {
766 t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
767 }
768 if string(network) != proto {
769 t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
770 }
771 if pfd.Sysfd != int(fd) {
772 t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
773 }
774 if !pfd.IsStream {
775 t.Fatalf("expected IsStream to be true")
776 }
777 if err = pfd.Init(proto, true); err == nil {
778 t.Fatalf("Init should have failed with the initialized poll.FD and return EEXIST error")
779 }
780 }); err != nil {
781 t.Fatalf("server Control error: %v", err)
782 }
783 }
784
View as plain text