1
2
3
4
5 package astutil
6
7
8
9 import (
10 "fmt"
11 "go/ast"
12 "go/token"
13 "sort"
14 )
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60 func PathEnclosingInterval(root *ast.File, start, end token.Pos) (path []ast.Node, exact bool) {
61
62
63
64 var visit func(node ast.Node) bool
65 visit = func(node ast.Node) bool {
66 path = append(path, node)
67
68 nodePos := node.Pos()
69 nodeEnd := node.End()
70
71
72
73
74 if start < nodePos {
75 start = nodePos
76 }
77 if end > nodeEnd {
78 end = nodeEnd
79 }
80
81
82 children := childrenOf(node)
83 l := len(children)
84 for i, child := range children {
85
86 childPos := child.Pos()
87 childEnd := child.End()
88
89
90 augPos := childPos
91 augEnd := childEnd
92 if i > 0 {
93 augPos = children[i-1].End()
94 }
95 if i < l-1 {
96 nextChildPos := children[i+1].Pos()
97
98 if start >= augEnd && end <= nextChildPos {
99 return false
100 }
101 augEnd = nextChildPos
102 }
103
104
105
106
107
108 if augPos <= start && end <= augEnd {
109 if is[tokenNode](child) {
110 return true
111 }
112
113
114
115
116
117 if decl, ok := node.(*ast.FuncDecl); ok {
118 if fields, ok := child.(*ast.FieldList); ok && fields != decl.Recv {
119 path = append(path, decl.Type)
120 }
121 }
122
123 return visit(child)
124 }
125
126
127
128
129 if start < childEnd && end > augEnd {
130 break
131 }
132 }
133
134
135
136
137
138
139
140
141 if start == nodePos && end == nodeEnd {
142 return true
143 }
144
145 return false
146 }
147
148
149 if start > end {
150 start, end = end, start
151 }
152
153 if start < root.End() && end > root.Pos() {
154 if start == end {
155 end = start + 1
156 }
157 exact = visit(root)
158
159
160 for i, l := 0, len(path); i < l/2; i++ {
161 path[i], path[l-1-i] = path[l-1-i], path[i]
162 }
163 } else {
164
165
166
167 path = append(path, root)
168 }
169
170 return
171 }
172
173
174
175
// tokenNode is a pseudo-node representing the source extent of a single
// lexical token, such as a keyword, operator, bracket or other punctuation,
// or the text of an identifier or literal. It covers the half-open interval
// [pos, end) and satisfies ast.Node through its Pos and End methods.
type tokenNode struct {
	pos token.Pos // position of the token's first byte
	end token.Pos // position just past the token's last byte
}

// Pos returns the position of the token's first byte.
func (tn tokenNode) Pos() token.Pos { return tn.pos }

