1  
     2  
     3  
     4  
     5  package lostcancel
     6  
     7  import (
     8  	_ "embed"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/types"
    12  
    13  	"golang.org/x/tools/go/analysis"
    14  	"golang.org/x/tools/go/analysis/passes/ctrlflow"
    15  	"golang.org/x/tools/go/analysis/passes/inspect"
    16  	"golang.org/x/tools/go/analysis/passes/internal/analysisutil"
    17  	"golang.org/x/tools/go/ast/inspector"
    18  	"golang.org/x/tools/go/cfg"
    19  	"golang.org/x/tools/internal/analysisinternal"
    20  	"golang.org/x/tools/internal/astutil"
    21  )
    22  
    23  
    24  var doc string
    25  
    26  var Analyzer = &analysis.Analyzer{
    27  	Name: "lostcancel",
    28  	Doc:  analysisutil.MustExtractDoc(doc, "lostcancel"),
    29  	URL:  "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/lostcancel",
    30  	Run:  run,
    31  	Requires: []*analysis.Analyzer{
    32  		inspect.Analyzer,
    33  		ctrlflow.Analyzer,
    34  	},
    35  }
    36  
    37  const debug = false
    38  
    39  var contextPackage = "context"
    40  
    41  
    42  
    43  
    44  
    45  
    46  
    47  
    48  
    49  
    50  
    51  func run(pass *analysis.Pass) (any, error) {
    52  	
    53  	if !analysisinternal.Imports(pass.Pkg, contextPackage) {
    54  		return nil, nil
    55  	}
    56  
    57  	
    58  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    59  	nodeTypes := []ast.Node{
    60  		(*ast.FuncLit)(nil),
    61  		(*ast.FuncDecl)(nil),
    62  	}
    63  	inspect.Preorder(nodeTypes, func(n ast.Node) {
    64  		runFunc(pass, n)
    65  	})
    66  	return nil, nil
    67  }
    68  
    69  func runFunc(pass *analysis.Pass, node ast.Node) {
    70  	
    71  	var funcScope *types.Scope
    72  	switch v := node.(type) {
    73  	case *ast.FuncLit:
    74  		funcScope = pass.TypesInfo.Scopes[v.Type]
    75  	case *ast.FuncDecl:
    76  		funcScope = pass.TypesInfo.Scopes[v.Type]
    77  	}
    78  
    79  	
    80  	cancelvars := make(map[*types.Var]ast.Node)
    81  
    82  	
    83  	
    84  	
    85  
    86  	
    87  	astutil.PreorderStack(node, nil, func(n ast.Node, stack []ast.Node) bool {
    88  		if _, ok := n.(*ast.FuncLit); ok && len(stack) > 0 {
    89  			return false 
    90  		}
    91  
    92  		
    93  		
    94  		
    95  		
    96  		
    97  		
    98  		if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-1]) {
    99  			return true
   100  		}
   101  		var id *ast.Ident 
   102  		stmt := stack[len(stack)-2]
   103  		switch stmt := stmt.(type) {
   104  		case *ast.ValueSpec:
   105  			if len(stmt.Names) > 1 {
   106  				id = stmt.Names[1]
   107  			}
   108  		case *ast.AssignStmt:
   109  			if len(stmt.Lhs) > 1 {
   110  				id, _ = stmt.Lhs[1].(*ast.Ident)
   111  			}
   112  		}
   113  		if id != nil {
   114  			if id.Name == "_" {
   115  				pass.ReportRangef(id,
   116  					"the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
   117  					n.(*ast.SelectorExpr).Sel.Name)
   118  			} else if v, ok := pass.TypesInfo.Uses[id].(*types.Var); ok {
   119  				
   120  				
   121  				if funcScope.Contains(v.Pos()) {
   122  					cancelvars[v] = stmt
   123  				}
   124  			} else if v, ok := pass.TypesInfo.Defs[id].(*types.Var); ok {
   125  				cancelvars[v] = stmt
   126  			}
   127  		}
   128  		return true
   129  	})
   130  
   131  	if len(cancelvars) == 0 {
   132  		return 
   133  	}
   134  
   135  	
   136  	cfgs := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs)
   137  	var g *cfg.CFG
   138  	var sig *types.Signature
   139  	switch node := node.(type) {
   140  	case *ast.FuncDecl:
   141  		sig, _ = pass.TypesInfo.Defs[node.Name].Type().(*types.Signature)
   142  		if node.Name.Name == "main" && sig.Recv() == nil && pass.Pkg.Name() == "main" {
   143  			
   144  			
   145  			return
   146  		}
   147  		g = cfgs.FuncDecl(node)
   148  
   149  	case *ast.FuncLit:
   150  		sig, _ = pass.TypesInfo.Types[node.Type].Type.(*types.Signature)
   151  		g = cfgs.FuncLit(node)
   152  	}
   153  	if sig == nil {
   154  		return 
   155  	}
   156  
   157  	
   158  	if debug {
   159  		fmt.Println(g.Format(pass.Fset))
   160  	}
   161  
   162  	
   163  	
   164  	
   165  	for v, stmt := range cancelvars {
   166  		if ret := lostCancelPath(pass, g, v, stmt, sig); ret != nil {
   167  			lineno := pass.Fset.Position(stmt.Pos()).Line
   168  			pass.ReportRangef(stmt, "the %s function is not used on all paths (possible context leak)", v.Name())
   169  
   170  			pos, end := ret.Pos(), ret.End()
   171  			
   172  			
   173  			if pass.Fset.File(pos) != pass.Fset.File(end) {
   174  				end = pos
   175  			}
   176  			pass.Report(analysis.Diagnostic{
   177  				Pos:     pos,
   178  				End:     end,
   179  				Message: fmt.Sprintf("this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno),
   180  			})
   181  		}
   182  	}
   183  }
   184  
   185  func isCall(n ast.Node) bool { _, ok := n.(*ast.CallExpr); return ok }
   186  
   187  
   188  
   189  func isContextWithCancel(info *types.Info, n ast.Node) bool {
   190  	sel, ok := n.(*ast.SelectorExpr)
   191  	if !ok {
   192  		return false
   193  	}
   194  	switch sel.Sel.Name {
   195  	case "WithCancel", "WithCancelCause",
   196  		"WithTimeout", "WithTimeoutCause",
   197  		"WithDeadline", "WithDeadlineCause":
   198  	default:
   199  		return false
   200  	}
   201  	if x, ok := sel.X.(*ast.Ident); ok {
   202  		if pkgname, ok := info.Uses[x].(*types.PkgName); ok {
   203  			return pkgname.Imported().Path() == contextPackage
   204  		}
   205  		
   206  		
   207  		return x.Name == "context"
   208  	}
   209  	return false
   210  }
   211  
   212  
   213  
   214  
   215  
