1
2
3
4
5 package midway
6
7 import (
8 "cmd/compile/internal/syntax"
9 "cmd/compile/internal/types2"
10 "fmt"
11 "internal/buildcfg"
12 "strings"
13 )
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40 type Rewriter struct {
41 pkg *types2.Package
42 analyzer *Analyzer
43 info *types2.Info
44 sizes []int
45 }
46
47 func NewRewriter(pkg *types2.Package, info *types2.Info, analyzer *Analyzer, sizes []int) *Rewriter {
48 return &Rewriter{
49 pkg: pkg,
50 info: info,
51 analyzer: analyzer,
52 sizes: sizes,
53 }
54 }
55
56 func (r *Rewriter) Rewrite(files []*syntax.File) {
57
58
59 for _, fileAST := range files {
60
61 var newDecls []syntax.Decl
62 for _, k := range r.sizes {
63 newDecls = r.generateForSize(fileAST, k, newDecls)
64 }
65
66
67 r.generateDispatchers(fileAST)
68
69 fileAST.DeclList = append(fileAST.DeclList, newDecls...)
70 }
71 }
72
73 func (r *Rewriter) generateDispatchers(fileAST *syntax.File) {
74 var newDecls []syntax.Decl
75
76 for _, decl := range fileAST.DeclList {
77 switch d := decl.(type) {
78 case *syntax.FuncDecl:
79 if d.Name == nil {
80 newDecls = append(newDecls, d)
81 continue
82 }
83 obj := r.info.Defs[d.Name]
84 if !r.analyzer.dependentObj[obj] || r.analyzer.inSimd {
85 newDecls = append(newDecls, d)
86 continue
87 }
88
89 sig, ok := obj.Type().(*types2.Signature)
90 if !ok {
91 newDecls = append(newDecls, d)
92 continue
93 }
94
95 if r.analyzer.HasDependentSignature(sig) {
96
97 continue
98 }
99
100
101 d.Body = r.createDispatcherBody(d, sig)
102 newDecls = append(newDecls, d)
103
104 case *syntax.VarDecl:
105
106 keep := false
107 for _, name := range d.NameList {
108 if !r.analyzer.dependentObj[r.info.Defs[name]] {
109 keep = true
110 break
111 }
112 }
113 if keep {
114 newDecls = append(newDecls, d)
115 }
116 case *syntax.TypeDecl:
117 if !r.analyzer.dependentObj[r.info.Defs[d.Name]] || r.analyzer.inSimd {
118 newDecls = append(newDecls, d)
119 }
120 default:
121 newDecls = append(newDecls, decl)
122 }
123 }
124
125 fileAST.DeclList = newDecls
126
127 if !r.analyzer.inSimd {
128
129 hasArchSimd := false
130 var simdImport *syntax.ImportDecl
131 for _, decl := range fileAST.DeclList {
132 if imp, ok := decl.(*syntax.ImportDecl); ok {
133 if imp.Path.Value == `"`+archFullPkg+`"` {
134 hasArchSimd = true
135 }
136 if imp.Path.Value == `"`+simdPkg+`"` {
137 simdImport = imp
138 }
139
140 }
141 }
142 p := simdImport.Pos()
143 if !hasArchSimd {
144 r.injectImport(fileAST, archFullPkg, p)
145 }
146
147
148
149 fun := &syntax.SelectorExpr{
150 X: syntax.NewName(p, simdPkg),
151 Sel: syntax.NewName(p, vectorSizeFn),
152 }
153 fun.SetPos(p)
154 call := &syntax.CallExpr{Fun: fun}
155 call.SetPos(p)
156
157 name := syntax.NewName(p, "_")
158
159 varDecl := &syntax.VarDecl{NameList: []*syntax.Name{name}, Values: call}
160 varDecl.SetPos(p)
161 fileAST.DeclList = append(fileAST.DeclList, varDecl)
162 }
163 }
164
165 func (r *Rewriter) injectImport(fileAST *syntax.File, toImport string, simdImportPos syntax.Pos) {
166 importDecl := &syntax.ImportDecl{
167 Path: &syntax.BasicLit{Value: `"` + toImport + `"`, Kind: syntax.StringLit},
168 }
169 importDecl.Path.SetPos(simdImportPos)
170 importDecl.SetPos(simdImportPos)
171 fileAST.DeclList = append([]syntax.Decl{importDecl}, fileAST.DeclList...)
172 }
173
174 func (r *Rewriter) createDispatcherBody(d *syntax.FuncDecl, sig *types2.Signature) *syntax.BlockStmt {
175
176
177 args := func() []syntax.Expr {
178 var args []syntax.Expr
179 if d.Type.ParamList != nil {
180 for _, field := range d.Type.ParamList {
181 if field.Name != nil {
182 paramName := syntax.NewName(field.Pos(), field.Name.Value)
183 args = append(args, paramName)
184 }
185 }
186 }
187 return args
188 }
189
190
191 pe := func(e syntax.Expr) syntax.Expr {
192 e.SetPos(d.Pos())
193 return e
194 }
195
196 ps := func(e syntax.Stmt) syntax.Stmt {
197 e.SetPos(d.Pos())
198 return e
199 }
200
201
202
203
204
205
206
207
208 switchStmt := &syntax.SwitchStmt{
209 Tag: pe(&syntax.CallExpr{
210 Fun: pe(&syntax.SelectorExpr{
211 X: syntax.NewName(d.Pos(), simdPkg),
212 Sel: syntax.NewName(d.Pos(), vectorSizeFn),
213 }),
214 }),
215 Body: []*syntax.CaseClause{},
216 }
217
218 for _, k := range r.sizes {
219 fnName := fmt.Sprintf("%s@simd%d", d.Name.Value, k)
220 fnIdent := syntax.NewName(d.Pos(), fnName)
221
222 callExpr := pe(&syntax.CallExpr{
223 Fun: pe(fnIdent),
224 ArgList: args(),
225 })
226
227 var branchStmt syntax.Stmt
228 if d.Type.ResultList != nil && len(d.Type.ResultList) > 0 {
229 branchStmt = &syntax.ReturnStmt{Results: callExpr}
230 } else {
231 branchStmt = &syntax.BlockStmt{
232 List: []syntax.Stmt{
233 ps(&syntax.ExprStmt{X: callExpr}),
234 ps(&syntax.ReturnStmt{}),
235 },
236 }
237 }
238 branchStmt.SetPos(d.Pos())
239
240 caseClause := &syntax.CaseClause{
241 Cases: pe(&syntax.BasicLit{Kind: syntax.IntLit, Value: fmt.Sprintf("%d", k)}),
242 Body: []syntax.Stmt{branchStmt},
243 }
244 caseClause.SetPos(d.Pos())
245 switchStmt.Body = append(switchStmt.Body, caseClause)
246 }
247
248 fnName := "panic"
249 fnIdent := pe(syntax.NewName(d.Pos(), fnName))
250
251 callExpr := pe(&syntax.CallExpr{
252 Fun: fnIdent,
253 ArgList: []syntax.Expr{pe(&syntax.BasicLit{Value: "\"unsupported vector size in simd-rewritten code\"", Kind: syntax.StringLit})},
254 })
255
256 panicStmt := &syntax.ExprStmt{X: callExpr}
257 blockStmt := &syntax.BlockStmt{List: []syntax.Stmt{ps(switchStmt), ps(panicStmt)}}
258
259 blockStmt.SetPos(d.Pos())
260
261 return blockStmt
262 }
263
264 func (r *Rewriter) generateForSize(fileAST *syntax.File, k int, newDecls []syntax.Decl) []syntax.Decl {
265 copier := NewDeepCopier(r.pkg, r.info, k, r.analyzer, fmt.Sprintf("@simd%d", k))
266 for _, decl := range fileAST.DeclList {
267 if r.shouldIncludeDecl(decl) {
268 newDecl := copier.CopyDecl(decl)
269 newDecls = append(newDecls, newDecl)
270 }
271 }
272 return newDecls
273 }
274
// nameToElemBitWidth maps a vector type name such as "Int32s" or "Mask8s" to
// its element width in bits. It returns 0 for any name it does not recognize.
func nameToElemBitWidth(name string) int {
	switch name {
	case "Int8s", "Uint8s", "Mask8s":
		return 8
	case "Int16s", "Uint16s", "Mask16s":
		return 16
	case "Int32s", "Uint32s", "Float32s", "Mask32s":
		return 32
	case "Int64s", "Uint64s", "Float64s", "Mask64s":
		return 64
	default:
		return 0
	}
}
289
290 func (r *Rewriter) shouldIncludeDecl(decl syntax.Decl) bool {
291
292
293
294 if r.analyzer.inSimd {
295 theFile := decl.Pos().Base().Filename()
296
297 if simdSlash := strings.LastIndex(theFile, simdPkg+"/"); simdSlash == -1 || !strings.HasPrefix(theFile[simdSlash:], simdPkg+"/tofrom_") {
298 return false
299 }
300 }
301
302 switch d := decl.(type) {
303 case *syntax.FuncDecl:
304 if d.Name != nil {
305 return r.analyzer.dependentObj[r.info.Defs[d.Name]]
306 }
307 case *syntax.TypeDecl:
308 return r.analyzer.dependentObj[r.info.Defs[d.Name]]
309 case *syntax.VarDecl:
310 for _, name := range d.NameList {
311 if r.analyzer.dependentObj[r.info.Defs[name]] {
312 return true
313 }
314 }
315 }
316 return false
317 }
318
319
320 func RewriteWrapper(pkg *types2.Package, info *types2.Info, files []*syntax.File) bool {
321 if !buildcfg.Experiment.SIMD {
322 return false
323 }
324
325 switch buildcfg.GOARCH {
326 case "wasm", "amd64", "arm64":
327 default:
328 return false
329 }
330
331 sizes := rewriteSizes()
332 if len(sizes) == 0 {
333 return false
334 }
335 analyzer := NewAnalyzer(pkg, info)
336 if !analyzer.Analyze(files) {
337 return false
338 }
339
340 CheckPositions(files, "before midway")
341
342 rewriter := NewRewriter(pkg, info, analyzer, sizes)
343 rewriter.Rewrite(files)
344
345 CheckPositions(files, "after midway")
346
347 return true
348 }
349
View as plain text