Source file
src/net/resolverdialfunc_test.go
1
2
3
4
5
6
7
8 package net
9
10 import (
11 "bytes"
12 "context"
13 "errors"
14 "fmt"
15 "reflect"
16 "slices"
17 "testing"
18 "time"
19
20 "golang.org/x/net/dns/dnsmessage"
21 )
22
23 func TestResolverDialFunc(t *testing.T) {
24 r := &Resolver{
25 PreferGo: true,
26 Dial: newResolverDialFunc(&resolverDialHandler{
27 StartDial: func(network, address string) error {
28 t.Logf("StartDial(%q, %q) ...", network, address)
29 return nil
30 },
31 Question: func(h dnsmessage.Header, q dnsmessage.Question) {
32 t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
33 q.Name.String(), q.Type, q.Class)
34 },
35
36
37 HandleA: func(w AWriter, name string) error {
38 w.AddIP([4]byte{1, 2, 3, 4})
39 w.AddIP([4]byte{5, 6, 7, 8})
40 return nil
41 },
42 HandleAAAA: func(w AAAAWriter, name string) error {
43 w.AddIP([16]byte{1: 1, 15: 15})
44 w.AddIP([16]byte{2: 2, 14: 14})
45 return nil
46 },
47 HandleSRV: func(w SRVWriter, name string) error {
48 w.AddSRV(1, 2, 80, "foo.bar.")
49 w.AddSRV(2, 3, 81, "bar.baz.")
50 return nil
51 },
52 }),
53 }
54 ctx := context.Background()
55 const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
56
57 t.Run("LookupIP", func(t *testing.T) {
58 ips, err := r.LookupIP(ctx, "ip", fakeDomain)
59 if err != nil {
60 t.Fatal(err)
61 }
62 if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) {
63 t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
64 }
65 })
66
67 t.Run("LookupSRV", func(t *testing.T) {
68 _, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
69 if err != nil {
70 t.Fatal(err)
71 }
72 want := []*SRV{
73 {
74 Target: "foo.bar.",
75 Port: 80,
76 Priority: 1,
77 Weight: 2,
78 },
79 {
80 Target: "bar.baz.",
81 Port: 81,
82 Priority: 2,
83 Weight: 3,
84 },
85 }
86 if !reflect.DeepEqual(got, want) {
87 t.Errorf("wrong result. got:")
88 for _, r := range got {
89 t.Logf(" - %+v", r)
90 }
91 }
92 })
93 }
94
95 func sortedIPStrings(ips []IP) []string {
96 ret := make([]string, len(ips))
97 for i, ip := range ips {
98 ret[i] = ip.String()
99 }
100 slices.Sort(ret)
101 return ret
102 }
103
104 func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
105 return func(ctx context.Context, network, address string) (Conn, error) {
106 a := &resolverFuncConn{
107 h: h,
108 network: network,
109 address: address,
110 ttl: 10,
111 }
112 if h.StartDial != nil {
113 if err := h.StartDial(network, address); err != nil {
114 return nil, err
115 }
116 }
117 return a, nil
118 }
119 }
120
// resolverDialHandler holds the optional hooks that drive the in-process
// connections created by newResolverDialFunc. Any hook left nil is skipped.
type resolverDialHandler struct {
	// StartDial, if non-nil, is called with the dial's network and address
	// when a connection is requested; a non-nil error fails the dial.
	StartDial func(network, address string) error

	// Question, if non-nil, is called with the request header and its first
	// question for each query that contains a question.
	Question func(dnsmessage.Header, dnsmessage.Question)

	// HandleA, HandleAAAA, and HandleSRV, if non-nil, answer queries of the
	// corresponding type for name by adding records through w. The returned
	// error is converted to a DNS response code with mapRCode. Queries of
	// other types receive no answer records.
	HandleA    func(w AWriter, name string) error
	HandleAAAA func(w AAAAWriter, name string) error
	HandleSRV  func(w SRVWriter, name string) error
}
134
135 type ResponseWriter struct{ a *resolverFuncConn }
136
137 func (w ResponseWriter) header() dnsmessage.ResourceHeader {
138 q := w.a.q
139 return dnsmessage.ResourceHeader{
140 Name: q.Name,
141 Type: q.Type,
142 Class: q.Class,
143 TTL: w.a.ttl,
144 }
145 }
146
147
148
149
150
151 func (w ResponseWriter) SetTTL(seconds uint32) {
152
153
154
155
156
157
158 if w.a.wrote {
159 return
160 }
161 w.a.ttl = seconds
162
163 }
164
165 type AWriter struct{ ResponseWriter }
166
167 func (w AWriter) AddIP(v4 [4]byte) {
168 w.a.wrote = true
169 err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
170 if err != nil {
171 panic(err)
172 }
173 }
174
175 type AAAAWriter struct{ ResponseWriter }
176
177 func (w AAAAWriter) AddIP(v6 [16]byte) {
178 w.a.wrote = true
179 err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
180 if err != nil {
181 panic(err)
182 }
183 }
184
185 type SRVWriter struct{ ResponseWriter }
186
187
188
189 func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
190 targetName, err := dnsmessage.NewName(target)
191 if err != nil {
192 return err
193 }
194 w.a.wrote = true
195 err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
196 Priority: priority,
197 Weight: weight,
198 Port: port,
199 Target: targetName,
200 })
201 if err != nil {
202 panic(err)
203 }
204 return nil
205 }
206
// ErrNotExist is a sentinel a handler hook may return; mapRCode maps it to
// dnsmessage.RCodeNameError.
var ErrNotExist = errors.New("name does not exist")

