Source file
src/io/multi_test.go
1
2
3
4
5 package io_test
6
7 import (
8 "bytes"
9 "crypto/sha1"
10 "errors"
11 "fmt"
12 . "io"
13 "runtime"
14 "strings"
15 "testing"
16 "time"
17 )
18
// TestMultiReader reads a MultiReader over "foo ", "" and "bar" with a series
// of differently sized Read calls and checks the byte count, data and error
// returned by each call. Every scenario starts from fresh readers and a fresh
// 20-byte buffer; the step number in failure messages counts across all
// scenarios.
func TestMultiReader(t *testing.T) {
	type readStep struct {
		size    int    // length of the slice handed to Read
		want    string // data Read should produce
		wantErr error  // error Read should return
	}
	scenarios := [][]readStep{
		{
			{2, "fo", nil},
			{5, "o ", nil},
			{5, "bar", nil},
			{5, "", EOF},
		},
		{
			{4, "foo ", nil},
			{1, "b", nil},
			{3, "ar", nil},
			{1, "", EOF},
		},
		{
			{5, "foo ", nil},
		},
	}

	step := 0
	for _, steps := range scenarios {
		mr := MultiReader(
			strings.NewReader("foo "),
			strings.NewReader(""),
			strings.NewReader("bar"),
		)
		buf := make([]byte, 20)
		for _, s := range steps {
			step++
			n, err := mr.Read(buf[:s.size])
			if n != len(s.want) {
				t.Errorf("#%d, expected %d bytes; got %d",
					step, len(s.want), n)
			}
			if got := string(buf[:n]); got != s.want {
				t.Errorf("#%d, expected %q; got %q",
					step, s.want, got)
			}
			if err != s.wantErr {
				t.Errorf("#%d, expected error %v; got %v",
					step, s.wantErr, err)
			}
			// The next Read fills the still-unused tail of the buffer.
			buf = buf[n:]
		}
	}
}
65
// TestMultiReaderAsWriterTo checks that a MultiReader (here one that nests
// another MultiReader) implements WriterTo, and that WriteTo emits the
// concatenation of every underlying reader and reports the total byte count.
func TestMultiReaderAsWriterTo(t *testing.T) {
	nested := MultiReader(strings.NewReader(""), strings.NewReader("bar"))
	var r Reader = MultiReader(strings.NewReader("foo "), nested)

	wt, ok := r.(WriterTo)
	if !ok {
		t.Fatalf("expected cast to WriterTo to succeed")
	}

	var sink strings.Builder
	n, err := wt.WriteTo(&sink)
	if err != nil {
		t.Fatalf("expected no error; got %v", err)
	}
	if n != 7 {
		t.Errorf("expected read 7 bytes; got %d", n)
	}
	if got := sink.String(); got != "foo bar" {
		t.Errorf(`expected "foo bar"; got %q`, got)
	}
}
90
91 func TestMultiWriter(t *testing.T) {
92 sink := new(bytes.Buffer)
93
94 testMultiWriter(t, struct {
95 Writer
96 fmt.Stringer
97 }{sink, sink})
98 }
99
100 func TestMultiWriter_String(t *testing.T) {
101 testMultiWriter(t, new(bytes.Buffer))
102 }
103
104
105
// TestMultiWriter_WriteStringSingleAlloc checks that WriteString on a
// MultiWriter whose two targets lack a WriteString method costs exactly one
// allocation per call, as measured by testing.AllocsPerRun.
func TestMultiWriter_WriteStringSingleAlloc(t *testing.T) {
	// writeOnly embeds only Writer, which hides bytes.Buffer's WriteString.
	type writeOnly struct {
		Writer
	}
	var first, second bytes.Buffer
	mw := MultiWriter(writeOnly{&first}, writeOnly{&second})

	writeFoo := func() {
		WriteString(mw, "foo")
	}
	if allocs := int(testing.AllocsPerRun(1000, writeFoo)); allocs != 1 {
		t.Errorf("num allocations = %d; want 1", allocs)
	}
}
119
// writeStringChecker is a writer that also has a WriteString method and
// records whether WriteString was ever invoked. Both methods report their whole
// input as written and never return an error.
type writeStringChecker struct {
	called bool // set to true by WriteString
}

// WriteString marks the checker as called and reports len(s) bytes written.
func (w *writeStringChecker) WriteString(s string) (n int, err error) {
	w.called = true
	n = len(s)
	return n, nil
}

