1
2
3
4
5 package lostcancel
6
7 import (
8 _ "embed"
9 "fmt"
10 "go/ast"
11 "go/types"
12
13 "golang.org/x/tools/go/analysis"
14 "golang.org/x/tools/go/analysis/passes/ctrlflow"
15 "golang.org/x/tools/go/analysis/passes/inspect"
16 "golang.org/x/tools/go/analysis/passes/internal/analysisutil"
17 "golang.org/x/tools/go/ast/inspector"
18 "golang.org/x/tools/go/cfg"
19 )
20
21
22 var doc string
23
// Analyzer is the lostcancel analyzer. Its documentation is extracted from
// doc, its entry point is run, and it depends on the results of the inspect
// and ctrlflow analyzers, which run and runFunc read from the pass.
var Analyzer = &analysis.Analyzer{
	Name: "lostcancel",
	Doc:  analysisutil.MustExtractDoc(doc, "lostcancel"),
	URL:  "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/lostcancel",
	Run:  run,
	Requires: []*analysis.Analyzer{
		inspect.Analyzer,
		ctrlflow.Analyzer,
	},
}
34
// debug, when true, prints each analyzed function's control-flow graph and
// the blocks on the path found to a return statement.
const debug = false
36
// contextPackage is the import path of the context package. run skips
// packages that do not import it, and isContextWithCancel compares the path
// of an imported package against it.
var contextPackage = "context"
38
39
40
41
42
43
44
45
46
47
48
49 func run(pass *analysis.Pass) (interface{}, error) {
50
51 if !analysisutil.Imports(pass.Pkg, contextPackage) {
52 return nil, nil
53 }
54
55
56 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
57 nodeTypes := []ast.Node{
58 (*ast.FuncLit)(nil),
59 (*ast.FuncDecl)(nil),
60 }
61 inspect.Preorder(nodeTypes, func(n ast.Node) {
62 runFunc(pass, n)
63 })
64 return nil, nil
65 }
66
67 func runFunc(pass *analysis.Pass, node ast.Node) {
68
69 var funcScope *types.Scope
70 switch v := node.(type) {
71 case *ast.FuncLit:
72 funcScope = pass.TypesInfo.Scopes[v.Type]
73 case *ast.FuncDecl:
74 funcScope = pass.TypesInfo.Scopes[v.Type]
75 }
76
77
78 cancelvars := make(map[*types.Var]ast.Node)
79
80
81
82
83
84
85 stack := make([]ast.Node, 0, 32)
86 ast.Inspect(node, func(n ast.Node) bool {
87 switch n.(type) {
88 case *ast.FuncLit:
89 if len(stack) > 0 {
90 return false
91 }
92 case nil:
93 stack = stack[:len(stack)-1]
94 return true
95 }
96 stack = append(stack, n)
97
98
99
100
101
102
103
104 if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-2]) {
105 return true
106 }
107 var id *ast.Ident
108 stmt := stack[len(stack)-3]
109 switch stmt := stmt.(type) {
110 case *ast.ValueSpec:
111 if len(stmt.Names) > 1 {
112 id = stmt.Names[1]
113 }
114 case *ast.AssignStmt:
115 if len(stmt.Lhs) > 1 {
116 id, _ = stmt.Lhs[1].(*ast.Ident)
117 }
118 }
119 if id != nil {
120 if id.Name == "_" {
121 pass.ReportRangef(id,
122 "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
123 n.(*ast.SelectorExpr).Sel.Name)
124 } else if v, ok := pass.TypesInfo.Uses[id].(*types.Var); ok {
125
126
127 if funcScope.Contains(v.Pos()) {
128 cancelvars[v] = stmt
129 }
130 } else if v, ok := pass.TypesInfo.Defs[id].(*types.Var); ok {
131 cancelvars[v] = stmt
132 }
133 }
134 return true
135 })
136
137 if len(cancelvars) == 0 {
138 return
139 }
140
141
142 cfgs := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs)
143 var g *cfg.CFG
144 var sig *types.Signature
145 switch node := node.(type) {
146 case *ast.FuncDecl:
147 sig, _ = pass.TypesInfo.Defs[node.Name].Type().(*types.Signature)
148 if node.Name.Name == "main" && sig.Recv() == nil && pass.Pkg.Name() == "main" {
149
150
151 return
152 }
153 g = cfgs.FuncDecl(node)
154
155 case *ast.FuncLit:
156 sig, _ = pass.TypesInfo.Types[node.Type].Type.(*types.Signature)
157 g = cfgs.FuncLit(node)
158 }
159 if sig == nil {
160 return
161 }
162
163
164 if debug {
165 fmt.Println(g.Format(pass.Fset))
166 }
167
168
169
170
171 for v, stmt := range cancelvars {
172 if ret := lostCancelPath(pass, g, v, stmt, sig); ret != nil {
173 lineno := pass.Fset.Position(stmt.Pos()).Line
174 pass.ReportRangef(stmt, "the %s function is not used on all paths (possible context leak)", v.Name())
175
176 pos, end := ret.Pos(), ret.End()
177
178
179 if pass.Fset.File(pos) != pass.Fset.File(end) {
180 end = pos
181 }
182 pass.Report(analysis.Diagnostic{
183 Pos: pos,
184 End: end,
185 Message: fmt.Sprintf("this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno),
186 })
187 }
188 }
189 }
190
// isCall reports whether n is an *ast.CallExpr.
func isCall(n ast.Node) bool {
	switch n.(type) {
	case *ast.CallExpr:
		return true
	default:
		return false
	}
}
192
193
194
195 func isContextWithCancel(info *types.Info, n ast.Node) bool {
196 sel, ok := n.(*ast.SelectorExpr)
197 if !ok {
198 return false
199 }
200 switch sel.Sel.Name {
201 case "WithCancel", "WithTimeout", "WithDeadline":
202 default:
203 return false
204 }
205 if x, ok := sel.X.(*ast.Ident); ok {
206 if pkgname, ok := info.Uses[x].(*types.PkgName); ok {
207 return pkgname.Imported().Path() == contextPackage
208 }
209
210
211 return x.Name == "context"
212 }
213 return false
214 }
215
216
217
218
219
220 func lostCancelPath(pass *analysis.Pass, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt {
221 vIsNamedResult := sig != nil && tupleContains(sig.Results(), v)
222
223
224 uses := func(pass *analysis.Pass, v *types.Var, stmts []ast.Node) bool {
225 found := false
226 for _, stmt := range stmts {
227 ast.Inspect(stmt, func(n ast.Node) bool {
228 switch n := n.(type) {
229 case *ast.Ident:
230 if pass.TypesInfo.Uses[n] == v {
231 found = true
232 }
233 case *ast.ReturnStmt:
234
235
236 if n.Results == nil && vIsNamedResult {
237 found = true
238 }
239 }
240 return !found
241 })
242 }
243 return found
244 }
245
246
247 memo := make(map[*cfg.Block]bool)
248 blockUses := func(pass *analysis.Pass, v *types.Var, b *cfg.Block) bool {
249 res, ok := memo[b]
250 if !ok {
251 res = uses(pass, v, b.Nodes)
252 memo[b] = res
253 }
254 return res
255 }
256
257
258
259 var defblock *cfg.Block
260 var rest []ast.Node
261 outer:
262 for _, b := range g.Blocks {
263 for i, n := range b.Nodes {
264 if n == stmt {
265 defblock = b
266 rest = b.Nodes[i+1:]
267 break outer
268 }
269 }
270 }
271 if defblock == nil {
272 panic("internal error: can't find defining block for cancel var")
273 }
274
275
276 if uses(pass, v, rest) {
277 return nil
278 }
279
280
281 if ret := defblock.Return(); ret != nil {
282 return ret
283 }
284
285
286
287 seen := make(map[*cfg.Block]bool)
288 var search func(blocks []*cfg.Block) *ast.ReturnStmt
289 search = func(blocks []*cfg.Block) *ast.ReturnStmt {
290 for _, b := range blocks {
291 if seen[b] {
292 continue
293 }
294 seen[b] = true
295
296
297 if blockUses(pass, v, b) {
298 continue
299 }
300
301
302 if ret := b.Return(); ret != nil {
303 if debug {
304 fmt.Printf("found path to return in block %s\n", b)
305 }
306 return ret
307 }
308
309
310 if ret := search(b.Succs); ret != nil {
311 if debug {
312 fmt.Printf(" from block %s\n", b)
313 }
314 return ret
315 }
316 }
317 return nil
318 }
319 return search(defblock.Succs)
320 }
321
// tupleContains reports whether v is one of the variables of tuple.
// Variables are compared by identity, not by name.
func tupleContains(tuple *types.Tuple, v *types.Var) bool {
	found := false
	for i := tuple.Len() - 1; i >= 0 && !found; i-- {
		found = tuple.At(i) == v
	}
	return found
}
330
View as plain text