Source file
src/os/copy_test.go
1
2
3
4
5 package os_test
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "math/rand/v2"
13 "net"
14 "os"
15 "runtime"
16 "sync"
17 "testing"
18
19 "golang.org/x/net/nettest"
20 )
21
22
23
24
25
26 func TestLargeCopyViaNetwork(t *testing.T) {
27 const size = 10 * 1024 * 1024
28 dir := t.TempDir()
29
30 src, err := os.Create(dir + "/src")
31 if err != nil {
32 t.Fatal(err)
33 }
34 defer src.Close()
35 if _, err := io.CopyN(src, newRandReader(), size); err != nil {
36 t.Fatal(err)
37 }
38 if _, err := src.Seek(0, 0); err != nil {
39 t.Fatal(err)
40 }
41
42 dst, err := os.Create(dir + "/dst")
43 if err != nil {
44 t.Fatal(err)
45 }
46 defer dst.Close()
47
48 client, server := createSocketPair(t, "tcp")
49 var wg sync.WaitGroup
50 wg.Add(2)
51 go func() {
52 defer wg.Done()
53 if n, err := io.Copy(dst, server); n != size || err != nil {
54 t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size)
55 }
56 }()
57 go func() {
58 defer wg.Done()
59 defer client.Close()
60 if n, err := io.Copy(client, src); n != size || err != nil {
61 t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size)
62 }
63 }()
64 wg.Wait()
65
66 if _, err := dst.Seek(0, 0); err != nil {
67 t.Fatal(err)
68 }
69 if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil {
70 t.Fatal(err)
71 }
72 }
73
74 func TestCopyFileToFile(t *testing.T) {
75 const size = 1 * 1024 * 1024
76 dir := t.TempDir()
77
78 src, err := os.Create(dir + "/src")
79 if err != nil {
80 t.Fatal(err)
81 }
82 defer src.Close()
83 if _, err := io.CopyN(src, newRandReader(), size); err != nil {
84 t.Fatal(err)
85 }
86 if _, err := src.Seek(0, 0); err != nil {
87 t.Fatal(err)
88 }
89
90 mustSeek := func(f *os.File, offset int64, whence int) int64 {
91 ret, err := f.Seek(offset, whence)
92 if err != nil {
93 t.Fatal(err)
94 }
95 return ret
96 }
97
98 for _, srcStart := range []int64{0, 100, size} {
99 remaining := size - srcStart
100 for _, dstStart := range []int64{0, 200} {
101 for _, limit := range []int64{remaining, remaining - 100, size * 2, 0} {
102 if limit < 0 {
103 continue
104 }
105 name := fmt.Sprintf("srcStart=%v/dstStart=%v/limit=%v", srcStart, dstStart, limit)
106 t.Run(name, func(t *testing.T) {
107 dst, err := os.CreateTemp(dir, "dst")
108 if err != nil {
109 t.Fatal(err)
110 }
111 defer dst.Close()
112 defer os.Remove(dst.Name())
113
114 mustSeek(src, srcStart, io.SeekStart)
115 if _, err := io.CopyN(dst, zeroReader{}, dstStart); err != nil {
116 t.Fatal(err)
117 }
118
119 var copied int64
120 if limit == 0 {
121 copied, err = io.Copy(dst, src)
122 } else {
123 copied, err = io.CopyN(dst, src, limit)
124 }
125 if limit > remaining {
126 if err != io.EOF {
127 t.Errorf("Copy: %v; want io.EOF", err)
128 }
129 } else {
130 if err != nil {
131 t.Errorf("Copy: %v; want nil", err)
132 }
133 }
134
135 wantCopied := remaining
136 if limit != 0 {
137 wantCopied = min(limit, wantCopied)
138 }
139 if copied != wantCopied {
140 t.Errorf("copied %v bytes, want %v", copied, wantCopied)
141 }
142
143 srcPos := mustSeek(src, 0, io.SeekCurrent)
144 wantSrcPos := srcStart + wantCopied
145 if srcPos != wantSrcPos {
146 t.Errorf("source position = %v, want %v", srcPos, wantSrcPos)
147 }
148
149 dstPos := mustSeek(dst, 0, io.SeekCurrent)
150 wantDstPos := dstStart + wantCopied
151 if dstPos != wantDstPos {
152 t.Errorf("destination position = %v, want %v", dstPos, wantDstPos)
153 }
154
155 mustSeek(dst, 0, io.SeekStart)
156 rr := newRandReader()
157 io.CopyN(io.Discard, rr, srcStart)
158 wantReader := io.MultiReader(
159 io.LimitReader(zeroReader{}, dstStart),
160 io.LimitReader(rr, wantCopied),
161 )
162 if err := compareReaders(dst, wantReader); err != nil {
163 t.Fatal(err)
164 }
165 })
166
167 }
168 }
169 }
170 }
171
// compareReaders reads a and b to the end in 4096-byte chunks and returns nil
// if both yield exactly the same byte sequence. It returns the first read
// error from either reader other than io.EOF / io.ErrUnexpectedEOF. If the
// contents differ, it returns an error naming the byte offset of the first
// difference. When one stream is a strict prefix of the other, that offset is
// the length of the shorter stream.
func compareReaders(a, b io.Reader) error {
	bufa := make([]byte, 4096)
	bufb := make([]byte, 4096)
	var off int64 // stream offset corresponding to bufa[0] and bufb[0]
	for {
		// io.ReadFull fills the whole buffer unless it returns an error, so a
		// short read marks the end of that stream and chunks stay aligned.
		na, erra := io.ReadFull(a, bufa)
		if erra != nil && !errors.Is(erra, io.EOF) && !errors.Is(erra, io.ErrUnexpectedEOF) {
			return erra
		}
		nb, errb := io.ReadFull(b, bufb)
		if errb != nil && !errors.Is(errb, io.EOF) && !errors.Is(errb, io.ErrUnexpectedEOF) {
			return errb
		}
		if !bytes.Equal(bufa[:na], bufb[:nb]) {
			// Locate the first differing byte (or the end of the shorter chunk)
			// so a failure points at where the streams diverge.
			i := 0
			for i < na && i < nb && bufa[i] == bufb[i] {
				i++
			}
			return fmt.Errorf("contents mismatch at offset %d", off+int64(i))
		}
		if erra != nil && errb != nil {
			return nil
		}
		// Both chunks were full and equal; advance to the next chunk.
		off += int64(na)
	}
}
195
// zeroReader is an io.Reader that produces an endless run of zero bytes.
type zeroReader struct{}