// ErrRefused is a sentinel a handler hook may return; mapRCode maps it to
// dnsmessage.RCodeRefused.
var ErrRefused = errors.New("refused")
211
// resolverFuncConn is the Conn produced by the dial function from
// newResolverDialFunc. Write parses a length-prefixed DNS query and answers
// it with the handler hooks, appending the response to rbuf for Read.
type resolverFuncConn struct {
	// h holds the hooks used to answer queries.
	h *resolverDialHandler
	// network and address are the arguments the dial function received.
	network string
	address string
	// builder is the response under construction during Write; the
	// ResponseWriter types append records through it.
	builder *dnsmessage.Builder
	// q is the question currently being answered.
	q dnsmessage.Question
	// ttl is the TTL given to answer records; newResolverDialFunc sets it
	// to 10 and SetTTL may change it before any record is written.
	ttl uint32
	// wrote is set once an answer record has been added; SetTTL is ignored
	// after that.
	wrote bool

	// rbuf holds length-prefixed responses waiting to be returned by Read.
	rbuf bytes.Buffer
}
223
224 func (*resolverFuncConn) Close() error { return nil }
225 func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} }
226 func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} }
227 func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil }
228 func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil }
229 func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
230
231 func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
232 return a.rbuf.Read(p)
233 }
234
235 func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
236 if len(packet) < 2 {
237 return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
238 }
239 reqLen := int(packet[0])<<8 | int(packet[1])
240 req := packet[2:]
241 if len(req) != reqLen {
242 return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
243 }
244
245 var parser dnsmessage.Parser
246 h, err := parser.Start(req)
247 if err != nil {
248
249 return 0, err
250 }
251 q, err := parser.Question()
252 hadQ := (err == nil)
253 if err == nil && a.h.Question != nil {
254 a.h.Question(h, q)
255 }
256 if err != nil && err != dnsmessage.ErrSectionDone {
257 return 0, err
258 }
259
260 resh := h
261 resh.Response = true
262 resh.Authoritative = true
263 if hadQ {
264 resh.RCode = dnsmessage.RCodeSuccess
265 } else {
266 resh.RCode = dnsmessage.RCodeNotImplemented
267 }
268 a.rbuf.Grow(514)
269 a.rbuf.WriteByte('X')
270 a.rbuf.WriteByte('Y')
271 builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
272 a.builder = &builder
273 if hadQ {
274 a.q = q
275 a.builder.StartQuestions()
276 err := a.builder.Question(q)
277 if err != nil {
278 return 0, fmt.Errorf("Question: %w", err)
279 }
280 a.builder.StartAnswers()
281 switch q.Type {
282 case dnsmessage.TypeA:
283 if a.h.HandleA != nil {
284 resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
285 }
286 case dnsmessage.TypeAAAA:
287 if a.h.HandleAAAA != nil {
288 resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
289 }
290 case dnsmessage.TypeSRV:
291 if a.h.HandleSRV != nil {
292 resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
293 }
294 }
295 }
296 tcpRes, err := builder.Finish()
297 if err != nil {
298 return 0, fmt.Errorf("Finish: %w", err)
299 }
300
301 n = len(tcpRes) - 2
302 tcpRes[0] = byte(n >> 8)
303 tcpRes[1] = byte(n)
304 a.rbuf.Write(tcpRes[2:])
305
306 return len(packet), nil
307 }
308
// someaddr is a fixed placeholder address; resolverFuncConn returns it for
// both LocalAddr and RemoteAddr.
type someaddr struct{}

// Network returns the constant placeholder "unused".
func (someaddr) Network() string {
	const network = "unused"
	return network
}

// String returns the constant placeholder "unused-someaddr".
func (someaddr) String() string {
	const text = "unused-someaddr"
	return text
}
313
314 func mapRCode(err error) dnsmessage.RCode {
315 switch err {
316 case nil:
317 return dnsmessage.RCodeSuccess
318 case ErrNotExist:
319 return dnsmessage.RCodeNameError
320 case ErrRefused:
321 return dnsmessage.RCodeRefused
322 default:
323 return dnsmessage.RCodeServerFailure
324 }
325 }
326
View as plain text