Source file src/net/resolverdialfunc_test.go

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Test that Resolver.Dial can be a func returning an in-memory net.Conn
     6  // speaking DNS.
     7  
     8  package net
     9  
    10  import (
    11  	"bytes"
    12  	"context"
    13  	"errors"
    14  	"fmt"
    15  	"reflect"
    16  	"slices"
    17  	"testing"
    18  	"time"
    19  
    20  	"golang.org/x/net/dns/dnsmessage"
    21  )
    22  
    23  func TestResolverDialFunc(t *testing.T) {
    24  	r := &Resolver{
    25  		PreferGo: true,
    26  		Dial: newResolverDialFunc(&resolverDialHandler{
    27  			StartDial: func(network, address string) error {
    28  				t.Logf("StartDial(%q, %q) ...", network, address)
    29  				return nil
    30  			},
    31  			Question: func(h dnsmessage.Header, q dnsmessage.Question) {
    32  				t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
    33  					q.Name.String(), q.Type, q.Class)
    34  			},
    35  			// TODO: add test without HandleA* hooks specified at all, that Go
    36  			// doesn't issue retries; map to something terminal.
    37  			HandleA: func(w AWriter, name string) error {
    38  				w.AddIP([4]byte{1, 2, 3, 4})
    39  				w.AddIP([4]byte{5, 6, 7, 8})
    40  				return nil
    41  			},
    42  			HandleAAAA: func(w AAAAWriter, name string) error {
    43  				w.AddIP([16]byte{1: 1, 15: 15})
    44  				w.AddIP([16]byte{2: 2, 14: 14})
    45  				return nil
    46  			},
    47  			HandleSRV: func(w SRVWriter, name string) error {
    48  				w.AddSRV(1, 2, 80, "foo.bar.")
    49  				w.AddSRV(2, 3, 81, "bar.baz.")
    50  				return nil
    51  			},
    52  		}),
    53  	}
    54  	ctx := context.Background()
    55  	const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
    56  
    57  	t.Run("LookupIP", func(t *testing.T) {
    58  		ips, err := r.LookupIP(ctx, "ip", fakeDomain)
    59  		if err != nil {
    60  			t.Fatal(err)
    61  		}
    62  		if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !slices.Equal(got, want) {
    63  			t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
    64  		}
    65  	})
    66  
    67  	t.Run("LookupSRV", func(t *testing.T) {
    68  		_, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
    69  		if err != nil {
    70  			t.Fatal(err)
    71  		}
    72  		want := []*SRV{
    73  			{
    74  				Target:   "foo.bar.",
    75  				Port:     80,
    76  				Priority: 1,
    77  				Weight:   2,
    78  			},
    79  			{
    80  				Target:   "bar.baz.",
    81  				Port:     81,
    82  				Priority: 2,
    83  				Weight:   3,
    84  			},
    85  		}
    86  		if !reflect.DeepEqual(got, want) {
    87  			t.Errorf("wrong result. got:")
    88  			for _, r := range got {
    89  				t.Logf("  - %+v", r)
    90  			}
    91  		}
    92  	})
    93  }
    94  
    95  func sortedIPStrings(ips []IP) []string {
    96  	ret := make([]string, len(ips))
    97  	for i, ip := range ips {
    98  		ret[i] = ip.String()
    99  	}
   100  	slices.Sort(ret)
   101  	return ret
   102  }
   103  
   104  func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
   105  	return func(ctx context.Context, network, address string) (Conn, error) {
   106  		a := &resolverFuncConn{
   107  			h:       h,
   108  			network: network,
   109  			address: address,
   110  			ttl:     10, // 10 second default if unset
   111  		}
   112  		if h.StartDial != nil {
   113  			if err := h.StartDial(network, address); err != nil {
   114  				return nil, err
   115  			}
   116  		}
   117  		return a, nil
   118  	}
   119  }
   120  
