5 package sync
6
7 import (
8 "internal/abi"
9 "internal/goarch"
10 "sync/atomic"
11 "unsafe"
12 )
13
14
15
16
17
18
19
20
// HashTrieMap is a concurrent hash-trie map. The zero value is ready to
// use; every method lazily initializes the map on first use via init.
type HashTrieMap[K comparable, V any] struct {
	inited   atomic.Uint32                  // non-zero once one-time initialization has completed
	initMu   Mutex                          // serializes initSlow
	root     atomic.Pointer[indirect[K, V]] // root indirect node of the trie
	keyHash  hashFunc                       // hash function for K, taken from the built-in map type's metadata
	valEqual equalFunc                      // equality for V; nil when V is not comparable (see CompareAndSwap)
	seed     uintptr                        // per-map hash seed, mixed into every keyHash call
}
29
30 func (ht *HashTrieMap[K, V]) init() {
31 if ht.inited.Load() == 0 {
32 ht.initSlow()
33 }
34 }
35
36
// initSlow performs the map's one-time initialization under initMu.
// It is kept separate from init, presumably so the fast path stays small —
// confirm against the compiler's inlining of init.
func (ht *HashTrieMap[K, V]) initSlow() {
	ht.initMu.Lock()
	defer ht.initMu.Unlock()

	if ht.inited.Load() != 0 {
		// Someone raced with us and completed initialization first.
		return
	}

	// Borrow the hasher and value-equality function for K and V from the
	// runtime's metadata for a built-in map[K]V type.
	var m map[K]V
	mapType := abi.TypeOf(m).MapType()
	ht.root.Store(newIndirectNode[K, V](nil))
	ht.keyHash = mapType.Hasher
	ht.valEqual = mapType.Elem.Equal
	ht.seed = uintptr(runtime_rand())

	// Publish last: init's lock-free fast path must only observe a fully
	// initialized map once inited is non-zero.
	ht.inited.Store(1)
}
57
// hashFunc hashes the pointed-to key, mixed with a seed.
type hashFunc func(unsafe.Pointer, uintptr) uintptr

