Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/modernize/minmax.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package modernize
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/token"
    11  	"go/types"
    12  	"strings"
    13  
    14  	"golang.org/x/tools/go/analysis"
    15  	"golang.org/x/tools/go/analysis/passes/inspect"
    16  	"golang.org/x/tools/go/ast/edge"
    17  	"golang.org/x/tools/go/ast/inspector"
    18  	"golang.org/x/tools/internal/analysis/analyzerutil"
    19  	typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
    20  	"golang.org/x/tools/internal/astutil"
    21  	"golang.org/x/tools/internal/typeparams"
    22  	"golang.org/x/tools/internal/typesinternal/typeindex"
    23  	"golang.org/x/tools/internal/versions"
    24  )
    25  
    26  var MinMaxAnalyzer = &analysis.Analyzer{
    27  	Name: "minmax",
    28  	Doc:  analyzerutil.MustExtractDoc(doc, "minmax"),
    29  	Requires: []*analysis.Analyzer{
    30  		inspect.Analyzer,
    31  		typeindexanalyzer.Analyzer,
    32  	},
    33  	Run: minmax,
    34  	URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#minmax",
    35  }
    36  
    37  // The minmax pass replaces if/else statements with calls to min or max,
    38  // and removes user-defined min/max functions that are equivalent to built-ins.
    39  //
    40  // If/else replacement patterns:
    41  //
    42  //  1. if a < b { x = a } else { x = b }        =>      x = min(a, b)
    43  //  2. x = a; if a < b { x = b }                =>      x = max(a, b)
    44  //
    45  // Pattern 1 requires that a is not NaN, and pattern 2 requires that b
    46  // is not Nan. Since this is hard to prove, we reject floating-point
    47  // numbers.
    48  //
    49  // Function removal:
    50  // User-defined min/max functions are suggested for removal if they may
    51  // be safely replaced by their built-in namesake.
    52  //
    53  // Variants:
    54  // - all four ordered comparisons
    55  // - "x := a" or "x = a" or "var x = a" in pattern 2
    56  // - "x < b" or "a < b" in pattern 2
    57  func minmax(pass *analysis.Pass) (any, error) {
    58  	// Check for user-defined min/max functions that can be removed
    59  	checkUserDefinedMinMax(pass)
    60  
    61  	// check is called for all statements of this form:
    62  	//   if a < b { lhs = rhs }
    63  	check := func(file *ast.File, curIfStmt inspector.Cursor, compare *ast.BinaryExpr) {
    64  		var (
    65  			ifStmt  = curIfStmt.Node().(*ast.IfStmt)
    66  			tassign = ifStmt.Body.List[0].(*ast.AssignStmt)
    67  			a       = compare.X
    68  			b       = compare.Y
    69  			lhs     = tassign.Lhs[0]
    70  			rhs     = tassign.Rhs[0]
    71  			sign    = isInequality(compare.Op)
    72  
    73  			// callArg formats a call argument, preserving comments from [start-end).
    74  			callArg = func(arg ast.Expr, start, end token.Pos) string {
    75  				comments := allComments(file, start, end)
    76  				return cond(arg == b, ", ", "") + // second argument needs a comma
    77  					cond(comments != "", "\n", "") + // comments need their own line
    78  					comments +
    79  					astutil.Format(pass.Fset, arg)
    80  			}
    81  		)
    82  
    83  		if fblock, ok := ifStmt.Else.(*ast.BlockStmt); ok && isAssignBlock(fblock) {
    84  			fassign := fblock.List[0].(*ast.AssignStmt)
    85  
    86  			// Have: if a < b { lhs = rhs } else { lhs2 = rhs2 }
    87  			lhs2 := fassign.Lhs[0]
    88  			rhs2 := fassign.Rhs[0]
    89  
    90  			// For pattern 1, check that:
    91  			// - lhs = lhs2
    92  			// - {rhs,rhs2} = {a,b}
    93  			if astutil.EqualSyntax(lhs, lhs2) {
    94  				if astutil.EqualSyntax(rhs, a) && astutil.EqualSyntax(rhs2, b) {
    95  					sign = +sign
    96  				} else if astutil.EqualSyntax(rhs2, a) && astutil.EqualSyntax(rhs, b) {
    97  					sign = -sign
    98  				} else {
    99  					return
   100  				}
   101  
   102  				sym := cond(sign < 0, "min", "max")
   103  
   104  				if !is[*types.Builtin](lookup(pass.TypesInfo, curIfStmt, sym)) {
   105  					return // min/max function is shadowed
   106  				}
   107  
   108  				// pattern 1
   109  				//
   110  				// TODO(adonovan): if lhs is declared "var lhs T" on preceding line,
   111  				// simplify the whole thing to "lhs := min(a, b)".
   112  				pass.Report(analysis.Diagnostic{
   113  					// Highlight the condition a < b.
   114  					Pos:     compare.Pos(),
   115  					End:     compare.End(),
   116  					Message: fmt.Sprintf("if/else statement can be modernized using %s", sym),
   117  					SuggestedFixes: []analysis.SuggestedFix{{
   118  						Message: fmt.Sprintf("Replace if statement with %s", sym),
   119  						TextEdits: []analysis.TextEdit{{
   120  							// Replace IfStmt with lhs = min(a, b).
   121  							Pos: ifStmt.Pos(),
   122  							End: ifStmt.End(),
   123  							NewText: fmt.Appendf(nil, "%s = %s(%s%s)",
   124  								astutil.Format(pass.Fset, lhs),
   125  								sym,
   126  								callArg(a, ifStmt.Pos(), ifStmt.Else.Pos()),
   127  								callArg(b, ifStmt.Else.Pos(), ifStmt.End()),
   128  							),
   129  						}},
   130  					}},
   131  				})
   132  			}
   133  
   134  		} else if prev, ok := curIfStmt.PrevSibling(); ok && isSimpleAssign(prev.Node()) && ifStmt.Else == nil {
   135  			fassign := prev.Node().(*ast.AssignStmt)
   136  
   137  			// Have: lhs0 = rhs0; if a < b { lhs = rhs }
   138  			//
   139  			// For pattern 2, check that
   140  			// - lhs = lhs0
   141  			// - {a,b} = {rhs,rhs0} or {rhs,lhs0}
   142  			//   The replacement must use rhs0 not lhs0 though.
   143  			//   For example, we accept this variant:
   144  			//     lhs = x; if lhs < y { lhs = y }   =>   lhs = min(x, y), not min(lhs, y)
   145  			//
   146  			// TODO(adonovan): accept "var lhs0 = rhs0" form too.
   147  			lhs0 := fassign.Lhs[0]
   148  			rhs0 := fassign.Rhs[0]
   149  
   150  			// If the assignment occurs within a select
   151  			// comms clause (like "case lhs0 := <-rhs0:"),
   152  			// there's no way of rewriting it into a min/max call.
   153  			if ek, _ := prev.ParentEdge(); ek == edge.CommClause_Comm {
   154  				return
   155  			}
   156  
   157  			if astutil.EqualSyntax(lhs, lhs0) {
   158  				if astutil.EqualSyntax(rhs, a) && (astutil.EqualSyntax(rhs0, b) || astutil.EqualSyntax(lhs0, b)) {
   159  					sign = +sign
   160  				} else if (astutil.EqualSyntax(rhs0, a) || astutil.EqualSyntax(lhs0, a)) && astutil.EqualSyntax(rhs, b) {
   161  					sign = -sign
   162  				} else {
   163  					return
   164  				}
   165  				sym := cond(sign < 0, "min", "max")
   166  
   167  				if !is[*types.Builtin](lookup(pass.TypesInfo, curIfStmt, sym)) {
   168  					return // min/max function is shadowed
   169  				}
   170  
   171  				// Permit lhs0 to stand for rhs0 in the matching,
   172  				// but don't actually reduce to lhs0 = min(lhs0, rhs)
   173  				// since the "=" could be a ":=". Use min(rhs0, rhs).
   174  				if astutil.EqualSyntax(lhs0, a) {
   175  					a = rhs0
   176  				} else if astutil.EqualSyntax(lhs0, b) {
   177  					b = rhs0
   178  				}
   179  
   180  				// pattern 2
   181  				pass.Report(analysis.Diagnostic{
   182  					// Highlight the condition a < b.
   183  					Pos:     compare.Pos(),
   184  					End:     compare.End(),
   185  					Message: fmt.Sprintf("if statement can be modernized using %s", sym),
   186  					SuggestedFixes: []analysis.SuggestedFix{{
   187  						Message: fmt.Sprintf("Replace if/else with %s", sym),
   188  						TextEdits: []analysis.TextEdit{{
   189  							Pos: fassign.Pos(),
   190  							End: ifStmt.End(),
   191  							// Replace "x := a; if ... {}" with "x = min(...)", preserving comments.
   192  							NewText: fmt.Appendf(nil, "%s %s %s(%s%s)",
   193  								astutil.Format(pass.Fset, lhs),
   194  								fassign.Tok.String(),
   195  								sym,
   196  								callArg(a, fassign.Pos(), ifStmt.Pos()),
   197  								callArg(b, ifStmt.Pos(), ifStmt.End()),
   198  							),
   199  						}},
   200  					}},
   201  				})
   202  			}
   203  		}
   204  	}
   205  
   206  	// Find all "if a < b { lhs = rhs }" statements.
   207  	info := pass.TypesInfo
   208  	for curFile := range filesUsingGoVersion(pass, versions.Go1_21) {
   209  		astFile := curFile.Node().(*ast.File)
   210  		for curIfStmt := range curFile.Preorder((*ast.IfStmt)(nil)) {
   211  			ifStmt := curIfStmt.Node().(*ast.IfStmt)
   212  
   213  			// Don't bother handling "if a < b { lhs = rhs }" when it appears
   214  			// as the "else" branch of another if-statement.
   215  			//    if cond { ... } else if a < b { lhs = rhs }
   216  			// (This case would require introducing another block
   217  			//    if cond { ... } else { if a < b { lhs = rhs } }
   218  			// and checking that there is no following "else".)
   219  			if astutil.IsChildOf(curIfStmt, edge.IfStmt_Else) {
   220  				continue
   221  			}
   222  
   223  			if compare, ok := ifStmt.Cond.(*ast.BinaryExpr); ok &&
   224  				ifStmt.Init == nil &&
   225  				isInequality(compare.Op) != 0 &&
   226  				isAssignBlock(ifStmt.Body) {
   227  				// a blank var has no type.
   228  				if tLHS := info.TypeOf(ifStmt.Body.List[0].(*ast.AssignStmt).Lhs[0]); tLHS != nil && !maybeNaN(tLHS) {
   229  					// Have: if a < b { lhs = rhs }
   230  					check(astFile, curIfStmt, compare)
   231  				}
   232  			}
   233  		}
   234  	}
   235  	return nil, nil
   236  }
   237  
   238  // allComments collects all the comments from start to end.
   239  func allComments(file *ast.File, start, end token.Pos) string {
   240  	var buf strings.Builder
   241  	for co := range astutil.Comments(file, start, end) {
   242  		_, _ = fmt.Fprintf(&buf, "%s\n", co.Text)
   243  	}
   244  	return buf.String()
   245  }
   246  
   247  // isInequality reports non-zero if tok is one of < <= => >:
   248  // +1 for > and -1 for <.
   249  func isInequality(tok token.Token) int {
   250  	switch tok {
   251  	case token.LEQ, token.LSS:
   252  		return -1
   253  	case token.GEQ, token.GTR:
   254  		return +1
   255  	}
   256  	return 0
   257  }
   258  
   259  // isAssignBlock reports whether b is a block of the form { lhs = rhs }.
   260  func isAssignBlock(b *ast.BlockStmt) bool {
   261  	if len(b.List) != 1 {
   262  		return false
   263  	}
   264  	// Inv: the sole statement cannot be { lhs := rhs }.
   265  	return isSimpleAssign(b.List[0])
   266  }
   267  
   268  // isSimpleAssign reports whether n has the form "lhs = rhs" or "lhs := rhs".
   269  func isSimpleAssign(n ast.Node) bool {
   270  	assign, ok := n.(*ast.AssignStmt)
   271  	return ok &&
   272  		(assign.Tok == token.ASSIGN || assign.Tok == token.DEFINE) &&
   273  		len(assign.Lhs) == 1 &&
   274  		len(assign.Rhs) == 1
   275  }
   276  
   277  // maybeNaN reports whether t is (or may be) a floating-point type.
   278  func maybeNaN(t types.Type) bool {
   279  	// For now, we rely on core types.
   280  	// TODO(adonovan): In the post-core-types future,
   281  	// follow the approach of types.Checker.applyTypeFunc.
   282  	t = typeparams.CoreType(t)
   283  	if t == nil {
   284  		return true // fail safe
   285  	}
   286  	if basic, ok := t.(*types.Basic); ok && basic.Info()&types.IsFloat != 0 {
   287  		return true
   288  	}
   289  	return false
   290  }
   291  
   292  // checkUserDefinedMinMax looks for user-defined min/max functions that are
   293  // equivalent to the built-in functions and suggests removing them.
   294  func checkUserDefinedMinMax(pass *analysis.Pass) {
   295  	index := pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
   296  
   297  	// Look up min and max functions by name in package scope
   298  	for _, funcName := range []string{"min", "max"} {
   299  		if fn, ok := pass.Pkg.Scope().Lookup(funcName).(*types.Func); ok {
   300  			// Use typeindex to get the FuncDecl directly
   301  			if def, ok := index.Def(fn); ok {
   302  				decl := def.Parent().Node().(*ast.FuncDecl)
   303  				// Check if this function matches the built-in min/max signature and behavior
   304  				if canUseBuiltinMinMax(fn, decl.Body) {
   305  					// Expand to include leading doc comment
   306  					pos := decl.Pos()
   307  					if docs := astutil.DocComment(decl); docs != nil {
   308  						pos = docs.Pos()
   309  					}
   310  
   311  					pass.Report(analysis.Diagnostic{
   312  						Pos:     decl.Pos(),
   313  						End:     decl.End(),
   314  						Message: fmt.Sprintf("user-defined %s function is equivalent to built-in %s and can be removed", funcName, funcName),
   315  						SuggestedFixes: []analysis.SuggestedFix{{
   316  							Message: fmt.Sprintf("Remove user-defined %s function", funcName),
   317  							TextEdits: []analysis.TextEdit{{
   318  								Pos: pos,
   319  								End: decl.End(),
   320  							}},
   321  						}},
   322  					})
   323  				}
   324  			}
   325  		}
   326  	}
   327  }
   328  
   329  // canUseBuiltinMinMax reports whether it is safe to replace a call
   330  // to this min or max function by its built-in namesake.
   331  func canUseBuiltinMinMax(fn *types.Func, body *ast.BlockStmt) bool {
   332  	sig := fn.Type().(*types.Signature)
   333  
   334  	// Only consider the most common case: exactly 2 parameters
   335  	if sig.Params().Len() != 2 {
   336  		return false
   337  	}
   338  
   339  	// Check if any parameter might be floating-point
   340  	for param := range sig.Params().Variables() {
   341  		if maybeNaN(param.Type()) {
   342  			return false // Don't suggest removal for float types due to NaN handling
   343  		}
   344  	}
   345  
   346  	// Must have exactly one return value
   347  	if sig.Results().Len() != 1 {
   348  		return false
   349  	}
   350  
   351  	// Check that the function body implements the expected min/max logic
   352  	if body == nil {
   353  		return false
   354  	}
   355  
   356  	return hasMinMaxLogic(body, fn.Name())
   357  }
   358  
   359  // hasMinMaxLogic checks if the function body implements simple min/max logic.
   360  func hasMinMaxLogic(body *ast.BlockStmt, funcName string) bool {
   361  	// Pattern 1: Single if/else statement
   362  	if len(body.List) == 1 {
   363  		if ifStmt, ok := body.List[0].(*ast.IfStmt); ok {
   364  			// Get the "false" result from the else block
   365  			if elseBlock, ok := ifStmt.Else.(*ast.BlockStmt); ok && len(elseBlock.List) == 1 {
   366  				if elseRet, ok := elseBlock.List[0].(*ast.ReturnStmt); ok && len(elseRet.Results) == 1 {
   367  					return checkMinMaxPattern(ifStmt, elseRet.Results[0], funcName)
   368  				}
   369  			}
   370  		}
   371  	}
   372  
   373  	// Pattern 2: if statement followed by return
   374  	if len(body.List) == 2 {
   375  		if ifStmt, ok := body.List[0].(*ast.IfStmt); ok && ifStmt.Else == nil {
   376  			if retStmt, ok := body.List[1].(*ast.ReturnStmt); ok && len(retStmt.Results) == 1 {
   377  				return checkMinMaxPattern(ifStmt, retStmt.Results[0], funcName)
   378  			}
   379  		}
   380  	}
   381  
   382  	return false
   383  }
   384  
   385  // checkMinMaxPattern checks if an if statement implements min/max logic.
   386  // ifStmt: the if statement to check
   387  // falseResult: the expression returned when the condition is false
   388  // funcName: "min" or "max"
   389  func checkMinMaxPattern(ifStmt *ast.IfStmt, falseResult ast.Expr, funcName string) bool {
   390  	// Must have condition with comparison
   391  	cmp, ok := ifStmt.Cond.(*ast.BinaryExpr)
   392  	if !ok {
   393  		return false
   394  	}
   395  
   396  	// Check if then branch returns one of the compared values
   397  	if len(ifStmt.Body.List) != 1 {
   398  		return false
   399  	}
   400  
   401  	thenRet, ok := ifStmt.Body.List[0].(*ast.ReturnStmt)
   402  	if !ok || len(thenRet.Results) != 1 {
   403  		return false
   404  	}
   405  
   406  	// Use the same logic as the existing minmax analyzer
   407  	sign := isInequality(cmp.Op)
   408  	if sign == 0 {
   409  		return false // Not a comparison operator
   410  	}
   411  
   412  	t := thenRet.Results[0] // "true" result
   413  	f := falseResult        // "false" result
   414  	x := cmp.X              // left operand
   415  	y := cmp.Y              // right operand
   416  
   417  	// Check operand order and adjust sign accordingly
   418  	if astutil.EqualSyntax(t, x) && astutil.EqualSyntax(f, y) {
   419  		sign = +sign
   420  	} else if astutil.EqualSyntax(t, y) && astutil.EqualSyntax(f, x) {
   421  		sign = -sign
   422  	} else {
   423  		return false
   424  	}
   425  
   426  	// Check if the sign matches the function name
   427  	return cond(sign < 0, "min", "max") == funcName
   428  }
   429  
   430  // -- utils --
   431  
   432  func is[T any](x any) bool {
   433  	_, ok := x.(T)
   434  	return ok
   435  }
   436  
   437  func cond[T any](cond bool, t, f T) T {
   438  	if cond {
   439  		return t
   440  	} else {
   441  		return f
   442  	}
   443  }
   444  

View as plain text