// Write discards p and reports len(p) bytes written; it does not touch called.
func (w *writeStringChecker) Write(p []byte) (n int, err error) {
	n = len(p)
	return n, nil
}
130
131 func TestMultiWriter_StringCheckCall(t *testing.T) {
132 var c writeStringChecker
133 mw := MultiWriter(&c)
134 WriteString(mw, "foo")
135 if !c.called {
136 t.Error("did not see WriteString call to writeStringChecker")
137 }
138 }
139
// testMultiWriter copies a fixed input through a MultiWriter that fans out to
// a SHA-1 hash and to sink, then checks that every byte was written without
// error, that the hash equals the known digest of the input, and that sink
// received an exact copy of the input.
func testMultiWriter(t *testing.T, sink interface {
	Writer
	fmt.Stringer
}) {
	// Named hasher (not sha1) so the crypto/sha1 package is not shadowed.
	hasher := sha1.New()
	mw := MultiWriter(hasher, sink)

	const sourceString = "My input text."
	const wantSHA1 = "01cb303fa8c30a64123067c5aa6284ba7ec2d31b"

	written, err := Copy(mw, strings.NewReader(sourceString))
	if written != int64(len(sourceString)) {
		t.Errorf("short write of %d, not %d", written, len(sourceString))
	}
	if err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	// Report the actual digest so a mismatch can be diagnosed.
	if got := fmt.Sprintf("%x", hasher.Sum(nil)); got != wantSHA1 {
		t.Errorf("incorrect sha1 value: got %s, want %s", got, wantSHA1)
	}

	if got := sink.String(); got != sourceString {
		t.Errorf("expected %q; got %q", sourceString, got)
	}
}
168
169
// writerFunc adapts an ordinary function to the Writer interface: Write hands
// p to the function and returns its results unchanged.
type writerFunc func(p []byte) (int, error)

// Write calls fn(p).
func (fn writerFunc) Write(p []byte) (n int, err error) {
	n, err = fn(p)
	return n, err
}
175
176
177 func TestMultiWriterSingleChainFlatten(t *testing.T) {
178 pc := make([]uintptr, 1000)
179 n := runtime.Callers(0, pc)
180 var myDepth = callDepth(pc[:n])
181 var writeDepth int
182 var w Writer = MultiWriter(writerFunc(func(p []byte) (int, error) {
183 n := runtime.Callers(1, pc)
184 writeDepth += callDepth(pc[:n])
185 return 0, nil
186 }))
187
188 mw := w
189
190 for i := 0; i < 100; i++ {
191 mw = MultiWriter(w)
192 }
193
194 mw = MultiWriter(w, mw, w, mw)
195 mw.Write(nil)
196
197 if writeDepth != 4*(myDepth+2) {
198 t.Errorf("multiWriter did not flatten chained multiWriters: expected writeDepth %d, got %d",
199 4*(myDepth+2), writeDepth)
200 }
201 }
202
203 func TestMultiWriterError(t *testing.T) {
204 f1 := writerFunc(func(p []byte) (int, error) {
205 return len(p) / 2, ErrShortWrite
206 })
207 f2 := writerFunc(func(p []byte) (int, error) {
208 t.Errorf("MultiWriter called f2.Write")
209 return len(p), nil
210 })
211 w := MultiWriter(f1, f2)
212 n, err := w.Write(make([]byte, 100))
213 if n != 50 || err != ErrShortWrite {
214 t.Errorf("Write = %d, %v, want 50, ErrShortWrite", n, err)
215 }
216 }
217
218
// TestMultiReaderCopy checks that MultiReader does not alias the caller's
// slice: clearing the slice after construction must not affect reading.
func TestMultiReaderCopy(t *testing.T) {
	const want = "hello world"
	readers := []Reader{strings.NewReader(want)}
	r := MultiReader(readers...)
	readers[0] = nil // clobber the caller's copy

	if data, err := ReadAll(r); err != nil || string(data) != want {
		t.Errorf("ReadAll() = %q, %v, want %q, nil", data, err, want)
	}
}
228
229
// TestMultiWriterCopy checks that MultiWriter does not alias the caller's
// slice: clearing the slice after construction must not affect writing.
func TestMultiWriterCopy(t *testing.T) {
	const msg = "hello world"
	var sb strings.Builder
	writers := []Writer{&sb}
	w := MultiWriter(writers...)
	writers[0] = nil // clobber the caller's copy

	n, err := w.Write([]byte(msg))
	if n != 11 || err != nil {
		t.Errorf("Write(`hello world`) = %d, %v, want 11, nil", n, err)
	}
	if got := sb.String(); got != msg {
		t.Errorf("buf.String() = %q, want %q", got, msg)
	}
}
243
244
// readerFunc adapts an ordinary function to the Reader interface: Read hands
// p to the function and returns its results unchanged.
type readerFunc func(p []byte) (int, error)

