Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/sigchanyzer/sigchanyzer.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
// Package sigchanyzer defines an Analyzer that detects
// misuse of an unbuffered os.Signal channel as argument to signal.Notify.
     7  package sigchanyzer
     8  
     9  import (
    10  	"bytes"
    11  	_ "embed"
    12  	"go/ast"
    13  	"go/format"
    14  	"go/token"
    15  	"go/types"
    16  
    17  	"golang.org/x/tools/go/analysis"
    18  	"golang.org/x/tools/go/analysis/passes/inspect"
    19  	"golang.org/x/tools/go/analysis/passes/internal/analysisutil"
    20  	"golang.org/x/tools/go/ast/inspector"
    21  )
    22  
// doc holds the text of doc.go, embedded at build time. The Analyzer's
// documentation is extracted from it.
//
//go:embed doc.go
var doc string
    25  
    26  // Analyzer describes sigchanyzer analysis function detector.
    27  var Analyzer = &analysis.Analyzer{
    28  	Name:     "sigchanyzer",
    29  	Doc:      analysisutil.MustExtractDoc(doc, "sigchanyzer"),
    30  	URL:      "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/sigchanyzer",
    31  	Requires: []*analysis.Analyzer{inspect.Analyzer},
    32  	Run:      run,
    33  }
    34  
    35  func run(pass *analysis.Pass) (interface{}, error) {
    36  	if !analysisutil.Imports(pass.Pkg, "os/signal") {
    37  		return nil, nil // doesn't directly import signal
    38  	}
    39  
    40  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    41  
    42  	nodeFilter := []ast.Node{
    43  		(*ast.CallExpr)(nil),
    44  	}
    45  	inspect.Preorder(nodeFilter, func(n ast.Node) {
    46  		call := n.(*ast.CallExpr)
    47  		if !isSignalNotify(pass.TypesInfo, call) {
    48  			return
    49  		}
    50  		var chanDecl *ast.CallExpr
    51  		switch arg := call.Args[0].(type) {
    52  		case *ast.Ident:
    53  			if decl, ok := findDecl(arg).(*ast.CallExpr); ok {
    54  				chanDecl = decl
    55  			}
    56  		case *ast.CallExpr:
    57  			// Only signal.Notify(make(chan os.Signal), os.Interrupt) is safe,
    58  			// conservatively treat others as not safe, see golang/go#45043
    59  			if isBuiltinMake(pass.TypesInfo, arg) {
    60  				return
    61  			}
    62  			chanDecl = arg
    63  		}
    64  		if chanDecl == nil || len(chanDecl.Args) != 1 {
    65  			return
    66  		}
    67  
    68  		// Make a copy of the channel's declaration to avoid
    69  		// mutating the AST. See https://golang.org/issue/46129.
    70  		chanDeclCopy := &ast.CallExpr{}
    71  		*chanDeclCopy = *chanDecl
    72  		chanDeclCopy.Args = append([]ast.Expr(nil), chanDecl.Args...)
    73  		chanDeclCopy.Args = append(chanDeclCopy.Args, &ast.BasicLit{
    74  			Kind:  token.INT,
    75  			Value: "1",
    76  		})
    77  
    78  		var buf bytes.Buffer
    79  		if err := format.Node(&buf, token.NewFileSet(), chanDeclCopy); err != nil {
    80  			return
    81  		}
    82  		pass.Report(analysis.Diagnostic{
    83  			Pos:     call.Pos(),
    84  			End:     call.End(),
    85  			Message: "misuse of unbuffered os.Signal channel as argument to signal.Notify",
    86  			SuggestedFixes: []analysis.SuggestedFix{{
    87  				Message: "Change to buffer channel",
    88  				TextEdits: []analysis.TextEdit{{
    89  					Pos:     chanDecl.Pos(),
    90  					End:     chanDecl.End(),
    91  					NewText: buf.Bytes(),
    92  				}},
    93  			}},
    94  		})
    95  	})
    96  	return nil, nil
    97  }
    98  
    99  func isSignalNotify(info *types.Info, call *ast.CallExpr) bool {
   100  	check := func(id *ast.Ident) bool {
   101  		obj := info.ObjectOf(id)
   102  		return obj.Name() == "Notify" && obj.Pkg().Path() == "os/signal"
   103  	}
   104  	switch fun := call.Fun.(type) {
   105  	case *ast.SelectorExpr:
   106  		return check(fun.Sel)
   107  	case *ast.Ident:
   108  		if fun, ok := findDecl(fun).(*ast.SelectorExpr); ok {
   109  			return check(fun.Sel)
   110  		}
   111  		return false
   112  	default:
   113  		return false
   114  	}
   115  }
   116  
   117  func findDecl(arg *ast.Ident) ast.Node {
   118  	if arg.Obj == nil {
   119  		return nil
   120  	}
   121  	switch as := arg.Obj.Decl.(type) {
   122  	case *ast.AssignStmt:
   123  		if len(as.Lhs) != len(as.Rhs) {
   124  			return nil
   125  		}
   126  		for i, lhs := range as.Lhs {
   127  			lid, ok := lhs.(*ast.Ident)
   128  			if !ok {
   129  				continue
   130  			}
   131  			if lid.Obj == arg.Obj {
   132  				return as.Rhs[i]
   133  			}
   134  		}
   135  	case *ast.ValueSpec:
   136  		if len(as.Names) != len(as.Values) {
   137  			return nil
   138  		}
   139  		for i, name := range as.Names {
   140  			if name.Obj == arg.Obj {
   141  				return as.Values[i]
   142  			}
   143  		}
   144  	}
   145  	return nil
   146  }
   147  
   148  func isBuiltinMake(info *types.Info, call *ast.CallExpr) bool {
   149  	typVal := info.Types[call.Fun]
   150  	if !typVal.IsBuiltin() {
   151  		return false
   152  	}
   153  	switch fun := call.Fun.(type) {
   154  	case *ast.Ident:
   155  		return info.ObjectOf(fun).Name() == "make"
   156  	default:
   157  		return false
   158  	}
   159  }
   160  

View as plain text