1
2
3
4
5 package modernize
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/token"
11 "go/types"
12 "strings"
13
14 "golang.org/x/tools/go/analysis"
15 "golang.org/x/tools/go/analysis/passes/inspect"
16 "golang.org/x/tools/go/ast/edge"
17 "golang.org/x/tools/go/ast/inspector"
18 "golang.org/x/tools/internal/analysis/analyzerutil"
19 typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
20 "golang.org/x/tools/internal/astutil"
21 "golang.org/x/tools/internal/typeparams"
22 "golang.org/x/tools/internal/typesinternal/typeindex"
23 "golang.org/x/tools/internal/versions"
24 )
25
26 var MinMaxAnalyzer = &analysis.Analyzer{
27 Name: "minmax",
28 Doc: analyzerutil.MustExtractDoc(doc, "minmax"),
29 Requires: []*analysis.Analyzer{
30 inspect.Analyzer,
31 typeindexanalyzer.Analyzer,
32 },
33 Run: minmax,
34 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#minmax",
35 }
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// minmax is the Run function of MinMaxAnalyzer.
//
// It first calls checkUserDefinedMinMax to report package-level min/max
// functions that duplicate the built-ins. It then examines each
// if-statement in the files yielded by filesUsingGoVersion for go1.21.
// A statement qualifies when it has no init statement, its condition is an
// ordered comparison (<, <=, >, >=), and its body is a single simple
// assignment. Two patterns are reported:
//
//  1. if a < b { x = a } else { x = b }   =>   x = min(a, b)
//  2. x = a; if a < b { x = b }           =>   x = max(a, b)
//
// Operands may appear in either order in the condition and the branches.
// In pattern 2 the first statement may use "=" or ":=", and the condition
// may name x in place of a.
//
// Statements are skipped when the assigned variable may have a
// floating-point type (see maybeNaN). They are also skipped when the name
// min or max does not resolve to the built-in function at that point.
func minmax(pass *analysis.Pass) (any, error) {
	// Report user-defined min/max functions equivalent to the built-ins.
	checkUserDefinedMinMax(pass)

	// check examines one "if a < b { lhs = rhs }" statement, already
	// filtered by the loop below. It looks at the else branch, or at the
	// preceding statement, and reports pattern 1 or 2 if either matches.
	check := func(file *ast.File, curIfStmt inspector.Cursor, compare *ast.BinaryExpr) {
		var (
			ifStmt  = curIfStmt.Node().(*ast.IfStmt)
			tassign = ifStmt.Body.List[0].(*ast.AssignStmt)
			a       = compare.X
			b       = compare.Y
			lhs     = tassign.Lhs[0]
			rhs     = tassign.Rhs[0]
			sign    = isInequality(compare.Op) // -1 for < and <=, +1 for > and >=

			// callArg formats one argument of the replacement call. It
			// prefixes the argument with the comments that allComments
			// finds between start and end, so the fix does not drop them.
			// The argument b is preceded by ", ", and any comments begin
			// on a new line.
			//
			// callArg reads the variable b when it is called, so it sees
			// any reassignment of b made earlier (see pattern 2).
			callArg = func(arg ast.Expr, start, end token.Pos) string {
				comments := allComments(file, start, end)
				return cond(arg == b, ", ", "") +
					cond(comments != "", "\n", "") +
					comments +
					astutil.Format(pass.Fset, arg)
			}
		)

		if fblock, ok := ifStmt.Else.(*ast.BlockStmt); ok && isAssignBlock(fblock) {
			fassign := fblock.List[0].(*ast.AssignStmt)

			// Have: if a < b { lhs = rhs } else { lhs2 = rhs2 }
			lhs2 := fassign.Lhs[0]
			rhs2 := fassign.Rhs[0]

			// Pattern 1 requires lhs == lhs2 and {rhs, rhs2} == {a, b}.
			// The branch that receives a decides the function:
			// "if a < b { x = a } else { x = b }" is min, and the
			// swapped assignment is max.
			if astutil.EqualSyntax(lhs, lhs2) {
				if astutil.EqualSyntax(rhs, a) && astutil.EqualSyntax(rhs2, b) {
					sign = +sign
				} else if astutil.EqualSyntax(rhs2, a) && astutil.EqualSyntax(rhs, b) {
					sign = -sign
				} else {
					return
				}

				sym := cond(sign < 0, "min", "max")

				// Give up unless sym resolves to the built-in here, i.e. it
				// is not shadowed. (lookup is defined elsewhere; presumably
				// it resolves sym in the scope enclosing curIfStmt.)
				if !is[*types.Builtin](lookup(pass.TypesInfo, curIfStmt, sym)) {
					return
				}

				pass.Report(analysis.Diagnostic{
					// Highlight the condition a < b.
					Pos:     compare.Pos(),
					End:     compare.End(),
					Message: fmt.Sprintf("if/else statement can be modernized using %s", sym),
					SuggestedFixes: []analysis.SuggestedFix{{
						Message: fmt.Sprintf("Replace if statement with %s", sym),
						TextEdits: []analysis.TextEdit{{
							// Replace the whole if/else with "lhs = sym(a, b)".
							// Comments from the then-part go before a, and
							// comments from the else-part go before b.
							Pos: ifStmt.Pos(),
							End: ifStmt.End(),
							NewText: fmt.Appendf(nil, "%s = %s(%s%s)",
								astutil.Format(pass.Fset, lhs),
								sym,
								callArg(a, ifStmt.Pos(), ifStmt.Else.Pos()),
								callArg(b, ifStmt.Else.Pos(), ifStmt.End()),
							),
						}},
					}},
				})
			}

		} else if prev, ok := curIfStmt.PrevSibling(); ok && isSimpleAssign(prev.Node()) && ifStmt.Else == nil {
			fassign := prev.Node().(*ast.AssignStmt)

			// Have: lhs0 = rhs0; if a < b { lhs = rhs }
			// (the first statement may also use ":=").
			//
			// Pattern 2 requires lhs == lhs0 and {a, b} == {rhs, rhs0}.
			// lhs0 may stand in for rhs0 in the condition, e.g.
			//
			//	x = y; if x < z { x = z }   =>   x = max(y, z)
			lhs0 := fassign.Lhs[0]
			rhs0 := fassign.Rhs[0]

			// An assignment that is the communication of a select case
			// ("case lhs0 := <-rhs0:") cannot be rewritten into a call.
			if ek, _ := prev.ParentEdge(); ek == edge.CommClause_Comm {
				return
			}

			if astutil.EqualSyntax(lhs, lhs0) {
				if astutil.EqualSyntax(rhs, a) && (astutil.EqualSyntax(rhs0, b) || astutil.EqualSyntax(lhs0, b)) {
					sign = +sign
				} else if (astutil.EqualSyntax(rhs0, a) || astutil.EqualSyntax(lhs0, a)) && astutil.EqualSyntax(rhs, b) {
					sign = -sign
				} else {
					return
				}
				sym := cond(sign < 0, "min", "max")

				// Give up unless sym resolves to the built-in here.
				if !is[*types.Builtin](lookup(pass.TypesInfo, curIfStmt, sym)) {
					return
				}

				// The fix deletes "lhs0 = rhs0". If lhs0 matched an operand
				// of the condition, use rhs0 in the call instead of lhs0.
				if astutil.EqualSyntax(lhs0, a) {
					a = rhs0
				} else if astutil.EqualSyntax(lhs0, b) {
					b = rhs0
				}

				// NOTE(review): the messages below say "if statement" and
				// "Replace if/else", the reverse of pattern 1, although this
				// pattern has no else branch. Confirm the intended wording.
				pass.Report(analysis.Diagnostic{
					// Highlight the condition a < b.
					Pos:     compare.Pos(),
					End:     compare.End(),
					Message: fmt.Sprintf("if statement can be modernized using %s", sym),
					SuggestedFixes: []analysis.SuggestedFix{{
						Message: fmt.Sprintf("Replace if/else with %s", sym),
						TextEdits: []analysis.TextEdit{{
							Pos: fassign.Pos(),
							End: ifStmt.End(),
							// Replace "lhs0 = rhs0; if ... { ... }" with
							// "lhs = sym(a, b)". This keeps the original
							// "=" or ":=" token and carries over comments.
							NewText: fmt.Appendf(nil, "%s %s %s(%s%s)",
								astutil.Format(pass.Fset, lhs),
								fassign.Tok.String(),
								sym,
								callArg(a, fassign.Pos(), ifStmt.Pos()),
								callArg(b, ifStmt.Pos(), ifStmt.End()),
							),
						}},
					}},
				})
			}
		}
	}

	// Find every "if a < b { lhs = rhs }" statement in eligible files.
	// (filesUsingGoVersion is defined elsewhere; presumably it yields only
	// files whose Go version is at least go1.21, where min/max exist.)
	info := pass.TypesInfo
	for curFile := range filesUsingGoVersion(pass, versions.Go1_21) {
		astFile := curFile.Node().(*ast.File)
		for curIfStmt := range curFile.Preorder((*ast.IfStmt)(nil)) {
			ifStmt := curIfStmt.Node().(*ast.IfStmt)

			// Skip "... else if a < b { lhs = rhs }". Rewriting it would
			// require wrapping the assignment in a new else block, and
			// checking that no further else follows.
			if astutil.IsChildOf(curIfStmt, edge.IfStmt_Else) {
				continue
			}

			if compare, ok := ifStmt.Cond.(*ast.BinaryExpr); ok &&
				ifStmt.Init == nil &&
				isInequality(compare.Op) != 0 &&
				isAssignBlock(ifStmt.Body) {
				// tLHS may be nil (presumably for a blank "_" destination,
				// which has no type). Possibly-float destinations are
				// rejected by maybeNaN.
				if tLHS := info.TypeOf(ifStmt.Body.List[0].(*ast.AssignStmt).Lhs[0]); tLHS != nil && !maybeNaN(tLHS) {
					// Have: if a < b { lhs = rhs }
					check(astFile, curIfStmt, compare)
				}
			}
		}
	}
	return nil, nil
}
237
238
// allComments concatenates the text of every comment that
// astutil.Comments reports for file between start and end, in the order
// reported. Each comment's text is followed by a newline.
func allComments(file *ast.File, start, end token.Pos) string {
	var sb strings.Builder
	for c := range astutil.Comments(file, start, end) {
		sb.WriteString(c.Text)
		sb.WriteByte('\n')
	}
	return sb.String()
}
246
247
248
// isInequality classifies an ordered-comparison operator. It returns -1 for
// < and <=, +1 for > and >=, and 0 for every other token.
func isInequality(tok token.Token) int {
	if tok == token.LSS || tok == token.LEQ {
		return -1
	}
	if tok == token.GTR || tok == token.GEQ {
		return +1
	}
	return 0
}
258
259
// isAssignBlock reports whether block b holds exactly one statement and
// that statement is a simple single-value assignment (see isSimpleAssign).
func isAssignBlock(b *ast.BlockStmt) bool {
	return len(b.List) == 1 && isSimpleAssign(b.List[0])
}
267
268
// isSimpleAssign reports whether n is an assignment of exactly one
// right-hand expression to exactly one left-hand operand using plain "="
// or ":=". Compound forms such as "+=" are rejected.
func isSimpleAssign(n ast.Node) bool {
	as, isAssign := n.(*ast.AssignStmt)
	if !isAssign {
		return false
	}
	if as.Tok != token.ASSIGN && as.Tok != token.DEFINE {
		return false
	}
	return len(as.Lhs) == 1 && len(as.Rhs) == 1
}
276
277
// maybeNaN reports whether values of type t may be floating-point, and so
// possibly NaN. The decision is based on typeparams.CoreType. A type with
// no core type (e.g. a type parameter constrained by int | float64) is
// conservatively treated as maybe-NaN.
func maybeNaN(t types.Type) bool {
	core := typeparams.CoreType(t)
	if core == nil {
		return true
	}
	basic, isBasic := core.(*types.Basic)
	return isBasic && basic.Info()&types.IsFloat != 0
}
291
292
293
// checkUserDefinedMinMax looks up package-level functions named min and
// max. When one is equivalent to the corresponding built-in (see
// canUseBuiltinMinMax), it reports the declaration. The suggested fix
// deletes the declaration and its doc comment, so that existing calls
// resolve to the built-in instead.
//
// Nothing is reported if deleting the declaration could break a reference
// to the function (see builtinCanReplaceUses).
func checkUserDefinedMinMax(pass *analysis.Pass) {
	index := pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)

	for _, funcName := range []string{"min", "max"} {
		fn, ok := pass.Pkg.Scope().Lookup(funcName).(*types.Func)
		if !ok {
			continue // not declared at package level, or not a function
		}
		def, ok := index.Def(fn)
		if !ok {
			continue
		}
		// The defining identifier's parent is the function declaration.
		decl := def.Parent().Node().(*ast.FuncDecl)
		if !canUseBuiltinMinMax(fn, decl.Body) || !builtinCanReplaceUses(pass, index, fn) {
			continue
		}

		// Extend the deletion to cover the leading doc comment.
		pos := decl.Pos()
		if docs := astutil.DocComment(decl); docs != nil {
			pos = docs.Pos()
		}

		pass.Report(analysis.Diagnostic{
			Pos:     decl.Pos(),
			End:     decl.End(),
			Message: fmt.Sprintf("user-defined %s function is equivalent to built-in %s and can be removed", funcName, funcName),
			SuggestedFixes: []analysis.SuggestedFix{{
				Message: fmt.Sprintf("Remove user-defined %s function", funcName),
				TextEdits: []analysis.TextEdit{{
					Pos: pos,
					End: decl.End(),
				}},
			}},
		})
	}
}

