1
2
3
4
5
6
7 package analysisinternal
8
9 import (
10 "bytes"
11 "cmp"
12 "fmt"
13 "go/ast"
14 "go/printer"
15 "go/scanner"
16 "go/token"
17 "go/types"
18 "iter"
19 pathpkg "path"
20 "slices"
21 "strings"
22
23 "golang.org/x/tools/go/analysis"
24 "golang.org/x/tools/go/ast/inspector"
25 "golang.org/x/tools/internal/typesinternal"
26 )
27
28
29
// TypeErrorEndPos returns a heuristic end position for a type error
// that starts at start.
//
// src must be the content of the file containing start. If start's
// file is unknown to fset, or start lies beyond the end of src,
// TypeErrorEndPos returns start unchanged.
//
// Otherwise the result spans at least the first token at start (unless
// that token is a semicolon). It is then extended up to, but not
// including, the next byte from the set " \n,():;[]+-*/". This only
// approximates the end of the operand at start.
func TypeErrorEndPos(fset *token.FileSet, src []byte, start token.Pos) token.Pos {
	file := fset.File(start)
	if file == nil {
		return start
	}
	offset := file.PositionFor(start, false).Offset
	if offset > len(src) {
		return start
	}
	src = src[offset:]

	end := start

	// Ensure the result spans the token at start. The scanner runs over
	// a private single-file FileSet, so the positions it reports must be
	// converted to offsets using that private file. The caller's file
	// generally has a different base.
	{
		var s scanner.Scanner
		tmpSet := token.NewFileSet()
		tmp := tmpSet.AddFile("", tmpSet.Base(), len(src))
		s.Init(tmp, src, nil /* no error handler */, scanner.ScanComments)
		pos, tok, lit := s.Scan()
		if tok != token.SEMICOLON && token.Pos(tmp.Base()) <= pos && pos <= token.Pos(tmp.Base()+tmp.Size()) {
			// lit can be longer than the source text it came from.
			// For example, an invalid UTF-8 byte is reported as a
			// 3-byte U+FFFD. Clamp to avoid slicing past the end of src.
			off := min(tmp.Offset(pos)+len(lit), len(src))
			src = src[off:]
			end += token.Pos(off)
		}
	}

	// Extend to the next byte that might terminate the operand.
	if width := bytes.IndexAny(src, " \n,():;[]+-*/"); width > 0 {
		end += token.Pos(width)
	}
	return end
}
75
76
77
// WalkASTWithParent walks the AST rooted at n. It is like ast.Inspect,
// except that f also receives the parent of each node (nil for the root)
// and f is never called with a nil node.
//
// As with ast.Inspect, the children of a node are visited only if f
// returns true for that node.
func WalkASTWithParent(n ast.Node, f func(n ast.Node, parent ast.Node) bool) {
	var ancestors []ast.Node
	ast.Inspect(n, func(n ast.Node) (recurse bool) {
		if n == nil {
			// All children of the innermost ancestor have been visited.
			ancestors = ancestors[:len(ancestors)-1]
			return false
		}

		var parent ast.Node
		if len(ancestors) > 0 {
			parent = ancestors[len(ancestors)-1]
		}
		recurse = f(n, parent)
		// ast.Inspect calls back with nil (the pop above) only for nodes
		// whose children it visited, i.e. those for which we return true.
		// Push n only in that case, so the stack stays in sync.
		if recurse {
			ancestors = append(ancestors, n)
		}
		return recurse
	})
}
94
95
96
97
98
99 func MatchingIdents(typs []types.Type, node ast.Node, pos token.Pos, info *types.Info, pkg *types.Package) map[types.Type][]string {
100
101
102 matches := make(map[types.Type][]string)
103 for _, typ := range typs {
104 if typ == nil {
105 continue
106 }
107 matches[typ] = nil
108 }
109
110 seen := map[types.Object]struct{}{}
111 ast.Inspect(node, func(n ast.Node) bool {
112 if n == nil {
113 return false
114 }
115
116
117
118
119
120
121 if assign, ok := n.(*ast.AssignStmt); ok && pos > assign.Pos() && pos <= assign.End() {
122 return false
123 }
124 if n.End() > pos {
125 return n.Pos() <= pos
126 }
127 ident, ok := n.(*ast.Ident)
128 if !ok || ident.Name == "_" {
129 return true
130 }
131 obj := info.Defs[ident]
132 if obj == nil || obj.Type() == nil {
133 return true
134 }
135 if _, ok := obj.(*types.TypeName); ok {
136 return true
137 }
138
139 if _, ok = seen[obj]; ok {
140 return true
141 }
142 seen[obj] = struct{}{}
143
144
145 innerScope := pkg.Scope().Innermost(pos)
146 if innerScope == nil {
147 return true
148 }
149 _, foundObj := innerScope.LookupParent(ident.Name, pos)
150 if foundObj != obj {
151 return true
152 }
153
154
155 if names, ok := matches[obj.Type()]; ok {
156 matches[obj.Type()] = append(names, ident.Name)
157 } else {
158
159
160
161 for typ := range matches {
162 if equivalentTypes(obj.Type(), typ) {
163 matches[typ] = append(matches[typ], ident.Name)
164 }
165 }
166 }
167 return true
168 })
169 return matches
170 }
171
// equivalentTypes reports whether a value of type want can stand in for
// a value of type got.
//
// The two match if they are identical. An untyped basic want matches a
// got whose underlying type is basic when both fall in the same constant
// kind (boolean, numeric subkind, or string, per types.IsConstType).
// Otherwise they match if want is assignable to got.
func equivalentTypes(want, got types.Type) bool {
	if types.Identical(want, got) {
		return true
	}
	untyped, isBasic := want.(*types.Basic)
	if isBasic && untyped.Info()&types.IsUntyped != 0 {
		target, targetIsBasic := got.Underlying().(*types.Basic)
		if targetIsBasic {
			const kinds = types.IsConstType
			return untyped.Info()&kinds == target.Info()&kinds
		}
	}
	return types.AssignableTo(want, got)
}
184
185
186
// ReadFileFunc is the signature of a function that returns the content
// of the named file, as os.ReadFile does.
type ReadFileFunc = func(name string) ([]byte, error)
188
189
190
191 func CheckedReadFile(pass *analysis.Pass, readFile ReadFileFunc) ReadFileFunc {
192 return func(filename string) ([]byte, error) {
193 if err := CheckReadable(pass, filename); err != nil {
194 return nil, err
195 }
196 return readFile(filename)
197 }
198 }
199
200
201 func CheckReadable(pass *analysis.Pass, filename string) error {
202 if slices.Contains(pass.OtherFiles, filename) ||
203 slices.Contains(pass.IgnoredFiles, filename) {
204 return nil
205 }
206 for _, f := range pass.Files {
207 if pass.Fset.File(f.FileStart).Name() == filename {
208 return nil
209 }
210 }
211 return fmt.Errorf("Pass.ReadFile: %s is not among OtherFiles, IgnoredFiles, or names of Files", filename)
212 }
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228 func AddImport(info *types.Info, file *ast.File, preferredName, pkgpath, member string, pos token.Pos) (name, prefix string, newImport []analysis.TextEdit) {
229
230 scope := info.Scopes[file].Innermost(pos)
231 if scope == nil {
232 panic("no enclosing lexical block")
233 }
234
235
236
237 for _, spec := range file.Imports {
238 pkgname := info.PkgNameOf(spec)
239 if pkgname != nil && pkgname.Imported().Path() == pkgpath {
240 name = pkgname.Name()
241 if name == "." {
242
243 if s, _ := scope.LookupParent(member, pos); s == info.Scopes[file] {
244 return name, "", nil
245 }
246 } else if _, obj := scope.LookupParent(name, pos); obj == pkgname {
247 return name, name + ".", nil
248 }
249 }
250 }
251
252
253
254 newName := FreshName(scope, pos, preferredName)
255
256
257
258
259
260
261
262
263 newText := fmt.Sprintf("%q", pkgpath)
264 if newName != preferredName || newName != pathpkg.Base(pkgpath) {
265 newText = fmt.Sprintf("%s %q", newName, pkgpath)
266 }
267 decl0 := file.Decls[0]
268 var before ast.Node = decl0
269 switch decl0 := decl0.(type) {
270 case *ast.GenDecl:
271 if decl0.Doc != nil {
272 before = decl0.Doc
273 }
274 case *ast.FuncDecl:
275 if decl0.Doc != nil {
276 before = decl0.Doc
277 }
278 }
279
280 if gd, ok := before.(*ast.GenDecl); ok && gd.Tok == token.IMPORT && gd.Rparen.IsValid() {
281 pos = gd.Rparen
282
283
284
285 if IsStdPackage(pkgpath) && len(gd.Specs) != 0 {
286 pos = gd.Specs[0].Pos()
287 newText += "\n\t"
288 } else {
289 newText = "\t" + newText + "\n"
290 }
291 } else {
292 pos = before.Pos()
293 newText = "import " + newText + "\n\n"
294 }
295 return newName, newName + ".", []analysis.TextEdit{{
296 Pos: pos,
297 End: pos,
298 NewText: []byte(newText),
299 }}
300 }
301
302
303
// FreshName returns a name that resolves to no object when looked up
// from scope at pos. It tries preferred first, then preferred0,
// preferred1, and so on.
func FreshName(scope *types.Scope, pos token.Pos, preferred string) string {
	candidate := preferred
	for n := 0; ; n++ {
		if _, obj := scope.LookupParent(candidate, pos); obj == nil {
			return candidate
		}
		candidate = fmt.Sprintf("%s%d", preferred, n)
	}
}
314
315
// Format returns the source text for n as printed by go/printer using
// fset. Any printer error is discarded; the result is whatever was
// written before it occurred.
func Format(fset *token.FileSet, n ast.Node) string {
	out := new(strings.Builder)
	_ = printer.Fprint(out, fset, n)
	return out.String()
}
321
322
// Imports reports whether path is the path of one of the packages in
// pkg.Imports().
func Imports(pkg *types.Package, path string) bool {
	return slices.ContainsFunc(pkg.Imports(), func(imp *types.Package) bool {
		return imp.Path() == path
	})
}
331
332
333
334
335
336
337
338 func IsTypeNamed(t types.Type, pkgPath string, names ...string) bool {
339 if named, ok := types.Unalias(t).(*types.Named); ok {
340 tname := named.Obj()
341 return tname != nil &&
342 typesinternal.IsPackageLevel(tname) &&
343 tname.Pkg().Path() == pkgPath &&
344 slices.Contains(names, tname.Name())
345 }
346 return false
347 }
348
349
350
351
352 func IsPointerToNamed(t types.Type, pkgPath string, names ...string) bool {
353 r := typesinternal.Unpointer(t)
354 if r == t {
355 return false
356 }
357 return IsTypeNamed(r, pkgPath, names...)
358 }
359
360
361
362
363
364
365
366 func IsFunctionNamed(obj types.Object, pkgPath string, names ...string) bool {
367 f, ok := obj.(*types.Func)
368 return ok &&
369 typesinternal.IsPackageLevel(obj) &&
370 f.Pkg().Path() == pkgPath &&
371 f.Type().(*types.Signature).Recv() == nil &&
372 slices.Contains(names, f.Name())
373 }
374
375
376
377
378
379
380
381 func IsMethodNamed(obj types.Object, pkgPath string, typeName string, names ...string) bool {
382 if fn, ok := obj.(*types.Func); ok {
383 if recv := fn.Type().(*types.Signature).Recv(); recv != nil {
384 _, T := typesinternal.ReceiverNamed(recv)
385 return T != nil &&
386 IsTypeNamed(T, pkgPath, typeName) &&
387 slices.Contains(names, fn.Name())
388 }
389 }
390 return false
391 }
392
393
394
395
396
397
398
399 func ValidateFixes(fset *token.FileSet, a *analysis.Analyzer, fixes []analysis.SuggestedFix) error {
400 fixMessages := make(map[string]bool)
401 for i := range fixes {
402 fix := &fixes[i]
403 if fixMessages[fix.Message] {
404 return fmt.Errorf("analyzer %q suggests two fixes with same Message (%s)", a.Name, fix.Message)
405 }
406 fixMessages[fix.Message] = true
407 if err := validateFix(fset, fix); err != nil {
408 return fmt.Errorf("analyzer %q suggests invalid fix (%s): %v", a.Name, fix.Message, err)
409 }
410 }
411 return nil
412 }
413
414
415
416
417
418 func validateFix(fset *token.FileSet, fix *analysis.SuggestedFix) error {
419
420
421
422
423
424 slices.SortStableFunc(fix.TextEdits, func(x, y analysis.TextEdit) int {
425 if sign := cmp.Compare(x.Pos, y.Pos); sign != 0 {
426 return sign
427 }
428 return cmp.Compare(x.End, y.End)
429 })
430
431 var prev *analysis.TextEdit
432 for i := range fix.TextEdits {
433 edit := &fix.TextEdits[i]
434
435
436 start := edit.Pos
437 file := fset.File(start)
438 if file == nil {
439 return fmt.Errorf("no token.File for TextEdit.Pos (%v)", edit.Pos)
440 }
441 fileEnd := token.Pos(file.Base() + file.Size())
442 if end := edit.End; end.IsValid() {
443 if end < start {
444 return fmt.Errorf("TextEdit.Pos (%v) > TextEdit.End (%v)", edit.Pos, edit.End)
445 }
446 endFile := fset.File(end)
447 if endFile != file && end < fileEnd+10 {
448
449
450
451
452
453
454
455
456
457 edit.End = fileEnd
458 continue
459 }
460 if endFile == nil {
461 return fmt.Errorf("no token.File for TextEdit.End (%v; File(start).FileEnd is %d)", end, file.Base()+file.Size())
462 }
463 if endFile != file {
464 return fmt.Errorf("edit #%d spans files (%v and %v)",
465 i, file.Position(edit.Pos), endFile.Position(edit.End))
466 }
467 } else {
468 edit.End = start
469 }
470 if eof := fileEnd; edit.End > eof {
471 return fmt.Errorf("end is (%v) beyond end of file (%v)", edit.End, eof)
472 }
473
474
475
476 if prev != nil && edit.Pos < prev.End {
477 xpos := fset.Position(prev.Pos)
478 xend := fset.Position(prev.End)
479 ypos := fset.Position(edit.Pos)
480 yend := fset.Position(edit.End)
481 return fmt.Errorf("overlapping edits to %s (%d:%d-%d:%d and %d:%d-%d:%d)",
482 xpos.Filename,
483 xpos.Line, xpos.Column,
484 xend.Line, xend.Column,
485 ypos.Line, ypos.Column,
486 yend.Line, yend.Column,
487 )
488 }
489 prev = edit
490 }
491
492 return nil
493 }
494
495
496
497
498
// CanImport reports whether the package with path from may import the
// package with path to, according to Go's rules for "internal"
// directories.
//
// A path whose last "internal" element has parent P may be imported
// only from P itself or from packages beneath P. The match is on whole
// path elements, so "a/foobar" cannot import "a/foo/internal/x". A
// trailing "_test" on from is ignored, because an external test package
// "p_test" lives in p's directory.
//
// Top-level "internal/..." packages may only be imported from the
// standard library. This is approximated: importers whose first path
// element contains a dot, or is "testdata", are rejected.
func CanImport(from, to string) bool {
	if to == "internal" || strings.HasPrefix(to, "internal/") {
		// Heuristic: we cannot reliably tell whether from is in std.
		first, _, _ := strings.Cut(from, "/")
		if strings.Contains(first, ".") {
			return false // example.com/foo is not in std
		}
		if first == "testdata" {
			return false // testdata/foo is not in std
		}
	}

	// withinTree reports whether the importer lies in the tree rooted at
	// import path dir, matching whole path elements.
	importer := strings.TrimSuffix(from, "_test")
	withinTree := func(dir string) bool {
		return dir == "" || importer == dir || strings.HasPrefix(importer, dir+"/")
	}

	if parent, ok := strings.CutSuffix(to, "/internal"); ok {
		return withinTree(parent)
	}
	// The last "/internal/" imposes the most restrictive requirement.
	if i := strings.LastIndex(to, "/internal/"); i >= 0 {
		return withinTree(to[:i])
	}
	return true
}
521
522
523
524
525 func DeleteStmt(fset *token.FileSet, astFile *ast.File, stmt ast.Stmt, report func(string, ...any)) []analysis.TextEdit {
526
527 insp := inspector.New([]*ast.File{astFile})
528 root := insp.Root()
529 cstmt, ok := root.FindNode(stmt)
530 if !ok {
531 report("%s not found in file", stmt.Pos())
532 return nil
533 }
534
535 if !stmt.Pos().IsValid() || !stmt.End().IsValid() {
536 report("%s: stmt has invalid position", stmt.Pos())
537 return nil
538 }
539
540
541
542
543
544
545
546 tokFile := fset.File(stmt.Pos())
547 lineOf := tokFile.Line
548 stmtStartLine, stmtEndLine := lineOf(stmt.Pos()), lineOf(stmt.End())
549
550 var from, to token.Pos
551
552 limits := func(left, right token.Pos) {
553 if lineOf(left) == stmtStartLine {
554 from = left
555 }
556 if lineOf(right) == stmtEndLine {
557 to = right
558 }
559 }
560
561
562
563
564
565 switch parent := cstmt.Parent().Node().(type) {
566 case *ast.SwitchStmt:
567 limits(parent.Switch, parent.Body.Lbrace)
568 case *ast.TypeSwitchStmt:
569 limits(parent.Switch, parent.Body.Lbrace)
570 if parent.Assign == stmt {
571 return nil
572 }
573 case *ast.BlockStmt:
574 limits(parent.Lbrace, parent.Rbrace)
575 case *ast.CommClause:
576 limits(parent.Colon, cstmt.Parent().Parent().Node().(*ast.BlockStmt).Rbrace)
577 if parent.Comm == stmt {
578 return nil
579 }
580 case *ast.CaseClause:
581 limits(parent.Colon, cstmt.Parent().Parent().Node().(*ast.BlockStmt).Rbrace)
582 case *ast.ForStmt:
583 limits(parent.For, parent.Body.Lbrace)
584
585 default:
586 return nil
587 }
588
589 if prev, found := cstmt.PrevSibling(); found && lineOf(prev.Node().End()) == stmtStartLine {
590 from = prev.Node().End()
591 }
592 if next, found := cstmt.NextSibling(); found && lineOf(next.Node().Pos()) == stmtEndLine {
593 to = next.Node().Pos()
594 }
595
596 Outer:
597 for _, cg := range astFile.Comments {
598 for _, co := range cg.List {
599 if lineOf(co.End()) < stmtStartLine {
600 continue
601 } else if lineOf(co.Pos()) > stmtEndLine {
602 break Outer
603 }
604 if lineOf(co.End()) == stmtStartLine && co.End() < stmt.Pos() {
605 if !from.IsValid() || co.End() > from {
606 from = co.End()
607 continue
608 }
609 }
610 if lineOf(co.Pos()) == stmtEndLine && co.Pos() > stmt.End() {
611 if !to.IsValid() || co.Pos() < to {
612 to = co.Pos()
613 continue
614 }
615 }
616 }
617 }
618
619
620 edit := analysis.TextEdit{Pos: stmt.Pos(), End: stmt.End()}
621 if from.IsValid() || to.IsValid() {
622
623
624
625
626
627
628 return []analysis.TextEdit{edit}
629 }
630
631 for lineOf(edit.Pos) == stmtStartLine {
632 edit.Pos--
633 }
634 edit.Pos++
635 for lineOf(edit.End) == stmtEndLine {
636 edit.End++
637 }
638 return []analysis.TextEdit{edit}
639 }
640
641
642 func Comments(file *ast.File, start, end token.Pos) iter.Seq[*ast.Comment] {
643
644 return func(yield func(*ast.Comment) bool) {
645 for _, cg := range file.Comments {
646 for _, co := range cg.List {
647 if co.Pos() > end {
648 return
649 }
650 if co.End() < start {
651 continue
652 }
653
654 if !yield(co) {
655 return
656 }
657 }
658 }
659 }
660 }
661
662
663
// IsStdPackage reports whether the package path plausibly belongs to the
// standard library, including its internal and vendored dependencies.
//
// A standard package has no dot in its first path element (though later
// elements may contain one, as in "vendor/golang.org/x/foo"). Paths whose
// first element is "testdata" are excluded, consistent with CanImport.
// Previously only the exact path "testdata" was excluded.
func IsStdPackage(path string) bool {
	first, _, _ := strings.Cut(path, "/")
	return !strings.Contains(first, ".") && first != "testdata"
}
673
674
675 func Range(pos, end token.Pos) analysis.Range {
676 return tokenRange{pos, end}
677 }
678
679
680 type tokenRange struct{ StartPos, EndPos token.Pos }
681
682 func (r tokenRange) Pos() token.Pos { return r.StartPos }
683 func (r tokenRange) End() token.Pos { return r.EndPos }
684
View as plain text