// End returns the position immediately after the token's last byte.
func (tn tokenNode) End() token.Pos { return tn.end }
188
189 func tok(pos token.Pos, len int) ast.Node {
190 return tokenNode{pos, pos + token.Pos(len)}
191 }
192
193
194
195
196 func childrenOf(n ast.Node) []ast.Node {
197 var children []ast.Node
198
199
200 ast.Inspect(n, func(node ast.Node) bool {
201 if node == n {
202 return true
203 }
204 if node != nil {
205 children = append(children, node)
206 }
207 return false
208 })
209
210
211 switch n := n.(type) {
212 case *ast.ArrayType:
213 children = append(children,
214 tok(n.Lbrack, len("[")),
215 tok(n.Elt.End(), len("]")))
216
217 case *ast.AssignStmt:
218 children = append(children,
219 tok(n.TokPos, len(n.Tok.String())))
220
221 case *ast.BasicLit:
222 children = append(children,
223 tok(n.ValuePos, len(n.Value)))
224
225 case *ast.BinaryExpr:
226 children = append(children, tok(n.OpPos, len(n.Op.String())))
227
228 case *ast.BlockStmt:
229 children = append(children,
230 tok(n.Lbrace, len("{")),
231 tok(n.Rbrace, len("}")))
232
233 case *ast.BranchStmt:
234 children = append(children,
235 tok(n.TokPos, len(n.Tok.String())))
236
237 case *ast.CallExpr:
238 children = append(children,
239 tok(n.Lparen, len("(")),
240 tok(n.Rparen, len(")")))
241 if n.Ellipsis != 0 {
242 children = append(children, tok(n.Ellipsis, len("...")))
243 }
244
245 case *ast.CaseClause:
246 if n.List == nil {
247 children = append(children,
248 tok(n.Case, len("default")))
249 } else {
250 children = append(children,
251 tok(n.Case, len("case")))
252 }
253 children = append(children, tok(n.Colon, len(":")))
254
255 case *ast.ChanType:
256 switch n.Dir {
257 case ast.RECV:
258 children = append(children, tok(n.Begin, len("<-chan")))
259 case ast.SEND:
260 children = append(children, tok(n.Begin, len("chan<-")))
261 case ast.RECV | ast.SEND:
262 children = append(children, tok(n.Begin, len("chan")))
263 }
264
265 case *ast.CommClause:
266 if n.Comm == nil {
267 children = append(children,
268 tok(n.Case, len("default")))
269 } else {
270 children = append(children,
271 tok(n.Case, len("case")))
272 }
273 children = append(children, tok(n.Colon, len(":")))
274
275 case *ast.Comment:
276
277
278 case *ast.CommentGroup:
279
280
281 case *ast.CompositeLit:
282 children = append(children,
283 tok(n.Lbrace, len("{")),
284 tok(n.Rbrace, len("{")))
285
286 case *ast.DeclStmt:
287
288
289 case *ast.DeferStmt:
290 children = append(children,
291 tok(n.Defer, len("defer")))
292
293 case *ast.Ellipsis:
294 children = append(children,
295 tok(n.Ellipsis, len("...")))
296
297 case *ast.EmptyStmt:
298
299
300 case *ast.ExprStmt:
301
302
303 case *ast.Field:
304
305
306 case *ast.FieldList:
307 children = append(children,
308 tok(n.Opening, len("(")),
309 tok(n.Closing, len(")")))
310
311 case *ast.File:
312
313 children = append(children,
314 tok(n.Package, len("package")))
315
316 case *ast.ForStmt:
317 children = append(children,
318 tok(n.For, len("for")))
319
320 case *ast.FuncDecl:
321
322
323
324
325
326
327
328
329
330
331
332 children = nil
333 children = append(children, tok(n.Type.Func, len("func")))
334 if n.Recv != nil {
335 children = append(children, n.Recv)
336 }
337 children = append(children, n.Name)
338 if tparams := n.Type.TypeParams; tparams != nil {
339 children = append(children, tparams)
340 }
341 if n.Type.Params != nil {
342 children = append(children, n.Type.Params)
343 }
344 if n.Type.Results != nil {
345 children = append(children, n.Type.Results)
346 }
347 if n.Body != nil {
348 children = append(children, n.Body)
349 }
350
351 case *ast.FuncLit:
352
353
354 case *ast.FuncType:
355 if n.Func != 0 {
356 children = append(children,
357 tok(n.Func, len("func")))
358 }
359
360 case *ast.GenDecl:
361 children = append(children,
362 tok(n.TokPos, len(n.Tok.String())))
363 if n.Lparen != 0 {
364 children = append(children,
365 tok(n.Lparen, len("(")),
366 tok(n.Rparen, len(")")))
367 }
368
369 case *ast.GoStmt:
370 children = append(children,
371 tok(n.Go, len("go")))
372
373 case *ast.Ident:
374 children = append(children,
375 tok(n.NamePos, len(n.Name)))
376
377 case *ast.IfStmt:
378 children = append(children,
379 tok(n.If, len("if")))
380
381 case *ast.ImportSpec:
382
383
384 case *ast.IncDecStmt:
385 children = append(children,
386 tok(n.TokPos, len(n.Tok.String())))
387
388 case *ast.IndexExpr:
389 children = append(children,
390 tok(n.Lbrack, len("[")),
391 tok(n.Rbrack, len("]")))
392
393 case *ast.IndexListExpr:
394 children = append(children,
395 tok(n.Lbrack, len("[")),
396 tok(n.Rbrack, len("]")))
397
398 case *ast.InterfaceType:
399 children = append(children,
400 tok(n.Interface, len("interface")))
401
402 case *ast.KeyValueExpr:
403 children = append(children,
404 tok(n.Colon, len(":")))
405
406 case *ast.LabeledStmt:
407 children = append(children,
408 tok(n.Colon, len(":")))
409
410 case *ast.MapType:
411 children = append(children,
412 tok(n.Map, len("map")))
413
414 case *ast.ParenExpr:
415 children = append(children,
416 tok(n.Lparen, len("(")),
417 tok(n.Rparen, len(")")))
418
419 case *ast.RangeStmt:
420 children = append(children,
421 tok(n.For, len("for")),
422 tok(n.TokPos, len(n.Tok.String())))
423
424 case *ast.ReturnStmt:
425 children = append(children,
426 tok(n.Return, len("return")))
427
428 case *ast.SelectStmt:
429 children = append(children,
430 tok(n.Select, len("select")))
431
432 case *ast.SelectorExpr:
433
434
435 case *ast.SendStmt:
436 children = append(children,
437 tok(n.Arrow, len("<-")))
438
439 case *ast.SliceExpr:
440 children = append(children,
441 tok(n.Lbrack, len("[")),
442 tok(n.Rbrack, len("]")))
443
444 case *ast.StarExpr:
445 children = append(children, tok(n.Star, len("*")))
446
447 case *ast.StructType:
448 children = append(children, tok(n.Struct, len("struct")))
449
450 case *ast.SwitchStmt:
451 children = append(children, tok(n.Switch, len("switch")))
452
453 case *ast.TypeAssertExpr:
454 children = append(children,
455 tok(n.Lparen-1, len(".")),
456 tok(n.Lparen, len("(")),
457 tok(n.Rparen, len(")")))
458
459 case *ast.TypeSpec:
460
461
462 case *ast.TypeSwitchStmt:
463 children = append(children, tok(n.Switch, len("switch")))
464
465 case *ast.UnaryExpr:
466 children = append(children, tok(n.OpPos, len(n.Op.String())))
467
468 case *ast.ValueSpec:
469
470
471 case *ast.BadDecl, *ast.BadExpr, *ast.BadStmt:
472
473 }
474
475
476
477
478
479 sort.Sort(byPos(children))
480
481 return children
482 }
483
// byPos orders AST nodes by their starting position. It implements
// sort.Interface so that a node slice can be ordered with sort.Sort.
type byPos []ast.Node

