1
2
3
4
5 package concurrent
6
7 import (
8 "internal/abi"
9 "internal/goarch"
10 "math/rand/v2"
11 "sync"
12 "sync/atomic"
13 "unsafe"
14 )
15
16
17
18
19
20 type HashTrieMap[K, V comparable] struct {
21 root *indirect[K, V]
22 keyHash hashFunc
23 keyEqual equalFunc
24 valEqual equalFunc
25 seed uintptr
26 }
27
28
29 func NewHashTrieMap[K, V comparable]() *HashTrieMap[K, V] {
30 var m map[K]V
31 mapType := abi.TypeOf(m).MapType()
32 ht := &HashTrieMap[K, V]{
33 root: newIndirectNode[K, V](nil),
34 keyHash: mapType.Hasher,
35 keyEqual: mapType.Key.Equal,
36 valEqual: mapType.Elem.Equal,
37 seed: uintptr(rand.Uint64()),
38 }
39 return ht
40 }
41
42 type hashFunc func(unsafe.Pointer, uintptr) uintptr
43 type equalFunc func(unsafe.Pointer, unsafe.Pointer) bool
44
45
46
47
48 func (ht *HashTrieMap[K, V]) Load(key K) (value V, ok bool) {
49 hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
50
51 i := ht.root
52 hashShift := 8 * goarch.PtrSize
53 for hashShift != 0 {
54 hashShift -= nChildrenLog2
55
56 n := i.children[(hash>>hashShift)&nChildrenMask].Load()
57 if n == nil {
58 return *new(V), false
59 }
60 if n.isEntry {
61 return n.entry().lookup(key, ht.keyEqual)
62 }
63 i = n.indirect()
64 }
65 panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
66 }
67
68
69
70
71 func (ht *HashTrieMap[K, V]) LoadOrStore(key K, value V) (result V, loaded bool) {
72 hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
73 var i *indirect[K, V]
74 var hashShift uint
75 var slot *atomic.Pointer[node[K, V]]
76 var n *node[K, V]
77 for {
78
79 i = ht.root
80 hashShift = 8 * goarch.PtrSize
81 haveInsertPoint := false
82 for hashShift != 0 {
83 hashShift -= nChildrenLog2
84
85 slot = &i.children[(hash>>hashShift)&nChildrenMask]
86 n = slot.Load()
87 if n == nil {
88
89 haveInsertPoint = true
90 break
91 }
92 if n.isEntry {
93
94
95
96 if v, ok := n.entry().lookup(key, ht.keyEqual); ok {
97 return v, true
98 }
99 haveInsertPoint = true
100 break
101 }
102 i = n.indirect()
103 }
104 if !haveInsertPoint {
105 panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
106 }
107
108
109 i.mu.Lock()
110 n = slot.Load()
111 if (n == nil || n.isEntry) && !i.dead.Load() {
112
113 break
114 }
115
116 i.mu.Unlock()
117 }
118
119
120
121
122
123 defer i.mu.Unlock()
124
125 var oldEntry *entry[K, V]
126 if n != nil {
127 oldEntry = n.entry()
128 if v, ok := oldEntry.lookup(key, ht.keyEqual); ok {
129
130 return v, true
131 }
132 }
133 newEntry := newEntryNode(key, value)
134 if oldEntry == nil {
135
136 slot.Store(&newEntry.node)
137 } else {
138
139
140
141
142 slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
143 }
144 return value, false
145 }
146
147
148
149 func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uintptr, hashShift uint, parent *indirect[K, V]) *node[K, V] {
150
151 oldHash := ht.keyHash(unsafe.Pointer(&oldEntry.key), ht.seed)
152 if oldHash == newHash {
153
154
155 newEntry.overflow.Store(oldEntry)
156 return &newEntry.node
157 }
158
159 newIndirect := newIndirectNode(parent)
160 top := newIndirect
161 for {
162 if hashShift == 0 {
163 panic("internal/concurrent.HashMapTrie: ran out of hash bits while inserting")
164 }
165 hashShift -= nChildrenLog2
166 oi := (oldHash >> hashShift) & nChildrenMask
167 ni := (newHash >> hashShift) & nChildrenMask
168 if oi != ni {
169 newIndirect.children[oi].Store(&oldEntry.node)
170 newIndirect.children[ni].Store(&newEntry.node)
171 break
172 }
173 nextIndirect := newIndirectNode(newIndirect)
174 newIndirect.children[oi].Store(&nextIndirect.node)
175 newIndirect = nextIndirect
176 }
177 return &top.node
178 }
179
180
181
182
183
184 func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
185 hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
186 var i *indirect[K, V]
187 var hashShift uint
188 var slot *atomic.Pointer[node[K, V]]
189 var n *node[K, V]
190 for {
191
192 i = ht.root
193 hashShift = 8 * goarch.PtrSize
194 found := false
195 for hashShift != 0 {
196 hashShift -= nChildrenLog2
197
198 slot = &i.children[(hash>>hashShift)&nChildrenMask]
199 n = slot.Load()
200 if n == nil {
201
202 return
203 }
204 if n.isEntry {
205
206 if _, ok := n.entry().lookup(key, ht.keyEqual); !ok {
207
208 return
209 }
210
211 found = true
212 break
213 }
214 i = n.indirect()
215 }
216 if !found {
217 panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
218 }
219
220
221 i.mu.Lock()
222 n = slot.Load()
223 if !i.dead.Load() {
224 if n == nil {
225
226 i.mu.Unlock()
227 return
228 }
229 if n.isEntry {
230
231 break
232 }
233 }
234
235 i.mu.Unlock()
236 }
237
238 e, deleted := n.entry().compareAndDelete(key, old, ht.keyEqual, ht.valEqual)
239 if !deleted {
240
241 i.mu.Unlock()
242 return false
243 }
244 if e != nil {
245
246
247 slot.Store(&e.node)
248 i.mu.Unlock()
249 return true
250 }
251
252 slot.Store(nil)
253
254
255 for i.parent != nil && i.empty() {
256 if hashShift == 8*goarch.PtrSize {
257 panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
258 }
259 hashShift += nChildrenLog2
260
261
262 parent := i.parent
263 parent.mu.Lock()
264 i.dead.Store(true)
265 parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
266 i.mu.Unlock()
267 i = parent
268 }
269 i.mu.Unlock()
270 return true
271 }
272
273
274
275
276
277
278 func (ht *HashTrieMap[K, V]) All() func(yield func(K, V) bool) {
279 return func(yield func(key K, value V) bool) {
280 ht.iter(ht.root, yield)
281 }
282 }
283
284 func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) bool) bool {
285 for j := range i.children {
286 n := i.children[j].Load()
287 if n == nil {
288 continue
289 }
290 if !n.isEntry {
291 if !ht.iter(n.indirect(), yield) {
292 return false
293 }
294 continue
295 }
296 e := n.entry()
297 for e != nil {
298 if !yield(e.key, e.value) {
299 return false
300 }
301 e = e.overflow.Load()
302 }
303 }
304 return true
305 }
306
307 const (
308
309
310
311
312 nChildrenLog2 = 4
313 nChildren = 1 << nChildrenLog2
314 nChildrenMask = nChildren - 1
315 )
316
317
318 type indirect[K, V comparable] struct {
319 node[K, V]
320 dead atomic.Bool
321 mu sync.Mutex
322 parent *indirect[K, V]
323 children [nChildren]atomic.Pointer[node[K, V]]
324 }
325
326 func newIndirectNode[K, V comparable](parent *indirect[K, V]) *indirect[K, V] {
327 return &indirect[K, V]{node: node[K, V]{isEntry: false}, parent: parent}
328 }
329
330 func (i *indirect[K, V]) empty() bool {
331 nc := 0
332 for j := range i.children {
333 if i.children[j].Load() != nil {
334 nc++
335 }
336 }
337 return nc == 0
338 }
339
340
341 type entry[K, V comparable] struct {
342 node[K, V]
343 overflow atomic.Pointer[entry[K, V]]
344 key K
345 value V
346 }
347
348 func newEntryNode[K, V comparable](key K, value V) *entry[K, V] {
349 return &entry[K, V]{
350 node: node[K, V]{isEntry: true},
351 key: key,
352 value: value,
353 }
354 }
355
356 func (e *entry[K, V]) lookup(key K, equal equalFunc) (V, bool) {
357 for e != nil {
358 if equal(unsafe.Pointer(&e.key), abi.NoEscape(unsafe.Pointer(&key))) {
359 return e.value, true
360 }
361 e = e.overflow.Load()
362 }
363 return *new(V), false
364 }
365
366
367
368
369
370 func (head *entry[K, V]) compareAndDelete(key K, value V, keyEqual, valEqual equalFunc) (*entry[K, V], bool) {
371 if keyEqual(unsafe.Pointer(&head.key), abi.NoEscape(unsafe.Pointer(&key))) &&
372 valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
373
374 return head.overflow.Load(), true
375 }
376 i := &head.overflow
377 e := i.Load()
378 for e != nil {
379 if keyEqual(unsafe.Pointer(&e.key), abi.NoEscape(unsafe.Pointer(&key))) &&
380 valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
381 i.Store(e.overflow.Load())
382 return head, true
383 }
384 i = &e.overflow
385 e = e.overflow.Load()
386 }
387 return head, false
388 }
389
390
391
392 type node[K, V comparable] struct {
393 isEntry bool
394 }
395
396 func (n *node[K, V]) entry() *entry[K, V] {
397 if !n.isEntry {
398 panic("called entry on non-entry node")
399 }
400 return (*entry[K, V])(unsafe.Pointer(n))
401 }
402
403 func (n *node[K, V]) indirect() *indirect[K, V] {
404 if n.isEntry {
405 panic("called indirect on entry node")
406 }
407 return (*indirect[K, V])(unsafe.Pointer(n))
408 }
409
View as plain text