// builtinCanReplaceUses reports whether every reference to fn would still
// compile if the name resolved to the predeclared function instead.
// The built-in min and max have these restrictions:
//   - they can only be called, not used as function values or explicitly
//     instantiated;
//   - they are not permitted in statement context (expression, go, and
//     defer statements);
//   - they are only available in files whose Go version is go1.21 or later.
func builtinCanReplaceUses(pass *analysis.Pass, index *typeindex.Index, fn *types.Func) bool {
	for use := range index.Uses(fn) {
		// Reject e.g. "f := min", "apply(xs, max)", or "min[int](x, y)".
		if !astutil.IsChildOf(use, edge.CallExpr_Fun) {
			return false
		}
		// Reject "min(x, y)" used as a statement, or after go/defer.
		switch ek, _ := use.Parent().ParentEdge(); ek {
		case edge.ExprStmt_X, edge.GoStmt_Call, edge.DeferStmt_Call:
			return false
		}
		// Reject calls in files whose Go version predates the built-ins.
		// An unknown version ("") is treated as recent by versions.Before.
		for curFile := range use.Enclosing((*ast.File)(nil)) {
			file := curFile.Node().(*ast.File)
			if versions.Before(pass.TypesInfo.FileVersions[file], versions.Go1_21) {
				return false
			}
		}
	}
	return true
}
328
329
330
// canUseBuiltinMinMax reports whether fn, a package-level function named
// min or max whose declaration has the given body, can be replaced by the
// built-in function of the same name. It requires all of the following:
//   - fn has exactly two parameters and one result, all of one identical
//     type, so that the built-in returns the same type the callers expect
//     (this rejects e.g. "func max(a, b int) any");
//   - that type cannot be floating-point (see maybeNaN);
//   - both parameters are named, so the body can refer to them;
//   - body computes min or max of exactly those two parameters, in one of
//     the forms accepted by hasMinMaxLogic.
func canUseBuiltinMinMax(fn *types.Func, body *ast.BlockStmt) bool {
	sig := fn.Type().(*types.Signature)
	params, results := sig.Params(), sig.Results()
	if params.Len() != 2 || results.Len() != 1 {
		return false
	}

	p, q := params.At(0), params.At(1)
	typ := p.Type()
	if !types.Identical(typ, q.Type()) || !types.Identical(typ, results.At(0).Type()) {
		return false
	}
	if maybeNaN(typ) {
		return false
	}

	// Unnamed or blank parameters cannot be referenced by the body.
	for _, v := range []*types.Var{p, q} {
		if v.Name() == "" || v.Name() == "_" {
			return false
		}
	}

	if body == nil {
		return false // declared without a Go body
	}

	return hasMinMaxLogic(body, fn.Name(), p.Name(), q.Name())
}

