1
2
3
4
5 package ssa
6
7 import (
8 "fmt"
9 "os"
10 )
11
12
// debugPoset enables expensive self-checks. When true, the exported
// query/update methods and Undo run CheckIntegrity on return, and Undo
// panics if the poset is not empty (per CheckEmpty) once it has unwound
// the entire undo stack.
var debugPoset = false
14
// uintSize is the width in bits of a uint on the build platform: 64 when
// ^uint(0)>>63 is 1 (a 64-bit uint), otherwise 32.
const uintSize = 32 << (^uint(0) >> 63)

// bitset is a fixed-size set of small non-negative integers, packed
// uintSize bits per word. Accessing an index beyond its capacity panics.
type bitset []uint

// newBitset returns an empty bitset able to hold the indices 0..n-1
// (capacity is rounded up to a whole number of words).
func newBitset(n int) bitset {
	words := (n + uintSize - 1) / uintSize
	return make(bitset, words)
}

// Reset removes every element from the set, keeping its capacity.
func (bs bitset) Reset() {
	for w := 0; w < len(bs); w++ {
		bs[w] = 0
	}
}

// Set adds idx to the set.
func (bs bitset) Set(idx uint32) {
	word, bit := idx/uintSize, idx%uintSize
	bs[word] |= uint(1) << bit
}

// Clear removes idx from the set.
func (bs bitset) Clear(idx uint32) {
	word, bit := idx/uintSize, idx%uintSize
	bs[word] &^= uint(1) << bit
}

// Test reports whether idx is in the set.
func (bs bitset) Test(idx uint32) bool {
	word, bit := idx/uintSize, idx%uintSize
	return bs[word]>>bit&1 == 1
}
41
// undoType identifies the kind of mutation recorded in a posetUndo entry,
// which tells Undo how to revert it.
type undoType uint8

const (
	// undoInvalid is the zero value; Undo has no case for it and panics.
	undoInvalid     undoType = iota
	undoCheckpoint           // marks a checkpoint: Undo stops here
	undoSetChl               // restore the left edge of node idx to edge
	undoSetChr               // restore the right edge of node idx to edge
	undoNonEqual             // clear bit idx in noneq[uint32(ID)]
	undoNewNode              // drop the last node idx, and unmap Value ID if non-zero
	undoNewConstant          // unmap the constant at node idx, or remap it to node uint32(ID)
	undoAliasNode            // remap Value ID back to node idx (or unmap it if idx is 0)
	undoNewRoot              // remove node idx from the roots
	undoChangeRoot           // put root edge.Target() back in place of root idx
	undoMergeRoot            // split merged root idx back into its two children
)

// posetUndo is one entry of the poset undo stack. How idx, ID and edge are
// interpreted depends on typ; some kinds reuse ID to carry a node index
// rather than a Value ID.
type posetUndo struct {
	typ  undoType  // kind of change to revert
	idx  uint32    // node index (meaning depends on typ)
	ID   ID        // Value ID, or a node index (see upushneq, upushconst)
	edge posetEdge // saved edge for undoSetChl/undoSetChr/undoChangeRoot
}
68
const (
	// posetFlagUnsigned makes the poset treat constants as unsigned
	// integers: newconst keys and orders them by AuxUnsigned rather than
	// AuxInt. It is toggled by SetUnsigned.
	posetFlagUnsigned = 1 << iota
)
73
74
75
// posetEdge is a compact DAG edge: the target node index lives in the
// upper 31 bits and bit 0 says whether the edge is strict ("<") rather
// than non-strict ("<="). The zero value is used to mean "no edge".
type posetEdge uint32

// newedge returns the edge pointing at node t; strict selects "<" over "<=".
func newedge(t uint32, strict bool) posetEdge {
	e := posetEdge(t << 1)
	if strict {
		e |= 1
	}
	return e
}

// Target returns the index of the node the edge points to.
func (e posetEdge) Target() uint32 {
	return uint32(e >> 1)
}

// Strict reports whether the edge encodes a strict ("<") relation.
func (e posetEdge) Strict() bool {
	return e&1 == 1
}

