Source file src/cmd/fix/fix.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
import (
	"fmt"
	"go/ast"
	"go/token"
	"path"
	"sort"
	"strconv"
)
    14  
    15  type fix struct {
    16  	name     string
    17  	date     string // date that fix was introduced, in YYYY-MM-DD format
    18  	f        func(*ast.File) bool
    19  	desc     string
    20  	disabled bool // whether this fix should be disabled by default
    21  }
    22  
    23  var fixes []fix
    24  
    25  func register(f fix) {
    26  	fixes = append(fixes, f)
    27  }
    28  
    29  // walk traverses the AST x, calling visit(y) for each node y in the tree but
    30  // also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt,
    31  // in a bottom-up traversal.
    32  func walk(x any, visit func(any)) {
    33  	walkBeforeAfter(x, nop, visit)
    34  }
    35  
    36  func nop(any) {}
    37  
    38  // walkBeforeAfter is like walk but calls before(x) before traversing
    39  // x's children and after(x) afterward.
    40  func walkBeforeAfter(x any, before, after func(any)) {
    41  	before(x)
    42  
    43  	switch n := x.(type) {
    44  	default:
    45  		panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x))
    46  
    47  	case nil:
    48  
    49  	// pointers to interfaces
    50  	case *ast.Decl:
    51  		walkBeforeAfter(*n, before, after)
    52  	case *ast.Expr:
    53  		walkBeforeAfter(*n, before, after)
    54  	case *ast.Spec:
    55  		walkBeforeAfter(*n, before, after)
    56  	case *ast.Stmt:
    57  		walkBeforeAfter(*n, before, after)
    58  
    59  	// pointers to struct pointers
    60  	case **ast.BlockStmt:
    61  		walkBeforeAfter(*n, before, after)
    62  	case **ast.CallExpr:
    63  		walkBeforeAfter(*n, before, after)
    64  	case **ast.FieldList:
    65  		walkBeforeAfter(*n, before, after)
    66  	case **ast.FuncType:
    67  		walkBeforeAfter(*n, before, after)
    68  	case **ast.Ident:
    69  		walkBeforeAfter(*n, before, after)
    70  	case **ast.BasicLit:
    71  		walkBeforeAfter(*n, before, after)
    72  
    73  	// pointers to slices
    74  	case *[]ast.Decl:
    75  		walkBeforeAfter(*n, before, after)
    76  	case *[]ast.Expr:
    77  		walkBeforeAfter(*n, before, after)
    78  	case *[]*ast.File:
    79  		walkBeforeAfter(*n, before, after)
    80  	case *[]*ast.Ident:
    81  		walkBeforeAfter(*n, before, after)
    82  	case *[]ast.Spec:
    83  		walkBeforeAfter(*n, before, after)
    84  	case *[]ast.Stmt:
    85  		walkBeforeAfter(*n, before, after)
    86  
    87  	// These are ordered and grouped to match ../../go/ast/ast.go
    88  	case *ast.Field:
    89  		walkBeforeAfter(&n.Names, before, after)
    90  		walkBeforeAfter(&n.Type, before, after)
    91  		walkBeforeAfter(&n.Tag, before, after)
    92  	case *ast.FieldList:
    93  		for _, field := range n.List {
    94  			walkBeforeAfter(field, before, after)
    95  		}
    96  	case *ast.BadExpr:
    97  	case *ast.Ident:
    98  	case *ast.Ellipsis:
    99  		walkBeforeAfter(&n.Elt, before, after)
   100  	case *ast.BasicLit:
   101  	case *ast.FuncLit:
   102  		walkBeforeAfter(&n.Type, before, after)
   103  		walkBeforeAfter(&n.Body, before, after)
   104  	case *ast.CompositeLit:
   105  		walkBeforeAfter(&n.Type, before, after)
   106  		walkBeforeAfter(&n.Elts, before, after)
   107  	case *ast.ParenExpr:
   108  		walkBeforeAfter(&n.X, before, after)
   109  	case *ast.SelectorExpr:
   110  		walkBeforeAfter(&n.X, before, after)
   111  	case *ast.IndexExpr:
   112  		walkBeforeAfter(&n.X, before, after)
   113  		walkBeforeAfter(&n.Index, before, after)
   114  	case *ast.IndexListExpr:
   115  		walkBeforeAfter(&n.X, before, after)
   116  		walkBeforeAfter(&n.Indices, before, after)
   117  	case *ast.SliceExpr:
   118  		walkBeforeAfter(&n.X, before, after)
   119  		if n.Low != nil {
   120  			walkBeforeAfter(&n.Low, before, after)
   121  		}
   122  		if n.High != nil {
   123  			walkBeforeAfter(&n.High, before, after)
   124  		}
   125  	case *ast.TypeAssertExpr:
   126  		walkBeforeAfter(&n.X, before, after)
   127  		walkBeforeAfter(&n.Type, before, after)
   128  	case *ast.CallExpr:
   129  		walkBeforeAfter(&n.Fun, before, after)
   130  		walkBeforeAfter(&n.Args, before, after)
   131  	case *ast.StarExpr:
   132  		walkBeforeAfter(&n.X, before, after)
   133  	case *ast.UnaryExpr:
   134  		walkBeforeAfter(&n.X, before, after)
   135  	case *ast.BinaryExpr:
   136  		walkBeforeAfter(&n.X, before, after)
   137  		walkBeforeAfter(&n.Y, before, after)
   138  	case *ast.KeyValueExpr:
   139  		walkBeforeAfter(&n.Key, before, after)
   140  		walkBeforeAfter(&n.Value, before, after)
   141  
   142  	case *ast.ArrayType:
   143  		walkBeforeAfter(&n.Len, before, after)
   144  		walkBeforeAfter(&n.Elt, before, after)
   145  	case *ast.StructType:
   146  		walkBeforeAfter(&n.Fields, before, after)
   147  	case *ast.FuncType:
   148  		if n.TypeParams != nil {
   149  			walkBeforeAfter(&n.TypeParams, before, after)
   150  		}
   151  		walkBeforeAfter(&n.Params, before, after)
   152  		if n.Results != nil {
   153  			walkBeforeAfter(&n.Results, before, after)
   154  		}
   155  	case *ast.InterfaceType:
   156  		walkBeforeAfter(&n.Methods, before, after)
   157  	case *ast.MapType:
   158  		walkBeforeAfter(&n.Key, before, after)
   159  		walkBeforeAfter(&n.Value, before, after)
   160  	case *ast.ChanType:
   161  		walkBeforeAfter(&n.Value, before, after)
   162  
   163  	case *ast.BadStmt:
   164  	case *ast.DeclStmt:
   165  		walkBeforeAfter(&n.Decl, before, after)
   166  	case *ast.EmptyStmt:
   167  	case *ast.LabeledStmt:
   168  		walkBeforeAfter(&n.Stmt, before, after)
   169  	case *ast.ExprStmt:
   170  		walkBeforeAfter(&n.X, before, after)
   171  	case *ast.SendStmt:
   172  		walkBeforeAfter(&n.Chan, before, after)
   173  		walkBeforeAfter(&n.Value, before, after)
   174  	case *ast.IncDecStmt:
   175  		walkBeforeAfter(&n.X, before, after)
   176  	case *ast.AssignStmt:
   177  		walkBeforeAfter(&n.Lhs, before, after)
   178  		walkBeforeAfter(&n.Rhs, before, after)
   179  	case *ast.GoStmt:
   180  		walkBeforeAfter(&n.Call, before, after)
   181  	case *ast.DeferStmt:
   182  		walkBeforeAfter(&n.Call, before, after)
   183  	case *ast.ReturnStmt:
   184  		walkBeforeAfter(&n.Results, before, after)
   185  	case *ast.BranchStmt:
   186  	case *ast.BlockStmt:
   187  		walkBeforeAfter(&n.List, before, after)
   188  	case *ast.IfStmt:
   189  		walkBeforeAfter(&n.Init, before, after)
   190  		walkBeforeAfter(&n.Cond, before, after)
   191  		walkBeforeAfter(&n.Body, before, after)
   192  		walkBeforeAfter(&n.Else, before, after)
   193  	case *ast.CaseClause:
   194  		walkBeforeAfter(&n.List, before, after)
   195  		walkBeforeAfter(&n.Body, before, after)
   196  	case *ast.SwitchStmt:
   197  		walkBeforeAfter(&n.Init, before, after)
   198  		walkBeforeAfter(&n.Tag, before, after)
   199  		walkBeforeAfter(&n.Body, before, after)
   200  	case *ast.TypeSwitchStmt:
   201  		walkBeforeAfter(&n.Init, before, after)
   202  		walkBeforeAfter(&n.Assign, before, after)
   203  		walkBeforeAfter(&n.Body, before, after)
   204  	case *ast.CommClause:
   205  		walkBeforeAfter(&n.Comm, before, after)
   206  		walkBeforeAfter(&n.Body, before, after)
   207  	case *ast.SelectStmt:
   208  		walkBeforeAfter(&n.Body, before, after)
   209  	case *ast.ForStmt:
   210  		walkBeforeAfter(&n.Init, before, after)
   211  		walkBeforeAfter(&n.Cond, before, after)
   212  		walkBeforeAfter(&n.Post, before, after)
   213  		walkBeforeAfter(&n.Body, before, after)
   214  	case *ast.RangeStmt:
   215  		walkBeforeAfter(&n.Key, before, after)
   216  		walkBeforeAfter(&n.Value, before, after)
   217  		walkBeforeAfter(&n.X, before, after)
   218  		walkBeforeAfter(&n.Body, before, after)
   219  
   220  	case *ast.ImportSpec:
   221  	case *ast.ValueSpec:
   222  		walkBeforeAfter(&n.Type, before, after)
   223  		walkBeforeAfter(&n.Values, before, after)
   224  		walkBeforeAfter(&n.Names, before, after)
   225  	case *ast.TypeSpec:
   226  		if n.TypeParams != nil {
   227  			walkBeforeAfter(&n.TypeParams, before, after)
   228  		}
   229  		walkBeforeAfter(&n.Type, before, after)
   230  
   231  	case *ast.BadDecl:
   232  	case *ast.GenDecl:
   233  		walkBeforeAfter(&n.Specs, before, after)
   234  	case *ast.FuncDecl:
   235  		if n.Recv != nil {
   236  			walkBeforeAfter(&n.Recv, before, after)
   237  		}
   238  		walkBeforeAfter(&n.Type, before, after)
   239  		if n.Body != nil {
   240  			walkBeforeAfter(&n.Body, before, after)
   241  		}
   242  
   243  	case *ast.File:
   244  		walkBeforeAfter(&n.Decls, before, after)
   245  
   246  	case *ast.Package:
   247  		walkBeforeAfter(&n.Files, before, after)
   248  
   249  	case []*ast.File:
   250  		for i := range n {
   251  			walkBeforeAfter(&n[i], before, after)
   252  		}
   253  	case []ast.Decl:
   254  		for i := range n {
   255  			walkBeforeAfter(&n[i], before, after)
   256  		}
   257  	case []ast.Expr:
   258  		for i := range n {
   259  			walkBeforeAfter(&n[i], before, after)
   260  		}
   261  	case []*ast.Ident:
   262  		for i := range n {
   263  			walkBeforeAfter(&n[i], before, after)
   264  		}
   265  	case []ast.Stmt:
   266  		for i := range n {
   267  			walkBeforeAfter(&n[i], before, after)
   268  		}
   269  	case []ast.Spec:
   270  		for i := range n {
   271  			walkBeforeAfter(&n[i], before, after)
   272  		}
   273  	}
   274  	after(x)
   275  }
   276  
   277  // imports reports whether f imports path.
   278  func imports(f *ast.File, path string) bool {
   279  	return importSpec(f, path) != nil
   280  }
   281  
   282  // importSpec returns the import spec if f imports path,
   283  // or nil otherwise.
   284  func importSpec(f *ast.File, path string) *ast.ImportSpec {
   285  	for _, s := range f.Imports {
   286  		if importPath(s) == path {
   287  			return s
   288  		}
   289  	}
   290  	return nil
   291  }
   292  
   293  // importPath returns the unquoted import path of s,
   294  // or "" if the path is not properly quoted.
   295  func importPath(s *ast.ImportSpec) string {
   296  	t, err := strconv.Unquote(s.Path.Value)
   297  	if err == nil {
   298  		return t
   299  	}
   300  	return ""
   301  }
   302  
   303  // declImports reports whether gen contains an import of path.
   304  func declImports(gen *ast.GenDecl, path string) bool {
   305  	if gen.Tok != token.IMPORT {
   306  		return false
   307  	}
   308  	for _, spec := range gen.Specs {
   309  		impspec := spec.(*ast.ImportSpec)
   310  		if importPath(impspec) == path {
   311  			return true
   312  		}
   313  	}
   314  	return false
   315  }
   316  
   317  // isTopName reports whether n is a top-level unresolved identifier with the given name.
   318  func isTopName(n ast.Expr, name string) bool {
   319  	id, ok := n.(*ast.Ident)
   320  	return ok && id.Name == name && id.Obj == nil
   321  }
   322  
   323  // renameTop renames all references to the top-level name old.
   324  // It reports whether it makes any changes.
   325  func renameTop(f *ast.File, old, new string) bool {
   326  	var fixed bool
   327  
   328  	// Rename any conflicting imports
   329  	// (assuming package name is last element of path).
   330  	for _, s := range f.Imports {
   331  		if s.Name != nil {
   332  			if s.Name.Name == old {
   333  				s.Name.Name = new
   334  				fixed = true
   335  			}
   336  		} else {
   337  			_, thisName := path.Split(importPath(s))
   338  			if thisName == old {
   339  				s.Name = ast.NewIdent(new)
   340  				fixed = true
   341  			}
   342  		}
   343  	}
   344  
   345  	// Rename any top-level declarations.
   346  	for _, d := range f.Decls {
   347  		switch d := d.(type) {
   348  		case *ast.FuncDecl:
   349  			if d.Recv == nil && d.Name.Name == old {
   350  				d.Name.Name = new
   351  				d.Name.Obj.Name = new
   352  				fixed = true
   353  			}
   354  		case *ast.GenDecl:
   355  			for _, s := range d.Specs {
   356  				switch s := s.(type) {
   357  				case *ast.TypeSpec:
   358  					if s.Name.Name == old {
   359  						s.Name.Name = new
   360  						s.Name.Obj.Name = new
   361  						fixed = true
   362  					}
   363  				case *ast.ValueSpec:
   364  					for _, n := range s.Names {
   365  						if n.Name == old {
   366  							n.Name = new
   367  							n.Obj.Name = new
   368  							fixed = true
   369  						}
   370  					}
   371  				}
   372  			}
   373  		}
   374  	}
   375  
   376  	// Rename top-level old to new, both unresolved names
   377  	// (probably defined in another file) and names that resolve
   378  	// to a declaration we renamed.
   379  	walk(f, func(n any) {
   380  		id, ok := n.(*ast.Ident)
   381  		if ok && isTopName(id, old) {
   382  			id.Name = new
   383  			fixed = true
   384  		}
   385  		if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new {
   386  			id.Name = id.Obj.Name
   387  			fixed = true
   388  		}
   389  	})
   390  
   391  	return fixed
   392  }
   393  
   394  // matchLen returns the length of the longest prefix shared by x and y.
   395  func matchLen(x, y string) int {
   396  	i := 0
   397  	for i < len(x) && i < len(y) && x[i] == y[i] {
   398  		i++
   399  	}
   400  	return i
   401  }
   402  
