1
2
3
4
5 package modernize
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/constant"
11 "go/token"
12 "go/types"
13 "iter"
14 "strconv"
15
16 "golang.org/x/tools/go/analysis"
17 "golang.org/x/tools/go/analysis/passes/inspect"
18 "golang.org/x/tools/go/ast/edge"
19 "golang.org/x/tools/go/ast/inspector"
20 "golang.org/x/tools/go/types/typeutil"
21 "golang.org/x/tools/internal/analysis/analyzerutil"
22 typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
23 "golang.org/x/tools/internal/astutil"
24 "golang.org/x/tools/internal/goplsexport"
25 "golang.org/x/tools/internal/refactor"
26 "golang.org/x/tools/internal/typesinternal"
27 "golang.org/x/tools/internal/typesinternal/typeindex"
28 "golang.org/x/tools/internal/versions"
29 )
30
31 var stringscutAnalyzer = &analysis.Analyzer{
32 Name: "stringscut",
33 Doc: analyzerutil.MustExtractDoc(doc, "stringscut"),
34 Requires: []*analysis.Analyzer{
35 inspect.Analyzer,
36 typeindexanalyzer.Analyzer,
37 },
38 Run: stringscut,
39 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#stringscut",
40 }
41
42 func init() {
43
44 goplsexport.StringsCutModernizer = stringscutAnalyzer
45 }
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116 func stringscut(pass *analysis.Pass) (any, error) {
117 var (
118 index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
119 info = pass.TypesInfo
120
121 stringsIndex = index.Object("strings", "Index")
122 stringsIndexByte = index.Object("strings", "IndexByte")
123 bytesIndex = index.Object("bytes", "Index")
124 bytesIndexByte = index.Object("bytes", "IndexByte")
125 )
126
127 for _, obj := range []types.Object{
128 stringsIndex,
129 stringsIndexByte,
130 bytesIndex,
131 bytesIndexByte,
132 } {
133
134 nextcall:
135 for curCall := range index.Calls(obj) {
136
137 if !analyzerutil.FileUsesGoVersion(pass, astutil.EnclosingFile(curCall), versions.Go1_18) {
138 continue
139 }
140 indexCall := curCall.Node().(*ast.CallExpr)
141 obj := typeutil.Callee(info, indexCall)
142 if obj == nil {
143 continue
144 }
145
146 var iIdent *ast.Ident
147 switch ek, idx := curCall.ParentEdge(); ek {
148 case edge.ValueSpec_Values:
149
150 curName := curCall.Parent().ChildAt(edge.ValueSpec_Names, idx)
151 iIdent = curName.Node().(*ast.Ident)
152 case edge.AssignStmt_Rhs:
153
154
155 curLhs := curCall.Parent().ChildAt(edge.AssignStmt_Lhs, idx)
156 iIdent, _ = curLhs.Node().(*ast.Ident)
157 }
158
159 if iIdent == nil {
160 continue
161 }
162
163
164 iObj := info.ObjectOf(iIdent)
165 if iObj == nil {
166 continue
167 }
168
169 var (
170 s = indexCall.Args[0]
171 substr = indexCall.Args[1]
172 )
173
174
175
176 if !indexArgValid(info, index, s, indexCall.Pos()) ||
177 !indexArgValid(info, index, substr, indexCall.Pos()) {
178 continue nextcall
179 }
180
181
182
183
184
185
186 negative, nonnegative, beforeSlice, afterSlice := checkIdxUses(pass.TypesInfo, index.Uses(iObj), s, substr, iObj)
187
188
189
190 if negative == nil && nonnegative == nil && beforeSlice == nil && afterSlice == nil {
191 continue
192 }
193
194
195 isContains := (len(negative) > 0 || len(nonnegative) > 0) && len(beforeSlice) == 0 && len(afterSlice) == 0
196
197 scope := iObj.Parent()
198 var (
199
200 okVarName = refactor.FreshName(scope, iIdent.Pos(), "ok")
201 beforeVarName = refactor.FreshName(scope, iIdent.Pos(), "before")
202 afterVarName = refactor.FreshName(scope, iIdent.Pos(), "after")
203 foundVarName = refactor.FreshName(scope, iIdent.Pos(), "found")
204 )
205
206
207
208 if len(negative) == 0 && len(nonnegative) == 0 {
209 okVarName = "_"
210 }
211 if len(beforeSlice) == 0 {
212 beforeVarName = "_"
213 }
214 if len(afterSlice) == 0 {
215 afterVarName = "_"
216 }
217
218 var edits []analysis.TextEdit
219 replace := func(exprs []ast.Expr, new string) {
220 for _, expr := range exprs {
221 edits = append(edits, analysis.TextEdit{
222 Pos: expr.Pos(),
223 End: expr.End(),
224 NewText: []byte(new),
225 })
226 }
227 }
228
229
230 indexCallId := typesinternal.UsedIdent(info, indexCall.Fun)
231 replacedFunc := "Cut"
232 if isContains {
233 replacedFunc = "Contains"
234 replace(negative, "!"+foundVarName)
235 replace(nonnegative, foundVarName)
236
237
238
239
240
241
242 edits = append(edits, analysis.TextEdit{
243 Pos: iIdent.Pos(),
244 End: iIdent.End(),
245 NewText: []byte(foundVarName),
246 }, analysis.TextEdit{
247 Pos: indexCallId.Pos(),
248 End: indexCallId.End(),
249 NewText: []byte("Contains"),
250 })
251 } else {
252 replace(negative, "!"+okVarName)
253 replace(nonnegative, okVarName)
254 replace(beforeSlice, beforeVarName)
255 replace(afterSlice, afterVarName)
256
257
258
259
260
261
262 edits = append(edits, analysis.TextEdit{
263 Pos: iIdent.Pos(),
264 End: iIdent.End(),
265 NewText: fmt.Appendf(nil, "%s, %s, %s", beforeVarName, afterVarName, okVarName),
266 }, analysis.TextEdit{
267 Pos: indexCallId.Pos(),
268 End: indexCallId.End(),
269 NewText: []byte("Cut"),
270 })
271 }
272
273
274
275 if obj.Name() == "IndexByte" {
276 switch obj.Pkg().Name() {
277 case "strings":
278 searchByteVal := info.Types[substr].Value
279 if searchByteVal == nil {
280
281
282 edits = append(edits, []analysis.TextEdit{
283 {
284 Pos: substr.Pos(),
285 NewText: []byte("string("),
286 },
287 {
288 Pos: substr.End(),
289 NewText: []byte(")"),
290 },
291 }...)
292 } else {
293
294 val, _ := constant.Int64Val(searchByteVal)
295
296 edits = append(edits, analysis.TextEdit{
297 Pos: substr.Pos(),
298 End: substr.End(),
299 NewText: strconv.AppendQuote(nil, string(byte(val))),
300 })
301 }
302 case "bytes":
303
304 edits = append(edits, []analysis.TextEdit{
305 {
306 Pos: substr.Pos(),
307 NewText: []byte("[]byte{"),
308 },
309 {
310 Pos: substr.End(),
311 NewText: []byte("}"),
312 },
313 }...)
314 }
315 }
316 pass.Report(analysis.Diagnostic{
317 Pos: indexCall.Fun.Pos(),
318 End: indexCall.Fun.End(),
319 Message: fmt.Sprintf("%s.%s can be simplified using %s.%s",
320 obj.Pkg().Name(), obj.Name(), obj.Pkg().Name(), replacedFunc),
321 Category: "stringscut",
322 SuggestedFixes: []analysis.SuggestedFix{{
323 Message: fmt.Sprintf("Simplify %s.%s call using %s.%s", obj.Pkg().Name(), obj.Name(), obj.Pkg().Name(), replacedFunc),
324 TextEdits: edits,
325 }},
326 })
327 }
328 }
329
330 return nil, nil
331 }
332
333
334
335
336
337
338
339
340 func indexArgValid(info *types.Info, index *typeindex.Index, expr ast.Expr, afterPos token.Pos) bool {
341 tv := info.Types[expr]
342 if tv.Value != nil {
343 return true
344 }
345 switch expr := expr.(type) {
346 case *ast.CallExpr:
347 return types.Identical(tv.Type, byteSliceType) &&
348 info.Types[expr.Fun].IsType() &&
349 indexArgValid(info, index, expr.Args[0], afterPos)
350 case *ast.Ident:
351 sObj := info.Uses[expr]
352 sUses := index.Uses(sObj)
353 return !hasModifyingUses(info, sUses, afterPos)
354 default:
355
356
357
358
359
360
361
362
363
364 return false
365 }
366 }
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389 func checkIdxUses(info *types.Info, uses iter.Seq[inspector.Cursor], s, substr ast.Expr, iObj types.Object) (negative, nonnegative, beforeSlice, afterSlice []ast.Expr) {
390 requireGuard := true
391 if l := constSubstrLen(info, substr); l != -1 && l != 1 {
392 requireGuard = false
393 }
394
395 use := func(cur inspector.Cursor) bool {
396 ek, _ := cur.ParentEdge()
397 n := cur.Parent().Node()
398 switch ek {
399 case edge.BinaryExpr_X, edge.BinaryExpr_Y:
400 check := n.(*ast.BinaryExpr)
401 switch checkIdxComparison(info, check, iObj) {
402 case -1:
403 negative = append(negative, check)
404 return true
405 case 1:
406 nonnegative = append(nonnegative, check)
407 return true
408 }
409
410
411
412
413
414 if slice, ok := cur.Parent().Parent().Node().(*ast.SliceExpr); ok &&
415 sameObject(info, s, slice.X) &&
416 slice.Max == nil {
417 if isBeforeSlice(info, ek, slice) && (!requireGuard || isSliceIndexGuarded(info, cur, iObj)) {
418 beforeSlice = append(beforeSlice, slice)
419 return true
420 } else if isAfterSlice(info, ek, slice, substr) && (!requireGuard || isSliceIndexGuarded(info, cur, iObj)) {
421 afterSlice = append(afterSlice, slice)
422 return true
423 }
424 }
425 case edge.SliceExpr_Low, edge.SliceExpr_High:
426 slice := n.(*ast.SliceExpr)
427
428
429 if sameObject(info, s, slice.X) && slice.Max == nil {
430 if isBeforeSlice(info, ek, slice) && (!requireGuard || isSliceIndexGuarded(info, cur, iObj)) {
431 beforeSlice = append(beforeSlice, slice)
432 return true
433 } else if isAfterSlice(info, ek, slice, substr) && (!requireGuard || isSliceIndexGuarded(info, cur, iObj)) {
434 afterSlice = append(afterSlice, slice)
435 return true
436 }
437 }
438 }
439 return false
440 }
441
442 for curIdent := range uses {
443 if !use(curIdent) {
444 return nil, nil, nil, nil
445 }
446 }
447 return negative, nonnegative, beforeSlice, afterSlice
448 }
449
450
451
452
453 func hasModifyingUses(info *types.Info, uses iter.Seq[inspector.Cursor], afterPos token.Pos) bool {
454 for curUse := range uses {
455 ek, _ := curUse.ParentEdge()
456 if ek == edge.AssignStmt_Lhs {
457 if curUse.Node().Pos() <= afterPos {
458 continue
459 }
460 assign := curUse.Parent().Node().(*ast.AssignStmt)
461 if sameObject(info, assign.Lhs[0], curUse.Node().(*ast.Ident)) {
462
463 return true
464 }
465 } else if ek == edge.UnaryExpr_X &&
466 curUse.Parent().Node().(*ast.UnaryExpr).Op == token.AND {
467
468
469
470
471 return true
472 }
473 }
474 return false
475 }
476
477
478
479
480
481
482
483
484
485 func checkIdxComparison(info *types.Info, check *ast.BinaryExpr, iObj types.Object) int {
486 isI := func(e ast.Expr) bool {
487 id, ok := e.(*ast.Ident)
488 return ok && info.Uses[id] == iObj
489 }
490 if !isI(check.X) && !isI(check.Y) {
491 return 0
492 }
493
494
495 x, op, y := check.X, check.Op, check.Y
496 if info.Types[x].Value != nil {
497 x, op, y = y, flip(op), x
498 }
499
500 yIsInt := func(k int64) bool {
501 return isIntLiteral(info, y, k)
502 }
503
504 if op == token.LSS && yIsInt(0) ||
505 op == token.EQL && yIsInt(-1) ||
506 op == token.LEQ && yIsInt(-1) {
507 return -1
508 }
509
510 if op == token.GEQ && yIsInt(0) ||
511 op == token.NEQ && yIsInt(-1) ||
512 op == token.GTR && yIsInt(-1) {
513 return +1
514 }
515
516 return 0
517 }
518
519
520
// flip returns the comparison operator obtained by swapping the operands of
// op: "x op y" is equivalent to "y flip(op) x". Symmetric operators (==, !=)
// and any non-ordering operator are returned unchanged.
func flip(op token.Token) token.Token {
	switch op {
	case token.LSS:
		return token.GTR
	case token.GTR:
		return token.LSS
	case token.LEQ:
		return token.GEQ
	case token.GEQ:
		return token.LEQ
	default:
		return op
	}
}
536
537
538 func isBeforeSlice(info *types.Info, ek edge.Kind, slice *ast.SliceExpr) bool {
539 return ek == edge.SliceExpr_High && (slice.Low == nil || isZeroIntConst(info, slice.Low))
540 }
541
542
543 func constSubstrLen(info *types.Info, substr ast.Expr) int {
544
545 if call, ok := substr.(*ast.CallExpr); ok {
546 tv := info.Types[call.Fun]
547 if tv.IsType() && types.Identical(tv.Type, byteSliceType) {
548
549 substr = call.Args[0]
550 }
551 }
552 substrVal := info.Types[substr].Value
553 if substrVal != nil {
554 switch substrVal.Kind() {
555 case constant.String:
556 return len(constant.StringVal(substrVal))
557 case constant.Int:
558
559
560
561 return 1
562 }
563 }
564 return -1
565 }
566
567
568
569 func isAfterSlice(info *types.Info, ek edge.Kind, slice *ast.SliceExpr, substr ast.Expr) bool {
570 lowExpr, ok := slice.Low.(*ast.BinaryExpr)
571 if !ok || slice.High != nil {
572 return false
573 }
574
575 isLenCall := func(expr ast.Expr) bool {
576 call, ok := expr.(*ast.CallExpr)
577 if !ok || len(call.Args) != 1 {
578 return false
579 }
580 return sameObject(info, substr, call.Args[0]) && typeutil.Callee(info, call) == builtinLen
581 }
582
583 substrLen := constSubstrLen(info, substr)
584
585 switch ek {
586 case edge.BinaryExpr_X:
587 kVal := info.Types[lowExpr.Y].Value
588 if kVal == nil {
589
590 return lowExpr.Op == token.ADD && isLenCall(lowExpr.Y)
591 } else {
592
593 kInt, ok := constant.Int64Val(kVal)
594 return ok && substrLen == int(kInt)
595 }
596 case edge.BinaryExpr_Y:
597 kVal := info.Types[lowExpr.X].Value
598 if kVal == nil {
599
600 return lowExpr.Op == token.ADD && isLenCall(lowExpr.X)
601 } else {
602
603 kInt, ok := constant.Int64Val(kVal)
604 return ok && substrLen == int(kInt)
605 }
606 }
607 return false
608 }
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624 func isSliceIndexGuarded(info *types.Info, cur inspector.Cursor, iObj types.Object) bool {
625 for anc := range cur.Enclosing() {
626 switch ek, _ := anc.ParentEdge(); ek {
627 case edge.IfStmt_Body, edge.IfStmt_Else:
628 ifStmt := anc.Parent().Node().(*ast.IfStmt)
629 check := condChecksIdx(info, ifStmt.Cond, iObj)
630 if ek == edge.IfStmt_Else {
631 check = -check
632 }
633 if check > 0 {
634 return true
635 }
636 if check < 0 {
637 return false
638 }
639 case edge.BlockStmt_List:
640
641 for sib, ok := anc.PrevSibling(); ok; sib, ok = sib.PrevSibling() {
642 ifStmt, ok := sib.Node().(*ast.IfStmt)
643 if ok && condChecksIdx(info, ifStmt.Cond, iObj) < 0 && bodyTerminates(ifStmt.Body) {
644 return true
645 }
646 }
647 case edge.FuncDecl_Body, edge.FuncLit_Body:
648 return false
649 }
650 }
651 return false
652 }
653
654
655
656
657 func condChecksIdx(info *types.Info, cond ast.Expr, iObj types.Object) int {
658 binExpr, ok := cond.(*ast.BinaryExpr)
659 if !ok {
660 return 0
661 }
662 return checkIdxComparison(info, binExpr, iObj)
663 }
664
665
666
// bodyTerminates reports whether block is non-empty and its final statement is
// a return or a branch statement (break, continue, goto or fallthrough).
// It is used to recognize early-exit guards.
func bodyTerminates(block *ast.BlockStmt) bool {
	n := len(block.List)
	if n == 0 {
		return false
	}
	switch block.List[n-1].(type) {
	case *ast.ReturnStmt, *ast.BranchStmt:
		return true
	default:
		return false
	}
}
678
679
// sameObject reports whether expr1 and expr2 are both identifiers that are
// recorded in info.Uses as referring to the same object. Identifiers absent
// from info.Uses (for example, defining occurrences) never match.
func sameObject(info *types.Info, expr1, expr2 ast.Expr) bool {
	id1, isIdent1 := expr1.(*ast.Ident)
	id2, isIdent2 := expr2.(*ast.Ident)
	if !isIdent1 || !isIdent2 {
		return false
	}
	obj1, found1 := info.Uses[id1]
	obj2, found2 := info.Uses[id2]
	return found1 && found2 && obj1 == obj2
}
690
View as plain text