// resolverDialHandler holds the hooks that drive the fake DNS server behind
// newResolverDialFunc. Every hook is optional (each is nil-checked before
// use).
type resolverDialHandler struct {
	// StartDial, if non-nil, is called at the start of every call to the
	// dial func returned by newResolverDialFunc.
	// Any error returned aborts the dial and is returned unwrapped.
	StartDial func(network, address string) error

	// Question, if non-nil, is called with the query header and its first
	// question, before the query is answered. It is not called for a query
	// that has no question.
	Question func(dnsmessage.Header, dnsmessage.Question)

	// HandleA, HandleAAAA and HandleSRV, if non-nil, answer queries of the
	// corresponding type by adding records through w. name is the queried
	// name as returned by dnsmessage.Name.String. A query whose type has no
	// hook set gets a successful response with no answers.
	//
	// The returned error is translated by mapRCode:
	// err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
	// A nil error means success.
	HandleA    func(w AWriter, name string) error
	HandleAAAA func(w AAAAWriter, name string) error
	HandleSRV  func(w SRVWriter, name string) error
}
   134  
   135  type ResponseWriter struct{ a *resolverFuncConn }
   136  
   137  func (w ResponseWriter) header() dnsmessage.ResourceHeader {
   138  	q := w.a.q
   139  	return dnsmessage.ResourceHeader{
   140  		Name:  q.Name,
   141  		Type:  q.Type,
   142  		Class: q.Class,
   143  		TTL:   w.a.ttl,
   144  	}
   145  }
   146  
   147  // SetTTL sets the TTL for subsequent written resources.
   148  // Once a resource has been written, SetTTL calls are no-ops.
   149  // That is, it can only be called at most once, before anything
   150  // else is written.
   151  func (w ResponseWriter) SetTTL(seconds uint32) {
   152  	// ... intention is last one wins and mutates all previously
   153  	// written records too, but that's a little annoying.
   154  	// But it's also annoying if the requirement is it needs to be set
   155  	// last.
   156  	// And it's also annoying if it's possible for users to set
   157  	// different TTLs per Answer.
   158  	if w.a.wrote {
   159  		return
   160  	}
   161  	w.a.ttl = seconds
   162  
   163  }
   164  
   165  type AWriter struct{ ResponseWriter }
   166  
   167  func (w AWriter) AddIP(v4 [4]byte) {
   168  	w.a.wrote = true
   169  	err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
   170  	if err != nil {
   171  		panic(err)
   172  	}
   173  }
   174  
   175  type AAAAWriter struct{ ResponseWriter }
   176  
   177  func (w AAAAWriter) AddIP(v6 [16]byte) {
   178  	w.a.wrote = true
   179  	err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
   180  	if err != nil {
   181  		panic(err)
   182  	}
   183  }
   184  
   185  type SRVWriter struct{ ResponseWriter }
   186  
   187  // AddSRV adds a SRV record. The target name must end in a period and
   188  // be 63 bytes or fewer.
   189  func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
   190  	targetName, err := dnsmessage.NewName(target)
   191  	if err != nil {
   192  		return err
   193  	}
   194  	w.a.wrote = true
   195  	err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
   196  		Priority: priority,
   197  		Weight:   weight,
   198  		Port:     port,
   199  		Target:   targetName,
   200  	})
   201  	if err != nil {
   202  		panic(err) // internal fault, not user
   203  	}
   204  	return nil
   205  }
   206  
   207  var (
   208  	ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN
   209  	ErrRefused  = errors.New("refused")             // maps to RCode5, REFUSED
   210  )
   211  
