Source file
src/cmd/fix/main.go
1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "flag"
10 "fmt"
11 "go/ast"
12 "go/format"
13 "go/parser"
14 "go/scanner"
15 "go/token"
16 "go/version"
17 "internal/diff"
18 "io"
19 "io/fs"
20 "os"
21 "path/filepath"
22 "slices"
23 "strings"
24
25 "cmd/internal/telemetry/counter"
26 )
27
// Package-level state shared by main and the per-file processing code.
var (
	// fset is the file set handed to every parser and printer call in
	// this file.
	fset = token.NewFileSet()

	// exitCode is the status main passes to os.Exit; report sets it to 2.
	exitCode = 0

	// allowedRewrites holds the raw value of the -r flag.
	allowedRewrites = flag.String("r", "", "restrict the rewrites to this comma-separated list")

	// forceRewrites holds the raw value of the -force flag.
	forceRewrites = flag.String("force", "", "force these fixes to run even if the code looks updated")

	// allowed and force are the parsed forms of -r and -force, filled in
	// by main. Each stays nil when its flag is empty.
	allowed, force map[string]bool

	// doDiff selects diff output instead of rewriting files.
	doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")

	// goVersion holds the raw value of the -go flag.
	goVersion = flag.String("go", "", "go language version for files")
)

// debug, when true, makes processFile print the reformatted source that
// failed to re-parse after a fix and exit immediately.
const debug = false
48
49 func usage() {
50 fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
51 flag.PrintDefaults()
52 fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
53 slices.SortFunc(fixes, func(a, b fix) int {
54 return strings.Compare(a.name, b.name)
55 })
56 for _, f := range fixes {
57 if f.disabled {
58 fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
59 } else {
60 fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
61 }
62 desc := strings.TrimSpace(f.desc)
63 desc = strings.ReplaceAll(desc, "\n", "\n\t")
64 fmt.Fprintf(os.Stderr, "\t%s\n", desc)
65 }
66 os.Exit(2)
67 }
68
69 func main() {
70 counter.Open()
71 flag.Usage = usage
72 flag.Parse()
73 counter.Inc("fix/invocations")
74 counter.CountFlags("fix/flag:", *flag.CommandLine)
75
76 if !version.IsValid(*goVersion) {
77 report(fmt.Errorf("invalid -go=%s", *goVersion))
78 os.Exit(exitCode)
79 }
80
81 slices.SortFunc(fixes, func(a, b fix) int {
82 return strings.Compare(a.date, b.date)
83 })
84
85 if *allowedRewrites != "" {
86 allowed = make(map[string]bool)
87 for _, f := range strings.Split(*allowedRewrites, ",") {
88 allowed[f] = true
89 }
90 }
91
92 if *forceRewrites != "" {
93 force = make(map[string]bool)
94 for _, f := range strings.Split(*forceRewrites, ",") {
95 force[f] = true
96 }
97 }
98
99 if flag.NArg() == 0 {
100 if err := processFile("standard input", true); err != nil {
101 report(err)
102 }
103 os.Exit(exitCode)
104 }
105
106 for i := 0; i < flag.NArg(); i++ {
107 path := flag.Arg(i)
108 switch dir, err := os.Stat(path); {
109 case err != nil:
110 report(err)
111 case dir.IsDir():
112 walkDir(path)
113 default:
114 if err := processFile(path, false); err != nil {
115 report(err)
116 }
117 }
118 }
119
120 os.Exit(exitCode)
121 }
122
123 const parserMode = parser.ParseComments
124
125 func gofmtFile(f *ast.File) ([]byte, error) {
126 var buf bytes.Buffer
127 if err := format.Node(&buf, fset, f); err != nil {
128 return nil, err
129 }
130 return buf.Bytes(), nil
131 }
132
133 func processFile(filename string, useStdin bool) error {
134 var f *os.File
135 var err error
136 var fixlog strings.Builder
137
138 if useStdin {
139 f = os.Stdin
140 } else {
141 f, err = os.Open(filename)
142 if err != nil {
143 return err
144 }
145 defer f.Close()
146 }
147
148 src, err := io.ReadAll(f)
149 if err != nil {
150 return err
151 }
152
153 file, err := parser.ParseFile(fset, filename, src, parserMode)
154 if err != nil {
155 return err
156 }
157
158
159
160 newSrc, err := gofmtFile(file)
161 if err != nil {
162 return err
163 }
164 if !bytes.Equal(newSrc, src) {
165 newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
166 if err != nil {
167 return err
168 }
169 file = newFile
170 fmt.Fprintf(&fixlog, " fmt")
171 }
172
173
174 newFile := file
175 fixed := false
176 for _, fix := range fixes {
177 if allowed != nil && !allowed[fix.name] {
178 continue
179 }
180 if fix.disabled && !force[fix.name] {
181 continue
182 }
183 if fix.f(newFile) {
184 fixed = true
185 fmt.Fprintf(&fixlog, " %s", fix.name)
186
187
188
189
190 newSrc, err := gofmtFile(newFile)
191 if err != nil {
192 return err
193 }
194 newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
195 if err != nil {
196 if debug {
197 fmt.Printf("%s", newSrc)
198 report(err)
199 os.Exit(exitCode)
200 }
201 return err
202 }
203 }
204 }
205 if !fixed {
206 return nil
207 }
208 fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
209
210
211
212
213
214
215
216 newSrc, err = gofmtFile(newFile)
217 if err != nil {
218 return err
219 }
220
221 if *doDiff {
222 os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc))
223 return nil
224 }
225
226 if useStdin {
227 os.Stdout.Write(newSrc)
228 return nil
229 }
230
231 return os.WriteFile(f.Name(), newSrc, 0)
232 }
233
234 func gofmt(n any) string {
235 var gofmtBuf strings.Builder
236 if err := format.Node(&gofmtBuf, fset, n); err != nil {
237 return "<" + err.Error() + ">"
238 }
239 return gofmtBuf.String()
240 }
241
242 func report(err error) {
243 scanner.PrintError(os.Stderr, err)
244 exitCode = 2
245 }
246
247 func walkDir(path string) {
248 filepath.WalkDir(path, visitFile)
249 }
250
251 func visitFile(path string, f fs.DirEntry, err error) error {
252 if err == nil && isGoFile(f) {
253 err = processFile(path, false)
254 }
255 if err != nil {
256 report(err)
257 }
258 return nil
259 }
260
// isGoFile reports whether f is a non-directory entry whose name ends in
// ".go" and does not begin with a dot.
func isGoFile(f fs.DirEntry) bool {
	name := f.Name()
	if f.IsDir() {
		return false
	}
	if strings.HasPrefix(name, ".") {
		return false
	}
	return strings.HasSuffix(name, ".go")
}
266
View as plain text