// lostCancelPath searches the control-flow graph g for a path from stmt
// (the statement that assigns the cancel variable v) to a return statement
// along which v is never "used". It returns that return statement, or nil
// if every path to a return uses v.
//
// A block counts as using v if any node in it refers to v (per
// pass.TypesInfo.Uses, including references inside nested function
// literals), or contains a bare "return" while v is one of sig's named
// results. sig may be nil.
//
// It panics if stmt does not appear in any block of g.
func lostCancelPath(pass *analysis.Pass, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt {
	vIsNamedResult := sig != nil && tupleContains(sig.Results(), v)

	// uses reports whether stmts contain a "use" of variable v.
	uses := func(pass *analysis.Pass, v *types.Var, stmts []ast.Node) bool {
		found := false
		for _, stmt := range stmts {
			ast.Inspect(stmt, func(n ast.Node) bool {
				switch n := n.(type) {
				case *ast.Ident:
					if pass.TypesInfo.Uses[n] == v {
						found = true
					}
				case *ast.ReturnStmt:
					// A naked return statement counts as a use
					// of the named result variables.
					if n.Results == nil && vIsNamedResult {
						found = true
					}
				}
				// Stop descending once a use has been found.
				return !found
			})
		}
		return found
	}

	// blockUses computes "uses" for each block, caching the result.
	memo := make(map[*cfg.Block]bool)
	blockUses := func(pass *analysis.Pass, v *types.Var, b *cfg.Block) bool {
		res, ok := memo[b]
		if !ok {
			res = uses(pass, v, b.Nodes)
			memo[b] = res
		}
		return res
	}

	// Find the var's defining block in the CFG,
	// plus the rest of the statements of that block.
	var defblock *cfg.Block
	var rest []ast.Node
outer:
	for _, b := range g.Blocks {
		for i, n := range b.Nodes {
			if n == stmt {
				defblock = b
				rest = b.Nodes[i+1:]
				break outer
			}
		}
	}
	if defblock == nil {
		panic("internal error: can't find defining block for cancel var")
	}

	// Is v "used" in the remainder of its defining block?
	if uses(pass, v, rest) {
		return nil
	}

	// Does the defining block return without using v?
	if ret := defblock.Return(); ret != nil {
		return ret
	}

	// Search the CFG depth-first for a path, from defblock to a
	// return block, in which v is never "used".
	seen := make(map[*cfg.Block]bool)
	var search func(blocks []*cfg.Block) *ast.ReturnStmt
	search = func(blocks []*cfg.Block) *ast.ReturnStmt {
		for _, b := range blocks {
			if seen[b] {
				continue
			}
			seen[b] = true

			// Prune the search if the block uses v.
			if blockUses(pass, v, b) {
				continue
			}

			// Found path to return statement?
			if ret := b.Return(); ret != nil {
				if debug {
					fmt.Printf("found path to return in block %s\n", b)
				}
				return ret // found
			}

			// Recur
			if ret := search(b.Succs); ret != nil {
				if debug {
					fmt.Printf(" from block %s\n", b)
				}
				return ret
			}
		}
		return nil
	}
	return search(defblock.Succs)
}
   317  
   318  func tupleContains(tuple *types.Tuple, v *types.Var) bool {
   319  	for i := 0; i < tuple.Len(); i++ {
   320  		if tuple.At(i) == v {
   321  			return true
   322  		}
   323  	}
   324  	return false
   325  }
   326  
View as plain text