// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssa

import (
	"cmd/compile/internal/types"
	"cmd/internal/src"
	"cmp"
	"fmt"
	"slices"
)

// cse does common-subexpression elimination on the Function.
// Values are just relinked, nothing is deleted. A subsequent deadcode
// pass is required to actually remove duplicate expressions.
func cse(f *Func) {
	// Two values are equivalent if they satisfy the following definition:
	// equivalent(v, w):
	//   v.Op == w.Op
	//   v.Type == w.Type
	//   v.Aux == w.Aux
	//   v.AuxInt == w.AuxInt
	//   len(v.Args) == len(w.Args)
	//   v.Block == w.Block if v.Op == OpPhi
	//   equivalent(v.Args[i], w.Args[i]) for i in 0..len(v.Args)-1
	//
	// The algorithm searches for a partition of f's values into
	// equivalence classes using the above definition.
	// It starts with a coarse partition and iteratively refines it
	// until it reaches a fixed point.
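	//
	// For example (illustrative only), given
	//	v1 = Add64 <int64> x y
	//	v2 = Add64 <int64> x y
	//	v3 = Add64 <int64> x z
	// all three start in the same coarse class (same op, type, aux, and
	// arity); refinement then splits v3 off because z is not in y's
	// equivalence class, leaving {v1, v2} as a final class.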
	// Make initial coarse partitions by using a subset of the conditions above.
	a := f.Cache.allocValueSlice(f.NumValues())
	defer func() { f.Cache.freeValueSlice(a) }()
	a = a[:0]
	o := f.Cache.allocInt32Slice(f.NumValues()) // memoization table for storeOrdering
	defer func() { f.Cache.freeInt32Slice(o) }()
	if f.auxmap == nil {
		f.auxmap = auxmap{}
	}
	for _, b := range f.Blocks {
		for _, v := range b.Values {
			if v.Type.IsMemory() {
				continue // memory values can never cse
			}
			if f.auxmap[v.Aux] == 0 {
				f.auxmap[v.Aux] = int32(len(f.auxmap)) + 1
			}
			a = append(a, v)
		}
	}
	partition := partitionValues(a, f.auxmap)

	// map from value ID back to equivalence class ID
	valueEqClass := f.Cache.allocIDSlice(f.NumValues())
	defer f.Cache.freeIDSlice(valueEqClass)
	for _, b := range f.Blocks {
		for _, v := range b.Values {
			// Use negative equivalence class #s for unique values.
			valueEqClass[v.ID] = -v.ID
		}
	}
	var pNum ID = 1
	for _, e := range partition {
		if f.pass.debug > 1 && len(e) > 500 {
			fmt.Printf("CSE.large partition (%d): ", len(e))
			for j := 0; j < 3; j++ {
				fmt.Printf("%s ", e[j].LongString())
			}
			fmt.Println()
		}

		for _, v := range e {
			valueEqClass[v.ID] = pNum
		}
		if f.pass.debug > 2 && len(e) > 1 {
			fmt.Printf("CSE.partition #%d:", pNum)
			for _, v := range e {
				fmt.Printf(" %s", v.String())
			}
			fmt.Printf("\n")
		}
		pNum++
	}

	// memTable memoizes the result of getEffectiveMemoryArg per value ID:
	// the ID of the memory state a memory-using value effectively depends
	// on and the number of disjoint stores skipped to reach it, packed
	// into an int32 (see memTableSkipBits).
	memTable := f.Cache.allocInt32Slice(f.NumValues())
	defer f.Cache.freeInt32Slice(memTable)

	// Split equivalence classes at points where they have
	// non-equivalent arguments. Repeat until we can't find any
	// more splits.
	var splitPoints []int
	for {
		changed := false

		// partition can grow in the loop. By not using a range loop here,
		// we process new additions as they arrive, avoiding O(n^2) behavior.
		for i := 0; i < len(partition); i++ {
			e := partition[i]

			if opcodeTable[e[0].Op].commutative {
				// Order the arguments of commutative ops by equivalence
				// class, so that equivalent values present their
				// arguments in the same order.
				for _, v := range e {
					if valueEqClass[v.Args[0].ID] > valueEqClass[v.Args[1].ID] {
						v.Args[0], v.Args[1] = v.Args[1], v.Args[0]
					}
				}
			}

			// Sort by the equivalence classes of the arguments, bringing
			// values with equivalent arguments next to each other.
			slices.SortFunc(e, func(v, w *Value) int {
				_, idxMem, _, _ := isMemUser(v)
				for i, a := range v.Args {
					var aId, bId ID
					if i != idxMem {
						b := w.Args[i]
						aId = a.ID
						bId = b.ID
					} else {
						// For memory arguments, compare the effective
						// memory states, which may skip stores that are
						// disjoint from this value's access.
						aId, _ = getEffectiveMemoryArg(memTable, v)
						bId, _ = getEffectiveMemoryArg(memTable, w)
					}
					if valueEqClass[aId] < valueEqClass[bId] {
						return -1
					}
					if valueEqClass[aId] > valueEqClass[bId] {
						return +1
					}
				}
				return 0
			})

			// Find the split points: indexes in the sorted class where
			// adjacent values have non-equivalent arguments.
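			// For example (illustrative only), with len(e) == 6 and splits
			// detected at j == 2 and j == 4, splitPoints ends up as
			// [0 2 4 6] and e is split into e[0:2], e[2:4], and e[4:6].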
			splitPoints = append(splitPoints[:0], 0)
			for j := 1; j < len(e); j++ {
				v, w := e[j-1], e[j]

				eqArgs := true
				_, idxMem, _, _ := isMemUser(v)
				for k, a := range v.Args {
					if v.Op == OpLocalAddr && k == 1 {
						// The memory argument of a LocalAddr does not
						// change the address it computes, so it doesn't
						// prevent CSE; the rewrite phase below orders
						// LocalAddrs by their memory args instead.
						continue
					}
					var aId, bId ID
					if k != idxMem {
						b := w.Args[k]
						aId = a.ID
						bId = b.ID
					} else {
						// As above, compare effective memory states.
						aId, _ = getEffectiveMemoryArg(memTable, v)
						bId, _ = getEffectiveMemoryArg(memTable, w)
					}
					if valueEqClass[aId] != valueEqClass[bId] {
						eqArgs = false
						break
					}
				}
				if !eqArgs {
					splitPoints = append(splitPoints, j)
				}
			}
			if len(splitPoints) == 1 {
				continue // no splits found
			}

			// Remove e from the partition; the pieces found below replace it.
			partition[i] = partition[len(partition)-1]
			partition = partition[:len(partition)-1]
			i--

			// Add new equivalence classes for the parts of e we found.
			splitPoints = append(splitPoints, len(e))
			for j := 0; j < len(splitPoints)-1; j++ {
				f := e[splitPoints[j]:splitPoints[j+1]]
				if len(f) == 1 {
					// Don't add singletons.
					valueEqClass[f[0].ID] = -f[0].ID
					continue
				}
				for _, v := range f {
					valueEqClass[v.ID] = pNum
				}
				pNum++
				partition = append(partition, f)
			}
			changed = true
		}

		if !changed {
			break
		}
	}

	sdom := f.Sdom()

	// Compute substitutions we would like to do. We substitute v for w
	// if v and w are in the same equivalence class and v dominates w.
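	// For example (illustrative only): if block b1 dominates block b2 and
	// each contains a value computing Add64 x y, the copy in b2 is mapped
	// to the copy in b1, and every use of it is rewritten below.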
	rewrite := f.Cache.allocValueSlice(f.NumValues())
	defer f.Cache.freeValueSlice(rewrite)
	for _, e := range partition {
		slices.SortFunc(e, func(v, w *Value) int {
			if c := cmp.Compare(sdom.domorder(v.Block), sdom.domorder(w.Block)); c != 0 {
				return c
			}
			if _, _, _, ok := isMemUser(v); ok {
				// Equivalent memory users can still have different actual
				// memory arguments. Prefer the one that skipped fewer
				// disjoint stores: it depends on the earliest memory
				// state, making it the safest representative to keep.
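				// For example (illustrative only):
				//	v1 = Load <*T> p mem0        // 0 skips
				//	mem1 = Store {T2} q x mem0   // q disjoint from p
				//	v2 = Load <*T> p mem1        // 1 skip, effective arg mem0
				// v1 sorts first and becomes the representative, so v2 is
				// rewritten to v1 rather than the other way around.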
				_, vSkips := getEffectiveMemoryArg(memTable, v)
				_, wSkips := getEffectiveMemoryArg(memTable, w)
				if c := cmp.Compare(vSkips, wSkips); c != 0 {
					return c
				}
			}
			if v.Op == OpLocalAddr {
				// Equivalent LocalAddrs may have different memory args;
				// order them so the one with the earliest memory state
				// comes first.
				vm := v.Args[1]
				wm := w.Args[1]
				if vm == wm {
					return 0
				}

				// A memory arg defined in an earlier block is live on
				// entry to the value's own block, so it precedes any
				// memory state produced within that block.
				if vm.Block != v.Block {
					return -1
				}
				if wm.Block != w.Block {
					return +1
				}

				// Both memory args are produced in the same block as
				// their LocalAddr; order them by their position on the
				// block's store chain.
				vs := storeOrdering(vm, o)
				ws := storeOrdering(wm, o)
				if vs <= 0 {
					f.Fatalf("unable to determine the order of %s", vm.LongString())
				}
				if ws <= 0 {
					f.Fatalf("unable to determine the order of %s", wm.LongString())
				}
				return cmp.Compare(vs, ws)
			}
			// All else being equal, sort statement-marked values first,
			// so the representative that survives keeps its marker.
			vStmt := v.Pos.IsStmt() == src.PosIsStmt
			wStmt := w.Pos.IsStmt() == src.PosIsStmt
			if vStmt != wStmt {
				if vStmt {
					return -1
				}
				return +1
			}
			return 0
		})

		for i := 0; i < len(e)-1; i++ {
			// e is sorted by domorder, so a maximal dominant element
			// comes first in the slice.
			v := e[i]
			if v == nil {
				continue
			}

			e[i] = nil
			// Replace all elements of e which v dominates.
			for j := i + 1; j < len(e); j++ {
				w := e[j]
				if w == nil {
					continue
				}
				if sdom.IsAncestorEq(v.Block, w.Block) {
					rewrite[w.ID] = v
					e[j] = nil
				} else {
					// e is sorted by domorder, so v.Block doesn't
					// dominate any subsequent blocks in e.
					break
				}
			}
		}
	}

	rewrites := int64(0)

	// Apply substitutions.
	for _, b := range f.Blocks {
		for _, v := range b.Values {
			for i, w := range v.Args {
				if x := rewrite[w.ID]; x != nil {
					if w.Pos.IsStmt() == src.PosIsStmt && w.Op != OpNilCheck {
						// We're about to lose a statement marker on w.
						// w is an input to v; if they are in the same
						// block on the same line, move the marker to v.
						if w.Block == v.Block && w.Pos.Line() == v.Pos.Line() {
							v.Pos = v.Pos.WithIsStmt()
							w.Pos = w.Pos.WithNotStmt()
						}
					}
					v.SetArg(i, x)
					rewrites++
				}
			}
		}
		for i, v := range b.ControlValues() {
			if x := rewrite[v.ID]; x != nil {
				if v.Op == OpNilCheck {
					// nilcheck pass will remove the nil checks and log
					// them appropriately, so don't mess with them here.
					continue
				}
				b.ReplaceControl(i, x)
			}
		}
	}

	if f.pass.stats > 0 {
		f.LogStat("CSE REWRITES", rewrites)
	}
}

// storeOrdering assigns a positive score to the store v, such that a store
// that appears later on its block's memory chain always scores strictly
// higher than an earlier one. Scores are memoized in cache, indexed by
// value ID; a non-positive result means the order could not be determined.
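//
// For example (illustrative only), for a chain of stores s1 -> s2 -> s3
// within one block, the scores satisfy
//	storeOrdering(s1, o) < storeOrdering(s2, o) < storeOrdering(s3, o)
// though the absolute values depend on what is already cached in o.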
func storeOrdering(v *Value, cache []int32) int32 {
	const minScore int32 = 1
	score := minScore
	w := v
	// Walk the memory chain backwards until we hit an already-scored
	// value, a chain head (Phi or InitMem), or the block boundary,
	// accumulating the score along the way.
	for {
		if s := cache[w.ID]; s >= minScore {
			score += s
			break
		}
		if w.Op == OpPhi || w.Op == OpInitMem {
			break
		}
		a := w.MemoryArg()
		if a.Block != w.Block {
			break
		}
		w = a
		score++
	}
	// Walk backwards again, recording the scores computed above.
	w = v
	for cache[w.ID] == 0 {
		cache[w.ID] = score
		if score == minScore {
			break
		}
		w = w.MemoryArg()
		score--
	}
	return cache[v.ID]
}

// An eqclass approximates an equivalence class of values, that is,
// a set of values that compute the same result.
type eqclass []*Value

// partitionValues partitions the values into equivalence classes
// based on having all the following features match:
//   - opcode
//   - type
//   - auxint
//   - aux
//   - number of arguments
//   - block, for phi ops
//
// partitionValues returns a list of equivalence classes, each class
// being a list of *Values sorted by ID. The eqclass slices are backed
// by the same storage as the input slice.
// Equivalence classes of size 1 are ignored.
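//
// For example (illustrative only), given values
//	v1 = Add64 <int64> x y
//	v2 = Const64 <int64> [1]
//	v3 = Add64 <int64> x z
//	v4 = Const64 <int64> [2]
// the result is the single class {v1, v3}: argument identity is not
// compared here (cse refines classes by arguments later), while the two
// Const64s differ in auxint and become ignored singletons.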
func partitionValues(a []*Value, auxIDs auxmap) []eqclass {
	slices.SortFunc(a, func(v, w *Value) int {
		switch cmpVal(v, w, auxIDs) {
		case types.CMPlt:
			return -1
		case types.CMPgt:
			return +1
		default:
			// Sort by value ID last to keep the sort result deterministic.
			return cmp.Compare(v.ID, w.ID)
		}
	})

	var partition []eqclass
	for len(a) > 0 {
		v := a[0]
		j := 1
		for ; j < len(a); j++ {
			w := a[j]
			if cmpVal(v, w, auxIDs) != types.CMPeq {
				break
			}
		}
		if j > 1 {
			partition = append(partition, a[:j])
		}
		a = a[j:]
	}

	return partition
}

// lt2Cmp converts the result of a less-than comparison into a types.Cmp.
func lt2Cmp(isLt bool) types.Cmp {
	if isLt {
		return types.CMPlt
	}
	return types.CMPgt
}

type auxmap map[Aux]int32

func cmpVal(v, w *Value, auxIDs auxmap) types.Cmp {
	// Try to order these comparisons by cost (cheapest first).
	if v.Op != w.Op {
		return lt2Cmp(v.Op < w.Op)
	}
	if v.AuxInt != w.AuxInt {
		return lt2Cmp(v.AuxInt < w.AuxInt)
	}
	if len(v.Args) != len(w.Args) {
		return lt2Cmp(len(v.Args) < len(w.Args))
	}
	if v.Op == OpPhi && v.Block != w.Block {
		return lt2Cmp(v.Block.ID < w.Block.ID)
	}
	if v.Type.IsMemory() {
		// We will never be able to CSE two values
		// that generate memory.
		return lt2Cmp(v.ID < w.ID)
	}
	// OpSelect is a pseudo-op. We need to be more aggressive
	// regarding CSE to keep multiple OpSelect's of the same
	// argument from existing.
	if v.Op != OpSelect0 && v.Op != OpSelect1 && v.Op != OpSelectN {
		if tc := v.Type.Compare(w.Type); tc != types.CMPeq {
			return tc
		}
	}

	if v.Aux != w.Aux {
		if v.Aux == nil {
			return types.CMPlt
		}
		if w.Aux == nil {
			return types.CMPgt
		}
		return lt2Cmp(auxIDs[v.Aux] < auxIDs[w.Aux])
	}

	return types.CMPeq
}

// isMemUser reports whether v is an operation that reads memory and that
// CSE knows how to reason about. It returns the index of v's pointer
// argument, the index of its memory argument, the width in bytes of the
// memory access, and whether v qualifies.
func isMemUser(v *Value) (int, int, int64, bool) {
	switch v.Op {
	case OpLoad:
		return 0, 1, v.Type.Size(), true
	case OpNilCheck:
		return 0, 1, 0, true
	default:
		return -1, -1, 0, false
	}
}

// isMemDef reports whether v is an operation that writes memory and that
// CSE knows how to reason about. It returns the index of v's pointer
// argument, the index of its memory argument, the width in bytes of the
// memory access, and whether v qualifies.
func isMemDef(v *Value) (int, int, int64, bool) {
	switch v.Op {
	case OpStore:
		return 0, 2, auxToType(v.Aux).Size(), true
	default:
		return -1, -1, 0, false
	}
}

// memTable entries pack an effective memory arg's value ID together with
// the number of skipped stores into an int32: the low memTableSkipBits
// bits hold the skip count and the remaining bits hold the ID.
const memTableSkipBits = 8

// maxId is the largest value ID that fits in a memTable entry.
const maxId = ID(1<<(31-memTableSkipBits)) - 1
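
// For example (illustrative only), a load whose effective memory arg has
// ID 7 after skipping 2 disjoint stores is recorded as
//	7<<memTableSkipBits | 2 == 0x702
// and decoded by getEffectiveMemoryArg as (ID(7), 2).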

// getEffectiveMemoryArg returns the ID of the memory state v effectively
// depends on, skipping over any stores that are provably disjoint from
// v's own memory access, together with the number of stores skipped.
// Results are memoized in memTable.
func getEffectiveMemoryArg(memTable []int32, v *Value) (ID, uint32) {
	if code := uint32(memTable[v.ID]); code != 0 {
		return ID(code >> memTableSkipBits), code & ((1 << memTableSkipBits) - 1)
	}
	if idxPtr, idxMem, width, ok := isMemUser(v); ok {
		// If the memory arg's ID doesn't fit in a memTable entry, fall
		// back to the actual memory arg with zero skips, unmemoized.
		memId := v.Args[idxMem].ID
		if memId > maxId {
			return memId, 0
		}
		mem, skips := skipDisjointMemDefs(v, idxPtr, idxMem, width)
		if mem.ID <= maxId {
			memId = mem.ID
		} else {
			skips = 0
		}
		memTable[v.ID] = int32(memId<<memTableSkipBits) | int32(skips)
		return memId, skips
	} else {
		v.Block.Func.Fatalf("expected memory user instruction: %v", v.LongString())
	}
	return 0, 0
}

// skipDisjointMemDefs walks the memory chain backwards from user's memory
// argument, skipping stores whose written region is provably disjoint from
// the region user reads. It returns the earliest memory state reached and
// the number of stores skipped.
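//
// For example (illustrative only), in
//	mem1 = Store {int64} q x mem0
//	v = Load <int64> p mem1
// if disjoint(q, 8, p, 8) holds, then skipDisjointMemDefs(v, 0, 1, 8)
// returns (mem0, 1): the load may be treated as reading from mem0.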
func skipDisjointMemDefs(user *Value, idxUserPtr, idxUserMem int, useWidth int64) (*Value, uint32) {
	usePtr, mem := user.Args[idxUserPtr], user.Args[idxUserMem]
	const maxSkips = (1 << memTableSkipBits) - 1
	var skips uint32
	for skips = 0; skips < maxSkips; skips++ {
		if idxPtr, idxMem, width, ok := isMemDef(mem); ok {
			if mem.Args[idxMem].Uses > 50 {
				// Give up on walking past heavily shared memory states;
				// chasing them is unlikely to pay for the compile time.
				break
			}
			defPtr := mem.Args[idxPtr]
			if disjoint(defPtr, width, usePtr, useWidth) {
				mem = mem.Args[idxMem]
				continue
			}
		}
		break
	}
	return mem, skips
}