Source file src/vendor/golang.org/x/net/quic/udp_msg.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build !quicbasicnet && (darwin || linux)
     6  
     7  package quic
     8  
     9  import (
    10  	"encoding/binary"
    11  	"net"
    12  	"net/netip"
    13  	"sync"
    14  	"syscall"
    15  	"unsafe"
    16  )
    17  
    18  // Network interface for platforms using sendmsg/recvmsg with cmsgs.
    19  
    20  type netUDPConn struct {
    21  	c         *net.UDPConn
    22  	localAddr netip.AddrPort
    23  }
    24  
    25  func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) {
    26  	a, _ := uc.LocalAddr().(*net.UDPAddr)
    27  	localAddr := a.AddrPort()
    28  	if localAddr.Addr().IsUnspecified() {
    29  		// If the conn is not bound to a specified (non-wildcard) address,
    30  		// then set localAddr.Addr to an invalid netip.Addr.
    31  		// This better conveys that this is not an address we should be using,
    32  		// and is a bit more efficient to test against.
    33  		localAddr = netip.AddrPortFrom(netip.Addr{}, localAddr.Port())
    34  	}
    35  
    36  	sc, err := uc.SyscallConn()
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	sc.Control(func(fd uintptr) {
    41  		// Ask for ECN info and (when we aren't bound to a fixed local address)
    42  		// destination info.
    43  		//
    44  		// If any of these calls fail, we won't get the requested information.
    45  		// That's fine, we'll gracefully handle the lack.
    46  		syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, ip_recvtos, 1)
    47  		syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_RECVTCLASS, 1)
    48  		if !localAddr.IsValid() {
    49  			syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1)
    50  			syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, ipv6_recvpktinfo, 1)
    51  		}
    52  	})
    53  
    54  	return &netUDPConn{
    55  		c:         uc,
    56  		localAddr: localAddr,
    57  	}, nil
    58  }
    59  
    60  func (c *netUDPConn) Close() error { return c.c.Close() }
    61  
    62  func (c *netUDPConn) LocalAddr() netip.AddrPort {
    63  	a, _ := c.c.LocalAddr().(*net.UDPAddr)
    64  	return a.AddrPort()
    65  }
    66  
    67  func (c *netUDPConn) Read(f func(*datagram)) {
    68  	// We shouldn't ever see all of these messages at the same time,
    69  	// but the total is small so just allocate enough space for everything we use.
    70  	const (
    71  		inPktinfoSize  = 12 // int + in_addr + in_addr
    72  		in6PktinfoSize = 20 // in6_addr + int
    73  		ipTOSSize      = 4
    74  		ipv6TclassSize = 4
    75  	)
    76  	control := make([]byte, 0+
    77  		syscall.CmsgSpace(inPktinfoSize)+
    78  		syscall.CmsgSpace(in6PktinfoSize)+
    79  		syscall.CmsgSpace(ipTOSSize)+
    80  		syscall.CmsgSpace(ipv6TclassSize))
    81  
    82  	for {
    83  		d := newDatagram()
    84  		n, controlLen, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(d.b, control)
    85  		if err != nil {
    86  			return
    87  		}
    88  		if n == 0 {
    89  			continue
    90  		}
    91  		d.localAddr = c.localAddr
    92  		d.peerAddr = unmapAddrPort(peerAddr)
    93  		d.b = d.b[:n]
    94  		parseControl(d, control[:controlLen])
    95  		f(d)
    96  	}
    97  }
    98  
    99  var cmsgPool = sync.Pool{
   100  	New: func() any {
   101  		return new([]byte)
   102  	},
   103  }
   104  
   105  func (c *netUDPConn) Write(dgram datagram) error {
   106  	controlp := cmsgPool.Get().(*[]byte)
   107  	control := *controlp
   108  	defer func() {
   109  		*controlp = control[:0]
   110  		cmsgPool.Put(controlp)
   111  	}()
   112  
   113  	localIP := dgram.localAddr.Addr()
   114  	if localIP.IsValid() {
   115  		if localIP.Is4() {
   116  			control = appendCmsgIPSourceAddrV4(control, localIP)
   117  		} else {
   118  			control = appendCmsgIPSourceAddrV6(control, localIP)
   119  		}
   120  	}
   121  	if dgram.ecn != ecnNotECT {
   122  		if dgram.peerAddr.Addr().Is4() {
   123  			control = appendCmsgECNv4(control, dgram.ecn)
   124  		} else {
   125  			control = appendCmsgECNv6(control, dgram.ecn)
   126  		}
   127  	}
   128  
   129  	_, _, err := c.c.WriteMsgUDPAddrPort(dgram.b, control, dgram.peerAddr)
   130  	return err
   131  }
   132  
   133  func parseControl(d *datagram, control []byte) {
   134  	msgs, err := syscall.ParseSocketControlMessage(control)
   135  	if err != nil {
   136  		return
   137  	}
   138  	for _, m := range msgs {
   139  		switch m.Header.Level {
   140  		case syscall.IPPROTO_IP:
   141  			switch m.Header.Type {
   142  			case syscall.IP_TOS, ip_recvtos:
   143  				// (Linux sets the type to IP_TOS, Darwin to IP_RECVTOS,
   144  				// just check for both.)
   145  				if ecn, ok := parseIPTOS(m.Data); ok {
   146  					d.ecn = ecn
   147  				}
   148  			case syscall.IP_PKTINFO:
   149  				if a, ok := parseInPktinfo(m.Data); ok {
   150  					d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port())
   151  				}
   152  			}
   153  		case syscall.IPPROTO_IPV6:
   154  			switch m.Header.Type {
   155  			case syscall.IPV6_TCLASS:
   156  				// 32-bit integer containing the traffic class field.
   157  				// The low two bits are the ECN field.
   158  				if ecn, ok := parseIPv6TCLASS(m.Data); ok {
   159  					d.ecn = ecn
   160  				}
   161  			case ipv6_pktinfo:
   162  				if a, ok := parseIn6Pktinfo(m.Data); ok {
   163  					d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port())
   164  				}
   165  			}
   166  		}
   167  	}
   168  }
   169  
   170  // IPV6_TCLASS is specified by RFC 3542 as an int.
   171  
   172  func parseIPv6TCLASS(b []byte) (ecnBits, bool) {
   173  	if len(b) != 4 {
   174  		return 0, false
   175  	}
   176  	return ecnBits(binary.NativeEndian.Uint32(b) & ecnMask), true
   177  }
   178  
   179  func appendCmsgECNv6(b []byte, ecn ecnBits) []byte {
   180  	b, data := appendCmsg(b, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, 4)
   181  	binary.NativeEndian.PutUint32(data, uint32(ecn))
   182  	return b
   183  }
   184  
   185  // struct in_pktinfo {
   186  //   unsigned int   ipi_ifindex;  /* send/recv interface index */
   187  //   struct in_addr ipi_spec_dst; /* Local address */
   188  //   struct in_addr ipi_addr;     /* IP Header dst address */
   189  // };
   190  
   191  // parseInPktinfo returns the destination address from an IP_PKTINFO.
   192  func parseInPktinfo(b []byte) (dst netip.Addr, ok bool) {
   193  	if len(b) != 12 {
   194  		return netip.Addr{}, false
   195  	}
   196  	return netip.AddrFrom4([4]byte(b[8:][:4])), true
   197  }
   198  
   199  // appendCmsgIPSourceAddrV4 appends an IP_PKTINFO setting the source address
   200  // for an outbound datagram.
   201  func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte {
   202  	// struct in_pktinfo {
   203  	//   unsigned int   ipi_ifindex;  /* send/recv interface index */
   204  	//   struct in_addr ipi_spec_dst; /* Local address */
   205  	//   struct in_addr ipi_addr;     /* IP Header dst address */
   206  	// };
   207  	b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_PKTINFO, 12)
   208  	ip := src.As4()
   209  	copy(data[4:], ip[:])
   210  	return b
   211  }
   212  
   213  // struct in6_pktinfo {
   214  //   struct in6_addr  ipi6_addr;    /* src/dst IPv6 address */
   215  //   unsigned int     ipi6_ifindex; /* send/recv interface index */
   216  // };
   217  
   218  // parseIn6Pktinfo returns the destination address from an IPV6_PKTINFO.
   219  func parseIn6Pktinfo(b []byte) (netip.Addr, bool) {
   220  	if len(b) != 20 {
   221  		return netip.Addr{}, false
   222  	}
   223  	return netip.AddrFrom16([16]byte(b[:16])).Unmap(), true
   224  }
   225  
   226  // appendCmsgIPSourceAddrV6 appends an IPV6_PKTINFO setting the source address
   227  // for an outbound datagram.
   228  func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte {
   229  	b, data := appendCmsg(b, syscall.IPPROTO_IPV6, ipv6_pktinfo, 20)
   230  	ip := src.As16()
   231  	copy(data[0:], ip[:])
   232  	return b
   233  }
   234  
   235  // appendCmsg appends a cmsg with the given level, type, and size to b.
   236  // It returns the new buffer, and the data section of the cmsg.
   237  func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) {
   238  	off := len(b)
   239  	b = append(b, make([]byte, syscall.CmsgSpace(size))...)
   240  	h := (*syscall.Cmsghdr)(unsafe.Pointer(&b[off]))
   241  	h.Level = level
   242  	h.Type = typ
   243  	h.SetLen(syscall.CmsgLen(size))
   244  	return b, b[off+syscall.CmsgSpace(0):][:size]
   245  }
   246  

View as plain text