Source file src/simd/archsimd/_gen/midway/intersect_simd_ops.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"flag"
    11  	"fmt"
    12  	"go/ast"
    13  	"go/format"
    14  	"go/parser"
    15  	"go/token"
    16  	"io"
    17  	"log"
    18  	"os"
    19  	"path/filepath"
    20  	"slices"
    21  	"sort"
    22  	"strings"
    23  	"unicode"
    24  	"unicode/utf8"
    25  
    26  	"gopkg.in/yaml.v3"
    27  )
    28  
    29  type MethodSet map[string]*ast.FuncDecl
    30  type TypeMethods map[string]MethodSet
    31  
    32  type Comments struct {
    33  	Types     map[string]string            `yaml:"types"`
    34  	Functions map[string]string            `yaml:"functions"`
    35  	Methods   map[string]map[string]string `yaml:"methods"`
    36  }
    37  
    38  var goRoot = flag.String("goroot", "../../../../..", "Go root")
    39  var verbose = flag.Bool("v", false, "Be much chattier about processing")
    40  
    41  type ArchAndFiles struct {
    42  	arch  string
    43  	files []string
    44  }
    45  
    46  type TypeMethod struct {
    47  	t, m string
    48  }
    49  
    50  type whyMissing struct {
    51  	wasm128, arm128, amd128, amd256, amd512 bool
    52  }
    53  
    54  func (w whyMissing) String() string {
    55  	why := ""
    56  	if w.wasm128 {
    57  		why += " wasm"
    58  	}
    59  	if w.arm128 {
    60  		why += " neon"
    61  	}
    62  	if w.amd128 {
    63  		why += " avx"
    64  	}
    65  	if w.amd256 {
    66  		why += " avx2"
    67  	}
    68  	if w.amd512 {
    69  		why += " avx512"
    70  	}
    71  	return why[1:]
    72  }
    73  
    74  func combine(arch, typ string) string {
    75  	return arch + "-" + typ
    76  }
    77  
    78  func main() {
    79  	minorProblem := false
    80  
    81  	flag.Parse()
    82  
    83  	var comments Comments
    84  	commentsData, err := os.ReadFile("comments.yaml")
    85  	if err != nil {
    86  		log.Fatalf("Failed to read comments.yaml: %v", err)
    87  	}
    88  	if err := yaml.Unmarshal(commentsData, &comments); err != nil {
    89  		log.Fatalf("Failed to parse comments.yaml: %v", err)
    90  	}
    91  
    92  	pv := func(f string, s ...any) {
    93  		if *verbose {
    94  			fmt.Fprintf(os.Stderr, f, s...)
    95  		}
    96  	}
    97  	pw := func(f string, s ...any) {
    98  		minorProblem = true
    99  		fmt.Fprintf(os.Stderr, f, s...)
   100  	}
   101  
   102  	// Hardcoded path to archsimd
   103  	archSimdPath := *goRoot + "/src/simd/archsimd"
   104  
   105  	// Hardcoded list of files
   106  	amd64Files := []string{"ops_amd64.go", "compare_gen_amd64.go", "types_amd64.go",
   107  		"other_gen_amd64.go", "extra_amd64.go", "maskmerge_gen_amd64.go",
   108  		"shuffles_amd64.go", "slice_gen_amd64.go", "slicepart_amd64.go",
   109  		"slicepart_128.go", "string.go", "ops_emulated_amd64.go"}
   110  	wasmFiles := []string{"ops_wasm.go", "types_wasm.go", "slicepart_wasm.go",
   111  		"string.go", "slicepart_128.go", "ops_emulated_wasm.go"}
   112  	neonFiles := []string{"clmul_arm64.go", "compare_gen_arm64.go",
   113  		"maskmerge_gen_arm64.go", "ops_arm64.go", "slicepart_128.go",
   114  		"ops_internal_arm64.go", "other_gen_arm64.go", "slice_gen_arm64.go",
   115  		"slicepart_arm64.go", "types_arm64.go"}
   116  
   117  	emulatedFile := *goRoot + "/src/simd/simd_emulated.go"
   118  
   119  	archAndFiles := []ArchAndFiles{
   120  		ArchAndFiles{"wasm", wasmFiles},
   121  		ArchAndFiles{"amd64", amd64Files},
   122  		ArchAndFiles{"arm64", neonFiles},
   123  	}
   124  
   125  	// Categories based on bit size
   126  	// 128-bit map: ElementType -> TypeName
   127  	map128 := map[string]string{
   128  		"Int8":    "Int8x16",
   129  		"Int16":   "Int16x8",
   130  		"Int32":   "Int32x4",
   131  		"Int64":   "Int64x2",
   132  		"Uint8":   "Uint8x16",
   133  		"Uint16":  "Uint16x8",
   134  		"Uint32":  "Uint32x4",
   135  		"Uint64":  "Uint64x2",
   136  		"Float32": "Float32x4",
   137  		"Float64": "Float64x2",
   138  		"Mask8":   "Mask8x16",
   139  		"Mask16":  "Mask16x8",
   140  		"Mask32":  "Mask32x4",
   141  		"Mask64":  "Mask64x2",
   142  	}
   143  
   144  	// 256-bit map: ElementType -> TypeName
   145  	map256 := map[string]string{
   146  		"Int8":    "Int8x32",
   147  		"Int16":   "Int16x16",
   148  		"Int32":   "Int32x8",
   149  		"Int64":   "Int64x4",
   150  		"Uint8":   "Uint8x32",
   151  		"Uint16":  "Uint16x16",
   152  		"Uint32":  "Uint32x8",
   153  		"Uint64":  "Uint64x4",
   154  		"Float32": "Float32x8",
   155  		"Float64": "Float64x4",
   156  		"Mask8":   "Mask8x32",
   157  		"Mask16":  "Mask16x16",
   158  		"Mask32":  "Mask32x8",
   159  		"Mask64":  "Mask64x4",
   160  	}
   161  
   162  	map512 := map[string]string{
   163  		"Int8":    "Int8x64",
   164  		"Int16":   "Int16x32",
   165  		"Int32":   "Int32x16",
   166  		"Int64":   "Int64x8",
   167  		"Uint8":   "Uint8x64",
   168  		"Uint16":  "Uint16x32",
   169  		"Uint32":  "Uint32x16",
   170  		"Uint64":  "Uint64x8",
   171  		"Float32": "Float32x16",
   172  		"Float64": "Float64x8",
   173  		"Mask8":   "Mask8x64",
   174  		"Mask16":  "Mask16x32",
   175  		"Mask32":  "Mask32x16",
   176  		"Mask64":  "Mask64x8",
   177  	}
   178  
   179  	sizeForType := make(map[string]int)
   180  
   181  	methodsByType := make(TypeMethods)
   182  
   183  	allMethodNames := make(map[string]bool)
   184  
   185  	missing := make(map[string]whyMissing)
   186  
   187  	fset := token.NewFileSet()
   188  
   189  	knownReceivers := make(map[string]string)
   190  	for k, v := range map128 {
   191  		knownReceivers[v] = k + "s"
   192  		sizeForType[v] = 128
   193  	}
   194  	for k, v := range map256 {
   195  		knownReceivers[v] = k + "s"
   196  		sizeForType[v] = 256
   197  	}
   198  	for k, v := range map512 {
   199  		knownReceivers[v] = k + "s"
   200  		sizeForType[v] = 512
   201  	}
   202  
   203  	receiver := func(funcDecl *ast.FuncDecl) string {
   204  		if funcDecl.Recv == nil {
   205  			return ""
   206  		}
   207  		recvType := ""
   208  		for _, field := range funcDecl.Recv.List {
   209  			// We assume single receiver
   210  			if ident, ok := field.Type.(*ast.Ident); ok {
   211  				recvType = ident.Name
   212  			} else if star, ok := field.Type.(*ast.StarExpr); ok {
   213  				if ident, ok := star.X.(*ast.Ident); ok {
   214  					recvType = ident.Name
   215  				}
   216  			}
   217  		}
   218  		return recvType
   219  	}
   220  
   221  	// Record existing emulated methods
   222  	emulated := make(map[TypeMethod]bool)
   223  	f, err := parser.ParseFile(fset, emulatedFile, nil, parser.ParseComments)
   224  	if err != nil {
   225  		log.Fatalf("Failed to parse %s: %v", emulatedFile, err)
   226  	}
   227  
   228  	for _, decl := range f.Decls {
   229  		if funcDecl, ok := decl.(*ast.FuncDecl); ok {
   230  			if receiver := receiver(funcDecl); receiver != "" {
   231  				method := funcDecl.Name.Name
   232  				// Exported methods only (must begin with uppercase)
   233  				if m, _ := utf8.DecodeRuneInString(method); unicode.IsUpper(m) {
   234  					emulated[TypeMethod{receiver, method}] = true
   235  				}
   236  			}
   237  		}
   238  	}
   239  
   240  	for _, aaf := range archAndFiles {
   241  		for _, fname := range aaf.files {
   242  			path := filepath.Join(archSimdPath, fname)
   243  			f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
   244  			if err != nil {
   245  				log.Fatalf("Failed to parse %s: %v", path, err)
   246  			}
   247  
   248  			lci := 0
   249  			fComments := f.Comments
   250  
   251  			for _, decl := range f.Decls {
   252  				if funcDecl, ok := decl.(*ast.FuncDecl); ok {
   253  
   254  					lastComment := ""
   255  					for ; lci < len(fComments) && fComments[lci].Pos() > funcDecl.Pos(); lci++ {
   256  						lastComment = fComments[lci].Text()
   257  					}
   258  
   259  					recvType := receiver(funcDecl)
   260  
   261  					if recvType == "" || knownReceivers[recvType] == "" {
   262  						continue
   263  					}
   264  
   265  					methodName := funcDecl.Name.Name
   266  
   267  					if strings.Contains(funcDecl.Doc.Text(), "Deprecated:") {
   268  						pv("Skipping deprecated %s.%s\n", recvType, methodName)
   269  						continue
   270  					}
   271  
   272  					if strings.Contains(lastComment, "Deprecated:") {
   273  						pv("Skipping MAYBE deprecated %s.%s (check comment)\n", recvType, methodName)
   274  						continue
   275  					}
   276  
   277  					if sizeForType[recvType] == 128 {
   278  						if s := funcDecl.Doc.Text(); strings.Contains(s, "AVX512") || strings.Contains(s, "AVX2") {
   279  							pv("Skipping 128-bit %s.%s because AVX2/AVX512\n", recvType, methodName)
   280  							continue
   281  						}
   282  					}
   283  					if sizeForType[recvType] == 256 {
   284  						if s := funcDecl.Doc.Text(); strings.Contains(s, "AVX512") {
   285  							pv("Skipping 256-bit %s.%s because AVX512\n", recvType, methodName)
   286  							continue
   287  						}
   288  					}
   289  
   290  					eltType := recvType[:strings.Index(recvType, "x")]
   291  
   292  					// Allow reinterpret vectors.
   293  					if xAt := strings.Index(methodName, "x"); xAt != -1 && (strings.HasPrefix(methodName, "As") || strings.HasPrefix(methodName, "ToInt") && strings.HasPrefix(eltType, "Mask")) {
   294  						// We think this is fine, even if it changes the number of elements in the vector.
   295  						// Tweak the method name so that they will line up properly.
   296  						methodName = methodName[:xAt] + "s"
   297  					} else if strings.HasPrefix(methodName, "Broadcast") {
   298  						// Broadcast is okay
   299  					} else {
   300  						// Exclude "grouped", "Store" (not slice), and vector-size-changing methods.
   301  						if strings.Contains(methodName, "Group") {
   302  							pv("Skipping grouped method %s.%s\n", recvType, methodName)
   303  							continue
   304  						}
   305  						if methodName == "StoreArray" || methodName == "StoreMasked" {
   306  							pv("Skipping fixed-size Store method method %s.%s\n", recvType, methodName)
   307  							continue
   308  						}
   309  						if methodName == "ToBits" && recvType[0] == 'M' {
   310  							pv("Skipping Mask ToBits method (has varying return type) %s.%s\n", recvType, methodName)
   311  							continue
   312  						}
   313  						if lastChar := methodName[len(methodName)-1]; unicode.IsDigit(rune(lastChar)) && lastChar != eltType[len(eltType)-1] {
   314  							pv("Skipping size-changing method %s.%s\n", recvType, methodName)
   315  							continue
   316  						}
   317  					}
   318  
   319  					archReceiver := combine(aaf.arch, recvType)
   320  
   321  					if methodsByType[archReceiver] == nil {
   322  						methodsByType[archReceiver] = make(MethodSet)
   323  					}
   324  					methodsByType[archReceiver][methodName] = funcDecl
   325  					allMethodNames[methodName] = true
   326  				}
   327  			}
   328  		}
   329  	}
   330  
   331  	type ElemMethod struct {
   332  		e, m string
   333  	}
   334  
   335  	intersectionByElem := make(map[string][]string)
   336  	signatureByElemMethod := make(map[ElemMethod]*ast.FuncDecl)
   337  
   338  	// elems is a slice of stems of vector types.
   339  	elems := []string{"Int8", "Int16", "Int32", "Int64", "Uint8", "Uint16", "Uint32", "Uint64", "Float32", "Float64", "Mask8", "Mask16", "Mask32", "Mask64"}
   340  
   341  	for _, elem := range elems {
   342  		type128 := map128[elem]
   343  		type256 := map256[elem]
   344  		type512 := map512[elem]
   345  
   346  		methods128w := methodsByType[combine("wasm", type128)]
   347  		methods128n := methodsByType[combine("arm64", type128)]
   348  		methods128 := methodsByType[combine("amd64", type128)]
   349  		methods256 := methodsByType[combine("amd64", type256)]
   350  		methods512 := methodsByType[combine("amd64", type512)]
   351  
   352  		var intersection []string
   353  		var missingNames []string
   354  		for m := range allMethodNames {
   355  			if wasm128, arm128, amd128, amd256, amd512 :=
   356  				methods128w[m] == nil, methods128n[m] == nil, methods128[m] == nil, methods256[m] == nil, methods512[m] == nil; !wasm128 && !arm128 && !amd128 && !amd256 && !amd512 {
   357  				intersection = append(intersection, m)
   358  				signatureByElemMethod[ElemMethod{elem, m}] = methods512[m] // Use 512-bit signature (arbitrary choice, they should match)
   359  			} else if !(wasm128 && arm128 && amd128 && amd256 && amd512) {
   360  				missing[m] = whyMissing{wasm128, arm128, amd128, amd256, amd512}
   361  				missingNames = append(missingNames, m)
   362  			}
   363  		}
   364  		sort.Strings(missingNames)
   365  
   366  		for _, m := range missingNames {
   367  			pv("Missing implementation for %ss.%s on %s\n", elem, m, missing[m].String())
   368  		}
   369  
   370  		sort.Strings(intersection)
   371  
   372  		intersectionByElem[elem] = intersection
   373  	}
   374  
   375  	// xlateType translates a type by replacing instances of types with keys in knownReceivers with their values,
   376  	// and generates the string representation of the resulting type.  E.g., []Int8x32 -> []Int8s
   377  	// (because Int8x32 -> Int8s in knownReceivers
   378  	var xlateType func(ast.Expr) string
   379  	xlateType = func(e ast.Expr) string {
   380  		switch t := e.(type) {
   381  		case *ast.Ident:
   382  			if mapped, ok := knownReceivers[t.Name]; ok {
   383  				return mapped
   384  			}
   385  			return t.Name
   386  		case *ast.StarExpr:
   387  			return "*" + xlateType(t.X)
   388  		case *ast.ArrayType:
   389  			lenStr := ""
   390  			if t.Len != nil {
   391  				var buf strings.Builder
   392  				format.Node(&buf, token.NewFileSet(), t.Len)
   393  				lenStr = buf.String()
   394  			}
   395  			return "[" + lenStr + "]" + xlateType(t.Elt)
   396  		case *ast.SelectorExpr:
   397  			return xlateType(t.X) + "." + t.Sel.Name
   398  		case *ast.Ellipsis:
   399  			return "..." + xlateType(t.Elt)
   400  		default:
   401  			var buf strings.Builder
   402  			format.Node(&buf, token.NewFileSet(), t)
   403  			return buf.String()
   404  		}
   405  	}
   406  
   407  	toScalar := func(s string) string {
   408  		if strings.HasPrefix(s, "Mask") {
   409  			return "int" + s[4:]
   410  		}
   411  		return strings.ToLower(s)
   412  	}
   413  
   414  	doTypes := func(w io.Writer) {
   415  
   416  		pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
   417  
   418  		fmt.Fprintln(w,
   419  			`// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.
   420  
   421  //go:build goexperiment.simd
   422  
   423  // Scalable vector types for rewriting and emulation
   424  
   425  package simd
   426  
   427  import "simd/internal/bridge"
   428  
   429  // internal SIMD marker, and hard dependence on simd/internal/bridge
   430  type _simd bridge.ZeroSized
   431  `)
   432  
   433  		for _, elem := range elems {
   434  			if c := comments.Types[elem+"s"]; c != "" {
   435  				pf("// %s\n", c)
   436  			}
   437  			pf("type %ss struct {\n\t_       _simd\n\ta, b uint64 // the actual vector size may be larger.\n}\n", elem)
   438  		}
   439  	}
   440  
   441  	doMethods := func(w io.Writer) {
   442  
   443  		p := func(s ...any) { fmt.Fprint(w, s...) }
   444  		pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
   445  		nl := func() { fmt.Fprintln(w) }
   446  
   447  		fmt.Fprintln(w,
   448  			`// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.
   449  
   450  //go:build goexperiment.simd && (amd64 || wasm || arm64)
   451  
   452  // Computed intersection of methods for supported SIMD architectures and vector widths
   453  
   454  package simd
   455  
   456  `)
   457  
   458  		for _, elem := range elems {
   459  			intersection := intersectionByElem[elem]
   460  
   461  			if elem[0] != 'M' {
   462  				// cannot load masks
   463  
   464  				loadComment := comments.Functions["Load"+elem]
   465  				if loadComment == "" && comments.Functions["default_LoadSlice"] != "" {
   466  					loadComment = fmt.Sprintf(comments.Functions["default_LoadSlice"], elem, toScalar(elem), elem)
   467  				}
   468  				if loadComment != "" {
   469  					pf("// %s\n", loadComment)
   470  				}
   471  				pf("func Load%ss([]%s) %ss\n", elem, toScalar(elem), elem)
   472  
   473  				loadPartComment := comments.Functions["Load"+elem+"Part"]
   474  				if loadPartComment == "" && comments.Functions["default_LoadPart"] != "" {
   475  					loadPartComment = fmt.Sprintf(comments.Functions["default_LoadPart"], elem, toScalar(elem), elem)
   476  				}
   477  				if loadPartComment != "" {
   478  					pf("// %s\n", loadPartComment)
   479  				}
   480  				pf("func Load%ssPart([]%s) (%ss, int)\n", elem, toScalar(elem), elem)
   481  
   482  				broadcastComment := comments.Functions["Broadcast"+elem]
   483  				if broadcastComment == "" && comments.Functions["default_Broadcast"] != "" {
   484  					broadcastComment = fmt.Sprintf(comments.Functions["default_Broadcast"], elem)
   485  				}
   486  				if broadcastComment != "" {
   487  					pf("// %s\n", broadcastComment)
   488  				}
   489  				pf("func Broadcast%ss(%s) %ss\n", elem, toScalar(elem), elem)
   490  			}
   491  
   492  			for _, m := range intersection {
   493  				fd := signatureByElemMethod[ElemMethod{elem, m}]
   494  				elems := elem + "s"
   495  				methodComment := ""
   496  				if typeMethods, ok := comments.Methods[elem+"s"]; ok {
   497  					methodComment = typeMethods[m]
   498  				}
   499  				if methodComment != "" {
   500  					pf("// %s\n", methodComment)
   501  				} else {
   502  					pw("Missing doc comment (in midway/comments.yaml) for %s.%s\n", elems, m)
   503  				}
   504  				pf("func (x %s) %s(", elems, m)
   505  
   506  				if !emulated[TypeMethod{elems, m}] {
   507  					pw("Missing emulated method for %s.%s\n", elems, m)
   508  				} else {
   509  					delete(emulated, TypeMethod{elems, m})
   510  				}
   511  
   512  				if fd.Type.Params != nil {
   513  					for i, field := range fd.Type.Params.List {
   514  						if i > 0 {
   515  							p(", ")
   516  						}
   517  						if len(field.Names) > 0 {
   518  							for j, name := range field.Names {
   519  								if j > 0 {
   520  									p(", ")
   521  								}
   522  								p(name.Name)
   523  							}
   524  							p(" ")
   525  						}
   526  						p(xlateType(field.Type))
   527  					}
   528  				}
   529  				p(")")
   530  
   531  				if fd.Type.Results != nil && len(fd.Type.Results.List) > 0 {
   532  					p(" ")
   533  					needsParens := len(fd.Type.Results.List) > 1 || (len(fd.Type.Results.List) == 1 && len(fd.Type.Results.List[0].Names) > 0)
   534  					if needsParens {
   535  						p("(")
   536  					}
   537  					for i, field := range fd.Type.Results.List {
   538  						if i > 0 {
   539  							p(", ")
   540  						}
   541  						if len(field.Names) > 0 {
   542  							for j, name := range field.Names {
   543  								if j > 0 {
   544  									p(", ")
   545  								}
   546  								p(name.Name)
   547  							}
   548  							p(" ")
   549  						}
   550  						p(xlateType(field.Type))
   551  					}
   552  					if needsParens {
   553  						p(")")
   554  					}
   555  				}
   556  				nl()
   557  			}
   558  		}
   559  	}
   560  
   561  	formatAndWrite(*goRoot+"/src/simd/simd_types.go", doTypes)
   562  	formatAndWrite(*goRoot+"/src/simd/simd_stubs.go", doMethods)
   563  
   564  	var extraMocks []TypeMethod
   565  	for x := range emulated {
   566  		extraMocks = append(extraMocks, x)
   567  	}
   568  	slices.SortFunc(extraMocks, func(a, b TypeMethod) int {
   569  		if c := strings.Compare(a.t, b.t); c != 0 {
   570  			return c
   571  		}
   572  		return strings.Compare(a.m, b.m)
   573  	})
   574  
   575  	for _, x := range extraMocks {
   576  		pw("%s contains %s.%s missing from intersected methods\n", emulatedFile, x.t, x.m)
   577  	}
   578  
   579  	for _, aaf := range archAndFiles {
   580  		arch := aaf.arch
   581  		doArchWrites := func(w io.Writer) {
   582  			p := func(s ...any) { fmt.Fprint(w, s...) }
   583  			pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
   584  			nl := func() { fmt.Fprintln(w) }
   585  
   586  			pf("// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.\n\n")
   587  			pf("//go:build goexperiment.simd && %s\n\n", arch)
   588  			pf("package bridge\n\n")
   589  			pf("import \"simd/archsimd\"\n\n")
   590  			pf("\n")
   591  			pf("// These types/methods/functions forward calls to their counterparts in simd/archsimd.\n")
   592  			pf("// Interposing this package allows a clean separation of \"simd\" from \"archsimd\" and\n")
   593  			pf("// also allows additional useful exported declarations that would weirdly pollute archsimd.\n")
   594  			pf("\n")
   595  
   596  			var typesForArch []string
   597  			for t := range knownReceivers {
   598  				if methodsByType[combine(arch, t)] != nil {
   599  					typesForArch = append(typesForArch, t)
   600  				}
   601  			}
   602  			sort.Strings(typesForArch)
   603  
   604  			toScalar := func(s string) string {
   605  				if strings.HasPrefix(s, "Mask") {
   606  					return "int" + s[4:]
   607  				}
   608  				return strings.ToLower(s)
   609  			}
   610  
   611  			for _, t := range typesForArch {
   612  				pf("type %s archsimd.%s\n", t, t)
   613  				if xAt := strings.Index(t, "x"); xAt != -1 && !strings.HasPrefix(t, "Mask") {
   614  					elem := t[:xAt]
   615  					scalar := toScalar(elem)
   616  					pf("func Load%s(s []%s) %s {\n\treturn %s(archsimd.Load%s(s))\n}\n", t, scalar, t, t, t)
   617  					pf("func Load%sPart(s []%s) (%s, int) {\n\tv, n := archsimd.Load%sPart(s)\n\treturn %s(v), n\n}\n", t, scalar, t, t, t)
   618  					pf("func Broadcast%s(x %s) %s {\n\treturn %s(archsimd.Broadcast%s(x))\n}\n", t, scalar, t, t, t)
   619  				}
   620  			}
   621  			nl()
   622  
   623  			typeStr := func(e ast.Expr) string {
   624  				var buf strings.Builder
   625  				format.Node(&buf, token.NewFileSet(), e)
   626  				return buf.String()
   627  			}
   628  
   629  			convertArg := func(name string, e ast.Expr) string {
   630  				switch t := e.(type) {
   631  				case *ast.Ident:
   632  					if _, ok := knownReceivers[t.Name]; ok {
   633  						return fmt.Sprintf("archsimd.%s(%s)", t.Name, name)
   634  					}
   635  				case *ast.StarExpr:
   636  					if ident, ok := t.X.(*ast.Ident); ok {
   637  						if _, ok := knownReceivers[ident.Name]; ok {
   638  							return fmt.Sprintf("(*archsimd.%s)(%s)", ident.Name, name)
   639  						}
   640  					}
   641  				}
   642  				return name
   643  			}
   644  
   645  			wrapResult := func(call string, e ast.Expr) string {
   646  				switch t := e.(type) {
   647  				case *ast.Ident:
   648  					if _, ok := knownReceivers[t.Name]; ok {
   649  						return fmt.Sprintf("%s(%s)", t.Name, call)
   650  					}
   651  				case *ast.StarExpr:
   652  					if ident, ok := t.X.(*ast.Ident); ok {
   653  						if _, ok := knownReceivers[ident.Name]; ok {
   654  							return fmt.Sprintf("(*%s)(%s)", ident.Name, call)
   655  						}
   656  					}
   657  				}
   658  				return call
   659  			}
   660  
   661  			for _, elem := range elems {
   662  				intersection := intersectionByElem[elem]
   663  				for _, m := range intersection {
   664  					for _, t := range typesForArch {
   665  						if map128[elem] != t && map256[elem] != t && map512[elem] != t {
   666  							continue
   667  						}
   668  						fd := methodsByType[combine(arch, t)][m]
   669  						if fd == nil {
   670  							continue
   671  						}
   672  						pf("func (x %s) %s(", t, fd.Name.Name)
   673  						var args []string
   674  						if fd.Type.Params != nil {
   675  							paramCount := 0
   676  							for _, field := range fd.Type.Params.List {
   677  								if len(field.Names) > 0 {
   678  									for _, name := range field.Names {
   679  										if paramCount > 0 {
   680  											p(", ")
   681  										}
   682  										pf("%s %s", name.Name, typeStr(field.Type))
   683  										args = append(args, convertArg(name.Name, field.Type))
   684  										paramCount++
   685  									}
   686  								} else {
   687  									if paramCount > 0 {
   688  										p(", ")
   689  									}
   690  									paramName := fmt.Sprintf("p%d", paramCount)
   691  									pf("%s %s", paramName, typeStr(field.Type))
   692  									args = append(args, convertArg(paramName, field.Type))
   693  									paramCount++
   694  								}
   695  							}
   696  						}
   697  						p(")")
   698  
   699  						var results []ast.Expr
   700  						if fd.Type.Results != nil {
   701  							p(" ")
   702  							needsParens := len(fd.Type.Results.List) > 1 || (len(fd.Type.Results.List) == 1 && len(fd.Type.Results.List[0].Names) > 0)
   703  							if needsParens {
   704  								p("(")
   705  							}
   706  							for i, field := range fd.Type.Results.List {
   707  								if i > 0 {
   708  									p(", ")
   709  								}
   710  								results = append(results, field.Type)
   711  								p(typeStr(field.Type))
   712  							}
   713  							if needsParens {
   714  								p(")")
   715  							}
   716  						}
   717  
   718  						p(" {\n\t")
   719  						if len(results) > 0 {
   720  							p("return ")
   721  						}
   722  
   723  						callStr := fmt.Sprintf("(archsimd.%s(x)).%s(%s)", t, fd.Name.Name, strings.Join(args, ", "))
   724  						if len(results) == 1 {
   725  							p(wrapResult(callStr, results[0]))
   726  						} else {
   727  							p(callStr)
   728  						}
   729  						p("\n}\n\n")
   730  					}
   731  				}
   732  			}
   733  		}
   734  		archDir := filepath.Join(*goRoot, "src", "simd", "internal", "bridge")
   735  		os.MkdirAll(archDir, 0755)
   736  		filename := filepath.Join(archDir, "decls_"+arch+".go")
   737  		formatAndWrite(filename, doArchWrites)
   738  
   739  		doToFromWrites := func(w io.Writer) {
   740  			pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
   741  
   742  			pf("// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.\n\n")
   743  			pf("//go:build goexperiment.simd && %s\n\n", arch)
   744  			pf("package simd\n\n")
   745  			pf("import (\n\t\"simd/archsimd\"\n\t\"simd/internal/bridge\"\n)\n\n")
   746  
   747  			for _, elem := range elems {
   748  				var archTypes []string
   749  				if methodsByType[combine(arch, map128[elem])] != nil {
   750  					archTypes = append(archTypes, map128[elem])
   751  				}
   752  				if methodsByType[combine(arch, map256[elem])] != nil {
   753  					archTypes = append(archTypes, map256[elem])
   754  				}
   755  				if methodsByType[combine(arch, map512[elem])] != nil {
   756  					archTypes = append(archTypes, map512[elem])
   757  				}
   758  
   759  				if len(archTypes) == 0 {
   760  					continue
   761  				}
   762  
   763  				pf("func (x %ss) ToArch() any\n\n", elem)
   764  
   765  				var intfOpts []string
   766  				for _, t := range archTypes {
   767  					intfOpts = append(intfOpts, "archsimd."+t)
   768  				}
   769  				pf("type archSimd%ss interface {\n\t%s\n}\n\n", elem, strings.Join(intfOpts, " | "))
   770  
   771  				pf("func %ssFromArch[T archSimd%ss](x T) %ss {\n", elem, elem, elem)
   772  				pf("\tswitch a := any(x).(type) {\n")
   773  				pf("\t// The return expression is written this way because the code will be rewritten\n")
   774  				pf("\t// with %ss replaced by one of the arch types, and without the any-assert\n", elem)
   775  				pf("\t// hack the rewritten code would not pass type checking.\n")
   776  				pf("\t// The backend of the compiler will eat this and turn it into no code at all,\n")
   777  				pf("\t// assuming it inlines.\n")
   778  
   779  				for _, t := range archTypes {
   780  					pf("\tcase archsimd.%s:\n", t)
   781  					pf("\t\tvar t bridge.%s = bridge.%s(a)\n", t, t)
   782  					pf("\t\treturn (any(t)).(%ss)\n", elem)
   783  				}
   784  				pf("\t}\n\tpanic(\"wrong type\")\n}\n\n")
   785  			}
   786  		}
   787  		toFromFilename := filepath.Join(*goRoot, "src", "simd", "tofrom_"+arch+".go")
   788  		formatAndWrite(toFromFilename, doToFromWrites)
   789  	}
   790  
   791  	if minorProblem {
   792  		pw("The logged warnings did not prevent generation of the midway API files, but the API is flawed (lacks emulations, documentation, etc).\n")
   793  	}
   794  }
   795  
   796  // numberLines takes a slice of bytes, and returns a string where each line
   797  // is numbered, starting from 1.
   798  func numberLines(data []byte) string {
   799  	var buf bytes.Buffer
   800  	r := bytes.NewReader(data)
   801  	s := bufio.NewScanner(r)
   802  	for i := 1; s.Scan(); i++ {
   803  		fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
   804  	}
   805  	return buf.String()
   806  }
   807  
   808  func formatAndWrite(filename string, doWrites func(w io.Writer)) {
   809  	if filename == "" {
   810  		return
   811  	}
   812  	f, err := os.Create(filename)
   813  	if err != nil {
   814  		log.Fatal(err)
   815  	}
   816  	defer f.Close()
   817  
   818  	out := new(bytes.Buffer)
   819  	doWrites(out)
   820  
   821  	b, err := format.Source(out.Bytes())
   822  	if err != nil {
   823  		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
   824  		fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
   825  		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
   826  		os.Exit(1)
   827  	} else {
   828  		f.Write(b)
   829  		f.Close()
   830  	}
   831  }
   832  

View as plain text