// hasMinMaxLogic reports whether body computes funcName ("min" or "max") of
// the parameters named param1 and param2. Two forms are accepted:
//
//	if x < y { return x } else { return y }
//
//	if x < y { return x }
//	return y
//
// Any ordered comparison and operand order is allowed; see
// checkMinMaxPattern.
func hasMinMaxLogic(body *ast.BlockStmt, funcName, param1, param2 string) bool {
	// Form 1: a single if/else whose branches each return one value.
	if len(body.List) == 1 {
		if ifStmt, ok := body.List[0].(*ast.IfStmt); ok {
			if elseBlock, ok := ifStmt.Else.(*ast.BlockStmt); ok && len(elseBlock.List) == 1 {
				if elseRet, ok := elseBlock.List[0].(*ast.ReturnStmt); ok && len(elseRet.Results) == 1 {
					return checkMinMaxPattern(ifStmt, elseRet.Results[0], funcName, param1, param2)
				}
			}
		}
	}

	// Form 2: an if without else, followed by a return.
	if len(body.List) == 2 {
		if ifStmt, ok := body.List[0].(*ast.IfStmt); ok && ifStmt.Else == nil {
			if retStmt, ok := body.List[1].(*ast.ReturnStmt); ok && len(retStmt.Results) == 1 {
				return checkMinMaxPattern(ifStmt, retStmt.Results[0], funcName, param1, param2)
			}
		}
	}

	return false
}

