// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package main import ( "bytes" "flag" "fmt" "go/ast" "go/format" "go/parser" "go/scanner" "go/token" "go/version" "internal/diff" "io" "io/fs" "os" "path/filepath" "slices" "strings" "cmd/internal/telemetry/counter" ) var ( fset = token.NewFileSet() exitCode = 0 ) var allowedRewrites = flag.String("r", "", "restrict the rewrites to this comma-separated list") var forceRewrites = flag.String("force", "", "force these fixes to run even if the code looks updated") var allowed, force map[string]bool var ( doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files") goVersion = flag.String("go", "", "go language version for files") ) // enable for debugging fix failures const debug = false // display incorrectly reformatted source and exit func usage() { fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n") flag.PrintDefaults() fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n") slices.SortFunc(fixes, func(a, b fix) int { return strings.Compare(a.name, b.name) }) for _, f := range fixes { if f.disabled { fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name) } else { fmt.Fprintf(os.Stderr, "\n%s\n", f.name) } desc := strings.TrimSpace(f.desc) desc = strings.ReplaceAll(desc, "\n", "\n\t") fmt.Fprintf(os.Stderr, "\t%s\n", desc) } os.Exit(2) } func main() { counter.Open() flag.Usage = usage flag.Parse() counter.Inc("fix/invocations") counter.CountFlags("fix/flag:", *flag.CommandLine) if !version.IsValid(*goVersion) { report(fmt.Errorf("invalid -go=%s", *goVersion)) os.Exit(exitCode) } slices.SortFunc(fixes, func(a, b fix) int { return strings.Compare(a.date, b.date) }) if *allowedRewrites != "" { allowed = make(map[string]bool) for _, f := range strings.Split(*allowedRewrites, ",") { allowed[f] = true } } if *forceRewrites != "" { force = make(map[string]bool) for _, f := range strings.Split(*forceRewrites, ",") { force[f] = true } } if flag.NArg() == 0 { if err := processFile("standard input", true); err != nil { report(err) } os.Exit(exitCode) } for i := 0; i < flag.NArg(); i++ { path := flag.Arg(i) switch dir, err := os.Stat(path); { case err != nil: report(err) case dir.IsDir(): walkDir(path) default: if err := processFile(path, false); err != nil { report(err) } } } os.Exit(exitCode) } const parserMode = parser.ParseComments func gofmtFile(f *ast.File) ([]byte, error) { var buf bytes.Buffer if err := format.Node(&buf, fset, f); err != nil { return nil, err } return buf.Bytes(), nil } func processFile(filename string, useStdin bool) error { var f *os.File var err error var fixlog strings.Builder if useStdin { f = os.Stdin } else { f, err = os.Open(filename) if err != nil { return err } defer f.Close() } src, err := io.ReadAll(f) if err != nil { return err } file, err := parser.ParseFile(fset, filename, src, parserMode) if err != nil { return err } // Make sure file is in canonical format. // This "fmt" pseudo-fix cannot be disabled. newSrc, err := gofmtFile(file) if err != nil { return err } if !bytes.Equal(newSrc, src) { newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode) if err != nil { return err } file = newFile fmt.Fprintf(&fixlog, " fmt") } // Apply all fixes to file. newFile := file fixed := false for _, fix := range fixes { if allowed != nil && !allowed[fix.name] { continue } if fix.disabled && !force[fix.name] { continue } if fix.f(newFile) { fixed = true fmt.Fprintf(&fixlog, " %s", fix.name) // AST changed. // Print and parse, to update any missing scoping // or position information for subsequent fixers. newSrc, err := gofmtFile(newFile) if err != nil { return err } newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode) if err != nil { if debug { fmt.Printf("%s", newSrc) report(err) os.Exit(exitCode) } return err } } } if !fixed { return nil } fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:]) // Print AST. We did that after each fix, so this appears // redundant, but it is necessary to generate gofmt-compatible // source code in a few cases. The official gofmt style is the // output of the printer run on a standard AST generated by the parser, // but the source we generated inside the loop above is the // output of the printer run on a mangled AST generated by a fixer. newSrc, err = gofmtFile(newFile) if err != nil { return err } if *doDiff { os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc)) return nil } if useStdin { os.Stdout.Write(newSrc) return nil } return os.WriteFile(f.Name(), newSrc, 0) } func gofmt(n any) string { var gofmtBuf strings.Builder if err := format.Node(&gofmtBuf, fset, n); err != nil { return "<" + err.Error() + ">" } return gofmtBuf.String() } func report(err error) { scanner.PrintError(os.Stderr, err) exitCode = 2 } func walkDir(path string) { filepath.WalkDir(path, visitFile) } func visitFile(path string, f fs.DirEntry, err error) error { if err == nil && isGoFile(f) { err = processFile(path, false) } if err != nil { report(err) } return nil } func isGoFile(f fs.DirEntry) bool { // ignore non-Go files name := f.Name() return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") }