Source file src/net/writev_test.go

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package net
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"internal/poll"
    11  	"io"
    12  	"reflect"
    13  	"runtime"
    14  	"sync"
    15  	"testing"
    16  )
    17  
    18  func TestBuffers_read(t *testing.T) {
    19  	const story = "once upon a time in Gopherland ... "
    20  	buffers := Buffers{
    21  		[]byte("once "),
    22  		[]byte("upon "),
    23  		[]byte("a "),
    24  		[]byte("time "),
    25  		[]byte("in "),
    26  		[]byte("Gopherland ... "),
    27  	}
    28  	got, err := io.ReadAll(&buffers)
    29  	if err != nil {
    30  		t.Fatal(err)
    31  	}
    32  	if string(got) != story {
    33  		t.Errorf("read %q; want %q", got, story)
    34  	}
    35  	if len(buffers) != 0 {
    36  		t.Errorf("len(buffers) = %d; want 0", len(buffers))
    37  	}
    38  }
    39  
    40  func TestBuffers_consume(t *testing.T) {
    41  	tests := []struct {
    42  		in      Buffers
    43  		consume int64
    44  		want    Buffers
    45  	}{
    46  		{
    47  			in:      Buffers{[]byte("foo"), []byte("bar")},
    48  			consume: 0,
    49  			want:    Buffers{[]byte("foo"), []byte("bar")},
    50  		},
    51  		{
    52  			in:      Buffers{[]byte("foo"), []byte("bar")},
    53  			consume: 2,
    54  			want:    Buffers{[]byte("o"), []byte("bar")},
    55  		},
    56  		{
    57  			in:      Buffers{[]byte("foo"), []byte("bar")},
    58  			consume: 3,
    59  			want:    Buffers{[]byte("bar")},
    60  		},
    61  		{
    62  			in:      Buffers{[]byte("foo"), []byte("bar")},
    63  			consume: 4,
    64  			want:    Buffers{[]byte("ar")},
    65  		},
    66  		{
    67  			in:      Buffers{nil, nil, nil, []byte("bar")},
    68  			consume: 1,
    69  			want:    Buffers{[]byte("ar")},
    70  		},
    71  		{
    72  			in:      Buffers{nil, nil, nil, []byte("foo")},
    73  			consume: 0,
    74  			want:    Buffers{[]byte("foo")},
    75  		},
    76  		{
    77  			in:      Buffers{nil, nil, nil},
    78  			consume: 0,
    79  			want:    Buffers{},
    80  		},
    81  	}
    82  	for i, tt := range tests {
    83  		in := tt.in
    84  		in.consume(tt.consume)
    85  		if !reflect.DeepEqual(in, tt.want) {
    86  			t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want)
    87  		}
    88  	}
    89  }
    90  
    91  func TestBuffers_WriteTo(t *testing.T) {
    92  	for _, name := range []string{"WriteTo", "Copy"} {
    93  		for _, size := range []int{0, 10, 1023, 1024, 1025} {
    94  			t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) {
    95  				testBuffer_writeTo(t, size, name == "Copy")
    96  			})
    97  		}
    98  	}
    99  }
   100  
   101  func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
   102  	oldHook := poll.TestHookDidWritev
   103  	defer func() { poll.TestHookDidWritev = oldHook }()
   104  	var writeLog struct {
   105  		sync.Mutex
   106  		log []int
   107  	}
   108  	poll.TestHookDidWritev = func(size int) {
   109  		writeLog.Lock()
   110  		writeLog.log = append(writeLog.log, size)
   111  		writeLog.Unlock()
   112  	}
   113  	var want bytes.Buffer
   114  	for i := 0; i < chunks; i++ {
   115  		want.WriteByte(byte(i))
   116  	}
   117  
   118  	withTCPConnPair(t, func(c *TCPConn) error {
   119  		buffers := make(Buffers, chunks)
   120  		for i := range buffers {
   121  			buffers[i] = want.Bytes()[i : i+1]
   122  		}
   123  		var n int64
   124  		var err error
   125  		if useCopy {
   126  			n, err = io.Copy(c, &buffers)
   127  		} else {
   128  			n, err = buffers.WriteTo(c)
   129  		}
   130  		if err != nil {
   131  			return err
   132  		}
   133  		if len(buffers) != 0 {
   134  			return fmt.Errorf("len(buffers) = %d; want 0", len(buffers))
   135  		}
   136  		if n != int64(want.Len()) {
   137  			return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())
   138  		}
   139  		return nil
   140  	}, func(c *TCPConn) error {
   141  		all, err := io.ReadAll(c)
   142  		if !bytes.Equal(all, want.Bytes()) || err != nil {
   143  			return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes())
   144  		}
   145  
   146  		writeLog.Lock() // no need to unlock
   147  		var gotSum int
   148  		for _, v := range writeLog.log {
   149  			gotSum += v
   150  		}
   151  
   152  		var wantSum int
   153  		switch runtime.GOOS {
   154  		case "aix", "android", "darwin", "ios", "dragonfly", "freebsd", "illumos", "linux", "netbsd", "openbsd", "solaris":
   155  			var wantMinCalls int
   156  			wantSum = want.Len()
   157  			v := chunks
   158  			for v > 0 {
   159  				wantMinCalls++
   160  				v -= 1024
   161  			}
   162  			if len(writeLog.log) < wantMinCalls {
   163  				t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls)
   164  			}
   165  		case "windows":
   166  			var wantCalls int
   167  			wantSum = want.Len()
   168  			if wantSum > 0 {
   169  				wantCalls = 1 // windows will always do 1 syscall, unless sending empty buffer
   170  			}
   171  			if len(writeLog.log) != wantCalls {
   172  				t.Errorf("write calls = %v; want %v", len(writeLog.log), wantCalls)
   173  			}
   174  		}
   175  		if gotSum != wantSum {
   176  			t.Errorf("writev call sum  = %v; want %v", gotSum, wantSum)
   177  		}
   178  		return nil
   179  	})
   180  }
   181  
   182  func TestWritevError(t *testing.T) {
   183  	if runtime.GOOS == "windows" {
   184  		t.Skipf("skipping the test: windows does not have problem sending large chunks of data")
   185  	}
   186  
   187  	ln := newLocalListener(t, "tcp")
   188  
   189  	ch := make(chan Conn, 1)
   190  	defer func() {
   191  		ln.Close()
   192  		for c := range ch {
   193  			c.Close()
   194  		}
   195  	}()
   196  
   197  	go func() {
   198  		defer close(ch)
   199  		c, err := ln.Accept()
   200  		if err != nil {
   201  			t.Error(err)
   202  			return
   203  		}
   204  		ch <- c
   205  	}()
   206  	c1, err := Dial("tcp", ln.Addr().String())
   207  	if err != nil {
   208  		t.Fatal(err)
   209  	}
   210  	defer c1.Close()
   211  	c2 := <-ch
   212  	if c2 == nil {
   213  		t.Fatal("no server side connection")
   214  	}
   215  	c2.Close()
   216  
   217  	// 1 GB of data should be enough to notice the connection is gone.
   218  	// Just a few bytes is not enough.
   219  	// Arrange to reuse the same 1 MB buffer so that we don't allocate much.
   220  	buf := make([]byte, 1<<20)
   221  	buffers := make(Buffers, 1<<10)
   222  	for i := range buffers {
   223  		buffers[i] = buf
   224  	}
   225  	if _, err := buffers.WriteTo(c1); err == nil {
   226  		t.Fatal("Buffers.WriteTo(closed conn) succeeded, want error")
   227  	}
   228  }
   229  

View as plain text