// equalFunc reports whether the two pointed-to values are equal.
type equalFunc func(unsafe.Pointer, unsafe.Pointer) bool
60
61
62
63
// Load returns the value stored in the map for a key, or the zero value
// if no value is present. The ok result indicates whether the value was
// found in the map.
func (ht *HashTrieMap[K, V]) Load(key K) (value V, ok bool) {
	ht.init()
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)

	// Lock-free walk down the trie, consuming nChildrenLog2 bits of the
	// hash per level, from most-significant to least-significant.
	i := ht.root.Load()
	hashShift := 8 * goarch.PtrSize
	for hashShift != 0 {
		hashShift -= nChildrenLog2

		n := i.children[(hash>>hashShift)&nChildrenMask].Load()
		if n == nil {
			// Empty slot: the key is not present.
			return *new(V), false
		}
		if n.isEntry {
			// Reached an entry chain; search it for the key.
			return n.entry().lookup(key)
		}
		i = n.indirect()
	}
	panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
}
84
85
86
87
// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (ht *HashTrieMap[K, V]) LoadOrStore(key K, value V) (result V, loaded bool) {
	ht.init()
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	var i *indirect[K, V]
	var hashShift uint
	var slot *atomic.Pointer[node[K, V]]
	var n *node[K, V]
	for {
		// Find the key or a candidate location for insertion.
		i = ht.root.Load()
		hashShift = 8 * goarch.PtrSize
		haveInsertPoint := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// Found a nil slot: a candidate insertion point.
				haveInsertPoint = true
				break
			}
			if n.isEntry {
				// Found an existing entry chain, which is as deep as the
				// walk can go. If the key is already here, return it.
				if v, ok := n.entry().lookup(key); ok {
					return v, true
				}
				// Otherwise this entry may need to be expanded below.
				haveInsertPoint = true
				break
			}
			i = n.indirect()
		}
		if !haveInsertPoint {
			panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
		}

		// Take the node's lock and re-check what we saw: a concurrent
		// mutator may have replaced the slot or unlinked (killed) i.
		i.mu.Lock()
		n = slot.Load()
		if (n == nil || n.isEntry) && !i.dead.Load() {
			// What we observed still holds; proceed with the insert
			// while holding i.mu.
			break
		}
		// The trie changed under us; start over.
		i.mu.Unlock()
	}
	// N.B. i.mu is held from the break out of the outer loop above; the
	// break exists specifically so this defer pairs with that Lock.
	defer i.mu.Unlock()

	var oldEntry *entry[K, V]
	if n != nil {
		oldEntry = n.entry()
		if v, ok := oldEntry.lookup(key); ok {
			// Under the lock, the key turned out to be present after all.
			return v, true
		}
	}
	newEntry := newEntryNode(key, value)
	if oldEntry == nil {
		// Easy case: store the new entry into the empty slot.
		slot.Store(&newEntry.node)
	} else {
		// The slot holds a different key (or keys) whose hash collides
		// down to hashShift: expand into a subtree (or overflow chain)
		// that holds both the old and new entries.
		slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
	}
	return value, false
}
164
165
166
// expand takes oldEntry and newEntry, whose hashes agree on every level
// above hashShift, and builds the subtree of indirect nodes needed to
// hold both, returning its root to store into the parent's slot.
// Called with parent.mu held.
func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uintptr, hashShift uint, parent *indirect[K, V]) *node[K, V] {
	// Check for a full hash collision.
	oldHash := ht.keyHash(unsafe.Pointer(&oldEntry.key), ht.seed)
	if oldHash == newHash {
		// Full collision: chain the old entry off the new entry's
		// overflow list and let the new entry head the chain.
		newEntry.overflow.Store(oldEntry)
		return &newEntry.node
	}
	// The hashes diverge somewhere below hashShift: build one indirect
	// node per still-colliding level.
	newIndirect := newIndirectNode(parent)
	top := newIndirect
	for {
		if hashShift == 0 {
			panic("internal/sync.HashTrieMap: ran out of hash bits while inserting")
		}
		// hashShift currently describes the parent's level; descend.
		hashShift -= nChildrenLog2
		oi := (oldHash >> hashShift) & nChildrenMask
		ni := (newHash >> hashShift) & nChildrenMask
		if oi != ni {
			// The hashes finally diverge at this level: place each entry
			// into its own child slot and stop.
			newIndirect.children[oi].Store(&oldEntry.node)
			newIndirect.children[ni].Store(&newEntry.node)
			break
		}
		// Still colliding at this level; add another indirect below.
		nextIndirect := newIndirectNode(newIndirect)
		newIndirect.children[oi].Store(&nextIndirect.node)
		newIndirect = nextIndirect
	}
	return &top.node
}
197
198
199 func (ht *HashTrieMap[K, V]) Store(key K, old V) {
200 _, _ = ht.Swap(key, old)
201 }
202
203
204
// Swap swaps the value for a key and returns the previous value if any.
// The loaded result reports whether the key was present.
func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
	ht.init()
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	var i *indirect[K, V]
	var hashShift uint
	var slot *atomic.Pointer[node[K, V]]
	var n *node[K, V]
	for {
		// Find the key or a candidate location for insertion.
		i = ht.root.Load()
		hashShift = 8 * goarch.PtrSize
		haveInsertPoint := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil || n.isEntry {
				// Either an empty slot to insert into, or an entry chain
				// that will be swapped/expanded below.
				haveInsertPoint = true
				break
			}
			i = n.indirect()
		}
		if !haveInsertPoint {
			panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
		}

		// Take the node's lock and re-check what we saw: a concurrent
		// mutator may have replaced the slot or unlinked (killed) i.
		i.mu.Lock()
		n = slot.Load()
		if (n == nil || n.isEntry) && !i.dead.Load() {
			// What we observed still holds; proceed while holding i.mu.
			break
		}
		// The trie changed under us; start over.
		i.mu.Unlock()
	}
	// N.B. i.mu is held from the break out of the outer loop above; the
	// break exists specifically so this defer pairs with that Lock.
	defer i.mu.Unlock()

	var zero V
	var oldEntry *entry[K, V]
	if n != nil {
		// The slot holds an entry chain; swap in place if the key matches.
		oldEntry = n.entry()
		newEntry, old, swapped := oldEntry.swap(key, new)
		if swapped {
			// swap returned the (possibly new) chain head to publish.
			slot.Store(&newEntry.node)
			return old, true
		}
	}
	// No matching key in the slot: insert a fresh entry.
	newEntry := newEntryNode(key, new)
	if oldEntry == nil {
		// Easy case: store the new entry into the empty slot.
		slot.Store(&newEntry.node)
	} else {
		// A different key (or keys) lives here; expand into a subtree
		// (or overflow chain) holding both old and new entries.
		slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
	}
	return zero, false
}
276
277
278
279
// CompareAndSwap swaps the old and new values for key
// if the value stored in the map is equal to old.
// The value type must be comparable, otherwise CompareAndSwap panics.
func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
	ht.init()
	if ht.valEqual == nil {
		// valEqual is only set when V is comparable (see initSlow).
		panic("called CompareAndSwap when value is not of comparable type")
	}
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)

	// Find the entry chain for key whose value compares equal to old.
	// When i != nil, find returns with i.mu held; unlock it on exit.
	i, _, slot, n := ht.find(key, hash, ht.valEqual, old)
	if i != nil {
		defer i.mu.Unlock()
	}
	if n == nil {
		// No matching key/value pair.
		return false
	}

	// Attempt the swap on the chain; compareAndSwap returns the
	// (possibly new) chain head to publish into the slot.
	e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual)
	if !swapped {
		// The matching entry disappeared between find and here.
		return false
	}
	// Publish the updated entry chain.
	slot.Store(&e.node)
	return true
}
306
307
308
// LoadAndDelete deletes the value for a key, returning the previous
// value if any. The loaded result reports whether the key was present.
func (ht *HashTrieMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
	ht.init()
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)

	// Find the entry chain for key (valEqual nil: match on key only).
	// When i != nil, find returns with i.mu held.
	i, hashShift, slot, n := ht.find(key, hash, nil, *new(V))
	if n == nil {
		if i != nil {
			i.mu.Unlock()
		}
		return *new(V), false
	}

	// Try to delete the entry from its chain.
	v, e, loaded := n.entry().loadAndDelete(key)
	if !loaded {
		// The key vanished between find and here; nothing deleted.
		i.mu.Unlock()
		return *new(V), false
	}
	if e != nil {
		// Only one link of the chain was removed; publish the surviving
		// chain head. The parent cannot be empty, so no cleanup needed.
		slot.Store(&e.node)
		i.mu.Unlock()
		return v, true
	}
	// The whole chain is gone; clear the slot.
	slot.Store(nil)

	// Prune now-empty indirect nodes (never the root) bottom-up.
	for i.parent != nil && i.empty() {
		if hashShift == 8*goarch.PtrSize {
			panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
		}
		// Move hashShift back up to the parent's level.
		hashShift += nChildrenLog2

		// Unlink i from its parent under the parent's lock; dead is set
		// while both locks are held so racing inserters retry.
		parent := i.parent
		parent.mu.Lock()
		i.dead.Store(true)
		parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
		i.mu.Unlock()
		i = parent
	}
	i.mu.Unlock()
	return v, true
}
357
358
359 func (ht *HashTrieMap[K, V]) Delete(key K) {
360 _, _ = ht.LoadAndDelete(key)
361 }
362
363
364
365
366
367
// CompareAndDelete deletes the entry for key if its value is equal to old.
// The value type must be comparable, otherwise CompareAndDelete panics.
func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
	ht.init()
	if ht.valEqual == nil {
		// valEqual is only set when V is comparable (see initSlow).
		panic("called CompareAndDelete when value is not of comparable type")
	}
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)

	// Find the entry chain for key (value comparison happens in
	// compareAndDelete below). When i != nil, find returns with i.mu held.
	i, hashShift, slot, n := ht.find(key, hash, nil, *new(V))
	if n == nil {
		if i != nil {
			i.mu.Unlock()
		}
		return false
	}

	// Try to delete the matching entry from its chain.
	e, deleted := n.entry().compareAndDelete(key, old, ht.valEqual)
	if !deleted {
		// No key/value match; nothing deleted.
		i.mu.Unlock()
		return false
	}
	if e != nil {
		// Only one link of the chain was removed; publish the surviving
		// chain head. The parent cannot be empty, so no cleanup needed.
		slot.Store(&e.node)
		i.mu.Unlock()
		return true
	}
	// The whole chain is gone; clear the slot.
	slot.Store(nil)

	// Prune now-empty indirect nodes (never the root) bottom-up.
	for i.parent != nil && i.empty() {
		if hashShift == 8*goarch.PtrSize {
			panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
		}
		// Move hashShift back up to the parent's level.
		hashShift += nChildrenLog2

		// Unlink i from its parent under the parent's lock; dead is set
		// while both locks are held so racing inserters retry.
		parent := i.parent
		parent.mu.Lock()
		i.dead.Store(true)
		parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
		i.mu.Unlock()
		i = parent
	}
	i.mu.Unlock()
	return true
}
419
420
421
422
423
424
425
// find searches the trie for a node containing key (hash must be the
// hash of key). If valEqual != nil, it additionally requires the stored
// value to compare equal to value.
//
// On a match, n is non-nil and always an entry. If i != nil, i.mu is
// locked and the caller is responsible for unlocking it.
func (ht *HashTrieMap[K, V]) find(key K, hash uintptr, valEqual equalFunc, value V) (i *indirect[K, V], hashShift uint, slot *atomic.Pointer[node[K, V]], n *node[K, V]) {
	for {
		// Lock-free walk down the trie for the key.
		i = ht.root.Load()
		hashShift = 8 * goarch.PtrSize
		found := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// Empty slot: the key is not present. i=nil tells the
				// caller that no lock is held.
				i = nil
				return
			}
			if n.isEntry {
				// Reached an entry chain; check for a key (and value) match.
				if _, ok := n.entry().lookupWithValue(key, value, valEqual); !ok {
					// No match; nothing found and no lock held.
					i = nil
					n = nil
					return
				}
				// Match found; fall through to lock and re-verify.
				found = true
				break
			}
			i = n.indirect()
		}
		if !found {
			panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
		}

		// Take the node's lock and re-check: the slot or the node itself
		// may have been changed by a concurrent mutator.
		i.mu.Lock()
		n = slot.Load()
		if !i.dead.Load() && (n == nil || n.isEntry) {
			// Still valid under the lock (n may have become nil; the
			// caller re-checks n and handles that case).
			return
		}
		// The trie changed under us; start over.
		i.mu.Unlock()
	}
}
472
473
474
475
476
477
478
479
480
481 func (ht *HashTrieMap[K, V]) All() func(yield func(K, V) bool) {
482 ht.init()
483 return func(yield func(key K, value V) bool) {
484 ht.iter(ht.root.Load(), yield)
485 }
486 }
487
488
489
490
491
492
493 func (ht *HashTrieMap[K, V]) Range(yield func(K, V) bool) {
494 ht.init()
495 ht.iter(ht.root.Load(), yield)
496 }
497
498 func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) bool) bool {
499 for j := range i.children {
500 n := i.children[j].Load()
501 if n == nil {
502 continue
503 }
504 if !n.isEntry {
505 if !ht.iter(n.indirect(), yield) {
506 return false
507 }
508 continue
509 }
510 e := n.entry()
511 for e != nil {
512 if !yield(e.key, e.value) {
513 return false
514 }
515 e = e.overflow.Load()
516 }
517 }
518 return true
519 }
520
521
522 func (ht *HashTrieMap[K, V]) Clear() {
523 ht.init()
524
525
526
527 ht.root.Store(newIndirectNode[K, V](nil))
528 }
529
const (
	// nChildrenLog2 is the base-2 logarithm of the fan-out of each
	// indirect node; every trie level consumes this many hash bits.
	nChildrenLog2 = 4
	// nChildren is the number of child slots per indirect node.
	nChildren = 1 << nChildrenLog2
	// nChildrenMask extracts one level's worth of bits from a hash.
	nChildrenMask = nChildren - 1
)
539
540
// indirect is an internal (non-leaf) node of the hash-trie.
type indirect[K comparable, V any] struct {
	node[K, V]
	dead     atomic.Bool // set, under the parent's lock, when this node is unlinked from parent
	mu       Mutex       // guards mutation of children
	parent   *indirect[K, V]
	children [nChildren]atomic.Pointer[node[K, V]]
}
548
549 func newIndirectNode[K comparable, V any](parent *indirect[K, V]) *indirect[K, V] {
550 return &indirect[K, V]{node: node[K, V]{isEntry: false}, parent: parent}
551 }
552
553 func (i *indirect[K, V]) empty() bool {
554 nc := 0
555 for j := range i.children {
556 if i.children[j].Load() != nil {
557 nc++
558 }
559 }
560 return nc == 0
561 }
562
563
// entry is a leaf node of the hash-trie holding one key/value pair.
// Entries whose keys fully collide in hash are chained via overflow.
type entry[K comparable, V any] struct {
	node[K, V]
	overflow atomic.Pointer[entry[K, V]] // next entry with the same full hash
	key      K
	value    V
}
570
571 func newEntryNode[K comparable, V any](key K, value V) *entry[K, V] {
572 return &entry[K, V]{
573 node: node[K, V]{isEntry: true},
574 key: key,
575 value: value,
576 }
577 }
578
579 func (e *entry[K, V]) lookup(key K) (V, bool) {
580 for e != nil {
581 if e.key == key {
582 return e.value, true
583 }
584 e = e.overflow.Load()
585 }
586 return *new(V), false
587 }
588
589 func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) {
590 for e != nil {
591 if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) {
592 return e.value, true
593 }
594 e = e.overflow.Load()
595 }
596 return *new(V), false
597 }
598
599
600
601
602
// swap replaces the value for key in the chain headed by head, if
// present. It returns the (possibly new) chain head, the previous value,
// and whether anything was swapped. Entries are never mutated in place:
// the matching link is replaced with a fresh copy.
//
// swap must be called under the mutex of the indirect node head hangs off.
func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) {
	if head.key == key {
		// The head itself matches: build a replacement head carrying the
		// rest of the chain, and return it as the new chain head.
		e := newEntryNode(key, new)
		if chain := head.overflow.Load(); chain != nil {
			e.overflow.Store(chain)
		}
		return e, head.value, true
	}
	// Search the overflow chain; i trails as the pointer to update.
	i := &head.overflow
	e := i.Load()
	for e != nil {
		if e.key == key {
			// Splice a fresh copy in place of the matching link; the
			// head stays the same.
			eNew := newEntryNode(key, new)
			eNew.overflow.Store(e.overflow.Load())
			i.Store(eNew)
			return head, e.value, true
		}
		i = &e.overflow
		e = e.overflow.Load()
	}
	var zero V
	return head, zero, false
}
627
628
629
630
631
// compareAndSwap replaces the value for key in the chain headed by head
// if the stored value compares equal (via valEqual) to old. It returns
// the (possibly new) chain head and whether anything was swapped.
// Entries are never mutated in place: the matching link is replaced.
//
// compareAndSwap must be called under the mutex of the indirect node
// head hangs off.
func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) {
	if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) {
		// The head matches: build a replacement head carrying the rest
		// of the chain, and return it as the new chain head.
		e := newEntryNode(key, new)
		if chain := head.overflow.Load(); chain != nil {
			e.overflow.Store(chain)
		}
		return e, true
	}
	// Search the overflow chain; i trails as the pointer to update.
	i := &head.overflow
	e := i.Load()
	for e != nil {
		if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) {
			// Splice a fresh copy in place of the matching link; the
			// head stays the same.
			eNew := newEntryNode(key, new)
			eNew.overflow.Store(e.overflow.Load())
			i.Store(eNew)
			return head, true
		}
		i = &e.overflow
		e = e.overflow.Load()
	}
	return head, false
}
655
656
657
658
659
// loadAndDelete removes the entry for key from the chain headed by head.
// It returns the deleted value, the new chain head (nil if the chain is
// now empty), and whether anything was deleted.
//
// loadAndDelete must be called under the mutex of the indirect node head
// hangs off.
func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
	if head.key == key {
		// Drop the head; the rest of the chain (possibly nil) takes over.
		return head.value, head.overflow.Load(), true
	}
	// Search the overflow chain; i trails as the pointer to update.
	i := &head.overflow
	e := i.Load()
	for e != nil {
		if e.key == key {
			// Unlink the matching link; the head stays the same.
			i.Store(e.overflow.Load())
			return e.value, head, true
		}
		i = &e.overflow
		e = e.overflow.Load()
	}
	return *new(V), head, false
}
677
678
679
680
681
// compareAndDelete removes the entry for key from the chain headed by
// head if its stored value compares equal (via valEqual) to value. It
// returns the new chain head (nil if the chain is now empty) and whether
// anything was deleted.
//
// compareAndDelete must be called under the mutex of the indirect node
// head hangs off.
func (head *entry[K, V]) compareAndDelete(key K, value V, valEqual equalFunc) (*entry[K, V], bool) {
	if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
		// Drop the head; the rest of the chain (possibly nil) takes over.
		return head.overflow.Load(), true
	}
	// Search the overflow chain; i trails as the pointer to update.
	i := &head.overflow
	e := i.Load()
	for e != nil {
		if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
			// Unlink the matching link; the head stays the same.
			i.Store(e.overflow.Load())
			return head, true
		}
		i = &e.overflow
		e = e.overflow.Load()
	}
	return head, false
}
699
700
701
// node is the common header embedded by both indirect and entry nodes;
// isEntry discriminates which concrete type a *node really points to
// (see the entry and indirect methods below).
type node[K comparable, V any] struct {
	isEntry bool
}
705
706 func (n *node[K, V]) entry() *entry[K, V] {
707 if !n.isEntry {
708 panic("called entry on non-entry node")
709 }
710 return (*entry[K, V])(unsafe.Pointer(n))
711 }
712
713 func (n *node[K, V]) indirect() *indirect[K, V] {
714 if n.isEntry {
715 panic("called indirect on entry node")
716 }
717 return (*indirect[K, V])(unsafe.Pointer(n))
718 }
719
// runtime_rand returns a random uint64, used to seed the map's hash.
// It has no body here: the implementation is supplied outside this
// package — presumably linked to the runtime's random source via
// go:linkname; confirm against the runtime (no directive is visible in
// this view).
func runtime_rand() uint64
725