Source file
src/net/dial.go
1
2
3
4
5 package net
6
7 import (
8 "context"
9 "internal/bytealg"
10 "internal/godebug"
11 "internal/nettrace"
12 "syscall"
13 "time"
14 )
15
const (
	// defaultTCPKeepAliveIdle, defaultTCPKeepAliveInterval and
	// defaultTCPKeepAliveCount are default TCP keep-alive parameters:
	// idle time before probing, time between probes, and probe count.
	// NOTE(review): their consumers are outside this file view;
	// presumably they fill in unset KeepAliveConfig fields — confirm.
	defaultTCPKeepAliveIdle     = 15 * time.Second
	defaultTCPKeepAliveInterval = 15 * time.Second
	defaultTCPKeepAliveCount    = 9
)

// Package defaults for Multipath TCP. They apply only when neither an
// explicit SetMultipathTCP call nor the GODEBUG multipathtcp setting
// decides (see mptcpStatusDial.get and mptcpStatusListen.get).
const (
	defaultMPTCPEnabledListen = true
	defaultMPTCPEnabledDial   = false
)
34
35
36
37
38
39
40
41 var multipathtcp = godebug.New("multipathtcp")
42
43
44
// mptcpStatusDial records whether Multipath TCP was explicitly chosen for
// a Dialer. The zero value means "use the default", so a zero Dialer
// defers to GODEBUG and defaultMPTCPEnabledDial.
type mptcpStatusDial uint8

// The values are spelled out; the zero value must stay "use default".
const (
	mptcpUseDefaultDial mptcpStatusDial = 0
	mptcpEnabledDial    mptcpStatusDial = 1
	mptcpDisabledDial   mptcpStatusDial = 2
)
53
54 func (m *mptcpStatusDial) get() bool {
55 switch *m {
56 case mptcpEnabledDial:
57 return true
58 case mptcpDisabledDial:
59 return false
60 }
61
62
63 if multipathtcp.Value() == "1" || multipathtcp.Value() == "3" {
64 multipathtcp.IncNonDefault()
65
66 return true
67 }
68
69 return defaultMPTCPEnabledDial
70 }
71
72 func (m *mptcpStatusDial) set(use bool) {
73 if use {
74 *m = mptcpEnabledDial
75 } else {
76 *m = mptcpDisabledDial
77 }
78 }
79
80
81
// mptcpStatusListen records whether Multipath TCP was explicitly chosen
// for a ListenConfig. The zero value means "use the default", so a zero
// ListenConfig defers to GODEBUG and defaultMPTCPEnabledListen.
type mptcpStatusListen uint8

