Source file
src/os/copy_test.go
1
2
3
4
5 package os_test
6
7 import (
8 "bytes"
9 "errors"
10 "io"
11 "math/rand/v2"
12 "net"
13 "os"
14 "runtime"
15 "sync"
16 "testing"
17
18 "golang.org/x/net/nettest"
19 )
20
21
22
23
24
25 func TestLargeCopyViaNetwork(t *testing.T) {
26 const size = 10 * 1024 * 1024
27 dir := t.TempDir()
28
29 src, err := os.Create(dir + "/src")
30 if err != nil {
31 t.Fatal(err)
32 }
33 defer src.Close()
34 if _, err := io.CopyN(src, newRandReader(), size); err != nil {
35 t.Fatal(err)
36 }
37 if _, err := src.Seek(0, 0); err != nil {
38 t.Fatal(err)
39 }
40
41 dst, err := os.Create(dir + "/dst")
42 if err != nil {
43 t.Fatal(err)
44 }
45 defer dst.Close()
46
47 client, server := createSocketPair(t, "tcp")
48 var wg sync.WaitGroup
49 wg.Add(2)
50 go func() {
51 defer wg.Done()
52 if n, err := io.Copy(dst, server); n != size || err != nil {
53 t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size)
54 }
55 }()
56 go func() {
57 defer wg.Done()
58 defer client.Close()
59 if n, err := io.Copy(client, src); n != size || err != nil {
60 t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size)
61 }
62 }()
63 wg.Wait()
64
65 if _, err := dst.Seek(0, 0); err != nil {
66 t.Fatal(err)
67 }
68 if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil {
69 t.Fatal(err)
70 }
71 }
72
// compareReaders reads a and b to exhaustion in lockstep, 4096 bytes at a
// time. It returns nil if both produce identical byte streams of equal length.
// It returns the first non-EOF read error from either reader, or a
// "contents mismatch" error if the streams differ in content or length.
func compareReaders(a, b io.Reader) error {
	bufa := make([]byte, 4096)
	bufb := make([]byte, 4096)
	// io.ReadFull reports an empty final read as io.EOF and a short (partial)
	// final read as io.ErrUnexpectedEOF; both just mean the reader is done.
	atEnd := func(err error) bool {
		return err == io.EOF || err == io.ErrUnexpectedEOF
	}
	for {
		na, erra := io.ReadFull(a, bufa)
		if erra != nil && !atEnd(erra) {
			return erra
		}
		nb, errb := io.ReadFull(b, bufb)
		if errb != nil && !atEnd(errb) {
			return errb
		}
		if !bytes.Equal(bufa[:na], bufb[:nb]) {
			return errors.New("contents mismatch")
		}
		// ReadFull returns a nil error exactly when it filled the buffer, so
		// two equal chunks that both ended with an error mean both readers
		// finished at the same offset.
		if erra != nil && errb != nil {
			break
		}
	}
	return nil
}
94
95 type randReader struct {
96 rand *rand.Rand
97 }
98
99 func newRandReader() *randReader {
100 return &randReader{rand.New(rand.NewPCG(0, 0))}
101 }
102
103 func (r *randReader) Read(p []byte) (int, error) {
104 var v uint64
105 var n int
106 for i := range p {
107 if n == 0 {
108 v = r.rand.Uint64()
109 n = 8
110 }
111 p[i] = byte(v & 0xff)
112 v >>= 8
113 n--
114 }
115 return len(p), nil
116 }
117
118 func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
119 t.Helper()
120 if !nettest.TestableNetwork(proto) {
121 t.Skipf("%s does not support %q", runtime.GOOS, proto)
122 }
123
124 ln, err := nettest.NewLocalListener(proto)
125 if err != nil {
126 t.Fatalf("NewLocalListener error: %v", err)
127 }
128 t.Cleanup(func() {
129 if ln != nil {
130 ln.Close()
131 }
132 if client != nil {
133 client.Close()
134 }
135 if server != nil {
136 server.Close()
137 }
138 })
139 ch := make(chan struct{})
140 go func() {
141 var err error
142 server, err = ln.Accept()
143 if err != nil {
144 t.Errorf("Accept new connection error: %v", err)
145 }
146 ch <- struct{}{}
147 }()
148 client, err = net.Dial(proto, ln.Addr().String())
149 <-ch
150 if err != nil {
151 t.Fatalf("Dial new connection error: %v", err)
152 }
153 return client, server
154 }
155
View as plain text