Source file src/cmd/compile/internal/midway/rewrite.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package midway
     6  
     7  import (
     8  	"cmd/compile/internal/syntax"
     9  	"cmd/compile/internal/types2"
    10  	"fmt"
    11  	"internal/buildcfg"
    12  	"strings"
    13  )
    14  
    15  // "Midway" rewriting
    16  //
// Go attempts to provide a package similar to the "Highway" library
    18  // for C++ (https://google.github.io/highway).  The library package is "simd"
    19  // and defines vector types with unspecified widths that are bound to particular
    20  // machine dependent types as late as program execution.  This is accomplished
    21  // by rewriting code that depends on these types into code that references
    22  // architecture-specific types, perhaps more than once, and if necessary
    23  // dynamically choosing which version to execute based on hardware attributes.
    24  //
    25  // The rewriting takes place early in the compiler, after type checking but
    26  // before conversion to "unified" IR.  To ensure that types are correctly set
    27  // on the modified version of the code, type checking information is reset and
// the type checking phase is re-run.  This places some limits on the shape of
    29  // the rewrites, but it also ensures that the rewritten code is well-formed.
    30  //
    31  // Rewritten code does not reference "archsimd" types directly, but instead
    32  // references types in a "bridge" package that filters the available methods
    33  // and adds a few more.  The package used relies on a builder/compiler hack;
    34  // the compiler's type checker enforces export naming conventions, but the
    35  // build system limits visibility to unrelated "internal" packages and can be
    36  // modified to allow access in special cases (like this one).  This allows the
    37  // rewritten code to reference types, functions, and methods that are not
    38  // accessible otherwise.
    39  
    40  type Rewriter struct {
    41  	pkg      *types2.Package
    42  	analyzer *Analyzer
    43  	info     *types2.Info
    44  	sizes    []int
    45  }
    46  
    47  func NewRewriter(pkg *types2.Package, info *types2.Info, analyzer *Analyzer, sizes []int) *Rewriter {
    48  	return &Rewriter{
    49  		pkg:      pkg,
    50  		info:     info,
    51  		analyzer: analyzer,
    52  		sizes:    sizes,
    53  	}
    54  }
    55  
    56  func (r *Rewriter) Rewrite(files []*syntax.File) {
    57  
    58  	// First duplicate and specialize all dependent functions and variables.
    59  	for _, fileAST := range files {
    60  
    61  		var newDecls []syntax.Decl
    62  		for _, k := range r.sizes {
    63  			newDecls = r.generateForSize(fileAST, k, newDecls)
    64  		}
    65  
    66  		// Then replace original functions with dispatchers.
    67  		r.generateDispatchers(fileAST)
    68  
    69  		fileAST.DeclList = append(fileAST.DeclList, newDecls...)
    70  	}
    71  }
    72  
    73  func (r *Rewriter) generateDispatchers(fileAST *syntax.File) {
    74  	var newDecls []syntax.Decl
    75  
    76  	for _, decl := range fileAST.DeclList {
    77  		switch d := decl.(type) {
    78  		case *syntax.FuncDecl:
    79  			if d.Name == nil {
    80  				newDecls = append(newDecls, d)
    81  				continue
    82  			}
    83  			obj := r.info.Defs[d.Name]
    84  			if !r.analyzer.dependentObj[obj] || r.analyzer.inSimd {
    85  				newDecls = append(newDecls, d)
    86  				continue
    87  			}
    88  
    89  			sig, ok := obj.Type().(*types2.Signature)
    90  			if !ok {
    91  				newDecls = append(newDecls, d)
    92  				continue
    93  			}
    94  
    95  			if r.analyzer.HasDependentSignature(sig) {
    96  				// Drop dependent signatures entirely
    97  				continue
    98  			}
    99  
   100  			// Clean signature -> Replace body with dispatcher
   101  			d.Body = r.createDispatcherBody(d, sig)
   102  			newDecls = append(newDecls, d)
   103  
   104  		case *syntax.VarDecl:
   105  			// Filter specs conceptually based on dependents
   106  			keep := false
   107  			for _, name := range d.NameList {
   108  				if !r.analyzer.dependentObj[r.info.Defs[name]] {
   109  					keep = true
   110  					break // Keep entire var decl if any name is clean, else drop
   111  				}
   112  			}
   113  			if keep {
   114  				newDecls = append(newDecls, d)
   115  			}
   116  		case *syntax.TypeDecl:
   117  			if !r.analyzer.dependentObj[r.info.Defs[d.Name]] || r.analyzer.inSimd {
   118  				newDecls = append(newDecls, d)
   119  			}
   120  		default:
   121  			newDecls = append(newDecls, decl)
   122  		}
   123  	}
   124  
   125  	fileAST.DeclList = newDecls
   126  
   127  	if !r.analyzer.inSimd {
   128  		// Inject an import to the bridge package (if not exists)
   129  		hasArchSimd := false
   130  		var simdImport *syntax.ImportDecl
   131  		for _, decl := range fileAST.DeclList {
   132  			if imp, ok := decl.(*syntax.ImportDecl); ok {
   133  				if imp.Path.Value == `"`+archFullPkg+`"` {
   134  					hasArchSimd = true
   135  				}
   136  				if imp.Path.Value == `"`+simdPkg+`"` {
   137  					simdImport = imp
   138  				}
   139  
   140  			}
   141  		}
   142  		p := simdImport.Pos()
   143  		if !hasArchSimd {
   144  			r.injectImport(fileAST, archFullPkg, p)
   145  		}
   146  
   147  		// Ensure at least one use of "simd"
   148  		// var _ = simd.VectorBitLen()
   149  		fun := &syntax.SelectorExpr{
   150  			X:   syntax.NewName(p, simdPkg), // Assume this is resolvable
   151  			Sel: syntax.NewName(p, vectorSizeFn),
   152  		}
   153  		fun.SetPos(p)
   154  		call := &syntax.CallExpr{Fun: fun}
   155  		call.SetPos(p)
   156  
   157  		name := syntax.NewName(p, "_")
   158  
   159  		varDecl := &syntax.VarDecl{NameList: []*syntax.Name{name}, Values: call}
   160  		varDecl.SetPos(p)
   161  		fileAST.DeclList = append(fileAST.DeclList, varDecl)
   162  	}
   163  }
   164  
   165  func (r *Rewriter) injectImport(fileAST *syntax.File, toImport string, simdImportPos syntax.Pos) {
   166  	importDecl := &syntax.ImportDecl{
   167  		Path: &syntax.BasicLit{Value: `"` + toImport + `"`, Kind: syntax.StringLit},
   168  	}
   169  	importDecl.Path.SetPos(simdImportPos)
   170  	importDecl.SetPos(simdImportPos)
   171  	fileAST.DeclList = append([]syntax.Decl{importDecl}, fileAST.DeclList...)
   172  }
   173  
   174  func (r *Rewriter) createDispatcherBody(d *syntax.FuncDecl, sig *types2.Signature) *syntax.BlockStmt {
   175  
   176  	// Build call arguments from the function parameters
   177  	args := func() []syntax.Expr {
   178  		var args []syntax.Expr
   179  		if d.Type.ParamList != nil {
   180  			for _, field := range d.Type.ParamList {
   181  				if field.Name != nil {
   182  					paramName := syntax.NewName(field.Pos(), field.Name.Value)
   183  					args = append(args, paramName)
   184  				}
   185  			}
   186  		}
   187  		return args
   188  	}
   189  
   190  	// Slap a pos on an expression
   191  	pe := func(e syntax.Expr) syntax.Expr {
   192  		e.SetPos(d.Pos())
   193  		return e
   194  	}
   195  	// Slap a pos on a statement
   196  	ps := func(e syntax.Stmt) syntax.Stmt {
   197  		e.SetPos(d.Pos())
   198  		return e
   199  	}
   200  
   201  	// switch ast node.
   202  	// the goal is something like (for now, till there are finer-grained choices)
   203  	// switch simd.VectorSize() {
   204  	//   case 128: call the specialize-for-128-code(args)
   205  	//   case 256: call the specialize-for-256-code(args)
   206  	//   etc
   207  	// }
   208  	switchStmt := &syntax.SwitchStmt{
   209  		Tag: pe(&syntax.CallExpr{
   210  			Fun: pe(&syntax.SelectorExpr{
   211  				X:   syntax.NewName(d.Pos(), simdPkg), // Assume this is resolvable
   212  				Sel: syntax.NewName(d.Pos(), vectorSizeFn),
   213  			}),
   214  		}),
   215  		Body: []*syntax.CaseClause{},
   216  	}
   217  
   218  	for _, k := range r.sizes {
   219  		fnName := fmt.Sprintf("%s@simd%d", d.Name.Value, k)
   220  		fnIdent := syntax.NewName(d.Pos(), fnName)
   221  
   222  		callExpr := pe(&syntax.CallExpr{
   223  			Fun:     pe(fnIdent),
   224  			ArgList: args(),
   225  		})
   226  
   227  		var branchStmt syntax.Stmt
   228  		if d.Type.ResultList != nil && len(d.Type.ResultList) > 0 {
   229  			branchStmt = &syntax.ReturnStmt{Results: callExpr}
   230  		} else {
   231  			branchStmt = &syntax.BlockStmt{
   232  				List: []syntax.Stmt{
   233  					ps(&syntax.ExprStmt{X: callExpr}),
   234  					ps(&syntax.ReturnStmt{}),
   235  				},
   236  			}
   237  		}
   238  		branchStmt.SetPos(d.Pos())
   239  
   240  		caseClause := &syntax.CaseClause{
   241  			Cases: pe(&syntax.BasicLit{Kind: syntax.IntLit, Value: fmt.Sprintf("%d", k)}),
   242  			Body:  []syntax.Stmt{branchStmt},
   243  		}
   244  		caseClause.SetPos(d.Pos())
   245  		switchStmt.Body = append(switchStmt.Body, caseClause)
   246  	}
   247  
   248  	fnName := "panic"
   249  	fnIdent := pe(syntax.NewName(d.Pos(), fnName))
   250  
   251  	callExpr := pe(&syntax.CallExpr{
   252  		Fun:     fnIdent,
   253  		ArgList: []syntax.Expr{pe(&syntax.BasicLit{Value: "\"unsupported vector size in simd-rewritten code\"", Kind: syntax.StringLit})},
   254  	})
   255  
   256  	panicStmt := &syntax.ExprStmt{X: callExpr}
   257  	blockStmt := &syntax.BlockStmt{List: []syntax.Stmt{ps(switchStmt), ps(panicStmt)}}
   258  
   259  	blockStmt.SetPos(d.Pos())
   260  
   261  	return blockStmt
   262  }
   263  
   264  func (r *Rewriter) generateForSize(fileAST *syntax.File, k int, newDecls []syntax.Decl) []syntax.Decl {
   265  	copier := NewDeepCopier(r.pkg, r.info, k, r.analyzer, fmt.Sprintf("@simd%d", k))
   266  	for _, decl := range fileAST.DeclList {
   267  		if r.shouldIncludeDecl(decl) {
   268  			newDecl := copier.CopyDecl(decl)
   269  			newDecls = append(newDecls, newDecl)
   270  		}
   271  	}
   272  	return newDecls
   273  }
   274  
   275  func nameToElemBitWidth(name string) int {
   276  	var width int
   277  	switch name {
   278  	case "Int8s", "Uint8s", "Mask8s":
   279  		width = 8
   280  	case "Int16s", "Uint16s", "Mask16s":
   281  		width = 16
   282  	case "Int32s", "Uint32s", "Float32s", "Mask32s":
   283  		width = 32
   284  	case "Int64s", "Uint64s", "Float64s", "Mask64s":
   285  		width = 64
   286  	}
   287  	return width
   288  }
   289  
   290  func (r *Rewriter) shouldIncludeDecl(decl syntax.Decl) bool {
   291  	// Files (and declarations) in the simd package are excluded
   292  	// from processing, except for those that whose name begins
   293  	// with "tofrom_".
   294  	if r.analyzer.inSimd {
   295  		theFile := decl.Pos().Base().Filename()
   296  		// within the compiler paths use "/" as a separator.
   297  		if simdSlash := strings.LastIndex(theFile, simdPkg+"/"); simdSlash == -1 || !strings.HasPrefix(theFile[simdSlash:], simdPkg+"/tofrom_") {
   298  			return false
   299  		}
   300  	}
   301  
   302  	switch d := decl.(type) {
   303  	case *syntax.FuncDecl:
   304  		if d.Name != nil {
   305  			return r.analyzer.dependentObj[r.info.Defs[d.Name]]
   306  		}
   307  	case *syntax.TypeDecl:
   308  		return r.analyzer.dependentObj[r.info.Defs[d.Name]]
   309  	case *syntax.VarDecl:
   310  		for _, name := range d.NameList {
   311  			if r.analyzer.dependentObj[r.info.Defs[name]] {
   312  				return true
   313  			}
   314  		}
   315  	}
   316  	return false
   317  }
   318  
   319  // Generate an API matching the standalone compilation call
   320  func RewriteWrapper(pkg *types2.Package, info *types2.Info, files []*syntax.File) bool {
   321  	if !buildcfg.Experiment.SIMD {
   322  		return false
   323  	}
   324  
   325  	switch buildcfg.GOARCH {
   326  	case "wasm", "amd64", "arm64":
   327  	default:
   328  		return false
   329  	}
   330  
   331  	sizes := rewriteSizes()
   332  	if len(sizes) == 0 {
   333  		return false
   334  	}
   335  	analyzer := NewAnalyzer(pkg, info)
   336  	if !analyzer.Analyze(files) {
   337  		return false
   338  	}
   339  
   340  	CheckPositions(files, "before midway")
   341  
   342  	rewriter := NewRewriter(pkg, info, analyzer, sizes)
   343  	rewriter.Rewrite(files)
   344  
   345  	CheckPositions(files, "after midway")
   346  
   347  	return true
   348  }
   349  

View as plain text