Source file src/os/copy_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package os_test
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"io"
    11  	"math/rand/v2"
    12  	"net"
    13  	"os"
    14  	"runtime"
    15  	"sync"
    16  	"testing"
    17  
    18  	"golang.org/x/net/nettest"
    19  )
    20  
    21  // Exercise sendfile/splice fast paths with a moderately large file.
    22  //
    23  // https://go.dev/issue/70000
    24  
    25  func TestLargeCopyViaNetwork(t *testing.T) {
    26  	const size = 10 * 1024 * 1024
    27  	dir := t.TempDir()
    28  
    29  	src, err := os.Create(dir + "/src")
    30  	if err != nil {
    31  		t.Fatal(err)
    32  	}
    33  	defer src.Close()
    34  	if _, err := io.CopyN(src, newRandReader(), size); err != nil {
    35  		t.Fatal(err)
    36  	}
    37  	if _, err := src.Seek(0, 0); err != nil {
    38  		t.Fatal(err)
    39  	}
    40  
    41  	dst, err := os.Create(dir + "/dst")
    42  	if err != nil {
    43  		t.Fatal(err)
    44  	}
    45  	defer dst.Close()
    46  
    47  	client, server := createSocketPair(t, "tcp")
    48  	var wg sync.WaitGroup
    49  	wg.Add(2)
    50  	go func() {
    51  		defer wg.Done()
    52  		if n, err := io.Copy(dst, server); n != size || err != nil {
    53  			t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size)
    54  		}
    55  	}()
    56  	go func() {
    57  		defer wg.Done()
    58  		defer client.Close()
    59  		if n, err := io.Copy(client, src); n != size || err != nil {
    60  			t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size)
    61  		}
    62  	}()
    63  	wg.Wait()
    64  
    65  	if _, err := dst.Seek(0, 0); err != nil {
    66  		t.Fatal(err)
    67  	}
    68  	if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil {
    69  		t.Fatal(err)
    70  	}
    71  }
    72  
    73  func compareReaders(a, b io.Reader) error {
    74  	bufa := make([]byte, 4096)
    75  	bufb := make([]byte, 4096)
    76  	for {
    77  		na, erra := io.ReadFull(a, bufa)
    78  		if erra != nil && erra != io.EOF {
    79  			return erra
    80  		}
    81  		nb, errb := io.ReadFull(b, bufb)
    82  		if errb != nil && errb != io.EOF {
    83  			return errb
    84  		}
    85  		if !bytes.Equal(bufa[:na], bufb[:nb]) {
    86  			return errors.New("contents mismatch")
    87  		}
    88  		if erra == io.EOF && errb == io.EOF {
    89  			break
    90  		}
    91  	}
    92  	return nil
    93  }
    94  
    95  type randReader struct {
    96  	rand *rand.Rand
    97  }
    98  
    99  func newRandReader() *randReader {
   100  	return &randReader{rand.New(rand.NewPCG(0, 0))}
   101  }
   102  
   103  func (r *randReader) Read(p []byte) (int, error) {
   104  	var v uint64
   105  	var n int
   106  	for i := range p {
   107  		if n == 0 {
   108  			v = r.rand.Uint64()
   109  			n = 8
   110  		}
   111  		p[i] = byte(v & 0xff)
   112  		v >>= 8
   113  		n--
   114  	}
   115  	return len(p), nil
   116  }
   117  
   118  func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
   119  	t.Helper()
   120  	if !nettest.TestableNetwork(proto) {
   121  		t.Skipf("%s does not support %q", runtime.GOOS, proto)
   122  	}
   123  
   124  	ln, err := nettest.NewLocalListener(proto)
   125  	if err != nil {
   126  		t.Fatalf("NewLocalListener error: %v", err)
   127  	}
   128  	t.Cleanup(func() {
   129  		if ln != nil {
   130  			ln.Close()
   131  		}
   132  		if client != nil {
   133  			client.Close()
   134  		}
   135  		if server != nil {
   136  			server.Close()
   137  		}
   138  	})
   139  	ch := make(chan struct{})
   140  	go func() {
   141  		var err error
   142  		server, err = ln.Accept()
   143  		if err != nil {
   144  			t.Errorf("Accept new connection error: %v", err)
   145  		}
   146  		ch <- struct{}{}
   147  	}()
   148  	client, err = net.Dial(proto, ln.Addr().String())
   149  	<-ch
   150  	if err != nil {
   151  		t.Fatalf("Dial new connection error: %v", err)
   152  	}
   153  	return client, server
   154  }
   155  

View as plain text