Source file src/cmd/fix/main.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/scanner"
    15  	"go/token"
    16  	"go/version"
    17  	"internal/diff"
    18  	"io"
    19  	"io/fs"
    20  	"os"
    21  	"path/filepath"
    22  	"slices"
    23  	"strings"
    24  
    25  	"cmd/internal/telemetry/counter"
    26  )
    27  
    28  var (
    29  	fset     = token.NewFileSet()
    30  	exitCode = 0
    31  )
    32  
    33  var allowedRewrites = flag.String("r", "",
    34  	"restrict the rewrites to this comma-separated list")
    35  
    36  var forceRewrites = flag.String("force", "",
    37  	"force these fixes to run even if the code looks updated")
    38  
    39  var allowed, force map[string]bool
    40  
    41  var (
    42  	doDiff    = flag.Bool("diff", false, "display diffs instead of rewriting files")
    43  	goVersion = flag.String("go", "", "go language version for files")
    44  )
    45  
    46  // enable for debugging fix failures
    47  const debug = false // display incorrectly reformatted source and exit
    48  
    49  func usage() {
    50  	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
    51  	flag.PrintDefaults()
    52  	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
    53  	slices.SortFunc(fixes, func(a, b fix) int {
    54  		return strings.Compare(a.name, b.name)
    55  	})
    56  	for _, f := range fixes {
    57  		if f.disabled {
    58  			fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
    59  		} else {
    60  			fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
    61  		}
    62  		desc := strings.TrimSpace(f.desc)
    63  		desc = strings.ReplaceAll(desc, "\n", "\n\t")
    64  		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
    65  	}
    66  	os.Exit(2)
    67  }
    68  
    69  func main() {
    70  	counter.Open()
    71  	flag.Usage = usage
    72  	flag.Parse()
    73  	counter.Inc("fix/invocations")
    74  	counter.CountFlags("fix/flag:", *flag.CommandLine)
    75  
    76  	if !version.IsValid(*goVersion) {
    77  		report(fmt.Errorf("invalid -go=%s", *goVersion))
    78  		os.Exit(exitCode)
    79  	}
    80  
    81  	slices.SortFunc(fixes, func(a, b fix) int {
    82  		return strings.Compare(a.date, b.date)
    83  	})
    84  
    85  	if *allowedRewrites != "" {
    86  		allowed = make(map[string]bool)
    87  		for _, f := range strings.Split(*allowedRewrites, ",") {
    88  			allowed[f] = true
    89  		}
    90  	}
    91  
    92  	if *forceRewrites != "" {
    93  		force = make(map[string]bool)
    94  		for _, f := range strings.Split(*forceRewrites, ",") {
    95  			force[f] = true
    96  		}
    97  	}
    98  
    99  	if flag.NArg() == 0 {
   100  		if err := processFile("standard input", true); err != nil {
   101  			report(err)
   102  		}
   103  		os.Exit(exitCode)
   104  	}
   105  
   106  	for i := 0; i < flag.NArg(); i++ {
   107  		path := flag.Arg(i)
   108  		switch dir, err := os.Stat(path); {
   109  		case err != nil:
   110  			report(err)
   111  		case dir.IsDir():
   112  			walkDir(path)
   113  		default:
   114  			if err := processFile(path, false); err != nil {
   115  				report(err)
   116  			}
   117  		}
   118  	}
   119  
   120  	os.Exit(exitCode)
   121  }
   122  
   123  const parserMode = parser.ParseComments
   124  
   125  func gofmtFile(f *ast.File) ([]byte, error) {
   126  	var buf bytes.Buffer
   127  	if err := format.Node(&buf, fset, f); err != nil {
   128  		return nil, err
   129  	}
   130  	return buf.Bytes(), nil
   131  }
   132  
   133  func processFile(filename string, useStdin bool) error {
   134  	var f *os.File
   135  	var err error
   136  	var fixlog strings.Builder
   137  
   138  	if useStdin {
   139  		f = os.Stdin
   140  	} else {
   141  		f, err = os.Open(filename)
   142  		if err != nil {
   143  			return err
   144  		}
   145  		defer f.Close()
   146  	}
   147  
   148  	src, err := io.ReadAll(f)
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   154  	if err != nil {
   155  		return err
   156  	}
   157  
   158  	// Make sure file is in canonical format.
   159  	// This "fmt" pseudo-fix cannot be disabled.
   160  	newSrc, err := gofmtFile(file)
   161  	if err != nil {
   162  		return err
   163  	}
   164  	if !bytes.Equal(newSrc, src) {
   165  		newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
   166  		if err != nil {
   167  			return err
   168  		}
   169  		file = newFile
   170  		fmt.Fprintf(&fixlog, " fmt")
   171  	}
   172  
   173  	// Apply all fixes to file.
   174  	newFile := file
   175  	fixed := false
   176  	for _, fix := range fixes {
   177  		if allowed != nil && !allowed[fix.name] {
   178  			continue
   179  		}
   180  		if fix.disabled && !force[fix.name] {
   181  			continue
   182  		}
   183  		if fix.f(newFile) {
   184  			fixed = true
   185  			fmt.Fprintf(&fixlog, " %s", fix.name)
   186  
   187  			// AST changed.
   188  			// Print and parse, to update any missing scoping
   189  			// or position information for subsequent fixers.
   190  			newSrc, err := gofmtFile(newFile)
   191  			if err != nil {
   192  				return err
   193  			}
   194  			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
   195  			if err != nil {
   196  				if debug {
   197  					fmt.Printf("%s", newSrc)
   198  					report(err)
   199  					os.Exit(exitCode)
   200  				}
   201  				return err
   202  			}
   203  		}
   204  	}
   205  	if !fixed {
   206  		return nil
   207  	}
   208  	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
   209  
   210  	// Print AST.  We did that after each fix, so this appears
   211  	// redundant, but it is necessary to generate gofmt-compatible
   212  	// source code in a few cases. The official gofmt style is the
   213  	// output of the printer run on a standard AST generated by the parser,
   214  	// but the source we generated inside the loop above is the
   215  	// output of the printer run on a mangled AST generated by a fixer.
   216  	newSrc, err = gofmtFile(newFile)
   217  	if err != nil {
   218  		return err
   219  	}
   220  
   221  	if *doDiff {
   222  		os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc))
   223  		return nil
   224  	}
   225  
   226  	if useStdin {
   227  		os.Stdout.Write(newSrc)
   228  		return nil
   229  	}
   230  
   231  	return os.WriteFile(f.Name(), newSrc, 0)
   232  }
   233  
   234  func gofmt(n any) string {
   235  	var gofmtBuf strings.Builder
   236  	if err := format.Node(&gofmtBuf, fset, n); err != nil {
   237  		return "<" + err.Error() + ">"
   238  	}
   239  	return gofmtBuf.String()
   240  }
   241  
   242  func report(err error) {
   243  	scanner.PrintError(os.Stderr, err)
   244  	exitCode = 2
   245  }
   246  
   247  func walkDir(path string) {
   248  	filepath.WalkDir(path, visitFile)
   249  }
   250  
   251  func visitFile(path string, f fs.DirEntry, err error) error {
   252  	if err == nil && isGoFile(f) {
   253  		err = processFile(path, false)
   254  	}
   255  	if err != nil {
   256  		report(err)
   257  	}
   258  	return nil
   259  }
   260  
   261  func isGoFile(f fs.DirEntry) bool {
   262  	// ignore non-Go files
   263  	name := f.Name()
   264  	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
   265  }
   266  

View as plain text