// checkMinMaxPattern reports whether ifStmt computes funcName ("min" or
// "max") of the two parameters named param1 and param2. falseResult is the
// value returned when the condition is false. The requirements are:
//   - ifStmt has no init statement;
//   - its condition compares the two parameters, in either order, with
//     <, <=, > or >=;
//   - its body is a single return of one of them;
//   - falseResult is the other one.
func checkMinMaxPattern(ifStmt *ast.IfStmt, falseResult ast.Expr, funcName, param1, param2 string) bool {
	// An init statement could rebind or modify the operands, as in
	// "if a := 2 * a; a < b { ... }", so its presence defeats the analysis.
	if ifStmt.Init != nil {
		return false
	}

	cmp, ok := ifStmt.Cond.(*ast.BinaryExpr)
	if !ok {
		return false
	}
	sign := isInequality(cmp.Op)
	if sign == 0 {
		return false
	}

	// The comparison must be between the two parameters themselves.
	// Otherwise "if a > 0 { return a }; return 0" would be mistaken for
	// max(a, b). Because the body declares nothing (no init, single
	// returns), a parameter's name can only refer to that parameter.
	refersTo := func(e ast.Expr, name string) bool {
		id, ok := e.(*ast.Ident)
		return ok && id.Name == name
	}
	x, y := cmp.X, cmp.Y
	if !(refersTo(x, param1) && refersTo(y, param2) || refersTo(x, param2) && refersTo(y, param1)) {
		return false
	}

	if len(ifStmt.Body.List) != 1 {
		return false
	}
	thenRet, ok := ifStmt.Body.List[0].(*ast.ReturnStmt)
	if !ok || len(thenRet.Results) != 1 {
		return false
	}

	t := thenRet.Results[0]
	f := falseResult
	switch {
	case astutil.EqualSyntax(t, x) && astutil.EqualSyntax(f, y):
		// "if x < y { return x }" selects the smaller: sign is unchanged.
	case astutil.EqualSyntax(t, y) && astutil.EqualSyntax(f, x):
		sign = -sign
	default:
		return false
	}

	return cond(sign < 0, "min", "max") == funcName
}
429
430
431
// is reports whether the type assertion x.(T) would succeed. It returns
// false for a nil x.
func is[T any](x any) bool {
	switch x.(type) {
	case T:
		return true
	default:
		return false
	}
}
436
// cond returns t when cond is true and f otherwise. Both t and f are
// ordinary arguments, so the caller evaluates both regardless of cond.
func cond[T any](cond bool, t, f T) T {
	if !cond {
		return f
	}
	return t
}
444
View as plain text