// addImport adds the import path to the file f, if absent.
// It reports whether it added the import (false if f already imports ipath).
//
// Before adding, top-level references that would clash with the new
// package name are renamed from name to name_ via renameTop. The new spec
// goes into the existing import declaration (skipping any that imports "C")
// whose spec shares the longest common prefix with ipath, directly after
// that spec. If no such declaration exists, a new import declaration is
// inserted right after the last import declaration, or at the start of
// f.Decls if there is none. The spec is also appended to f.Imports.
func addImport(f *ast.File, ipath string) (added bool) {
	if imports(f, ipath) {
		return false
	}

	// Determine name of import.
	// Assume added imports follow convention of using last element.
	_, name := path.Split(ipath)

	// Rename any conflicting top-level references from name to name_.
	renameTop(f, name, name+"_")

	newImport := &ast.ImportSpec{
		Path: &ast.BasicLit{
			Kind:  token.STRING,
			Value: strconv.Quote(ipath),
		},
	}

	// Find an import decl to add to.
	// bestMatch is the longest shared prefix seen so far (-1 = none yet);
	// impDecl/impIndex are the decl and the spec index that achieved it.
	// lastImport is the f.Decls index of the last import decl (-1 = none).
	var (
		bestMatch  = -1
		lastImport = -1
		impDecl    *ast.GenDecl
		impIndex   = -1
	)
	for i, decl := range f.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if ok && gen.Tok == token.IMPORT {
			lastImport = i
			// Do not add to import "C", to avoid disrupting the
			// association with its doc comment, breaking cgo.
			if declImports(gen, "C") {
				continue
			}

			// Compute longest shared prefix with imports in this block.
			// The strict > keeps the first spec among equal matches.
			for j, spec := range gen.Specs {
				impspec := spec.(*ast.ImportSpec)
				n := matchLen(importPath(impspec), ipath)
				if n > bestMatch {
					bestMatch = n
					impDecl = gen
					impIndex = j
				}
			}
		}
	}

	// If no import decl found, add one after the last import.
	// With no import decls at all, lastImport is -1 and the new
	// decl lands at index 0.
	if impDecl == nil {
		impDecl = &ast.GenDecl{
			Tok: token.IMPORT,
		}
		f.Decls = append(f.Decls, nil)
		copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
		f.Decls[lastImport+1] = impDecl
	}

	// Ensure the import decl has parentheses, if needed.
	// (A decl that already has a spec will hold at least two after the
	// insertion below.)
	if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() {
		impDecl.Lparen = impDecl.Pos()
	}

	// Insert right after the best-matching spec. impIndex is still -1
	// only when the decl was freshly created above, in which case the
	// spec is appended at the end.
	insertAt := impIndex + 1
	if insertAt == 0 {
		insertAt = len(impDecl.Specs)
	}
	impDecl.Specs = append(impDecl.Specs, nil)
	copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:])
	impDecl.Specs[insertAt] = newImport
	if insertAt > 0 {
		// Assign same position as the previous import,
		// so that the sorter sees it as being in the same block.
		prev := impDecl.Specs[insertAt-1]
		newImport.Path.ValuePos = prev.Pos()
		newImport.EndPos = prev.Pos()
	}

	f.Imports = append(f.Imports, newImport)
	return true
}
   486  
   487  // deleteImport deletes the import path from the file f, if present.
   488  func deleteImport(f *ast.File, path string) (deleted bool) {
   489  	oldImport := importSpec(f, path)
   490  
   491  	// Find the import node that imports path, if any.
   492  	for i, decl := range f.Decls {
   493  		gen, ok := decl.(*ast.GenDecl)
   494  		if !ok || gen.Tok != token.IMPORT {
   495  			continue
   496  		}
   497  		for j, spec := range gen.Specs {
   498  			impspec := spec.(*ast.ImportSpec)
   499  			if oldImport != impspec {
   500  				continue
   501  			}
   502  
   503  			// We found an import spec that imports path.
   504  			// Delete it.
   505  			deleted = true
   506  			copy(gen.Specs[j:], gen.Specs[j+1:])
   507  			gen.Specs = gen.Specs[:len(gen.Specs)-1]
   508  
   509  			// If this was the last import spec in this decl,
   510  			// delete the decl, too.
   511  			if len(gen.Specs) == 0 {
   512  				copy(f.Decls[i:], f.Decls[i+1:])
   513  				f.Decls = f.Decls[:len(f.Decls)-1]
   514  			} else if len(gen.Specs) == 1 {
   515  				gen.Lparen = token.NoPos // drop parens
   516  			}
   517  			if j > 0 {
   518  				// We deleted an entry but now there will be
   519  				// a blank line-sized hole where the import was.
   520  				// Close the hole by making the previous
   521  				// import appear to "end" where this one did.
   522  				gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End()
   523  			}
   524  			break
   525  		}
   526  	}
   527  
   528  	// Delete it from f.Imports.
   529  	for i, imp := range f.Imports {
   530  		if imp == oldImport {
   531  			copy(f.Imports[i:], f.Imports[i+1:])
   532  			f.Imports = f.Imports[:len(f.Imports)-1]
   533  			break
   534  		}
   535  	}
   536  
   537  	return
   538  }
   539  
   540  // rewriteImport rewrites any import of path oldPath to path newPath.
   541  func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
   542  	for _, imp := range f.Imports {
   543  		if importPath(imp) == oldPath {
   544  			rewrote = true
   545  			// record old End, because the default is to compute
   546  			// it using the length of imp.Path.Value.
   547  			imp.EndPos = imp.End()
   548  			imp.Path.Value = strconv.Quote(newPath)
   549  		}
   550  	}
   551  	return
   552  }
   553  

View as plain text