1
2
3
4
5 package modernize
6
7 import (
8 "bytes"
9 "fmt"
10 "go/ast"
11 "go/printer"
12 "slices"
13
14 "golang.org/x/tools/go/analysis"
15 "golang.org/x/tools/go/analysis/passes/inspect"
16 "golang.org/x/tools/go/types/typeutil"
17 "golang.org/x/tools/internal/analysis/analyzerutil"
18 typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
19 "golang.org/x/tools/internal/astutil"
20 "golang.org/x/tools/internal/refactor"
21 "golang.org/x/tools/internal/typesinternal/typeindex"
22 "golang.org/x/tools/internal/versions"
23 )
24
25 var WaitGroupAnalyzer = &analysis.Analyzer{
26 Name: "waitgroup",
27 Doc: analyzerutil.MustExtractDoc(doc, "waitgroup"),
28 Requires: []*analysis.Analyzer{
29 inspect.Analyzer,
30 typeindexanalyzer.Analyzer,
31 },
32 Run: waitgroup,
33 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#waitgroup",
34 }
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62 func waitgroup(pass *analysis.Pass) (any, error) {
63 var (
64 index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
65 info = pass.TypesInfo
66 syncWaitGroupAdd = index.Selection("sync", "WaitGroup", "Add")
67 syncWaitGroupDone = index.Selection("sync", "WaitGroup", "Done")
68 )
69 if !index.Used(syncWaitGroupDone) {
70 return nil, nil
71 }
72
73 for curAddCall := range index.Calls(syncWaitGroupAdd) {
74
75 addCall := curAddCall.Node().(*ast.CallExpr)
76 if !isIntLiteral(info, addCall.Args[0], 1) {
77 continue
78 }
79
80
81 addCallRecv := ast.Unparen(addCall.Fun).(*ast.SelectorExpr).X
82
83
84 curAddStmt := curAddCall.Parent()
85 if !is[*ast.ExprStmt](curAddStmt.Node()) {
86 continue
87 }
88 curNext, ok := curAddCall.Parent().NextSibling()
89 if !ok {
90 continue
91 }
92 goStmt, ok := curNext.Node().(*ast.GoStmt)
93 if !ok {
94 continue
95 }
96 lit, ok := goStmt.Call.Fun.(*ast.FuncLit)
97 if !ok || len(goStmt.Call.Args) != 0 {
98 continue
99 }
100 if lit.Type.Results != nil && len(lit.Type.Results.List) > 0 {
101 continue
102 }
103 list := lit.Body.List
104 if len(list) == 0 {
105 continue
106 }
107
108
109 var doneStmt ast.Stmt
110 if deferStmt, ok := list[0].(*ast.DeferStmt); ok &&
111 typeutil.Callee(info, deferStmt.Call) == syncWaitGroupDone &&
112 astutil.EqualSyntax(ast.Unparen(deferStmt.Call.Fun).(*ast.SelectorExpr).X, addCallRecv) {
113 doneStmt = deferStmt
114
115 } else if lastStmt, ok := list[len(list)-1].(*ast.ExprStmt); ok {
116 if doneCall, ok := lastStmt.X.(*ast.CallExpr); ok &&
117 typeutil.Callee(info, doneCall) == syncWaitGroupDone &&
118 astutil.EqualSyntax(ast.Unparen(doneCall.Fun).(*ast.SelectorExpr).X, addCallRecv) {
119 doneStmt = lastStmt
120 }
121 }
122 if doneStmt == nil {
123 continue
124 }
125 curDoneStmt, ok := curNext.FindNode(doneStmt)
126 if !ok {
127 panic("can't find Cursor for 'done' statement")
128 }
129
130 file := astutil.EnclosingFile(curAddCall)
131 if !analyzerutil.FileUsesGoVersion(pass, file, versions.Go1_25) {
132 continue
133 }
134 tokFile := pass.Fset.File(file.Pos())
135
136 var addCallRecvText bytes.Buffer
137 err := printer.Fprint(&addCallRecvText, pass.Fset, addCallRecv)
138 if err != nil {
139 continue
140 }
141
142 pass.Report(analysis.Diagnostic{
143
144
145 Pos: goStmt.Pos(),
146 End: lit.Type.End(),
147 Message: "Goroutine creation can be simplified using WaitGroup.Go",
148 SuggestedFixes: []analysis.SuggestedFix{{
149 Message: "Simplify by using WaitGroup.Go",
150 TextEdits: slices.Concat(
151
152 refactor.DeleteStmt(tokFile, curAddStmt),
153
154 refactor.DeleteStmt(tokFile, curDoneStmt),
155 []analysis.TextEdit{
156
157
158
159 {
160 Pos: goStmt.Pos(),
161 End: goStmt.Call.Pos(),
162 NewText: fmt.Appendf(nil, "%s.Go(", addCallRecvText.String()),
163 },
164
165
166
167 {
168 Pos: goStmt.Call.Lparen,
169 End: goStmt.Call.Rparen,
170 },
171 },
172 ),
173 }},
174 })
175 }
176 return nil, nil
177 }
178
View as plain text