Source file src/os/copy_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os_test
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"math/rand/v2"
    13  	"net"
    14  	"os"
    15  	"runtime"
    16  	"sync"
    17  	"testing"
    18  
    19  	"golang.org/x/net/nettest"
    20  )
    21  
    22  // Exercise sendfile/splice fast paths with a moderately large file.
    23  //
    24  // https://go.dev/issue/70000
    25  
    26  func TestLargeCopyViaNetwork(t *testing.T) {
    27  	const size = 10 * 1024 * 1024
    28  	dir := t.TempDir()
    29  
    30  	src, err := os.Create(dir + "/src")
    31  	if err != nil {
    32  		t.Fatal(err)
    33  	}
    34  	defer src.Close()
    35  	if _, err := io.CopyN(src, newRandReader(), size); err != nil {
    36  		t.Fatal(err)
    37  	}
    38  	if _, err := src.Seek(0, 0); err != nil {
    39  		t.Fatal(err)
    40  	}
    41  
    42  	dst, err := os.Create(dir + "/dst")
    43  	if err != nil {
    44  		t.Fatal(err)
    45  	}
    46  	defer dst.Close()
    47  
    48  	client, server := createSocketPair(t, "tcp")
    49  	var wg sync.WaitGroup
    50  	wg.Add(2)
    51  	go func() {
    52  		defer wg.Done()
    53  		if n, err := io.Copy(dst, server); n != size || err != nil {
    54  			t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size)
    55  		}
    56  	}()
    57  	go func() {
    58  		defer wg.Done()
    59  		defer client.Close()
    60  		if n, err := io.Copy(client, src); n != size || err != nil {
    61  			t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size)
    62  		}
    63  	}()
    64  	wg.Wait()
    65  
    66  	if _, err := dst.Seek(0, 0); err != nil {
    67  		t.Fatal(err)
    68  	}
    69  	if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil {
    70  		t.Fatal(err)
    71  	}
    72  }
    73  
    74  func TestCopyFileToFile(t *testing.T) {
    75  	const size = 1 * 1024 * 1024
    76  	dir := t.TempDir()
    77  
    78  	src, err := os.Create(dir + "/src")
    79  	if err != nil {
    80  		t.Fatal(err)
    81  	}
    82  	defer src.Close()
    83  	if _, err := io.CopyN(src, newRandReader(), size); err != nil {
    84  		t.Fatal(err)
    85  	}
    86  	if _, err := src.Seek(0, 0); err != nil {
    87  		t.Fatal(err)
    88  	}
    89  
    90  	mustSeek := func(f *os.File, offset int64, whence int) int64 {
    91  		ret, err := f.Seek(offset, whence)
    92  		if err != nil {
    93  			t.Fatal(err)
    94  		}
    95  		return ret
    96  	}
    97  
    98  	for _, srcStart := range []int64{0, 100, size} {
    99  		remaining := size - srcStart
   100  		for _, dstStart := range []int64{0, 200} {
   101  			for _, limit := range []int64{remaining, remaining - 100, size * 2, 0} {
   102  				if limit < 0 {
   103  					continue
   104  				}
   105  				name := fmt.Sprintf("srcStart=%v/dstStart=%v/limit=%v", srcStart, dstStart, limit)
   106  				t.Run(name, func(t *testing.T) {
   107  					dst, err := os.CreateTemp(dir, "dst")
   108  					if err != nil {
   109  						t.Fatal(err)
   110  					}
   111  					defer dst.Close()
   112  					defer os.Remove(dst.Name())
   113  
   114  					mustSeek(src, srcStart, io.SeekStart)
   115  					if _, err := io.CopyN(dst, zeroReader{}, dstStart); err != nil {
   116  						t.Fatal(err)
   117  					}
   118  
   119  					var copied int64
   120  					if limit == 0 {
   121  						copied, err = io.Copy(dst, src)
   122  					} else {
   123  						copied, err = io.CopyN(dst, src, limit)
   124  					}
   125  					if limit > remaining {
   126  						if err != io.EOF {
   127  							t.Errorf("Copy: %v; want io.EOF", err)
   128  						}
   129  					} else {
   130  						if err != nil {
   131  							t.Errorf("Copy: %v; want nil", err)
   132  						}
   133  					}
   134  
   135  					wantCopied := remaining
   136  					if limit != 0 {
   137  						wantCopied = min(limit, wantCopied)
   138  					}
   139  					if copied != wantCopied {
   140  						t.Errorf("copied %v bytes, want %v", copied, wantCopied)
   141  					}
   142  
   143  					srcPos := mustSeek(src, 0, io.SeekCurrent)
   144  					wantSrcPos := srcStart + wantCopied
   145  					if srcPos != wantSrcPos {
   146  						t.Errorf("source position = %v, want %v", srcPos, wantSrcPos)
   147  					}
   148  
   149  					dstPos := mustSeek(dst, 0, io.SeekCurrent)
   150  					wantDstPos := dstStart + wantCopied
   151  					if dstPos != wantDstPos {
   152  						t.Errorf("destination position = %v, want %v", dstPos, wantDstPos)
   153  					}
   154  
   155  					mustSeek(dst, 0, io.SeekStart)
   156  					rr := newRandReader()
   157  					io.CopyN(io.Discard, rr, srcStart)
   158  					wantReader := io.MultiReader(
   159  						io.LimitReader(zeroReader{}, dstStart),
   160  						io.LimitReader(rr, wantCopied),
   161  					)
   162  					if err := compareReaders(dst, wantReader); err != nil {
   163  						t.Fatal(err)
   164  					}
   165  				})
   166  
   167  			}
   168  		}
   169  	}
   170  }
   171  
   172  func compareReaders(a, b io.Reader) error {
   173  	bufa := make([]byte, 4096)
   174  	bufb := make([]byte, 4096)
   175  	off := 0
   176  	for {
   177  		na, erra := io.ReadFull(a, bufa)
   178  		if erra != nil && erra != io.EOF && erra != io.ErrUnexpectedEOF {
   179  			return erra
   180  		}
   181  		nb, errb := io.ReadFull(b, bufb)
   182  		if errb != nil && errb != io.EOF && errb != io.ErrUnexpectedEOF {
   183  			return errb
   184  		}
   185  		if !bytes.Equal(bufa[:na], bufb[:nb]) {
   186  			return errors.New("contents mismatch")
   187  		}
   188  		if erra != nil && errb != nil {
   189  			break
   190  		}
   191  		off += len(bufa)
   192  	}
   193  	return nil
   194  }
   195  
   196  type zeroReader struct{}
   197  
   198  func (r zeroReader) Read(p []byte) (int, error) {
   199  	clear(p)
   200  	return len(p), nil
   201  }
   202  
   203  type randReader struct {
   204  	rand *rand.Rand
   205  }
   206  
   207  func newRandReader() *randReader {
   208  	return &randReader{rand.New(rand.NewPCG(0, 0))}
   209  }
   210  
   211  func (r *randReader) Read(p []byte) (int, error) {
   212  	for i := range p {
   213  		p[i] = byte(r.rand.Uint32() & 0xff)
   214  	}
   215  	return len(p), nil
   216  }
   217  
   218  func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
   219  	t.Helper()
   220  	if !nettest.TestableNetwork(proto) {
   221  		t.Skipf("%s does not support %q", runtime.GOOS, proto)
   222  	}
   223  
   224  	ln, err := nettest.NewLocalListener(proto)
   225  	if err != nil {
   226  		t.Fatalf("NewLocalListener error: %v", err)
   227  	}
   228  	t.Cleanup(func() {
   229  		if ln != nil {
   230  			ln.Close()
   231  		}
   232  		if client != nil {
   233  			client.Close()
   234  		}
   235  		if server != nil {
   236  			server.Close()
   237  		}
   238  	})
   239  	ch := make(chan struct{})
   240  	go func() {
   241  		var err error
   242  		server, err = ln.Accept()
   243  		if err != nil {
   244  			t.Errorf("Accept new connection error: %v", err)
   245  		}
   246  		ch <- struct{}{}
   247  	}()
   248  	client, err = net.Dial(proto, ln.Addr().String())
   249  	<-ch
   250  	if err != nil {
   251  		t.Fatalf("Dial new connection error: %v", err)
   252  	}
   253  	return client, server
   254  }
   255  

View as plain text