// resolverFuncConn is the in-memory Conn returned by the dial func from
// newResolverDialFunc. Write takes a length-prefixed DNS query and answers it
// using the hooks in h. The length-prefixed response is queued for Read.
type resolverFuncConn struct {
	// h holds the hooks that Write consults.
	h *resolverDialHandler
	// network and address are the arguments the dial func was called
	// with; they are recorded but not otherwise read in this file.
	network string
	address string
	// builder is the response Write is currently building. The typed
	// writers (AWriter etc.) add answer records to it.
	builder *dnsmessage.Builder
	// q is the question being answered. Answer records take their name,
	// type and class from it.
	q dnsmessage.Question
	// ttl is the TTL for answer records: 10 when dialed, changeable via
	// SetTTL until the first record is written.
	ttl uint32
	// wrote is set once any answer record has been added; after that,
	// SetTTL is a no-op.
	wrote bool

	// rbuf holds response bytes waiting to be returned by Read.
	rbuf bytes.Buffer
}
   223  
   224  func (*resolverFuncConn) Close() error                       { return nil }
   225  func (*resolverFuncConn) LocalAddr() Addr                    { return someaddr{} }
   226  func (*resolverFuncConn) RemoteAddr() Addr                   { return someaddr{} }
   227  func (*resolverFuncConn) SetDeadline(t time.Time) error      { return nil }
   228  func (*resolverFuncConn) SetReadDeadline(t time.Time) error  { return nil }
   229  func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
   230  
   231  func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
   232  	return a.rbuf.Read(p)
   233  }
   234  
   235  func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
   236  	if len(packet) < 2 {
   237  		return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
   238  	}
   239  	reqLen := int(packet[0])<<8 | int(packet[1])
   240  	req := packet[2:]
   241  	if len(req) != reqLen {
   242  		return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
   243  	}
   244  
   245  	var parser dnsmessage.Parser
   246  	h, err := parser.Start(req)
   247  	if err != nil {
   248  		// TODO: hook
   249  		return 0, err
   250  	}
   251  	q, err := parser.Question()
   252  	hadQ := (err == nil)
   253  	if err == nil && a.h.Question != nil {
   254  		a.h.Question(h, q)
   255  	}
   256  	if err != nil && err != dnsmessage.ErrSectionDone {
   257  		return 0, err
   258  	}
   259  
   260  	resh := h
   261  	resh.Response = true
   262  	resh.Authoritative = true
   263  	if hadQ {
   264  		resh.RCode = dnsmessage.RCodeSuccess
   265  	} else {
   266  		resh.RCode = dnsmessage.RCodeNotImplemented
   267  	}
   268  	a.rbuf.Grow(514)
   269  	a.rbuf.WriteByte('X') // reserved header for beu16 length
   270  	a.rbuf.WriteByte('Y') // reserved header for beu16 length
   271  	builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
   272  	a.builder = &builder
   273  	if hadQ {
   274  		a.q = q
   275  		a.builder.StartQuestions()
   276  		err := a.builder.Question(q)
   277  		if err != nil {
   278  			return 0, fmt.Errorf("Question: %w", err)
   279  		}
   280  		a.builder.StartAnswers()
   281  		switch q.Type {
   282  		case dnsmessage.TypeA:
   283  			if a.h.HandleA != nil {
   284  				resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
   285  			}
   286  		case dnsmessage.TypeAAAA:
   287  			if a.h.HandleAAAA != nil {
   288  				resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
   289  			}
   290  		case dnsmessage.TypeSRV:
   291  			if a.h.HandleSRV != nil {
   292  				resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
   293  			}
   294  		}
   295  	}
   296  	tcpRes, err := builder.Finish()
   297  	if err != nil {
   298  		return 0, fmt.Errorf("Finish: %w", err)
   299  	}
   300  
   301  	n = len(tcpRes) - 2
   302  	tcpRes[0] = byte(n >> 8)
   303  	tcpRes[1] = byte(n)
   304  	a.rbuf.Write(tcpRes[2:])
   305  
   306  	return len(packet), nil
   307  }
   308  
   309  type someaddr struct{}
   310  
   311  func (someaddr) Network() string { return "unused" }
   312  func (someaddr) String() string  { return "unused-someaddr" }
   313  
   314  func mapRCode(err error) dnsmessage.RCode {
   315  	switch err {
   316  	case nil:
   317  		return dnsmessage.RCodeSuccess
   318  	case ErrNotExist:
   319  		return dnsmessage.RCodeNameError
   320  	case ErrRefused:
   321  		return dnsmessage.RCodeRefused
   322  	default:
   323  		return dnsmessage.RCodeServerFailure
   324  	}
   325  }
   326  

View as plain text