// Read calls fn(p).
func (fn readerFunc) Read(p []byte) (n int, err error) {
	n, err = fn(p)
	return n, err
}
250
251
// callDepth returns how many times Next is called on the frames of callers
// until it reports no more frames. This is the number of (inline-expanded)
// frames, or 1 when callers is empty.
func callDepth(callers []uintptr) int {
	frames := runtime.CallersFrames(callers)
	depth := 1
	for _, more := frames.Next(); more; _, more = frames.Next() {
		depth++
	}
	return depth
}
261
262
263 func TestMultiReaderFlatten(t *testing.T) {
264 pc := make([]uintptr, 1000)
265 n := runtime.Callers(0, pc)
266 var myDepth = callDepth(pc[:n])
267 var readDepth int
268 var r Reader = MultiReader(readerFunc(func(p []byte) (int, error) {
269 n := runtime.Callers(1, pc)
270 readDepth = callDepth(pc[:n])
271 return 0, errors.New("irrelevant")
272 }))
273
274
275 for i := 0; i < 100; i++ {
276 r = MultiReader(r)
277 }
278
279 r.Read(nil)
280
281 if readDepth != myDepth+2 {
282 t.Errorf("multiReader did not flatten chained multiReaders: expected readDepth %d, got %d",
283 myDepth+2, readDepth)
284 }
285 }
286
287
288
// byteAndEOFReader is a Reader that, in a single Read call, yields its own
// byte value together with EOF. It panics on a zero-length Read, so tests
// using it also catch callers that issue empty reads.
type byteAndEOFReader byte

func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
	if len(p) != 0 {
		p[0] = byte(b)
		return 1, EOF
	}
	// Empty reads are not expected by the tests that use this reader.
	panic("unexpected call")
}
300
301
302 func TestMultiReaderSingleByteWithEOF(t *testing.T) {
303 got, err := ReadAll(LimitReader(MultiReader(byteAndEOFReader('a'), byteAndEOFReader('b')), 10))
304 if err != nil {
305 t.Fatal(err)
306 }
307 const want = "ab"
308 if string(got) != want {
309 t.Errorf("got %q; want %q", got, want)
310 }
311 }
312
313
314
315
316 func TestMultiReaderFinalEOF(t *testing.T) {
317 r := MultiReader(bytes.NewReader(nil), byteAndEOFReader('a'))
318 buf := make([]byte, 2)
319 n, err := r.Read(buf)
320 if n != 1 || err != EOF {
321 t.Errorf("got %v, %v; want 1, EOF", n, err)
322 }
323 }
324
// TestMultiReaderFreesExhaustedReaders checks that once a MultiReader has fully
// consumed one of its readers, that reader can be garbage collected while the
// MultiReader is still in use. The remaining data must still be readable
// afterwards.
func TestMultiReaderFreesExhaustedReaders(t *testing.T) {
	var mr Reader
	closed := make(chan struct{})

	// Create the readers inside a closure so that, once it returns, no local
	// variable of this function refers to buf1; only mr can keep it alive.
	// The finalizer signals when buf1 has been collected.
	func() {
		buf1 := bytes.NewReader([]byte("foo"))
		buf2 := bytes.NewReader([]byte("bar"))
		mr = MultiReader(buf1, buf2)
		runtime.SetFinalizer(buf1, func(*bytes.Reader) {
			close(closed)
		})
	}()

	// Reading four bytes exhausts buf1 ("foo") and takes one byte of buf2.
	buf := make([]byte, 4)
	if n, err := ReadFull(mr, buf); err != nil || string(buf) != "foob" {
		t.Fatalf(`ReadFull = %d (%q), %v; want 4, "foob", nil`, n, buf[:n], err)
	}

	runtime.GC()
	select {
	case <-closed:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for collection of buf1")
	}

	// The rest of buf2 must still be readable after buf1 was collected.
	if n, err := ReadFull(mr, buf[:2]); err != nil || string(buf[:2]) != "ar" {
		t.Fatalf(`ReadFull = %d (%q), %v; want 2, "ar", nil`, n, buf[:n], err)
	}
}
356
// TestInterleavedMultiReader checks that reading through a MultiReader that
// wraps another MultiReader, and then reading from the inner one directly,
// stays consistent: the inner reader resumes exactly where the outer one
// stopped.
func TestInterleavedMultiReader(t *testing.T) {
	mr1 := MultiReader(strings.NewReader("123"), strings.NewReader("45678"))
	mr2 := MultiReader(mr1)
	buf := make([]byte, 4)

	// The first four bytes come through the outer reader and span both
	// underlying readers...
	n, err := ReadFull(mr2, buf)
	if got := string(buf[:n]); err != nil || got != "1234" {
		t.Errorf(`ReadFull(mr2) = (%q, %v), want ("1234", nil)`, got, err)
	}

	// ...and the inner reader must continue right after them.
	n, err = ReadFull(mr1, buf)
	if got := string(buf[:n]); err != nil || got != "5678" {
		t.Errorf(`ReadFull(mr1) = (%q, %v), want ("5678", nil)`, got, err)
	}
}
380
View as plain text