// String formats the edge as its target index, suffixed by "*" if strict.
func (e posetEdge) String() string {
	target := fmt.Sprint(e.Target())
	if !e.Strict() {
		return target
	}
	return target + "*"
}
94
95
// posetNode is a node of one of the poset DAGs. It has at most two
// outgoing edges, l and r; a zero edge means the slot is unused. When a
// node needs more children, addchild inserts an extra intermediate node.
type posetNode struct {
	l, r posetEdge
}
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// poset records a partial order among SSA values: facts of the form
// n1<n2, n1<=n2, n1==n2 and n1!=n2.
//
// Each Value is mapped (values) to a node index; Values known to be equal
// share a node. Nodes are linked into one or more DAGs whose entry points
// are listed in roots: an edge from A to B records A<B (strict edge) or
// A<=B (non-strict edge). All constants live in the DAG of roots[0], and
// constants maps each constant value to its node; newconst links every new
// constant to its neighbors so the ordering among constants is explicit.
// Non-equality is kept separately in noneq, keyed by the higher of the two
// node indices. Every mutation is pushed on undo so that it can be reverted
// up to the most recent Checkpoint.
//
// Node index 0 is reserved (nodes starts with one dummy entry), so an
// index or edge equal to 0 means "absent".
type poset struct {
	lastidx   uint32            // last used node index
	flags     uint8             // posetFlag* bits
	values    map[ID]uint32     // Value ID -> node index
	constants map[int64]uint32  // constant value -> node index
	nodes     []posetNode       // nodes by index (entry 0 unused)
	roots     []uint32          // DAG roots; roots[0] holds the constants
	noneq     map[uint32]bitset // node index -> lower node indices known non-equal to it
	undo      []posetUndo       // undo stack
}
158
159 func newPoset() *poset {
160 return &poset{
161 values: make(map[ID]uint32),
162 constants: make(map[int64]uint32, 8),
163 nodes: make([]posetNode, 1, 16),
164 roots: make([]uint32, 0, 4),
165 noneq: make(map[uint32]bitset),
166 undo: make([]posetUndo, 0, 4),
167 }
168 }
169
170 func (po *poset) SetUnsigned(uns bool) {
171 if uns {
172 po.flags |= posetFlagUnsigned
173 } else {
174 po.flags &^= posetFlagUnsigned
175 }
176 }
177
178
179 func (po *poset) setchl(i uint32, l posetEdge) { po.nodes[i].l = l }
180 func (po *poset) setchr(i uint32, r posetEdge) { po.nodes[i].r = r }
181 func (po *poset) chl(i uint32) uint32 { return po.nodes[i].l.Target() }
182 func (po *poset) chr(i uint32) uint32 { return po.nodes[i].r.Target() }
183 func (po *poset) children(i uint32) (posetEdge, posetEdge) {
184 return po.nodes[i].l, po.nodes[i].r
185 }
186
187
188
189 func (po *poset) upush(typ undoType, p uint32, e posetEdge) {
190 po.undo = append(po.undo, posetUndo{typ: typ, idx: p, edge: e})
191 }
192
193
194 func (po *poset) upushnew(id ID, idx uint32) {
195 po.undo = append(po.undo, posetUndo{typ: undoNewNode, ID: id, idx: idx})
196 }
197
198
199 func (po *poset) upushneq(idx1 uint32, idx2 uint32) {
200 po.undo = append(po.undo, posetUndo{typ: undoNonEqual, ID: ID(idx1), idx: idx2})
201 }
202
203
204 func (po *poset) upushalias(id ID, i2 uint32) {
205 po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
206 }
207
208
209 func (po *poset) upushconst(idx uint32, old uint32) {
210 po.undo = append(po.undo, posetUndo{typ: undoNewConstant, idx: idx, ID: ID(old)})
211 }
212
213
// addchild records i1<i2 (strict) or i1<=i2 (non-strict) by adding an edge
// from node i1 to node i2, and pushes the undo entries needed to revert it.
//
// A node has only two edge slots. If both are already used, a new extra
// node is inserted: it takes over one of i1's existing edges plus the new
// edge to i2, and i1 points to it through a non-strict edge, e.g.:
//
//	     i1                 i1
//	    /  \      ===>     /  \
//	  i1l  i1r           i1l  extra
//	                          /   \
//	                        i1r    i2
//
// (or the mirror image, when the left edge is the one moved).
func (po *poset) addchild(i1, i2 uint32, strict bool) {
	i1l, i1r := po.children(i1)
	e2 := newedge(i2, strict)

	if i1l == 0 {
		// Left slot free.
		po.setchl(i1, e2)
		po.upush(undoSetChl, i1, 0)
	} else if i1r == 0 {
		// Right slot free.
		po.setchr(i1, e2)
		po.upush(undoSetChr, i1, 0)
	} else {
		// Both slots used: go through an extra node. Its own edges need
		// no undo entries, because reverting the undoNewNode pushed by
		// newnode clears them.
		extra := po.newnode(nil)
		// The side is picked from the parity of i1^i2, a cheap
		// pseudo-random choice. NOTE(review): presumably meant to avoid
		// always growing the same side into a long chain.
		if (i1^i2)&1 != 0 {
			po.setchl(extra, i1r)
			po.setchr(extra, e2)
			po.setchr(i1, newedge(extra, false))
			po.upush(undoSetChr, i1, i1r)
		} else {
			po.setchl(extra, i1l)
			po.setchr(extra, e2)
			po.setchl(i1, newedge(extra, false))
			po.upush(undoSetChl, i1, i1l)
		}
	}
}
251
252
253
254 func (po *poset) newnode(n *Value) uint32 {
255 i := po.lastidx + 1
256 po.lastidx++
257 po.nodes = append(po.nodes, posetNode{})
258 if n != nil {
259 if po.values[n.ID] != 0 {
260 panic("newnode for Value already inserted")
261 }
262 po.values[n.ID] = i
263 po.upushnew(n.ID, i)
264 } else {
265 po.upushnew(0, i)
266 }
267 return i
268 }
269
270
271
272 func (po *poset) lookup(n *Value) (uint32, bool) {
273 i, f := po.values[n.ID]
274 if !f && n.isGenericIntConst() {
275 po.newconst(n)
276 i, f = po.values[n.ID]
277 }
278 return i, f
279 }
280
281
282
283
// newconst inserts the generic integer constant n into the poset.
//
// Constants are keyed by n.AuxInt, or by the bit pattern of n.AuxUnsigned()
// in unsigned mode. If a node for the same constant value already exists,
// n is simply aliased to it. Otherwise a new node is created in the DAG of
// roots[0] (which holds all constants) and linked with strict edges to the
// closest lower and/or higher constants already known, so that the
// ordering among constants is recorded explicitly.
//
// It panics if n is not a generic integer constant.
func (po *poset) newconst(n *Value) {
	if !n.isGenericIntConst() {
		panic("newconst on non-constant")
	}

	// Key under which the constant is stored: in unsigned mode, the int64
	// holds the bit pattern of the unsigned value.
	val := n.AuxInt
	if po.flags&posetFlagUnsigned != 0 {
		val = int64(n.AuxUnsigned())
	}
	// Same constant value seen before: alias n to the existing node.
	if c, found := po.constants[val]; found {
		po.values[n.ID] = c
		po.upushalias(n.ID, 0)
		return
	}

	// Create the node for the new constant.
	i := po.newnode(n)

	// First constant ever: make it a new root and move it into position 0,
	// so that roots[0] is always the DAG containing the constants.
	if len(po.constants) == 0 {
		idx := len(po.roots)
		po.roots = append(po.roots, i)
		po.roots[0], po.roots[idx] = po.roots[idx], po.roots[0]
		po.upush(undoNewRoot, i, 0)
		po.constants[val] = i
		po.upushconst(i, 0)
		return
	}

	// Find the closest existing constants below (lowerptr) and above
	// (higherptr) the new value, comparing as unsigned or signed depending
	// on the poset mode. A pointer stays 0 when no such constant exists.
	var lowerptr, higherptr uint32

	if po.flags&posetFlagUnsigned != 0 {
		var lower, higher uint64
		val1 := n.AuxUnsigned()
		for val2, ptr := range po.constants {
			val2 := uint64(val2)
			if val1 == val2 {
				// Equal constants were handled by the alias case above.
				panic("unreachable")
			}
			if val2 < val1 && (lowerptr == 0 || val2 > lower) {
				lower = val2
				lowerptr = ptr
			} else if val2 > val1 && (higherptr == 0 || val2 < higher) {
				higher = val2
				higherptr = ptr
			}
		}
	} else {
		var lower, higher int64
		val1 := n.AuxInt
		for val2, ptr := range po.constants {
			if val1 == val2 {
				// Equal constants were handled by the alias case above.
				panic("unreachable")
			}
			if val2 < val1 && (lowerptr == 0 || val2 > lower) {
				lower = val2
				lowerptr = ptr
			} else if val2 > val1 && (higherptr == 0 || val2 < higher) {
				higher = val2
				higherptr = ptr
			}
		}
	}

	if lowerptr == 0 && higherptr == 0 {
		// constants is non-empty and holds no equal value, so at least
		// one bound must have been found.
		panic("no constant found")
	}

	switch {
	case lowerptr != 0 && higherptr != 0:
		// Both bounds: record lower < n < higher.
		po.addchild(lowerptr, i, true)
		po.addchild(i, higherptr, true)

	case lowerptr != 0:
		// Lower bound only: record lower < n.
		po.addchild(lowerptr, i, true)

	case higherptr != 0:
		// Higher bound only: record n < higher. n must stay in the DAG of
		// roots[0], so a new extra root is placed above the current one,
		// with the old root and n as children:
		//
		//	       extra
		//	       /   \
		//	     r2     n
		//	     ...   /
		//	       \  /
		//	      higher
		i2 := higherptr
		r2 := po.findroot(i2)
		if r2 != po.roots[0] {
			panic("constant not in root #0")
		}
		extra := po.newnode(nil)
		po.changeroot(r2, extra)
		po.upush(undoChangeRoot, extra, newedge(r2, false))
		po.addchild(extra, r2, false)
		po.addchild(extra, i, false)
		po.addchild(i, i2, true)
	}

	po.constants[val] = i
	po.upushconst(i, 0)
}
408
409
410
411 func (po *poset) aliasnewnode(n1, n2 *Value) {
412 i1, i2 := po.values[n1.ID], po.values[n2.ID]
413 if i1 == 0 || i2 != 0 {
414 panic("aliasnewnode invalid arguments")
415 }
416
417 po.values[n2.ID] = i1
418 po.upushalias(n2.ID, 0)
419 }
420
421
422
423
424
425
// aliasnodes merges every node in the set i2s into the node of n1 (i1),
// pushing all changes on the undo stack. i1 must exist and must not be in
// i2s. Afterwards:
//   - edges that pointed to a node in i2s point to i1 (same strictness);
//   - outgoing edges of nodes in i2s are re-added to i1, except those that
//     target other nodes of i2s, and the i2s nodes are left without edges;
//   - every Value and every constant mapped to a node in i2s maps to i1.
//
// NOTE(review): callers pass nodes that are descendants of i1 (collapsepath)
// or unrelated to it (SetEqual). If a node in i2s had an edge to i1, moving
// that edge would create a self-loop on i1.
func (po *poset) aliasnodes(n1 *Value, i2s bitset) {
	i1 := po.values[n1.ID]
	if i1 == 0 {
		panic("aliasnode for non-existing node")
	}
	if i2s.Test(i1) {
		panic("aliasnode i2s contains n1 node")
	}

	// Scan all nodes to fix up both the parents and the children of i2s.
	// Nodes appended by addchild during the loop are not visited; they are
	// new extra nodes that need no fix-up.
	for idx, n := range po.nodes {
		// Skip i1 itself: renaming its edges to i1 would create self-loops.
		if uint32(idx) == i1 {
			continue
		}
		l, r := n.l, n.r

		// Redirect edges pointing into i2s so that they point to i1.
		if i2s.Test(l.Target()) {
			po.setchl(uint32(idx), newedge(i1, l.Strict()))
			po.upush(undoSetChl, uint32(idx), l)
		}
		if i2s.Test(r.Target()) {
			po.setchr(uint32(idx), newedge(i1, r.Strict()))
			po.upush(undoSetChr, uint32(idx), r)
		}

		// Move the children of a node in i2s to i1 (skipping children
		// that are themselves in i2s, which are being merged too), then
		// disconnect the node. l and r are the edges read before the
		// redirection above.
		if i2s.Test(uint32(idx)) {
			if l != 0 && !i2s.Test(l.Target()) {
				po.addchild(i1, l.Target(), l.Strict())
			}
			if r != 0 && !i2s.Test(r.Target()) {
				po.addchild(i1, r.Target(), r.Strict())
			}
			po.setchl(uint32(idx), 0)
			po.setchr(uint32(idx), 0)
			po.upush(undoSetChl, uint32(idx), l)
			po.upush(undoSetChr, uint32(idx), r)
		}
	}

	// Remap every Value that pointed to a merged node; this includes the
	// Values originally associated with i2s.
	for k, v := range po.values {
		if i2s.Test(v) {
			po.values[k] = i1
			po.upushalias(k, v)
		}
	}

	// If a merged node represented a constant, the constant now lives in
	// i1; remember the old node so Undo can restore it.
	for val, idx := range po.constants {
		if i2s.Test(idx) {
			po.constants[val] = i1
			po.upushconst(i1, idx)
		}
	}
}
487
488 func (po *poset) isroot(r uint32) bool {
489 for i := range po.roots {
490 if po.roots[i] == r {
491 return true
492 }
493 }
494 return false
495 }
496
497 func (po *poset) changeroot(oldr, newr uint32) {
498 for i := range po.roots {
499 if po.roots[i] == oldr {
500 po.roots[i] = newr
501 return
502 }
503 }
504 panic("changeroot on non-root")
505 }
506
507 func (po *poset) removeroot(r uint32) {
508 for i := range po.roots {
509 if po.roots[i] == r {
510 po.roots = append(po.roots[:i], po.roots[i+1:]...)
511 return
512 }
513 }
514 panic("removeroot on non-root")
515 }
516
517
518
519
520
521
522
523
524
// dfs runs a depth-first search over the DAG starting from node r, calling
// f on each visited node. It returns true as soon as f returns true, and
// false once every reachable node has been visited.
//
// If strict is true, f is only called on nodes reachable from r through a
// path containing at least one strict edge (nodes known to be greater than
// r). This is done in two phases. A first walk from r follows only
// non-strict edges and collects the targets of the strict edges it meets.
// A normal search then starts from all those targets.
func (po *poset) dfs(r uint32, strict bool, f func(i uint32) bool) bool {
	closed := newBitset(int(po.lastidx + 1))
	open := make([]uint32, 1, 64)
	open[0] = r

	if strict {
		// Phase 1: explore the "<=" neighborhood of r without calling f.
		// Every strict edge leaving it leads to a node greater than r;
		// collect those targets in next.
		next := make([]uint32, 0, 64)

		for len(open) > 0 {
			i := open[len(open)-1]
			open = open[:len(open)-1]

			// Visit each node at most once in this phase.
			if !closed.Test(i) {
				closed.Set(i)

				// (l, r shadow the parameter r, which is no longer needed.)
				l, r := po.children(i)
				if l != 0 {
					if l.Strict() {
						next = append(next, l.Target())
					} else {
						open = append(open, l.Target())
					}
				}
				if r != 0 {
					if r.Strict() {
						next = append(next, r.Target())
					} else {
						open = append(open, r.Target())
					}
				}
			}
		}
		// Phase 2 starts from the strict targets, with a fresh visited set.
		open = next
		closed.Reset()
	}

	// Plain DFS: call f on every node reachable from the nodes in open.
	for len(open) > 0 {
		i := open[len(open)-1]
		open = open[:len(open)-1]

		if !closed.Test(i) {
			if f(i) {
				return true
			}
			closed.Set(i)
			l, r := po.children(i)
			if l != 0 {
				open = append(open, l.Target())
			}
			if r != 0 {
				open = append(open, r.Target())
			}
		}
	}
	return false
}
588
589
590
591
592
593 func (po *poset) reaches(i1, i2 uint32, strict bool) bool {
594 return po.dfs(i1, strict, func(n uint32) bool {
595 return n == i2
596 })
597 }
598
599
600
601
602 func (po *poset) findroot(i uint32) uint32 {
603
604
605
606 for _, r := range po.roots {
607 if po.reaches(r, i, false) {
608 return r
609 }
610 }
611 panic("findroot didn't find any root")
612 }
613
614
615 func (po *poset) mergeroot(r1, r2 uint32) uint32 {
616
617
618
619 if r2 == po.roots[0] {
620 r1, r2 = r2, r1
621 }
622 r := po.newnode(nil)
623 po.setchl(r, newedge(r1, false))
624 po.setchr(r, newedge(r2, false))
625 po.changeroot(r1, r)
626 po.removeroot(r2)
627 po.upush(undoMergeRoot, r, 0)
628 return r
629 }
630
631
632
633
634
635 func (po *poset) collapsepath(n1, n2 *Value) bool {
636 i1, i2 := po.values[n1.ID], po.values[n2.ID]
637 if po.reaches(i1, i2, true) {
638 return false
639 }
640
641
642 paths := po.findpaths(i1, i2)
643
644
645 paths.Clear(i1)
646 po.aliasnodes(n1, paths)
647 return true
648 }
649
650
651
652
653
654
655
656 func (po *poset) findpaths(cur, dst uint32) bitset {
657 seen := newBitset(int(po.lastidx + 1))
658 path := newBitset(int(po.lastidx + 1))
659 path.Set(dst)
660 po.findpaths1(cur, dst, seen, path)
661 return path
662 }
663
664 func (po *poset) findpaths1(cur, dst uint32, seen bitset, path bitset) {
665 if cur == dst {
666 return
667 }
668 seen.Set(cur)
669 l, r := po.chl(cur), po.chr(cur)
670 if !seen.Test(l) {
671 po.findpaths1(l, dst, seen, path)
672 }
673 if !seen.Test(r) {
674 po.findpaths1(r, dst, seen, path)
675 }
676 if path.Test(l) || path.Test(r) {
677 path.Set(cur)
678 }
679 }
680
681
682 func (po *poset) isnoneq(i1, i2 uint32) bool {
683 if i1 == i2 {
684 return false
685 }
686 if i1 < i2 {
687 i1, i2 = i2, i1
688 }
689
690
691 if bs, ok := po.noneq[i1]; ok && bs.Test(i2) {
692 return true
693 }
694 return false
695 }
696
697
698 func (po *poset) setnoneq(n1, n2 *Value) {
699 i1, f1 := po.lookup(n1)
700 i2, f2 := po.lookup(n2)
701
702
703
704
705 if !f1 {
706 i1 = po.newnode(n1)
707 po.roots = append(po.roots, i1)
708 po.upush(undoNewRoot, i1, 0)
709 }
710 if !f2 {
711 i2 = po.newnode(n2)
712 po.roots = append(po.roots, i2)
713 po.upush(undoNewRoot, i2, 0)
714 }
715
716 if i1 == i2 {
717 panic("setnoneq on same node")
718 }
719 if i1 < i2 {
720 i1, i2 = i2, i1
721 }
722 bs := po.noneq[i1]
723 if bs == nil {
724
725
726
727
728 bs = newBitset(int(i1))
729 po.noneq[i1] = bs
730 } else if bs.Test(i2) {
731
732 return
733 }
734 bs.Set(i2)
735 po.upushneq(i1, i2)
736 }
737
738
739
740 func (po *poset) CheckIntegrity() {
741
742 constants := newBitset(int(po.lastidx + 1))
743 for _, c := range po.constants {
744 constants.Set(c)
745 }
746
747
748
749 seen := newBitset(int(po.lastidx + 1))
750 for ridx, r := range po.roots {
751 if r == 0 {
752 panic("empty root")
753 }
754
755 po.dfs(r, false, func(i uint32) bool {
756 if seen.Test(i) {
757 panic("duplicate node")
758 }
759 seen.Set(i)
760 if constants.Test(i) {
761 if ridx != 0 {
762 panic("constants not in the first DAG")
763 }
764 }
765 return false
766 })
767 }
768
769
770 for id, idx := range po.values {
771 if !seen.Test(idx) {
772 panic(fmt.Errorf("spurious value [%d]=%d", id, idx))
773 }
774 }
775
776
777 for i, n := range po.nodes {
778 if n.l|n.r != 0 {
779 if !seen.Test(uint32(i)) {
780 panic(fmt.Errorf("children of unknown node %d->%v", i, n))
781 }
782 if n.l.Target() == uint32(i) || n.r.Target() == uint32(i) {
783 panic(fmt.Errorf("self-loop on node %d", i))
784 }
785 }
786 }
787 }
788
789
790
791
792 func (po *poset) CheckEmpty() error {
793 if len(po.nodes) != 1 {
794 return fmt.Errorf("non-empty nodes list: %v", po.nodes)
795 }
796 if len(po.values) != 0 {
797 return fmt.Errorf("non-empty value map: %v", po.values)
798 }
799 if len(po.roots) != 0 {
800 return fmt.Errorf("non-empty root list: %v", po.roots)
801 }
802 if len(po.constants) != 0 {
803 return fmt.Errorf("non-empty constants: %v", po.constants)
804 }
805 if len(po.undo) != 0 {
806 return fmt.Errorf("non-empty undo list: %v", po.undo)
807 }
808 if po.lastidx != 0 {
809 return fmt.Errorf("lastidx index is not zero: %v", po.lastidx)
810 }
811 for _, bs := range po.noneq {
812 for _, x := range bs {
813 if x != 0 {
814 return fmt.Errorf("non-empty noneq map")
815 }
816 }
817 }
818 return nil
819 }
820
821
822 func (po *poset) DotDump(fn string, title string) error {
823 f, err := os.Create(fn)
824 if err != nil {
825 return err
826 }
827 defer f.Close()
828
829
830 names := make(map[uint32]string)
831 for id, i := range po.values {
832 s := names[i]
833 if s == "" {
834 s = fmt.Sprintf("v%d", id)
835 } else {
836 s += fmt.Sprintf(", v%d", id)
837 }
838 names[i] = s
839 }
840
841
842 consts := make(map[uint32]int64)
843 for val, idx := range po.constants {
844 consts[idx] = val
845 }
846
847 fmt.Fprintf(f, "digraph poset {\n")
848 fmt.Fprintf(f, "\tedge [ fontsize=10 ]\n")
849 for ridx, r := range po.roots {
850 fmt.Fprintf(f, "\tsubgraph root%d {\n", ridx)
851 po.dfs(r, false, func(i uint32) bool {
852 if val, ok := consts[i]; ok {
853
854 var vals string
855 if po.flags&posetFlagUnsigned != 0 {
856 vals = fmt.Sprint(uint64(val))
857 } else {
858 vals = fmt.Sprint(int64(val))
859 }
860 fmt.Fprintf(f, "\t\tnode%d [shape=box style=filled fillcolor=cadetblue1 label=<%s <font point-size=\"6\">%s [%d]</font>>]\n",
861 i, vals, names[i], i)
862 } else {
863
864 fmt.Fprintf(f, "\t\tnode%d [label=<%s <font point-size=\"6\">[%d]</font>>]\n", i, names[i], i)
865 }
866 chl, chr := po.children(i)
867 for _, ch := range []posetEdge{chl, chr} {
868 if ch != 0 {
869 if ch.Strict() {
870 fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <\" color=\"red\"]\n", i, ch.Target())
871 } else {
872 fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <=\" color=\"green\"]\n", i, ch.Target())
873 }
874 }
875 }
876 return false
877 })
878 fmt.Fprintf(f, "\t}\n")
879 }
880 fmt.Fprintf(f, "\tlabelloc=\"t\"\n")
881 fmt.Fprintf(f, "\tlabeldistance=\"3.0\"\n")
882 fmt.Fprintf(f, "\tlabel=%q\n", title)
883 fmt.Fprintf(f, "}\n")
884 return nil
885 }
886
887
888
889
890
891 func (po *poset) Ordered(n1, n2 *Value) bool {
892 if debugPoset {
893 defer po.CheckIntegrity()
894 }
895 if n1.ID == n2.ID {
896 panic("should not call Ordered with n1==n2")
897 }
898
899 i1, f1 := po.lookup(n1)
900 i2, f2 := po.lookup(n2)
901 if !f1 || !f2 {
902 return false
903 }
904
905 return i1 != i2 && po.reaches(i1, i2, true)
906 }
907
908
909
910
911
912 func (po *poset) OrderedOrEqual(n1, n2 *Value) bool {
913 if debugPoset {
914 defer po.CheckIntegrity()
915 }
916 if n1.ID == n2.ID {
917 panic("should not call Ordered with n1==n2")
918 }
919
920 i1, f1 := po.lookup(n1)
921 i2, f2 := po.lookup(n2)
922 if !f1 || !f2 {
923 return false
924 }
925
926 return i1 == i2 || po.reaches(i1, i2, false)
927 }
928
929
930
931
932
933 func (po *poset) Equal(n1, n2 *Value) bool {
934 if debugPoset {
935 defer po.CheckIntegrity()
936 }
937 if n1.ID == n2.ID {
938 panic("should not call Equal with n1==n2")
939 }
940
941 i1, f1 := po.lookup(n1)
942 i2, f2 := po.lookup(n2)
943 return f1 && f2 && i1 == i2
944 }
945
946
947
948
949
950
951 func (po *poset) NonEqual(n1, n2 *Value) bool {
952 if debugPoset {
953 defer po.CheckIntegrity()
954 }
955 if n1.ID == n2.ID {
956 panic("should not call NonEqual with n1==n2")
957 }
958
959
960
961 i1, f1 := po.lookup(n1)
962 i2, f2 := po.lookup(n2)
963 if !f1 || !f2 {
964 return false
965 }
966
967
968 if po.isnoneq(i1, i2) {
969 return true
970 }
971
972
973 if po.Ordered(n1, n2) || po.Ordered(n2, n1) {
974 return true
975 }
976
977 return false
978 }
979
980
981
982
// setOrder records n1<n2 (strict) or n1<=n2 (non-strict). It returns false
// if the relation contradicts what is already known, and true otherwise.
// Nodes are created for Values that are not in the poset yet.
func (po *poset) setOrder(n1, n2 *Value, strict bool) bool {
	i1, f1 := po.lookup(n1)
	i2, f2 := po.lookup(n2)

	switch {
	case !f1 && !f2:
		// Neither Value is known, so neither is related to anything:
		// start a new DAG rooted at n1, with n2 as its child.
		i1, i2 = po.newnode(n1), po.newnode(n2)
		po.roots = append(po.roots, i1)
		po.upush(undoNewRoot, i1, 0)
		po.addchild(i1, i2, strict)

	case f1 && !f2:
		// Only n1 is known: hang a new node for n2 below it.
		i2 = po.newnode(n2)
		po.addchild(i1, i2, strict)

	case !f1 && f2:
		// Only n2 is known: n1 has to be placed above it.
		i1 = po.newnode(n1)

		// If n2 is a root, n1 simply replaces it as the root.
		if po.isroot(i2) {
			po.changeroot(i2, i1)
			po.upush(undoChangeRoot, i1, newedge(i2, strict))
			po.addchild(i1, i2, strict)
			return true
		}

		// Otherwise find n2's root (a search over all DAGs)...
		r := po.findroot(i2)

		// ...and re-parent through an extra node that becomes the new root:
		//
		//	                 extra
		//	   r             /   \
		//	    \    ===>   r    i1
		//	    i2           \   /
		//	                  i2
		extra := po.newnode(nil)
		po.changeroot(r, extra)
		po.upush(undoChangeRoot, extra, newedge(r, false))
		po.addchild(extra, r, false)
		po.addchild(extra, i1, false)
		po.addchild(i1, i2, strict)

	case f1 && f2:
		// Same node means n1==n2: n1<=n2 holds, n1<n2 is a contradiction.
		if i1 == i2 {
			return !strict
		}

		// n1<=n2 together with a recorded n1!=n2 gives n1<n2: record the
		// stronger fact.
		if !strict && po.isnoneq(i1, i2) {
			strict = true
		}

		// A path n1 -> n2 exists (n1<=n2 is already known).
		if po.reaches(i1, i2, false) {
			//   Known          New      Action
			//   n1<=...<=n2    n1<=n2   nothing to do
			//   n1<=...<=n2    n1<n2    add a strict edge n1->n2
			//   n1<...<n2      any      nothing to do (already known)
			if strict && !po.reaches(i1, i2, true) {
				po.addchild(i1, i2, true)
				return true
			}

			return true
		}

		// A path n2 -> n1 exists (n2<=n1 is already known).
		if po.reaches(i2, i1, false) {
			//   Known          New      Action
			//   n2<=...<=n1    n1<n2    contradiction
			//   n2<...<n1      n1<n2    contradiction
			//   n2<=...<=n1    n1<=n2   the path collapses: all nodes equal
			//   n2<...<n1      n1<=n2   contradiction, found by collapsepath
			if strict {
				return false
			}

			return po.collapsepath(n2, n1)
		}

		// No known relation. If the nodes are in different DAGs, merge
		// them first: CheckIntegrity requires every node to be reachable
		// from a single root.
		r1, r2 := po.findroot(i1), po.findroot(i2)
		if r1 != r2 {
			po.mergeroot(r1, r2)
		}

		po.addchild(i1, i2, strict)
	}

	return true
}
1108
1109
1110
1111 func (po *poset) SetOrder(n1, n2 *Value) bool {
1112 if debugPoset {
1113 defer po.CheckIntegrity()
1114 }
1115 if n1.ID == n2.ID {
1116 panic("should not call SetOrder with n1==n2")
1117 }
1118 return po.setOrder(n1, n2, true)
1119 }
1120
1121
1122
1123 func (po *poset) SetOrderOrEqual(n1, n2 *Value) bool {
1124 if debugPoset {
1125 defer po.CheckIntegrity()
1126 }
1127 if n1.ID == n2.ID {
1128 panic("should not call SetOrder with n1==n2")
1129 }
1130 return po.setOrder(n1, n2, false)
1131 }
1132
1133
1134
1135
1136 func (po *poset) SetEqual(n1, n2 *Value) bool {
1137 if debugPoset {
1138 defer po.CheckIntegrity()
1139 }
1140 if n1.ID == n2.ID {
1141 panic("should not call Add with n1==n2")
1142 }
1143
1144 i1, f1 := po.lookup(n1)
1145 i2, f2 := po.lookup(n2)
1146
1147 switch {
1148 case !f1 && !f2:
1149 i1 = po.newnode(n1)
1150 po.roots = append(po.roots, i1)
1151 po.upush(undoNewRoot, i1, 0)
1152 po.aliasnewnode(n1, n2)
1153 case f1 && !f2:
1154 po.aliasnewnode(n1, n2)
1155 case !f1 && f2:
1156 po.aliasnewnode(n2, n1)
1157 case f1 && f2:
1158 if i1 == i2 {
1159
1160 return true
1161 }
1162
1163
1164 if po.isnoneq(i1, i2) {
1165 return false
1166 }
1167
1168
1169
1170 if po.reaches(i1, i2, false) {
1171 return po.collapsepath(n1, n2)
1172 }
1173 if po.reaches(i2, i1, false) {
1174 return po.collapsepath(n2, n1)
1175 }
1176
1177 r1 := po.findroot(i1)
1178 r2 := po.findroot(i2)
1179 if r1 != r2 {
1180
1181 po.mergeroot(r1, r2)
1182 }
1183
1184
1185
1186 i2s := newBitset(int(po.lastidx) + 1)
1187 i2s.Set(i2)
1188 po.aliasnodes(n1, i2s)
1189 }
1190 return true
1191 }
1192
1193
1194
1195
1196 func (po *poset) SetNonEqual(n1, n2 *Value) bool {
1197 if debugPoset {
1198 defer po.CheckIntegrity()
1199 }
1200 if n1.ID == n2.ID {
1201 panic("should not call SetNonEqual with n1==n2")
1202 }
1203
1204
1205 i1, f1 := po.lookup(n1)
1206 i2, f2 := po.lookup(n2)
1207
1208
1209
1210 if !f1 || !f2 {
1211 po.setnoneq(n1, n2)
1212 return true
1213 }
1214
1215
1216 if po.isnoneq(i1, i2) {
1217 return true
1218 }
1219
1220
1221 if po.Equal(n1, n2) {
1222 return false
1223 }
1224
1225
1226 po.setnoneq(n1, n2)
1227
1228
1229
1230
1231
1232 if po.reaches(i1, i2, false) && !po.reaches(i1, i2, true) {
1233 po.addchild(i1, i2, true)
1234 }
1235 if po.reaches(i2, i1, false) && !po.reaches(i2, i1, true) {
1236 po.addchild(i2, i1, true)
1237 }
1238
1239 return true
1240 }
1241
1242
1243
1244
1245 func (po *poset) Checkpoint() {
1246 po.undo = append(po.undo, posetUndo{typ: undoCheckpoint})
1247 }
1248
1249
1250
1251
1252
// Undo reverts every change made since the most recent Checkpoint. It pops
// entries off the undo stack up to and including that checkpoint entry.
// If there is no checkpoint, the whole stack is unwound; in debug mode the
// poset must then be empty. It panics if the undo stack is empty.
func (po *poset) Undo() {
	if len(po.undo) == 0 {
		panic("empty undo stack")
	}
	if debugPoset {
		defer po.CheckIntegrity()
	}

	for len(po.undo) > 0 {
		pass := po.undo[len(po.undo)-1]
		po.undo = po.undo[:len(po.undo)-1]

		switch pass.typ {
		case undoCheckpoint:
			return

		case undoSetChl:
			po.setchl(pass.idx, pass.edge)

		case undoSetChr:
			po.setchr(pass.idx, pass.edge)

		case undoNonEqual:
			// ID carries the higher node index that keys the bitset.
			po.noneq[uint32(pass.ID)].Clear(pass.idx)

		case undoNewNode:
			// Nodes are undone in reverse creation order, so the node to
			// drop must be the last one.
			if pass.idx != po.lastidx {
				panic("invalid newnode index")
			}
			if pass.ID != 0 {
				if po.values[pass.ID] != pass.idx {
					panic("invalid newnode undo pass")
				}
				delete(po.values, pass.ID)
			}
			po.setchl(pass.idx, 0)
			po.setchr(pass.idx, 0)
			po.nodes = po.nodes[:pass.idx]
			po.lastidx--

		case undoNewConstant:
			// Find which constant value maps to node pass.idx (linear scan
			// over the map).
			var val int64
			var i uint32
			for val, i = range po.constants {
				if i == pass.idx {
					break
				}
			}
			if i != pass.idx {
				panic("constant not found in undo pass")
			}
			if pass.ID == 0 {
				// The constant was new: forget it.
				delete(po.constants, val)
			} else {
				// The constant was moved by aliasnodes: ID carries the
				// node it was mapped to before.
				oldidx := uint32(pass.ID)
				po.constants[val] = oldidx
			}

		case undoAliasNode:
			// (The local ID shadows the ID type within this case.)
			ID, prev := pass.ID, pass.idx
			cur := po.values[ID]
			if prev == 0 {
				// The Value was not in the poset before being aliased.
				delete(po.values, ID)
			} else {
				if cur == prev {
					panic("invalid aliasnode undo pass")
				}
				// Point the Value back to its previous node.
				po.values[ID] = prev
			}

		case undoNewRoot:
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo newroot")
			}
			po.removeroot(i)

		case undoChangeRoot:
			// The children of the replacing root have already been undone
			// (they were pushed later), so it must be childless by now.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo changeroot")
			}
			po.changeroot(i, pass.edge.Target())

		case undoMergeRoot:
			// The merged root's left child takes back its slot; the right
			// child is appended as a separate root again.
			i := pass.idx
			l, r := po.children(i)
			po.changeroot(i, l.Target())
			po.roots = append(po.roots, r.Target())

		default:
			panic(pass.typ)
		}
	}

	if debugPoset && po.CheckEmpty() != nil {
		panic("poset not empty at the end of undo")
	}
}
1359
View as plain text