// Read overwrites every byte of p with zero and reports that all len(p)
// bytes were read. It never returns an error.
func (zeroReader) Read(p []byte) (n int, err error) {
	clear(p)
	n = len(p)
	return n, nil
}
202
203 type randReader struct {
204 rand *rand.Rand
205 }
206
207 func newRandReader() *randReader {
208 return &randReader{rand.New(rand.NewPCG(0, 0))}
209 }
210
211 func (r *randReader) Read(p []byte) (int, error) {
212 for i := range p {
213 p[i] = byte(r.rand.Uint32() & 0xff)
214 }
215 return len(p), nil
216 }
217
218 func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
219 t.Helper()
220 if !nettest.TestableNetwork(proto) {
221 t.Skipf("%s does not support %q", runtime.GOOS, proto)
222 }
223
224 ln, err := nettest.NewLocalListener(proto)
225 if err != nil {
226 t.Fatalf("NewLocalListener error: %v", err)
227 }
228 t.Cleanup(func() {
229 if ln != nil {
230 ln.Close()
231 }
232 if client != nil {
233 client.Close()
234 }
235 if server != nil {
236 server.Close()
237 }
238 })
239 ch := make(chan struct{})
240 go func() {
241 var err error
242 server, err = ln.Accept()
243 if err != nil {
244 t.Errorf("Accept new connection error: %v", err)
245 }
246 ch <- struct{}{}
247 }()
248 client, err = net.Dial(proto, ln.Addr().String())
249 <-ch
250 if err != nil {
251 t.Fatalf("Dial new connection error: %v", err)
252 }
253 return client, server
254 }
255
View as plain text