1
2
3
4
5 package abt
6
7 import (
8 "fmt"
9 "strconv"
10 "strings"
11 )
12
// Tree-wide constants.
const (
	// LEAF_HEIGHT is the height_ recorded for a node with no children.
	LEAF_HEIGHT = 1
	// ZERO_HEIGHT is the height of an empty subtree; node32.height returns
	// 0 for a nil node.
	ZERO_HEIGHT = 0
	// NOT_KEY32 is the key reported when there is no key to report (empty
	// tree, exhausted iterator, no bound found). For that reason it cannot be
	// stored in a tree: Insert panics on it. Its value is the smallest int32.
	NOT_KEY32 int32 = -0x80000000
)
18
19
20
21
// T is an AVL tree mapping int32 keys to arbitrary data.
//
// Tree nodes are never modified once they are reachable from a T: Insert and
// Delete copy every node on the path they change and then install a new root.
// A T obtained from Copy therefore shares nodes with the original, yet later
// updates to either one do not affect the other.
type T struct {
	root *node32 // root of the tree; nil when the tree is empty
	size int     // number of keys stored in the tree
}
26
27
// node32 is one node of the AVL tree. Once linked into a tree it is treated
// as immutable; update paths work on copies (see copy).
type node32 struct {
	// Standard search-tree ordering: left holds smaller keys, right larger
	// ones (see find).
	left, right *node32
	data        interface{}
	key         int32
	height_     int8 // height of the subtree rooted here; LEAF_HEIGHT for a leaf
}
35
36 func makeNode(key int32) *node32 {
37 return &node32{key: key, height_: LEAF_HEIGHT}
38 }
39
40
41 func (t *T) IsEmpty() bool {
42 return t.root == nil
43 }
44
45
46 func (t *T) IsSingle() bool {
47 return t.root != nil && t.root.isLeaf()
48 }
49
50
51
52 func (t *T) VisitInOrder(f func(int32, interface{})) {
53 if t.root == nil {
54 return
55 }
56 t.root.visitInOrder(f)
57 }
58
59 func (n *node32) nilOrData() interface{} {
60 if n == nil {
61 return nil
62 }
63 return n.data
64 }
65
66 func (n *node32) nilOrKeyAndData() (k int32, d interface{}) {
67 if n == nil {
68 k = NOT_KEY32
69 d = nil
70 } else {
71 k = n.key
72 d = n.data
73 }
74 return
75 }
76
77 func (n *node32) height() int8 {
78 if n == nil {
79 return 0
80 }
81 return n.height_
82 }
83
84
85
86 func (t *T) Find(x int32) interface{} {
87 return t.root.find(x).nilOrData()
88 }
89
90
91
92
93
94
95 func (t *T) Insert(x int32, data interface{}) interface{} {
96 if x == NOT_KEY32 {
97 panic("Cannot use sentinel value -0x80000000 as key")
98 }
99 n := t.root
100 var newroot *node32
101 var o *node32
102 if n == nil {
103 n = makeNode(x)
104 newroot = n
105 } else {
106 newroot, n, o = n.aInsert(x)
107 }
108 var r interface{}
109 if o != nil {
110 r = o.data
111 } else {
112 t.size++
113 }
114 n.data = data
115 t.root = newroot
116 return r
117 }
118
119 func (t *T) Copy() *T {
120 u := *t
121 return &u
122 }
123
124 func (t *T) Delete(x int32) interface{} {
125 n := t.root
126 if n == nil {
127 return nil
128 }
129 d, s := n.aDelete(x)
130 if d == nil {
131 return nil
132 }
133 t.root = s
134 t.size--
135 return d.data
136 }
137
138 func (t *T) DeleteMin() (int32, interface{}) {
139 n := t.root
140 if n == nil {
141 return NOT_KEY32, nil
142 }
143 d, s := n.aDeleteMin()
144 if d == nil {
145 return NOT_KEY32, nil
146 }
147 t.root = s
148 t.size--
149 return d.key, d.data
150 }
151
152 func (t *T) DeleteMax() (int32, interface{}) {
153 n := t.root
154 if n == nil {
155 return NOT_KEY32, nil
156 }
157 d, s := n.aDeleteMax()
158 if d == nil {
159 return NOT_KEY32, nil
160 }
161 t.root = s
162 t.size--
163 return d.key, d.data
164 }
165
166 func (t *T) Size() int {
167 return t.size
168 }
169
170
171
172
173
174
175 func (t *T) Intersection(u *T, f func(x, y interface{}) interface{}) *T {
176 if t.Size() == 0 || u.Size() == 0 {
177 return &T{}
178 }
179
180
181 if t.Size() <= u.Size() {
182 v := t.Copy()
183 for it := t.Iterator(); !it.Done(); {
184 k, d := it.Next()
185 e := u.Find(k)
186 if e == nil {
187 v.Delete(k)
188 continue
189 }
190 if f == nil {
191 continue
192 }
193 if c := f(d, e); c != d {
194 if c == nil {
195 v.Delete(k)
196 } else {
197 v.Insert(k, c)
198 }
199 }
200 }
201 return v
202 }
203 v := u.Copy()
204 for it := u.Iterator(); !it.Done(); {
205 k, e := it.Next()
206 d := t.Find(k)
207 if d == nil {
208 v.Delete(k)
209 continue
210 }
211 if f == nil {
212 continue
213 }
214 if c := f(d, e); c != d {
215 if c == nil {
216 v.Delete(k)
217 } else {
218 v.Insert(k, c)
219 }
220 }
221 }
222
223 return v
224 }
225
226
227
228
229
230 func (t *T) Union(u *T, f func(x, y interface{}) interface{}) *T {
231 if t.Size() == 0 {
232 return u
233 }
234 if u.Size() == 0 {
235 return t
236 }
237
238 if t.Size() >= u.Size() {
239 v := t.Copy()
240 for it := u.Iterator(); !it.Done(); {
241 k, e := it.Next()
242 d := t.Find(k)
243 if d == nil {
244 v.Insert(k, e)
245 continue
246 }
247 if f == nil {
248 continue
249 }
250 if c := f(d, e); c != d {
251 if c == nil {
252 v.Delete(k)
253 } else {
254 v.Insert(k, c)
255 }
256 }
257 }
258 return v
259 }
260
261 v := u.Copy()
262 for it := t.Iterator(); !it.Done(); {
263 k, d := it.Next()
264 e := u.Find(k)
265 if e == nil {
266 v.Insert(k, d)
267 continue
268 }
269 if f == nil {
270 continue
271 }
272 if c := f(d, e); c != d {
273 if c == nil {
274 v.Delete(k)
275 } else {
276 v.Insert(k, c)
277 }
278 }
279 }
280 return v
281 }
282
283
284
285
286
287 func (t *T) Difference(u *T, f func(x, y interface{}) interface{}) *T {
288 if t.Size() == 0 {
289 return &T{}
290 }
291 if u.Size() == 0 {
292 return t
293 }
294 v := t.Copy()
295 for it := t.Iterator(); !it.Done(); {
296 k, d := it.Next()
297 e := u.Find(k)
298 if e != nil {
299 if f == nil {
300 v.Delete(k)
301 continue
302 }
303 c := f(d, e)
304 if c == nil {
305 v.Delete(k)
306 continue
307 }
308 if c != d {
309 v.Insert(k, c)
310 }
311 }
312 }
313 return v
314 }
315
316 func (t *T) Iterator() Iterator {
317 return Iterator{it: t.root.iterator()}
318 }
319
320 func (t *T) Equals(u *T) bool {
321 if t == u {
322 return true
323 }
324 if t.Size() != u.Size() {
325 return false
326 }
327 return t.root.equals(u.root)
328 }
329
330 func (t *T) String() string {
331 var b strings.Builder
332 first := true
333 for it := t.Iterator(); !it.Done(); {
334 k, v := it.Next()
335 if first {
336 first = false
337 } else {
338 b.WriteString("; ")
339 }
340 b.WriteString(strconv.FormatInt(int64(k), 10))
341 b.WriteString(":")
342 fmt.Fprint(&b, v)
343 }
344 return b.String()
345 }
346
347 func (t *node32) equals(u *node32) bool {
348 if t == u {
349 return true
350 }
351 it, iu := t.iterator(), u.iterator()
352 for !it.done() && !iu.done() {
353 nt := it.next()
354 nu := iu.next()
355 if nt == nu {
356 continue
357 }
358 if nt.key != nu.key {
359 return false
360 }
361 if nt.data != nu.data {
362 return false
363 }
364 }
365 return it.done() == iu.done()
366 }
367
368 func (t *T) Equiv(u *T, eqv func(x, y interface{}) bool) bool {
369 if t == u {
370 return true
371 }
372 if t.Size() != u.Size() {
373 return false
374 }
375 return t.root.equiv(u.root, eqv)
376 }
377
378 func (t *node32) equiv(u *node32, eqv func(x, y interface{}) bool) bool {
379 if t == u {
380 return true
381 }
382 it, iu := t.iterator(), u.iterator()
383 for !it.done() && !iu.done() {
384 nt := it.next()
385 nu := iu.next()
386 if nt == nu {
387 continue
388 }
389 if nt.key != nu.key {
390 return false
391 }
392 if !eqv(nt.data, nu.data) {
393 return false
394 }
395 }
396 return it.done() == iu.done()
397 }
398
// iterator performs an in-order walk with an explicit stack.
type iterator struct {
	// parents is the path from the root of the walked subtree down to the
	// node that next() will return (the top of the stack). It is empty once
	// the walk is finished.
	parents []*node32
}
402
// Iterator walks a T in increasing key order; obtain one from T.Iterator.
type Iterator struct {
	it iterator
}
406
407 func (it *Iterator) Next() (int32, interface{}) {
408 x := it.it.next()
409 if x == nil {
410 return NOT_KEY32, nil
411 }
412 return x.key, x.data
413 }
414
415 func (it *Iterator) Done() bool {
416 return len(it.it.parents) == 0
417 }
418
419 func (t *node32) iterator() iterator {
420 if t == nil {
421 return iterator{}
422 }
423 it := iterator{parents: make([]*node32, 0, int(t.height()))}
424 it.leftmost(t)
425 return it
426 }
427
428 func (it *iterator) leftmost(t *node32) {
429 for t != nil {
430 it.parents = append(it.parents, t)
431 t = t.left
432 }
433 }
434
435 func (it *iterator) done() bool {
436 return len(it.parents) == 0
437 }
438
439 func (it *iterator) next() *node32 {
440 l := len(it.parents)
441 if l == 0 {
442 return nil
443 }
444 x := it.parents[l-1]
445 if x.right != nil {
446 it.leftmost(x.right)
447 return x
448 }
449
450 l--
451 it.parents = it.parents[:l]
452 y := x
453 for l > 0 && y == it.parents[l-1].right {
454 y = it.parents[l-1]
455 l--
456 it.parents = it.parents[:l]
457 }
458
459 return x
460 }
461
462
463
464 func (t *T) Min() (k int32, d interface{}) {
465 return t.root.min().nilOrKeyAndData()
466 }
467
468
469
470 func (t *T) Max() (k int32, d interface{}) {
471 return t.root.max().nilOrKeyAndData()
472 }
473
474
475
476 func (t *T) Glb(x int32) (k int32, d interface{}) {
477 return t.root.glb(x, false).nilOrKeyAndData()
478 }
479
480
481
482 func (t *T) GlbEq(x int32) (k int32, d interface{}) {
483 return t.root.glb(x, true).nilOrKeyAndData()
484 }
485
486
487
488 func (t *T) Lub(x int32) (k int32, d interface{}) {
489 return t.root.lub(x, false).nilOrKeyAndData()
490 }
491
492
493
494 func (t *T) LubEq(x int32) (k int32, d interface{}) {
495 return t.root.lub(x, true).nilOrKeyAndData()
496 }
497
498 func (t *node32) isLeaf() bool {
499 return t.left == nil && t.right == nil && t.height_ == LEAF_HEIGHT
500 }
501
502 func (t *node32) visitInOrder(f func(int32, interface{})) {
503 if t.left != nil {
504 t.left.visitInOrder(f)
505 }
506 f(t.key, t.data)
507 if t.right != nil {
508 t.right.visitInOrder(f)
509 }
510 }
511
512 func (t *node32) find(key int32) *node32 {
513 for t != nil {
514 if key < t.key {
515 t = t.left
516 } else if key > t.key {
517 t = t.right
518 } else {
519 return t
520 }
521 }
522 return nil
523 }
524
525 func (t *node32) min() *node32 {
526 if t == nil {
527 return t
528 }
529 for t.left != nil {
530 t = t.left
531 }
532 return t
533 }
534
535 func (t *node32) max() *node32 {
536 if t == nil {
537 return t
538 }
539 for t.right != nil {
540 t = t.right
541 }
542 return t
543 }
544
545 func (t *node32) glb(key int32, allow_eq bool) *node32 {
546 var best *node32 = nil
547 for t != nil {
548 if key <= t.key {
549 if allow_eq && key == t.key {
550 return t
551 }
552
553 t = t.left
554 } else {
555
556 best = t
557 t = t.right
558 }
559 }
560 return best
561 }
562
563 func (t *node32) lub(key int32, allow_eq bool) *node32 {
564 var best *node32 = nil
565 for t != nil {
566 if key >= t.key {
567 if allow_eq && key == t.key {
568 return t
569 }
570
571 t = t.right
572 } else {
573
574 best = t
575 t = t.left
576 }
577 }
578 return best
579 }
580
// aInsert returns a subtree equal to t with key x added (or re-created if
// already present). t is never modified. Every node on the path to x is
// copied, and any rebalancing rotation is applied to those copies.
//
// Results:
//   - newroot: the root of the updated subtree.
//   - newnode: the node that now holds x; the caller sets its data
//     (T.Insert does so).
//   - oldnode: the node that held x before, or nil if x was not present.
//
// t must be non-nil (t.key is read unconditionally).
func (t *node32) aInsert(x int32) (newroot, newnode, oldnode *node32) {
	// x already present: only this node is replaced, by a copy. The shape
	// and heights of the subtree are unchanged.
	if x == t.key {
		oldnode = t
		newt := *t
		newnode = &newt
		newroot = newnode
		return
	}
	if x < t.key {
		if t.left == nil {
			// New leaf hangs directly off (a copy of) t. With an empty left
			// subtree, the right subtree is at most a leaf if the AVL balance
			// invariant holds, so t's new height is 2.
			t = t.copy()
			n := makeNode(x)
			t.left = n
			newnode = n
			newroot = t
			t.height_ = 2
			return
		}
		var new_l *node32
		new_l, newnode, oldnode = t.left.aInsert(x)
		t = t.copy()
		t.left = new_l
		if new_l.height() > 1+t.right.height() {
			// Left is now two levels taller than right: rotate. newnode is
			// passed so aLeftIsHigh need not copy the freshly allocated node.
			newroot = t.aLeftIsHigh(newnode)
		} else {
			t.height_ = 1 + max(t.left.height(), t.right.height())
			newroot = t
		}
	} else {
		if t.right == nil {
			// Mirror image of the empty-left case above.
			t = t.copy()
			n := makeNode(x)
			t.right = n
			newnode = n
			newroot = t
			t.height_ = 2
			return
		}
		var new_r *node32
		new_r, newnode, oldnode = t.right.aInsert(x)
		t = t.copy()
		t.right = new_r
		if new_r.height() > 1+t.left.height() {
			// Right is now two levels taller than left: rotate.
			newroot = t.aRightIsHigh(newnode)
		} else {
			t.height_ = 1 + max(t.left.height(), t.right.height())
			newroot = t
		}
	}
	return
}
633
// aDelete removes key from t's subtree without modifying any existing node.
// It returns the node that held key (nil if key is absent) and the root of
// the resulting subtree. When key is absent, the returned subtree is t itself.
func (t *node32) aDelete(key int32) (deleted, newSubTree *node32) {
	if t == nil {
		return nil, nil
	}

	if key < t.key {
		oh := t.left.height()
		d, tleft := t.left.aDelete(key)
		if tleft == t.left {
			// Nothing was removed below; reuse t unchanged.
			return d, t
		}
		return d, t.copy().aRebalanceAfterLeftDeletion(oh, tleft)
	} else if key > t.key {
		oh := t.right.height()
		d, tright := t.right.aDelete(key)
		if tright == t.right {
			return d, t
		}
		return d, t.copy().aRebalanceAfterRightDeletion(oh, tright)
	}

	// t holds key. A leaf simply disappears.
	if t.height() == LEAF_HEIGHT {
		return t, nil
	}

	// Interior node: move an in-order neighbour's key and data into a copy of
	// t. The neighbour is removed from the taller side, so that is the side
	// that may shrink: the predecessor (max of left) if left is taller,
	// otherwise the successor (min of right). The original t is reported as
	// the deleted node.
	if t.left.height() > t.right.height() {
		oh := t.left.height()
		d, tleft := t.left.aDeleteMax()
		r := t
		t = t.copy()
		t.data, t.key = d.data, d.key
		return r, t.aRebalanceAfterLeftDeletion(oh, tleft)
	}

	oh := t.right.height()
	d, tright := t.right.aDeleteMin()
	r := t
	t = t.copy()
	t.data, t.key = d.data, d.key
	return r, t.aRebalanceAfterRightDeletion(oh, tright)
}
677
678 func (t *node32) aDeleteMin() (deleted, newSubTree *node32) {
679 if t == nil {
680 return nil, nil
681 }
682 if t.left == nil {
683 return t, t.right
684 }
685 oh := t.left.height()
686 d, tleft := t.left.aDeleteMin()
687 if tleft == t.left {
688 return d, t
689 }
690 return d, t.copy().aRebalanceAfterLeftDeletion(oh, tleft)
691 }
692
693 func (t *node32) aDeleteMax() (deleted, newSubTree *node32) {
694 if t == nil {
695 return nil, nil
696 }
697
698 if t.right == nil {
699 return t, t.left
700 }
701
702 oh := t.right.height()
703 d, tright := t.right.aDeleteMax()
704 if tright == t.right {
705 return d, t
706 }
707 return d, t.copy().aRebalanceAfterRightDeletion(oh, tright)
708 }
709
// aRebalanceAfterLeftDeletion installs tleft as t's left subtree and returns
// the root of the rebalanced subtree. tleft is the result of deleting one
// node from t's previous left subtree, whose height was oldLeftHeight.
//
// t is modified in place (left, height_, and possibly right), so it must not
// be shared; every caller visible here passes a fresh copy.
func (t *node32) aRebalanceAfterLeftDeletion(oldLeftHeight int8, tleft *node32) *node32 {
	t.left = tleft

	if oldLeftHeight == tleft.height() || oldLeftHeight == t.right.height() {
		// Either the left subtree did not shrink, or it was as tall as the
		// right one and is now one shorter. t is still balanced and its
		// height (governed by the right side) is unchanged.
		return t
	}

	if oldLeftHeight > t.right.height() {
		// Left was the taller side; now both sides are equal, so t loses one
		// level of height.
		t.height_--
		return t
	}

	// Left was already shorter than right and shrank further, leaving it two
	// levels short: rotate. The right child is copied first because the
	// rotation rewrites its fields.
	t.right = t.right.copy()
	return t.aRightIsHigh(nil)
}
728
// aRebalanceAfterRightDeletion installs tright as t's right subtree and
// returns the root of the rebalanced subtree. tright is the result of
// deleting one node from t's previous right subtree, whose height was
// oldRightHeight.
//
// t is modified in place (right, height_, and possibly left), so it must not
// be shared; every caller visible here passes a fresh copy.
func (t *node32) aRebalanceAfterRightDeletion(oldRightHeight int8, tright *node32) *node32 {
	t.right = tright

	if oldRightHeight == tright.height() || oldRightHeight == t.left.height() {
		// Either the right subtree did not shrink, or it was as tall as the
		// left one and is now one shorter. t is still balanced and its
		// height (governed by the left side) is unchanged.
		return t
	}

	if oldRightHeight > t.left.height() {
		// Right was the taller side; now both sides are equal, so t loses one
		// level of height.
		t.height_--
		return t
	}

	// Right was already shorter than left and shrank further, leaving it two
	// levels short: rotate. The left child is copied first because the
	// rotation rewrites its fields.
	t.left = t.left.copy()
	return t.aLeftIsHigh(nil)
}
747
748
749
// aRightIsHigh rebalances t when its right subtree is two levels taller than
// its left, and returns the new subtree root.
//
// If the right child's inner (left) subtree is its taller side, a double
// rotation is done: right.left is first rotated up within the right child,
// then the right child is rotated up to the root. Otherwise a single
// rotation suffices.
//
// t and t.right are modified in place and must already be unshared copies.
// newnode is the node just allocated by an insertion (nil from the deletion
// path). right.left is copied before being rotated unless it is that node,
// which no other tree can reference yet.
func (t *node32) aRightIsHigh(newnode *node32) *node32 {
	right := t.right
	if right.right.height() < right.left.height() {
		// double rotation
		if newnode != right.left {
			right.left = right.left.copy()
		}
		t.right = right.leftToRoot()
	}
	t = t.rightToRoot()
	return t
}
762
763
764
// aLeftIsHigh rebalances t when its left subtree is two levels taller than
// its right, and returns the new subtree root.
//
// If the left child's inner (right) subtree is its taller side, a double
// rotation is done: left.right is first rotated up within the left child,
// then the left child is rotated up to the root. Otherwise a single rotation
// suffices.
//
// t and t.left are modified in place and must already be unshared copies.
// newnode is the node just allocated by an insertion (nil from the deletion
// path). left.right is copied before being rotated unless it is that node,
// which no other tree can reference yet.
func (t *node32) aLeftIsHigh(newnode *node32) *node32 {
	left := t.left
	if left.left.height() < left.right.height() {
		// double rotation
		if newnode != left.right {
			left.right = left.right.copy()
		}
		t.left = left.rightToRoot()
	}
	t = t.leftToRoot()
	return t
}
777
778
// rightToRoot rotates t's right child up into t's place and returns it as
// the new subtree root. Heights of t and the new root are recomputed. Both
// nodes are modified in place, so callers must pass unshared copies.
func (t *node32) rightToRoot() *node32 {
	//    this
	// left  right
	//      rl   rr
	//
	// becomes
	//
	//       right
	//    this   rr
	// left  rl
	//
	right := t.right
	rl := right.left
	right.left = t
	// The parent's child pointer is fixed by the caller using the result.
	t.right = rl
	t.height_ = 1 + max(rl.height(), t.left.height())
	right.height_ = 1 + max(t.height(), right.right.height())
	return right
}
799
800
// leftToRoot rotates t's left child up into t's place and returns it as the
// new subtree root. Heights of t and the new root are recomputed. Both nodes
// are modified in place, so callers must pass unshared copies.
func (t *node32) leftToRoot() *node32 {
	//     this
	//  left  right
	// ll  lr
	//
	// becomes
	//
	//    left
	// ll    this
	//     lr   right
	//
	left := t.left
	lr := left.right
	left.right = t
	// The parent's child pointer is fixed by the caller using the result.
	t.left = lr
	t.height_ = 1 + max(lr.height(), t.right.height())
	left.height_ = 1 + max(t.height(), left.left.height())
	return left
}
821
// max returns the larger of a and b. It is specialized to int8, the type of
// node heights. Being declared at package level, it shadows the built-in max
// (Go 1.21+) throughout this package.
func max(a, b int8) int8 {
	if b >= a {
		return b
	}
	return a
}
828
829 func (t *node32) copy() *node32 {
830 u := *t
831 return &u
832 }
833
View as plain text