1
2
3
4
5 package main
6
7 import (
8 "bufio"
9 "bytes"
10 "flag"
11 "fmt"
12 "go/ast"
13 "go/format"
14 "go/parser"
15 "go/token"
16 "io"
17 "log"
18 "os"
19 "path/filepath"
20 "slices"
21 "sort"
22 "strings"
23 "unicode"
24 "unicode/utf8"
25
26 "gopkg.in/yaml.v3"
27 )
28
type (
	// MethodSet maps a method name to the declaration that was parsed for it.
	MethodSet map[string]*ast.FuncDecl

	// TypeMethods maps a receiver key to the methods collected for that
	// receiver. In this file the keys are built with combine(arch, type).
	TypeMethods map[string]MethodSet
)
31
// Comments holds the documentation text that main unmarshals from
// comments.yaml and attaches to the generated declarations.
type Comments struct {
	// Types is looked up by generated type name.
	Types map[string]string `yaml:"types"`

	// Functions is looked up by function name; main also reads the
	// "default_*" entries as fmt format templates.
	Functions map[string]string `yaml:"functions"`

	// Methods is looked up first by generated type name, then by method name.
	Methods map[string]map[string]string `yaml:"methods"`
}
37
// Command-line flags.
var (
	// goRoot is the path prefix under which main reads and writes files
	// (it is joined with "/src/simd/...").
	goRoot = flag.String("goroot", "../../../../..", "Go root")

	// verbose enables the extra progress output printed by main.
	verbose = flag.Bool("v", false, "Be much chattier about processing")
)
40
// ArchAndFiles pairs an architecture name with the source files that main
// parses for that architecture.
type ArchAndFiles struct {
	arch  string   // architecture name, e.g. "amd64"
	files []string // file names, joined onto the archsimd directory by main
}
45
// TypeMethod identifies a method by its receiver type name and method name.
// It is comparable, so it can be used as a map key.
type TypeMethod struct {
	t string // receiver type name
	m string // method name
}
49
// whyMissing records where a method is absent. Each field is true when the
// method was not found for the corresponding architecture and vector width.
// The field order matters: main builds values with a positional literal.
type whyMissing struct {
	wasm128, arm128, amd128, amd256, amd512 bool
}

