Source file
src/net/splice_linux_test.go
1
2
3
4
5
6
7 package net
8
9 import (
10 "internal/poll"
11 "io"
12 "os"
13 "strconv"
14 "sync"
15 "syscall"
16 "testing"
17 )
18
19 func TestSplice(t *testing.T) {
20 t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
21 if !testableNetwork("unixgram") {
22 t.Skip("skipping unix-to-tcp tests")
23 }
24 t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
25 t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
26 t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
27 t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
28 t.Run("no-unixpacket", testSpliceNoUnixpacket)
29 t.Run("no-unixgram", testSpliceNoUnixgram)
30 }
31
32 func testSpliceToFile(t *testing.T, upNet, downNet string) {
33 t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile)
34 t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile)
35 t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile)
36 t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile)
37 t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile)
38 t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile)
39 }
40
41 func testSplice(t *testing.T, upNet, downNet string) {
42 t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
43 t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
44 t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
45 t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
46 t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
47 t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
48 t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
49 t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
50 }
51
// spliceTestCase describes one splice scenario. The field order matters:
// cases are also built with positional composite literals.
type spliceTestCase struct {
	// Networks of the source ("up") and destination ("down") sides.
	upNet   string
	downNet string

	// Chunk size handed to the peers, and the total amount of data requested
	// before any read limit is applied.
	chunkSize int
	totalSize int

	// When positive, the source is wrapped in an io.LimitedReader with this N.
	limitReadSize int
}
58
59 func (tc spliceTestCase) test(t *testing.T) {
60 hook := hookSplice(t)
61
62
63
64
65 size := tc.totalSize
66 if tc.limitReadSize > 0 {
67 if tc.limitReadSize < size {
68 size = tc.limitReadSize
69 }
70 }
71
72 clientUp, serverUp := spawnTestSocketPair(t, tc.upNet)
73 defer serverUp.Close()
74 cleanup, err := startTestSocketPeer(t, clientUp, "w", tc.chunkSize, size)
75 if err != nil {
76 t.Fatal(err)
77 }
78 defer cleanup(t)
79 clientDown, serverDown := spawnTestSocketPair(t, tc.downNet)
80 defer serverDown.Close()
81 cleanup, err = startTestSocketPeer(t, clientDown, "r", tc.chunkSize, size)
82 if err != nil {
83 t.Fatal(err)
84 }
85 defer cleanup(t)
86
87 var r io.Reader = serverUp
88 if tc.limitReadSize > 0 {
89 r = &io.LimitedReader{
90 N: int64(tc.limitReadSize),
91 R: serverUp,
92 }
93 defer serverUp.Close()
94 }
95 n, err := io.Copy(serverDown, r)
96 if err != nil {
97 t.Fatal(err)
98 }
99
100 if want := int64(size); want != n {
101 t.Errorf("want %d bytes spliced, got %d", want, n)
102 }
103
104 if tc.limitReadSize > 0 {
105 wantN := 0
106 if tc.limitReadSize > size {
107 wantN = tc.limitReadSize - size
108 }
109
110 if n := r.(*io.LimitedReader).N; n != int64(wantN) {
111 t.Errorf("r.N = %d, want %d", n, wantN)
112 }
113 }
114
115
116
117 if tc.limitReadSize == 0 || tc.downNet == "tcp" {
118
119 if n > 0 && !hook.called {
120 t.Fatal("expected poll.Splice to be called")
121 }
122
123 verifySpliceFds(t, serverDown, hook, "dst")
124 verifySpliceFds(t, serverUp, hook, "src")
125
126
127 if !hook.handled || hook.written != int64(size) || hook.err != nil {
128 t.Errorf("expected handled = true, written = %d, err = nil, but got handled = %t, written = %d, err = %v",
129 size, hook.handled, hook.written, hook.err)
130 }
131 } else if hook.called {
132
133
134 t.Errorf("expected poll.Splice not be called")
135 }
136 }
137
138 func verifySpliceFds(t *testing.T, c Conn, hook *spliceHook, fdType string) {
139 t.Helper()
140
141 sc, ok := c.(syscall.Conn)
142 if !ok {
143 t.Fatalf("expected syscall.Conn")
144 }
145 rc, err := sc.SyscallConn()
146 if err != nil {
147 t.Fatalf("syscall.Conn.SyscallConn error: %v", err)
148 }
149 var hookFd int
150 switch fdType {
151 case "src":
152 hookFd = hook.srcfd
153 case "dst":
154 hookFd = hook.dstfd
155 default:
156 t.Fatalf("unknown fdType %q", fdType)
157 }
158 if err := rc.Control(func(fd uintptr) {
159 if hook.called && hookFd != int(fd) {
160 t.Fatalf("wrong %s file descriptor: got %d, want %d", fdType, hook.dstfd, int(fd))
161 }
162 }); err != nil {
163 t.Fatalf("syscall.RawConn.Control error: %v", err)
164 }
165 }
166
167 func (tc spliceTestCase) testFile(t *testing.T) {
168 hook := hookSplice(t)
169
170
171
172
173 actualSize := tc.totalSize
174 if tc.limitReadSize > 0 {
175 if tc.limitReadSize < actualSize {
176 actualSize = tc.limitReadSize
177 }
178 }
179
180 f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
181 if err != nil {
182 t.Fatal(err)
183 }
184 defer f.Close()
185
186 client, server := spawnTestSocketPair(t, tc.upNet)
187 defer server.Close()
188
189 cleanup, err := startTestSocketPeer(t, client, "w", tc.chunkSize, actualSize)
190 if err != nil {
191 client.Close()
192 t.Fatal("failed to start splice client:", err)
193 }
194 defer cleanup(t)
195
196 var r io.Reader = server
197 if tc.limitReadSize > 0 {
198 r = &io.LimitedReader{
199 N: int64(tc.limitReadSize),
200 R: r,
201 }
202 }
203
204 got, err := io.Copy(f, r)
205 if err != nil {
206 t.Fatalf("failed to ReadFrom with error: %v", err)
207 }
208
209
210
211 if got > 0 && hook.called {
212 t.Error("expected not poll.Splice to be called")
213 }
214
215 if want := int64(actualSize); got != want {
216 t.Errorf("got %d bytes, want %d", got, want)
217 }
218 if tc.limitReadSize > 0 {
219 wantN := 0
220 if tc.limitReadSize > actualSize {
221 wantN = tc.limitReadSize - actualSize
222 }
223
224 if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) {
225 t.Errorf("r.N = %d, want %d", gotN, wantN)
226 }
227 }
228 }
229
230 func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
231
232
233
234 if downNet == "unix" {
235 t.Skip("skipping test on unix socket")
236 }
237
238 hook := hookSplice(t)
239
240 clientUp, serverUp := spawnTestSocketPair(t, upNet)
241 defer clientUp.Close()
242 clientDown, serverDown := spawnTestSocketPair(t, downNet)
243 defer clientDown.Close()
244 defer serverDown.Close()
245
246 serverUp.Close()
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264 msg := "bye"
265 go func() {
266 serverDown.(io.ReaderFrom).ReadFrom(serverUp)
267 io.WriteString(serverDown, msg)
268 }()
269
270 buf := make([]byte, 3)
271 n, err := io.ReadFull(clientDown, buf)
272 if err != nil {
273 t.Errorf("clientDown: %v", err)
274 }
275 if string(buf) != msg {
276 t.Errorf("clientDown got %q, want %q", buf, msg)
277 }
278
279
280 if n > 0 && !hook.called {
281 t.Fatal("expected poll.Splice to be called")
282 }
283
284 verifySpliceFds(t, serverDown, hook, "dst")
285
286
287
288 if !hook.handled || hook.written > 0 || hook.err == nil {
289 t.Errorf("expected handled = true, written = 0, err != nil, but got handled = %t, written = %d, err = %v",
290 hook.handled, hook.written, hook.err)
291 }
292 }
293
294 func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
295 front := newLocalListener(t, upNet)
296 defer front.Close()
297 back := newLocalListener(t, downNet)
298 defer back.Close()
299
300 var wg sync.WaitGroup
301 wg.Add(2)
302
303 proxy := func() {
304 src, err := front.Accept()
305 if err != nil {
306 return
307 }
308 dst, err := Dial(downNet, back.Addr().String())
309 if err != nil {
310 return
311 }
312 defer dst.Close()
313 defer src.Close()
314 go func() {
315 io.Copy(src, dst)
316 wg.Done()
317 }()
318 go func() {
319 io.Copy(dst, src)
320 wg.Done()
321 }()
322 }
323
324 go proxy()
325
326 toFront, err := Dial(upNet, front.Addr().String())
327 if err != nil {
328 t.Fatal(err)
329 }
330
331 io.WriteString(toFront, "foo")
332 toFront.Close()
333
334 fromProxy, err := back.Accept()
335 if err != nil {
336 t.Fatal(err)
337 }
338 defer fromProxy.Close()
339
340 _, err = io.ReadAll(fromProxy)
341 if err != nil {
342 t.Fatal(err)
343 }
344
345 wg.Wait()
346 }
347
348 func testSpliceNoUnixpacket(t *testing.T) {
349 clientUp, serverUp := spawnTestSocketPair(t, "unixpacket")
350 defer clientUp.Close()
351 defer serverUp.Close()
352 clientDown, serverDown := spawnTestSocketPair(t, "tcp")
353 defer clientDown.Close()
354 defer serverDown.Close()
355
356
357
358
359
360
361
362
363 _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
364 if err != nil || handled != false {
365 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
366 }
367 }
368
369 func testSpliceNoUnixgram(t *testing.T) {
370 addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
371 if err != nil {
372 t.Fatal(err)
373 }
374 defer os.Remove(addr.Name)
375 up, err := ListenUnixgram("unixgram", addr)
376 if err != nil {
377 t.Fatal(err)
378 }
379 defer up.Close()
380 clientDown, serverDown := spawnTestSocketPair(t, "tcp")
381 defer clientDown.Close()
382 defer serverDown.Close()
383
384 _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
385 if err != nil || handled != false {
386 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
387 }
388 }
389
390 func BenchmarkSplice(b *testing.B) {
391 testHookUninstaller.Do(uninstallTestHooks)
392
393 b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
394 b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
395 b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
396 }
397
398 func benchSplice(b *testing.B, upNet, downNet string) {
399 for i := 0; i <= 10; i++ {
400 chunkSize := 1 << uint(i+10)
401 tc := spliceTestCase{
402 upNet: upNet,
403 downNet: downNet,
404 chunkSize: chunkSize,
405 }
406
407 b.Run(strconv.Itoa(chunkSize), tc.bench)
408 }
409 }
410
411 func (tc spliceTestCase) bench(b *testing.B) {
412
413 useSplice := true
414
415 clientUp, serverUp := spawnTestSocketPair(b, tc.upNet)
416 defer serverUp.Close()
417
418 cleanup, err := startTestSocketPeer(b, clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
419 if err != nil {
420 b.Fatal(err)
421 }
422 defer cleanup(b)
423
424 clientDown, serverDown := spawnTestSocketPair(b, tc.downNet)
425 defer serverDown.Close()
426
427 cleanup, err = startTestSocketPeer(b, clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
428 if err != nil {
429 b.Fatal(err)
430 }
431 defer cleanup(b)
432
433 b.SetBytes(int64(tc.chunkSize))
434 b.ResetTimer()
435
436 if useSplice {
437 _, err := io.Copy(serverDown, serverUp)
438 if err != nil {
439 b.Fatal(err)
440 }
441 } else {
442 type onlyReader struct {
443 io.Reader
444 }
445 _, err := io.Copy(serverDown, onlyReader{serverUp})
446 if err != nil {
447 b.Fatal(err)
448 }
449 }
450 }
451
452 func BenchmarkSpliceFile(b *testing.B) {
453 b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") })
454 b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") })
455 }
456
457 func benchmarkSpliceFile(b *testing.B, proto string) {
458 for i := 0; i <= 10; i++ {
459 size := 1 << (i + 10)
460 bench := spliceFileBench{
461 proto: proto,
462 chunkSize: size,
463 }
464 b.Run(strconv.Itoa(size), bench.benchSpliceFile)
465 }
466 }
467
468 type spliceFileBench struct {
469 proto string
470 chunkSize int
471 }
472
473 func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
474 f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
475 if err != nil {
476 b.Fatal(err)
477 }
478 defer f.Close()
479
480 totalSize := b.N * bench.chunkSize
481
482 client, server := spawnTestSocketPair(b, bench.proto)
483 defer server.Close()
484
485 cleanup, err := startTestSocketPeer(b, client, "w", bench.chunkSize, totalSize)
486 if err != nil {
487 client.Close()
488 b.Fatalf("failed to start splice client: %v", err)
489 }
490 defer cleanup(b)
491
492 b.ReportAllocs()
493 b.SetBytes(int64(bench.chunkSize))
494 b.ResetTimer()
495
496 got, err := io.Copy(f, server)
497 if err != nil {
498 b.Fatalf("failed to ReadFrom with error: %v", err)
499 }
500 if want := int64(totalSize); got != want {
501 b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
502 }
503 }
504
505 func hookSplice(t *testing.T) *spliceHook {
506 t.Helper()
507
508 h := new(spliceHook)
509 h.install()
510 t.Cleanup(h.uninstall)
511 return h
512 }
513
514 type spliceHook struct {
515 called bool
516 dstfd int
517 srcfd int
518 remain int64
519
520 written int64
521 handled bool
522 err error
523
524 original func(dst, src *poll.FD, remain int64) (int64, bool, error)
525 }
526
527 func (h *spliceHook) install() {
528 h.original = pollSplice
529 pollSplice = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
530 h.called = true
531 h.dstfd = dst.Sysfd
532 h.srcfd = src.Sysfd
533 h.remain = remain
534 h.written, h.handled, h.err = h.original(dst, src, remain)
535 return h.written, h.handled, h.err
536 }
537 }
538
539 func (h *spliceHook) uninstall() {
540 pollSplice = h.original
541 }
542
View as plain text