// Len returns the number of nodes.
func (s byPos) Len() int { return len(s) }

// Less reports whether node i starts before node j.
func (s byPos) Less(i, j int) bool {
	pi, pj := s[i].Pos(), s[j].Pos()
	return pi < pj
}

// Swap exchanges nodes i and j.
func (s byPos) Swap(i, j int) { s[j], s[i] = s[i], s[j] }
495
496
497
498
499
500
501
// NodeDescription returns a short English phrase describing the syntactic
// category of n, such as "if statement", "binary + operation" or
// "parenthesized identifier". Some descriptions depend on the node's
// contents: its operator or token, or the declaration or expression it wraps.
//
// It panics if n's concrete type is not one of the cases below (including a
// nil n), or if a BranchStmt or GenDecl carries a token it does not expect.
func NodeDescription(n ast.Node) string {
	switch n := n.(type) {
	// Expressions.
	case *ast.BadExpr:
		return "bad expression"
	case *ast.BasicLit:
		return "basic literal"
	case *ast.BinaryExpr:
		return fmt.Sprintf("binary %s operation", n.Op)
	case *ast.CallExpr:
		// With exactly one argument and no "...", the call could equally
		// be a type conversion.
		if len(n.Args) != 1 || n.Ellipsis.IsValid() {
			return "function call"
		}
		return "function call (or conversion)"
	case *ast.CompositeLit:
		return "composite literal"
	case *ast.Ellipsis:
		return "ellipsis"
	case *ast.FuncLit:
		return "function literal"
	case *ast.Ident:
		return "identifier"
	case *ast.IndexExpr:
		return "index expression"
	case *ast.IndexListExpr:
		return "index list expression"
	case *ast.KeyValueExpr:
		return "key/value association"
	case *ast.ParenExpr:
		return "parenthesized " + NodeDescription(n.X)
	case *ast.SelectorExpr:
		return "selector"
	case *ast.SliceExpr:
		return "slice expression"
	case *ast.StarExpr:
		return "*-operation"
	case *ast.TypeAssertExpr:
		return "type assertion"
	case *ast.UnaryExpr:
		return fmt.Sprintf("unary %s operation", n.Op)

	// Type expressions.
	case *ast.ArrayType:
		return "array type"
	case *ast.ChanType:
		return "channel type"
	case *ast.FuncType:
		return "function type"
	case *ast.InterfaceType:
		return "interface type"
	case *ast.MapType:
		return "map type"
	case *ast.StructType:
		return "struct type"

	// Statements.
	case *ast.AssignStmt:
		return "assignment"
	case *ast.BadStmt:
		return "bad statement"
	case *ast.BlockStmt:
		return "block"
	case *ast.BranchStmt:
		switch n.Tok {
		case token.BREAK, token.CONTINUE, token.GOTO:
			// The keyword spelling doubles as the description.
			return n.Tok.String() + " statement"
		case token.FALLTHROUGH:
			return "fall-through statement"
		}
		// Any other token falls out of the switch and panics below.
	case *ast.CaseClause:
		return "case clause"
	case *ast.CommClause:
		return "communication clause"
	case *ast.DeclStmt:
		return NodeDescription(n.Decl) + " statement"
	case *ast.DeferStmt:
		return "defer statement"
	case *ast.EmptyStmt:
		return "empty statement"
	case *ast.ExprStmt:
		return "expression statement"
	case *ast.ForStmt:
		return "for loop"
	case *ast.GoStmt:
		return "go statement"
	case *ast.IfStmt:
		return "if statement"
	case *ast.IncDecStmt:
		if n.Tok != token.INC {
			return "decrement statement"
		}
		return "increment statement"
	case *ast.LabeledStmt:
		return "statement label"
	case *ast.RangeStmt:
		return "range loop"
	case *ast.ReturnStmt:
		return "return statement"
	case *ast.SelectStmt:
		return "select statement"
	case *ast.SendStmt:
		return "channel send"
	case *ast.SwitchStmt:
		return "switch statement"
	case *ast.TypeSwitchStmt:
		return "type switch"

	// Declarations and specifications.
	case *ast.BadDecl:
		return "bad declaration"
	case *ast.FuncDecl:
		return "function declaration"
	case *ast.GenDecl:
		switch n.Tok {
		case token.IMPORT:
			return "import declaration"
		case token.CONST:
			return "constant declaration"
		case token.TYPE:
			return "type declaration"
		case token.VAR:
			return "variable declaration"
		}
		// Any other token falls out of the switch and panics below.
	case *ast.ImportSpec:
		return "import specification"
	case *ast.TypeSpec:
		return "type specification"
	case *ast.ValueSpec:
		return "value specification"

	// Everything else.
	case *ast.Comment:
		return "comment"
	case *ast.CommentGroup:
		return "comment group"
	case *ast.Field:
		// go/ast uses Field for struct fields (named or embedded),
		// interface methods and embeddings, and receivers, parameters
		// and results alike.
		return "field/method/parameter"
	case *ast.FieldList:
		return "field/method/parameter list"
	case *ast.File:
		return "source file"
	case *ast.Package:
		return "package"
	}
	panic(fmt.Sprintf("unexpected node type: %T", n))
}
650
// is reports whether x holds a value of type T: for a concrete T, whether
// x's dynamic type is exactly T; for an interface T, whether x's dynamic
// type implements T. A nil x never matches.
func is[T any](x any) bool {
	switch x.(type) {
	case T:
		return true
	}
	return false
}
655
View as plain text