// String returns the space-separated labels of every place the method is
// missing, in field order ("wasm", "neon", "avx", "avx2", "avx512").
// It returns the empty string when no field is set; previously that case
// panicked while slicing off a leading space that was never written.
func (w whyMissing) String() string {
	labels := [...]struct {
		missing bool
		label   string
	}{
		{w.wasm128, "wasm"},
		{w.arm128, "neon"},
		{w.amd128, "avx"},
		{w.amd256, "avx2"},
		{w.amd512, "avx512"},
	}

	why := ""
	for _, l := range labels {
		if !l.missing {
			continue
		}
		if why != "" {
			why += " "
		}
		why += l.label
	}
	return why
}
73
// combine builds the lookup key for a type on a given architecture, in the
// form "<arch>-<typ>".
func combine(arch, typ string) string {
	const sep = "-"
	key := arch + sep
	key += typ
	return key
}
77
78 func main() {
79 minorProblem := false
80
81 flag.Parse()
82
83 var comments Comments
84 commentsData, err := os.ReadFile("comments.yaml")
85 if err != nil {
86 log.Fatalf("Failed to read comments.yaml: %v", err)
87 }
88 if err := yaml.Unmarshal(commentsData, &comments); err != nil {
89 log.Fatalf("Failed to parse comments.yaml: %v", err)
90 }
91
92 pv := func(f string, s ...any) {
93 if *verbose {
94 fmt.Fprintf(os.Stderr, f, s...)
95 }
96 }
97 pw := func(f string, s ...any) {
98 minorProblem = true
99 fmt.Fprintf(os.Stderr, f, s...)
100 }
101
102
103 archSimdPath := *goRoot + "/src/simd/archsimd"
104
105
106 amd64Files := []string{"ops_amd64.go", "compare_gen_amd64.go", "types_amd64.go",
107 "other_gen_amd64.go", "extra_amd64.go", "maskmerge_gen_amd64.go",
108 "shuffles_amd64.go", "slice_gen_amd64.go", "slicepart_amd64.go",
109 "slicepart_128.go", "string.go", "ops_emulated_amd64.go"}
110 wasmFiles := []string{"ops_wasm.go", "types_wasm.go", "slicepart_wasm.go",
111 "string.go", "slicepart_128.go", "ops_emulated_wasm.go"}
112 neonFiles := []string{"clmul_arm64.go", "compare_gen_arm64.go",
113 "maskmerge_gen_arm64.go", "ops_arm64.go", "slicepart_128.go",
114 "ops_internal_arm64.go", "other_gen_arm64.go", "slice_gen_arm64.go",
115 "slicepart_arm64.go", "types_arm64.go"}
116
117 emulatedFile := *goRoot + "/src/simd/simd_emulated.go"
118
119 archAndFiles := []ArchAndFiles{
120 ArchAndFiles{"wasm", wasmFiles},
121 ArchAndFiles{"amd64", amd64Files},
122 ArchAndFiles{"arm64", neonFiles},
123 }
124
125
126
127 map128 := map[string]string{
128 "Int8": "Int8x16",
129 "Int16": "Int16x8",
130 "Int32": "Int32x4",
131 "Int64": "Int64x2",
132 "Uint8": "Uint8x16",
133 "Uint16": "Uint16x8",
134 "Uint32": "Uint32x4",
135 "Uint64": "Uint64x2",
136 "Float32": "Float32x4",
137 "Float64": "Float64x2",
138 "Mask8": "Mask8x16",
139 "Mask16": "Mask16x8",
140 "Mask32": "Mask32x4",
141 "Mask64": "Mask64x2",
142 }
143
144
145 map256 := map[string]string{
146 "Int8": "Int8x32",
147 "Int16": "Int16x16",
148 "Int32": "Int32x8",
149 "Int64": "Int64x4",
150 "Uint8": "Uint8x32",
151 "Uint16": "Uint16x16",
152 "Uint32": "Uint32x8",
153 "Uint64": "Uint64x4",
154 "Float32": "Float32x8",
155 "Float64": "Float64x4",
156 "Mask8": "Mask8x32",
157 "Mask16": "Mask16x16",
158 "Mask32": "Mask32x8",
159 "Mask64": "Mask64x4",
160 }
161
162 map512 := map[string]string{
163 "Int8": "Int8x64",
164 "Int16": "Int16x32",
165 "Int32": "Int32x16",
166 "Int64": "Int64x8",
167 "Uint8": "Uint8x64",
168 "Uint16": "Uint16x32",
169 "Uint32": "Uint32x16",
170 "Uint64": "Uint64x8",
171 "Float32": "Float32x16",
172 "Float64": "Float64x8",
173 "Mask8": "Mask8x64",
174 "Mask16": "Mask16x32",
175 "Mask32": "Mask32x16",
176 "Mask64": "Mask64x8",
177 }
178
179 sizeForType := make(map[string]int)
180
181 methodsByType := make(TypeMethods)
182
183 allMethodNames := make(map[string]bool)
184
185 missing := make(map[string]whyMissing)
186
187 fset := token.NewFileSet()
188
189 knownReceivers := make(map[string]string)
190 for k, v := range map128 {
191 knownReceivers[v] = k + "s"
192 sizeForType[v] = 128
193 }
194 for k, v := range map256 {
195 knownReceivers[v] = k + "s"
196 sizeForType[v] = 256
197 }
198 for k, v := range map512 {
199 knownReceivers[v] = k + "s"
200 sizeForType[v] = 512
201 }
202
203 receiver := func(funcDecl *ast.FuncDecl) string {
204 if funcDecl.Recv == nil {
205 return ""
206 }
207 recvType := ""
208 for _, field := range funcDecl.Recv.List {
209
210 if ident, ok := field.Type.(*ast.Ident); ok {
211 recvType = ident.Name
212 } else if star, ok := field.Type.(*ast.StarExpr); ok {
213 if ident, ok := star.X.(*ast.Ident); ok {
214 recvType = ident.Name
215 }
216 }
217 }
218 return recvType
219 }
220
221
222 emulated := make(map[TypeMethod]bool)
223 f, err := parser.ParseFile(fset, emulatedFile, nil, parser.ParseComments)
224 if err != nil {
225 log.Fatalf("Failed to parse %s: %v", emulatedFile, err)
226 }
227
228 for _, decl := range f.Decls {
229 if funcDecl, ok := decl.(*ast.FuncDecl); ok {
230 if receiver := receiver(funcDecl); receiver != "" {
231 method := funcDecl.Name.Name
232
233 if m, _ := utf8.DecodeRuneInString(method); unicode.IsUpper(m) {
234 emulated[TypeMethod{receiver, method}] = true
235 }
236 }
237 }
238 }
239
240 for _, aaf := range archAndFiles {
241 for _, fname := range aaf.files {
242 path := filepath.Join(archSimdPath, fname)
243 f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
244 if err != nil {
245 log.Fatalf("Failed to parse %s: %v", path, err)
246 }
247
248 lci := 0
249 fComments := f.Comments
250
251 for _, decl := range f.Decls {
252 if funcDecl, ok := decl.(*ast.FuncDecl); ok {
253
254 lastComment := ""
255 for ; lci < len(fComments) && fComments[lci].Pos() > funcDecl.Pos(); lci++ {
256 lastComment = fComments[lci].Text()
257 }
258
259 recvType := receiver(funcDecl)
260
261 if recvType == "" || knownReceivers[recvType] == "" {
262 continue
263 }
264
265 methodName := funcDecl.Name.Name
266
267 if strings.Contains(funcDecl.Doc.Text(), "Deprecated:") {
268 pv("Skipping deprecated %s.%s\n", recvType, methodName)
269 continue
270 }
271
272 if strings.Contains(lastComment, "Deprecated:") {
273 pv("Skipping MAYBE deprecated %s.%s (check comment)\n", recvType, methodName)
274 continue
275 }
276
277 if sizeForType[recvType] == 128 {
278 if s := funcDecl.Doc.Text(); strings.Contains(s, "AVX512") || strings.Contains(s, "AVX2") {
279 pv("Skipping 128-bit %s.%s because AVX2/AVX512\n", recvType, methodName)
280 continue
281 }
282 }
283 if sizeForType[recvType] == 256 {
284 if s := funcDecl.Doc.Text(); strings.Contains(s, "AVX512") {
285 pv("Skipping 256-bit %s.%s because AVX512\n", recvType, methodName)
286 continue
287 }
288 }
289
290 eltType := recvType[:strings.Index(recvType, "x")]
291
292
293 if xAt := strings.Index(methodName, "x"); xAt != -1 && (strings.HasPrefix(methodName, "As") || strings.HasPrefix(methodName, "ToInt") && strings.HasPrefix(eltType, "Mask")) {
294
295
296 methodName = methodName[:xAt] + "s"
297 } else if strings.HasPrefix(methodName, "Broadcast") {
298
299 } else {
300
301 if strings.Contains(methodName, "Group") {
302 pv("Skipping grouped method %s.%s\n", recvType, methodName)
303 continue
304 }
305 if methodName == "StoreArray" || methodName == "StoreMasked" {
306 pv("Skipping fixed-size Store method method %s.%s\n", recvType, methodName)
307 continue
308 }
309 if methodName == "ToBits" && recvType[0] == 'M' {
310 pv("Skipping Mask ToBits method (has varying return type) %s.%s\n", recvType, methodName)
311 continue
312 }
313 if lastChar := methodName[len(methodName)-1]; unicode.IsDigit(rune(lastChar)) && lastChar != eltType[len(eltType)-1] {
314 pv("Skipping size-changing method %s.%s\n", recvType, methodName)
315 continue
316 }
317 }
318
319 archReceiver := combine(aaf.arch, recvType)
320
321 if methodsByType[archReceiver] == nil {
322 methodsByType[archReceiver] = make(MethodSet)
323 }
324 methodsByType[archReceiver][methodName] = funcDecl
325 allMethodNames[methodName] = true
326 }
327 }
328 }
329 }
330
331 type ElemMethod struct {
332 e, m string
333 }
334
335 intersectionByElem := make(map[string][]string)
336 signatureByElemMethod := make(map[ElemMethod]*ast.FuncDecl)
337
338
339 elems := []string{"Int8", "Int16", "Int32", "Int64", "Uint8", "Uint16", "Uint32", "Uint64", "Float32", "Float64", "Mask8", "Mask16", "Mask32", "Mask64"}
340
341 for _, elem := range elems {
342 type128 := map128[elem]
343 type256 := map256[elem]
344 type512 := map512[elem]
345
346 methods128w := methodsByType[combine("wasm", type128)]
347 methods128n := methodsByType[combine("arm64", type128)]
348 methods128 := methodsByType[combine("amd64", type128)]
349 methods256 := methodsByType[combine("amd64", type256)]
350 methods512 := methodsByType[combine("amd64", type512)]
351
352 var intersection []string
353 var missingNames []string
354 for m := range allMethodNames {
355 if wasm128, arm128, amd128, amd256, amd512 :=
356 methods128w[m] == nil, methods128n[m] == nil, methods128[m] == nil, methods256[m] == nil, methods512[m] == nil; !wasm128 && !arm128 && !amd128 && !amd256 && !amd512 {
357 intersection = append(intersection, m)
358 signatureByElemMethod[ElemMethod{elem, m}] = methods512[m]
359 } else if !(wasm128 && arm128 && amd128 && amd256 && amd512) {
360 missing[m] = whyMissing{wasm128, arm128, amd128, amd256, amd512}
361 missingNames = append(missingNames, m)
362 }
363 }
364 sort.Strings(missingNames)
365
366 for _, m := range missingNames {
367 pv("Missing implementation for %ss.%s on %s\n", elem, m, missing[m].String())
368 }
369
370 sort.Strings(intersection)
371
372 intersectionByElem[elem] = intersection
373 }
374
375
376
377
378 var xlateType func(ast.Expr) string
379 xlateType = func(e ast.Expr) string {
380 switch t := e.(type) {
381 case *ast.Ident:
382 if mapped, ok := knownReceivers[t.Name]; ok {
383 return mapped
384 }
385 return t.Name
386 case *ast.StarExpr:
387 return "*" + xlateType(t.X)
388 case *ast.ArrayType:
389 lenStr := ""
390 if t.Len != nil {
391 var buf strings.Builder
392 format.Node(&buf, token.NewFileSet(), t.Len)
393 lenStr = buf.String()
394 }
395 return "[" + lenStr + "]" + xlateType(t.Elt)
396 case *ast.SelectorExpr:
397 return xlateType(t.X) + "." + t.Sel.Name
398 case *ast.Ellipsis:
399 return "..." + xlateType(t.Elt)
400 default:
401 var buf strings.Builder
402 format.Node(&buf, token.NewFileSet(), t)
403 return buf.String()
404 }
405 }
406
407 toScalar := func(s string) string {
408 if strings.HasPrefix(s, "Mask") {
409 return "int" + s[4:]
410 }
411 return strings.ToLower(s)
412 }
413
414 doTypes := func(w io.Writer) {
415
416 pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
417
418 fmt.Fprintln(w,
419 `// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.
420
421 //go:build goexperiment.simd
422
423 // Scalable vector types for rewriting and emulation
424
425 package simd
426
427 import "simd/internal/bridge"
428
429 // internal SIMD marker, and hard dependence on simd/internal/bridge
430 type _simd bridge.ZeroSized
431 `)
432
433 for _, elem := range elems {
434 if c := comments.Types[elem+"s"]; c != "" {
435 pf("// %s\n", c)
436 }
437 pf("type %ss struct {\n\t_ _simd\n\ta, b uint64 // the actual vector size may be larger.\n}\n", elem)
438 }
439 }
440
441 doMethods := func(w io.Writer) {
442
443 p := func(s ...any) { fmt.Fprint(w, s...) }
444 pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
445 nl := func() { fmt.Fprintln(w) }
446
447 fmt.Fprintln(w,
448 `// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.
449
450 //go:build goexperiment.simd && (amd64 || wasm || arm64)
451
452 // Computed intersection of methods for supported SIMD architectures and vector widths
453
454 package simd
455
456 `)
457
458 for _, elem := range elems {
459 intersection := intersectionByElem[elem]
460
461 if elem[0] != 'M' {
462
463
464 loadComment := comments.Functions["Load"+elem]
465 if loadComment == "" && comments.Functions["default_LoadSlice"] != "" {
466 loadComment = fmt.Sprintf(comments.Functions["default_LoadSlice"], elem, toScalar(elem), elem)
467 }
468 if loadComment != "" {
469 pf("// %s\n", loadComment)
470 }
471 pf("func Load%ss([]%s) %ss\n", elem, toScalar(elem), elem)
472
473 loadPartComment := comments.Functions["Load"+elem+"Part"]
474 if loadPartComment == "" && comments.Functions["default_LoadPart"] != "" {
475 loadPartComment = fmt.Sprintf(comments.Functions["default_LoadPart"], elem, toScalar(elem), elem)
476 }
477 if loadPartComment != "" {
478 pf("// %s\n", loadPartComment)
479 }
480 pf("func Load%ssPart([]%s) (%ss, int)\n", elem, toScalar(elem), elem)
481
482 broadcastComment := comments.Functions["Broadcast"+elem]
483 if broadcastComment == "" && comments.Functions["default_Broadcast"] != "" {
484 broadcastComment = fmt.Sprintf(comments.Functions["default_Broadcast"], elem)
485 }
486 if broadcastComment != "" {
487 pf("// %s\n", broadcastComment)
488 }
489 pf("func Broadcast%ss(%s) %ss\n", elem, toScalar(elem), elem)
490 }
491
492 for _, m := range intersection {
493 fd := signatureByElemMethod[ElemMethod{elem, m}]
494 elems := elem + "s"
495 methodComment := ""
496 if typeMethods, ok := comments.Methods[elem+"s"]; ok {
497 methodComment = typeMethods[m]
498 }
499 if methodComment != "" {
500 pf("// %s\n", methodComment)
501 } else {
502 pw("Missing doc comment (in midway/comments.yaml) for %s.%s\n", elems, m)
503 }
504 pf("func (x %s) %s(", elems, m)
505
506 if !emulated[TypeMethod{elems, m}] {
507 pw("Missing emulated method for %s.%s\n", elems, m)
508 } else {
509 delete(emulated, TypeMethod{elems, m})
510 }
511
512 if fd.Type.Params != nil {
513 for i, field := range fd.Type.Params.List {
514 if i > 0 {
515 p(", ")
516 }
517 if len(field.Names) > 0 {
518 for j, name := range field.Names {
519 if j > 0 {
520 p(", ")
521 }
522 p(name.Name)
523 }
524 p(" ")
525 }
526 p(xlateType(field.Type))
527 }
528 }
529 p(")")
530
531 if fd.Type.Results != nil && len(fd.Type.Results.List) > 0 {
532 p(" ")
533 needsParens := len(fd.Type.Results.List) > 1 || (len(fd.Type.Results.List) == 1 && len(fd.Type.Results.List[0].Names) > 0)
534 if needsParens {
535 p("(")
536 }
537 for i, field := range fd.Type.Results.List {
538 if i > 0 {
539 p(", ")
540 }
541 if len(field.Names) > 0 {
542 for j, name := range field.Names {
543 if j > 0 {
544 p(", ")
545 }
546 p(name.Name)
547 }
548 p(" ")
549 }
550 p(xlateType(field.Type))
551 }
552 if needsParens {
553 p(")")
554 }
555 }
556 nl()
557 }
558 }
559 }
560
561 formatAndWrite(*goRoot+"/src/simd/simd_types.go", doTypes)
562 formatAndWrite(*goRoot+"/src/simd/simd_stubs.go", doMethods)
563
564 var extraMocks []TypeMethod
565 for x := range emulated {
566 extraMocks = append(extraMocks, x)
567 }
568 slices.SortFunc(extraMocks, func(a, b TypeMethod) int {
569 if c := strings.Compare(a.t, b.t); c != 0 {
570 return c
571 }
572 return strings.Compare(a.m, b.m)
573 })
574
575 for _, x := range extraMocks {
576 pw("%s contains %s.%s missing from intersected methods\n", emulatedFile, x.t, x.m)
577 }
578
579 for _, aaf := range archAndFiles {
580 arch := aaf.arch
581 doArchWrites := func(w io.Writer) {
582 p := func(s ...any) { fmt.Fprint(w, s...) }
583 pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
584 nl := func() { fmt.Fprintln(w) }
585
586 pf("// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.\n\n")
587 pf("//go:build goexperiment.simd && %s\n\n", arch)
588 pf("package bridge\n\n")
589 pf("import \"simd/archsimd\"\n\n")
590 pf("\n")
591 pf("// These types/methods/functions forward calls to their counterparts in simd/archsimd.\n")
592 pf("// Interposing this package allows a clean separation of \"simd\" from \"archsimd\" and\n")
593 pf("// also allows additional useful exported declarations that would weirdly pollute archsimd.\n")
594 pf("\n")
595
596 var typesForArch []string
597 for t := range knownReceivers {
598 if methodsByType[combine(arch, t)] != nil {
599 typesForArch = append(typesForArch, t)
600 }
601 }
602 sort.Strings(typesForArch)
603
604 toScalar := func(s string) string {
605 if strings.HasPrefix(s, "Mask") {
606 return "int" + s[4:]
607 }
608 return strings.ToLower(s)
609 }
610
611 for _, t := range typesForArch {
612 pf("type %s archsimd.%s\n", t, t)
613 if xAt := strings.Index(t, "x"); xAt != -1 && !strings.HasPrefix(t, "Mask") {
614 elem := t[:xAt]
615 scalar := toScalar(elem)
616 pf("func Load%s(s []%s) %s {\n\treturn %s(archsimd.Load%s(s))\n}\n", t, scalar, t, t, t)
617 pf("func Load%sPart(s []%s) (%s, int) {\n\tv, n := archsimd.Load%sPart(s)\n\treturn %s(v), n\n}\n", t, scalar, t, t, t)
618 pf("func Broadcast%s(x %s) %s {\n\treturn %s(archsimd.Broadcast%s(x))\n}\n", t, scalar, t, t, t)
619 }
620 }
621 nl()
622
623 typeStr := func(e ast.Expr) string {
624 var buf strings.Builder
625 format.Node(&buf, token.NewFileSet(), e)
626 return buf.String()
627 }
628
629 convertArg := func(name string, e ast.Expr) string {
630 switch t := e.(type) {
631 case *ast.Ident:
632 if _, ok := knownReceivers[t.Name]; ok {
633 return fmt.Sprintf("archsimd.%s(%s)", t.Name, name)
634 }
635 case *ast.StarExpr:
636 if ident, ok := t.X.(*ast.Ident); ok {
637 if _, ok := knownReceivers[ident.Name]; ok {
638 return fmt.Sprintf("(*archsimd.%s)(%s)", ident.Name, name)
639 }
640 }
641 }
642 return name
643 }
644
645 wrapResult := func(call string, e ast.Expr) string {
646 switch t := e.(type) {
647 case *ast.Ident:
648 if _, ok := knownReceivers[t.Name]; ok {
649 return fmt.Sprintf("%s(%s)", t.Name, call)
650 }
651 case *ast.StarExpr:
652 if ident, ok := t.X.(*ast.Ident); ok {
653 if _, ok := knownReceivers[ident.Name]; ok {
654 return fmt.Sprintf("(*%s)(%s)", ident.Name, call)
655 }
656 }
657 }
658 return call
659 }
660
661 for _, elem := range elems {
662 intersection := intersectionByElem[elem]
663 for _, m := range intersection {
664 for _, t := range typesForArch {
665 if map128[elem] != t && map256[elem] != t && map512[elem] != t {
666 continue
667 }
668 fd := methodsByType[combine(arch, t)][m]
669 if fd == nil {
670 continue
671 }
672 pf("func (x %s) %s(", t, fd.Name.Name)
673 var args []string
674 if fd.Type.Params != nil {
675 paramCount := 0
676 for _, field := range fd.Type.Params.List {
677 if len(field.Names) > 0 {
678 for _, name := range field.Names {
679 if paramCount > 0 {
680 p(", ")
681 }
682 pf("%s %s", name.Name, typeStr(field.Type))
683 args = append(args, convertArg(name.Name, field.Type))
684 paramCount++
685 }
686 } else {
687 if paramCount > 0 {
688 p(", ")
689 }
690 paramName := fmt.Sprintf("p%d", paramCount)
691 pf("%s %s", paramName, typeStr(field.Type))
692 args = append(args, convertArg(paramName, field.Type))
693 paramCount++
694 }
695 }
696 }
697 p(")")
698
699 var results []ast.Expr
700 if fd.Type.Results != nil {
701 p(" ")
702 needsParens := len(fd.Type.Results.List) > 1 || (len(fd.Type.Results.List) == 1 && len(fd.Type.Results.List[0].Names) > 0)
703 if needsParens {
704 p("(")
705 }
706 for i, field := range fd.Type.Results.List {
707 if i > 0 {
708 p(", ")
709 }
710 results = append(results, field.Type)
711 p(typeStr(field.Type))
712 }
713 if needsParens {
714 p(")")
715 }
716 }
717
718 p(" {\n\t")
719 if len(results) > 0 {
720 p("return ")
721 }
722
723 callStr := fmt.Sprintf("(archsimd.%s(x)).%s(%s)", t, fd.Name.Name, strings.Join(args, ", "))
724 if len(results) == 1 {
725 p(wrapResult(callStr, results[0]))
726 } else {
727 p(callStr)
728 }
729 p("\n}\n\n")
730 }
731 }
732 }
733 }
734 archDir := filepath.Join(*goRoot, "src", "simd", "internal", "bridge")
735 os.MkdirAll(archDir, 0755)
736 filename := filepath.Join(archDir, "decls_"+arch+".go")
737 formatAndWrite(filename, doArchWrites)
738
739 doToFromWrites := func(w io.Writer) {
740 pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
741
742 pf("// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.\n\n")
743 pf("//go:build goexperiment.simd && %s\n\n", arch)
744 pf("package simd\n\n")
745 pf("import (\n\t\"simd/archsimd\"\n\t\"simd/internal/bridge\"\n)\n\n")
746
747 for _, elem := range elems {
748 var archTypes []string
749 if methodsByType[combine(arch, map128[elem])] != nil {
750 archTypes = append(archTypes, map128[elem])
751 }
752 if methodsByType[combine(arch, map256[elem])] != nil {
753 archTypes = append(archTypes, map256[elem])
754 }
755 if methodsByType[combine(arch, map512[elem])] != nil {
756 archTypes = append(archTypes, map512[elem])
757 }
758
759 if len(archTypes) == 0 {
760 continue
761 }
762
763 pf("func (x %ss) ToArch() any\n\n", elem)
764
765 var intfOpts []string
766 for _, t := range archTypes {
767 intfOpts = append(intfOpts, "archsimd."+t)
768 }
769 pf("type archSimd%ss interface {\n\t%s\n}\n\n", elem, strings.Join(intfOpts, " | "))
770
771 pf("func %ssFromArch[T archSimd%ss](x T) %ss {\n", elem, elem, elem)
772 pf("\tswitch a := any(x).(type) {\n")
773 pf("\t// The return expression is written this way because the code will be rewritten\n")
774 pf("\t// with %ss replaced by one of the arch types, and without the any-assert\n", elem)
775 pf("\t// hack the rewritten code would not pass type checking.\n")
776 pf("\t// The backend of the compiler will eat this and turn it into no code at all,\n")
777 pf("\t// assuming it inlines.\n")
778
779 for _, t := range archTypes {
780 pf("\tcase archsimd.%s:\n", t)
781 pf("\t\tvar t bridge.%s = bridge.%s(a)\n", t, t)
782 pf("\t\treturn (any(t)).(%ss)\n", elem)
783 }
784 pf("\t}\n\tpanic(\"wrong type\")\n}\n\n")
785 }
786 }
787 toFromFilename := filepath.Join(*goRoot, "src", "simd", "tofrom_"+arch+".go")
788 formatAndWrite(toFromFilename, doToFromWrites)
789 }
790
791 if minorProblem {
792 pw("The logged warnings did not prevent generation of the midway API files, but the API is flawed (lacks emulations, documentation, etc).\n")
793 }
794 }
795
796
797
// numberLines returns data with every line prefixed by its 1-based line
// number and ": ", each output line terminated by a newline. It is used to
// make formatter error positions easy to locate in generated code.
func numberLines(data []byte) string {
	var buf bytes.Buffer
	s := bufio.NewScanner(bytes.NewReader(data))
	// The default Scanner token limit (64 KiB) makes Scan stop silently at
	// the first longer line, truncating the listing. A limit one byte larger
	// than the input can never be exceeded by a line of that input.
	s.Buffer(nil, len(data)+1)
	for i := 1; s.Scan(); i++ {
		fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
	}
	return buf.String()
}
807
808 func formatAndWrite(filename string, doWrites func(w io.Writer)) {
809 if filename == "" {
810 return
811 }
812 f, err := os.Create(filename)
813 if err != nil {
814 log.Fatal(err)
815 }
816 defer f.Close()
817
818 out := new(bytes.Buffer)
819 doWrites(out)
820
821 b, err := format.Source(out.Bytes())
822 if err != nil {
823 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
824 fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
825 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
826 os.Exit(1)
827 } else {
828 f.Write(b)
829 f.Close()
830 }
831 }
832
View as plain text