// The values are spelled out; the zero value must stay "use default".
const (
	mptcpUseDefaultListen mptcpStatusListen = 0
	mptcpEnabledListen    mptcpStatusListen = 1
	mptcpDisabledListen   mptcpStatusListen = 2
)
90
91 func (m *mptcpStatusListen) get() bool {
92 switch *m {
93 case mptcpEnabledListen:
94 return true
95 case mptcpDisabledListen:
96 return false
97 }
98
99
100
101 if multipathtcp.Value() == "0" || multipathtcp.Value() == "3" {
102 multipathtcp.IncNonDefault()
103
104 return false
105 }
106
107 return defaultMPTCPEnabledListen
108 }
109
110 func (m *mptcpStatusListen) set(use bool) {
111 if use {
112 *m = mptcpEnabledListen
113 } else {
114 *m = mptcpDisabledListen
115 }
116 }
117
118
119
120
121
122
123
124
// A Dialer contains options for connecting to an address.
//
// The zero value for each field means "no such option": Dial uses a
// zero Dialer, and DialTimeout uses one with only Timeout set.
type Dialer struct {
	// Timeout bounds the whole dial, including name resolution.
	// When non-zero, now+Timeout becomes a candidate deadline. It is
	// combined with Deadline and the context's deadline, and the earliest
	// applies (see Dialer.deadline). Any non-zero value is used, so a
	// negative Timeout yields a deadline in the past. When several
	// addresses are tried in turn, each attempt gets a share of the
	// remaining time (see partialDeadline).
	Timeout time.Duration

	// Deadline is an absolute time after which dialing fails. The zero
	// value contributes no deadline. Otherwise it is combined with Timeout
	// and the context's deadline, and the earliest applies.
	Deadline time.Time

	// LocalAddr is the local address to use when dialing. It is passed to
	// resolveAddrList as the hint, which returns an error when the resolved
	// addresses have a different network type. When LocalAddr is not a
	// wildcard, it also drops remote addresses of a different IP family.
	// If nil, no local address is requested.
	LocalAddr Addr

	// DualStack is not read by any code in this file. Dual-stack fallback
	// racing is governed by FallbackDelay (see dualStack).
	// NOTE(review): presumably kept for API compatibility — confirm.
	DualStack bool

	// FallbackDelay controls the fallback race used for network "tcp"
	// (see DialContext and dialParallel). A non-negative value enables
	// the race. Zero selects the default delay of 300ms (see
	// fallbackDelay). A negative value disables the race, and all
	// resolved addresses are then tried serially.
	FallbackDelay time.Duration

	// KeepAlive is not read by code in this file.
	// NOTE(review): presumably it configures TCP keep-alive on the dialed
	// connection, with the defaultTCPKeepAlive* constants as fallbacks.
	// The consumer lives elsewhere; confirm there.
	KeepAlive time.Duration

	// KeepAliveConfig is not read by code in this file either.
	// NOTE(review): presumably the finer-grained keep-alive settings;
	// confirm at its use site.
	KeepAliveConfig KeepAliveConfig

	// Resolver optionally specifies an alternate resolver to use. If nil,
	// DefaultResolver is used (see Dialer.resolver).
	Resolver *Resolver

	// Cancel is an optional channel. Closing it, or sending on it, cancels
	// the dial: DialContext bridges it into the dial context with a helper
	// goroutine. The ctx passed to DialContext is the context-based
	// equivalent.
	Cancel <-chan struct{}

	// Control, if not nil, is a callback that receives the network, the
	// address and the raw connection. It is not invoked by code in this
	// file.
	// NOTE(review): presumably called after the socket is created and
	// before connecting. Confirm in the socket code, including precedence
	// when ControlContext is also set.
	Control func(network, address string, c syscall.RawConn) error

	// ControlContext is like Control but additionally receives the dial
	// context. It is not invoked by code in this file (see the note on
	// Control).
	ControlContext func(ctx context.Context, network, address string, c syscall.RawConn) error

	// mptcpStatus records an explicit SetMultipathTCP choice. The zero
	// value defers to the GODEBUG setting and the package default (see
	// mptcpStatusDial.get).
	mptcpStatus mptcpStatusDial
}
231
232 func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
233
// minNonzeroTime returns the earlier of a and b. The zero Time is treated
// as "no limit", so if either argument is zero the other is returned.
// When both are non-zero and equal, b is returned.
func minNonzeroTime(a, b time.Time) time.Time {
	switch {
	case a.IsZero():
		return b
	case b.IsZero(), a.Before(b):
		return a
	default:
		return b
	}
}
243
244
245
246
247
248
249
250 func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
251 if d.Timeout != 0 {
252 earliest = now.Add(d.Timeout)
253 }
254 if d, ok := ctx.Deadline(); ok {
255 earliest = minNonzeroTime(earliest, d)
256 }
257 return minNonzeroTime(earliest, d.Deadline)
258 }
259
260 func (d *Dialer) resolver() *Resolver {
261 if d.Resolver != nil {
262 return d.Resolver
263 }
264 return DefaultResolver
265 }
266
267
268
269 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
270 if deadline.IsZero() {
271 return deadline, nil
272 }
273 timeRemaining := deadline.Sub(now)
274 if timeRemaining <= 0 {
275 return time.Time{}, errTimeout
276 }
277
278 timeout := timeRemaining / time.Duration(addrsRemaining)
279
280 const saneMinimum = 2 * time.Second
281 if timeout < saneMinimum {
282 if timeRemaining < saneMinimum {
283 timeout = timeRemaining
284 } else {
285 timeout = saneMinimum
286 }
287 }
288 return now.Add(timeout), nil
289 }
290
291 func (d *Dialer) fallbackDelay() time.Duration {
292 if d.FallbackDelay > 0 {
293 return d.FallbackDelay
294 } else {
295 return 300 * time.Millisecond
296 }
297 }
298
299 func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) {
300 i := bytealg.LastIndexByteString(network, ':')
301 if i < 0 {
302 switch network {
303 case "tcp", "tcp4", "tcp6":
304 case "udp", "udp4", "udp6":
305 case "ip", "ip4", "ip6":
306 if needsProto {
307 return "", 0, UnknownNetworkError(network)
308 }
309 case "unix", "unixgram", "unixpacket":
310 default:
311 return "", 0, UnknownNetworkError(network)
312 }
313 return network, 0, nil
314 }
315 afnet = network[:i]
316 switch afnet {
317 case "ip", "ip4", "ip6":
318 protostr := network[i+1:]
319 proto, i, ok := dtoi(protostr)
320 if !ok || i != len(protostr) {
321 proto, err = lookupProtocol(ctx, protostr)
322 if err != nil {
323 return "", 0, err
324 }
325 }
326 return afnet, proto, nil
327 }
328 return "", 0, UnknownNetworkError(network)
329 }
330
331
332
333
// resolveAddrList resolves addr for network and returns the list of
// candidate addresses. op is "dial" or "listen" (the values passed by
// the callers in this file). Dialing with an empty addr fails with
// errMissingAddress.
//
// hint is consulted only when op is "dial"; DialContext passes the
// Dialer's LocalAddr. Every resolved address must then share hint's
// network type, otherwise an *AddrError "mismatched local address type"
// is returned. Addresses whose IP family differs from hint's are dropped
// unless either side is a wildcard address. If none remain, an *AddrError
// carrying errNoSuitableAddress's message is returned.
func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
	// needsProto=true: a bare "ip"/"ip4"/"ip6" without ":proto" is rejected.
	afnet, _, err := parseNetwork(ctx, network, true)
	if err != nil {
		return nil, err
	}
	if op == "dial" && addr == "" {
		return nil, errMissingAddress
	}
	switch afnet {
	case "unix", "unixgram", "unixpacket":
		// Unix-domain addresses need no lookup; only the hint type is checked.
		addr, err := ResolveUnixAddr(afnet, addr)
		if err != nil {
			return nil, err
		}
		if op == "dial" && hint != nil && addr.Network() != hint.Network() {
			return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
		}
		return addrList{addr}, nil
	}
	addrs, err := r.internetAddrList(ctx, afnet, addr)
	if err != nil || op != "dial" || hint == nil {
		return addrs, err
	}
	// Capture the hint's concrete address and whether it is a wildcard.
	// For any other hint type these stay nil/false. NOTE(review): such a
	// hint is presumably rejected by the Network() comparison below before
	// the nil pointers are used — confirm for custom Addr implementations.
	var (
		tcp      *TCPAddr
		udp      *UDPAddr
		ip       *IPAddr
		wildcard bool
	)
	switch hint := hint.(type) {
	case *TCPAddr:
		tcp = hint
		wildcard = tcp.isWildcard()
	case *UDPAddr:
		udp = hint
		wildcard = udp.isWildcard()
	case *IPAddr:
		ip = hint
		wildcard = ip.isWildcard()
	}
	// Filter in place, reusing addrs' backing array.
	naddrs := addrs[:0]
	for _, addr := range addrs {
		if addr.Network() != hint.Network() {
			return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
		}
		// Keep the address unless both it and the hint are specific
		// (non-wildcard) addresses of different IP families.
		switch addr := addr.(type) {
		case *TCPAddr:
			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
				continue
			}
			naddrs = append(naddrs, addr)
		case *UDPAddr:
			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
				continue
			}
			naddrs = append(naddrs, addr)
		case *IPAddr:
			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
				continue
			}
			naddrs = append(naddrs, addr)
		}
	}
	if len(naddrs) == 0 {
		return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
	}
	return naddrs, nil
}
402
403
404
405
406
407 func (d *Dialer) MultipathTCP() bool {
408 return d.mptcpStatus.get()
409 }
410
411
412
413
414
415
416
417 func (d *Dialer) SetMultipathTCP(use bool) {
418 d.mptcpStatus.set(use)
419 }
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469 func Dial(network, address string) (Conn, error) {
470 var d Dialer
471 return d.Dial(network, address)
472 }
473
474
475
476
477
478
479
480
481
482
483
484 func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
485 d := Dialer{Timeout: timeout}
486 return d.Dial(network, address)
487 }
488
489
// sysDialer holds the parameters and configuration of a single dial.
// DialContext fills it with a copy of the caller's Dialer, so later
// assignments to that Dialer's fields do not affect an in-flight dial.
type sysDialer struct {
	Dialer
	network, address string

	// testHookDialTCP is not used in this file.
	// NOTE(review): presumably a test seam that replaces the TCP dial;
	// confirm at its use site.
	testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
}
495
496
497
498
499
500
501
502
503 func (d *Dialer) Dial(network, address string) (Conn, error) {
504 return d.DialContext(context.Background(), network, address)
505 }
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
// DialContext connects to the address on the named network using the
// provided context.
//
// ctx must be non-nil; DialContext panics otherwise. The dial is bounded
// by the earliest of ctx's deadline, d.Deadline and now+d.Timeout (see
// Dialer.deadline). A non-nil d.Cancel channel can also abort it.
// Expiry or cancellation before a connection is established makes the
// dial fail.
//
// Names are resolved with d's resolver (DefaultResolver when d.Resolver
// is nil). Resolution errors are returned as an *OpError with Op "dial".
// For network "tcp" with dual-stack racing enabled (FallbackDelay >= 0),
// the addresses are split with addrs.partition(isIPv4) into primaries
// and fallbacks, which dialParallel races. Otherwise all addresses are
// tried one at a time.
//
// See func Dial for a description of the network and address
// parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
	if ctx == nil {
		panic("nil context")
	}
	deadline := d.deadline(ctx, time.Now())
	if !deadline.IsZero() {
		// NOTE(review): testHookStepTime is defined elsewhere; presumably a
		// no-op outside tests.
		testHookStepTime()
		// Wrap ctx only when our deadline is stricter than the one it
		// already carries, or it carries none. (d here shadows the receiver.)
		if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
			subCtx, cancel := context.WithDeadline(ctx, deadline)
			defer cancel()
			ctx = subCtx
		}
	}
	if oldCancel := d.Cancel; oldCancel != nil {
		// Bridge the Cancel channel into the context. The goroutine exits
		// when oldCancel fires or subCtx is done, and the deferred cancel
		// guarantees the latter once DialContext returns.
		subCtx, cancel := context.WithCancel(ctx)
		defer cancel()
		go func() {
			select {
			case <-oldCancel:
				cancel()
			case <-subCtx.Done():
			}
		}()
		ctx = subCtx
	}

	// Shadow the nettrace (if any) during resolve so Connect events don't
	// fire for DNS lookups: resolution sees a copy of the trace with
	// ConnectStart and ConnectDone cleared.
	resolveCtx := ctx
	if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
		shadow := *trace
		shadow.ConnectStart = nil
		shadow.ConnectDone = nil
		resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
	}

	addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
	if err != nil {
		return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
	}

	// Work from a copy of the Dialer for the rest of the dial.
	sd := &sysDialer{
		Dialer:  *d,
		network: network,
		address: address,
	}

	// Only plain "tcp" with racing enabled gets a fallback group; with no
	// fallbacks, dialParallel degenerates to dialSerial over primaries.
	var primaries, fallbacks addrList
	if d.dualStack() && network == "tcp" {
		primaries, fallbacks = addrs.partition(isIPv4)
	} else {
		primaries = addrs
	}

	return sd.dialParallel(ctx, primaries, fallbacks)
}
580
581
582
583
584
// dialParallel races two copies of dialSerial, giving the primaries a
// head start of sd.fallbackDelay(). It returns the first established
// connection; racers that connect after it has returned close their own
// connections. If both racers fail, the primary racer's error is
// returned. With no fallbacks it simply dials the primaries serially.
func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addrList) (Conn, error) {
	if len(fallbacks) == 0 {
		return sd.dialSerial(ctx, primaries)
	}

	// returned is closed when dialParallel returns. A racer still trying
	// to report then knows nobody is listening and closes its connection.
	returned := make(chan struct{})
	defer close(returned)

	// dialResult is what a racer reports. done distinguishes a reported
	// result from the zero value of a racer that has not finished.
	type dialResult struct {
		Conn
		error
		primary bool
		done    bool
	}
	// Unbuffered: a racer blocks until the loop below receives its result
	// or returned is closed.
	results := make(chan dialResult) // unbuffered

	// startRacer dials the primaries or the fallbacks serially and hands
	// the outcome to the loop below, or closes a late connection.
	startRacer := func(ctx context.Context, primary bool) {
		ras := primaries
		if !primary {
			ras = fallbacks
		}
		c, err := sd.dialSerial(ctx, ras)
		select {
		case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
		case <-returned:
			if c != nil {
				c.Close()
			}
		}
	}

	var primary, fallback dialResult

	// Start the main racer.
	primaryCtx, primaryCancel := context.WithCancel(ctx)
	defer primaryCancel()
	go startRacer(primaryCtx, true)

	// Start the timer for the fallback racer.
	fallbackTimer := time.NewTimer(sd.fallbackDelay())
	defer fallbackTimer.Stop()

	for {
		select {
		case <-fallbackTimer.C:
			// The timer is re-armed only after a successful Stop (below),
			// so this case, and hence this defer, runs at most once.
			fallbackCtx, fallbackCancel := context.WithCancel(ctx)
			defer fallbackCancel()
			go startRacer(fallbackCtx, false)

		case res := <-results:
			if res.error == nil {
				return res.Conn, nil
			}
			if res.primary {
				primary = res
			} else {
				fallback = res
			}
			if primary.done && fallback.done {
				return nil, primary.error
			}
			if res.primary && fallbackTimer.Stop() {
				// If we were able to stop the timer, that means it
				// was running (hadn't yet started the fallback), but
				// we just got an error on the primary path, so start
				// the fallback immediately (in 0 nanoseconds).
				fallbackTimer.Reset(0)
			}
		}
	}
}
656
657
658
659 func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
660 var firstErr error
661
662 for i, ra := range ras {
663 select {
664 case <-ctx.Done():
665 return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
666 default:
667 }
668
669 dialCtx := ctx
670 if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
671 partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
672 if err != nil {
673
674 if firstErr == nil {
675 firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
676 }
677 break
678 }
679 if partialDeadline.Before(deadline) {
680 var cancel context.CancelFunc
681 dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
682 defer cancel()
683 }
684 }
685
686 c, err := sd.dialSingle(dialCtx, ra)
687 if err == nil {
688 return c, nil
689 }
690 if firstErr == nil {
691 firstErr = err
692 }
693 }
694
695 if firstErr == nil {
696 firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
697 }
698 return nil, firstErr
699 }
700
701
702
// dialSingle attempts to establish, and returns, a single connection to
// the destination address ra.
//
// If ctx carries a nettrace.Trace, its ConnectStart hook is called before
// dialing. Its ConnectDone hook is called afterwards with the final
// error: nil on success, otherwise the *OpError returned to the caller.
// The dial helper is chosen by the concrete type of ra. TCP uses the
// MPTCP path when sd.MultipathTCP() reports true. Errors are wrapped in
// an *OpError with Op "dial".
func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
	trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
	if trace != nil {
		raStr := ra.String()
		if trace.ConnectStart != nil {
			trace.ConnectStart(sd.network, raStr)
		}
		if trace.ConnectDone != nil {
			// Deferred so the hook sees the named result err as finally
			// returned, after the *OpError wrapping below.
			defer func() { trace.ConnectDone(sd.network, raStr, err) }()
		}
	}
	la := sd.LocalAddr
	// Each case narrows la to ra's concrete type with comma-ok. A
	// LocalAddr that is nil or of another type therefore becomes a nil
	// pointer of that type. The error paths report the outer la (the
	// original sd.LocalAddr).
	switch ra := ra.(type) {
	case *TCPAddr:
		la, _ := la.(*TCPAddr)
		if sd.MultipathTCP() {
			c, err = sd.dialMPTCP(ctx, la, ra)
		} else {
			c, err = sd.dialTCP(ctx, la, ra)
		}
	case *UDPAddr:
		la, _ := la.(*UDPAddr)
		c, err = sd.dialUDP(ctx, la, ra)
	case *IPAddr:
		la, _ := la.(*IPAddr)
		c, err = sd.dialIP(ctx, la, ra)
	case *UnixAddr:
		la, _ := la.(*UnixAddr)
		c, err = sd.dialUnix(ctx, la, ra)
	default:
		return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: sd.address}}
	}
	if err != nil {
		// Return an explicit nil rather than c. NOTE(review): c presumably
		// may be a non-nil interface holding a nil pointer from the helper.
		return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: err}
	}
	return c, nil
}
740
741
// ListenConfig contains options for listening to an address.
// Listen and ListenPacket use a zero ListenConfig.
type ListenConfig struct {
	// Control, if not nil, is a callback that receives the network, the
	// address and the raw connection. It is not invoked by code in this
	// file.
	// NOTE(review): presumably called after creating the socket and
	// before binding; confirm in the socket code.
	Control func(network, address string, c syscall.RawConn) error

	// KeepAlive is not read by code in this file.
	// NOTE(review): presumably it configures keep-alive on accepted TCP
	// connections; confirm in the TCP listener code.
	KeepAlive time.Duration

	// KeepAliveConfig is not read by code in this file either.
	// NOTE(review): presumably the finer-grained keep-alive settings;
	// confirm at its use site.
	KeepAliveConfig KeepAliveConfig

	// mptcpStatus records an explicit SetMultipathTCP choice. The zero
	// value defers to the GODEBUG setting and the package default (see
	// mptcpStatusListen.get).
	mptcpStatus mptcpStatusListen
}
778
779
780
781
782
783 func (lc *ListenConfig) MultipathTCP() bool {
784 return lc.mptcpStatus.get()
785 }
786
787
788
789
790
791
792
793 func (lc *ListenConfig) SetMultipathTCP(use bool) {
794 lc.mptcpStatus.set(use)
795 }
796
797
798
799
800
801
802
803
804 func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
805 addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
806 if err != nil {
807 return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
808 }
809 sl := &sysListener{
810 ListenConfig: *lc,
811 network: network,
812 address: address,
813 }
814 var l Listener
815 la := addrs.first(isIPv4)
816 switch la := la.(type) {
817 case *TCPAddr:
818 if sl.MultipathTCP() {
819 l, err = sl.listenMPTCP(ctx, la)
820 } else {
821 l, err = sl.listenTCP(ctx, la)
822 }
823 case *UnixAddr:
824 l, err = sl.listenUnix(ctx, la)
825 default:
826 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
827 }
828 if err != nil {
829 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err}
830 }
831 return l, nil
832 }
833
834
835
836
837
838
839
840
841 func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
842 addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
843 if err != nil {
844 return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
845 }
846 sl := &sysListener{
847 ListenConfig: *lc,
848 network: network,
849 address: address,
850 }
851 var c PacketConn
852 la := addrs.first(isIPv4)
853 switch la := la.(type) {
854 case *UDPAddr:
855 c, err = sl.listenUDP(ctx, la)
856 case *IPAddr:
857 c, err = sl.listenIP(ctx, la)
858 case *UnixAddr:
859 c, err = sl.listenUnixgram(ctx, la)
860 default:
861 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
862 }
863 if err != nil {
864 return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err}
865 }
866 return c, nil
867 }
868
869
// sysListener holds the parameters and configuration of a single listen.
// ListenConfig.Listen and ListenConfig.ListenPacket fill it with a copy
// of the caller's ListenConfig together with the requested network and
// address strings.
type sysListener struct {
	ListenConfig
	network, address string
}
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896 func Listen(network, address string) (Listener, error) {
897 var lc ListenConfig
898 return lc.Listen(context.Background(), network, address)
899 }
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926 func ListenPacket(network, address string) (PacketConn, error) {
927 var lc ListenConfig
928 return lc.ListenPacket(context.Background(), network, address)
929 }
930
View as plain text