Source file src/internal/concurrent/hashtriemap.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package concurrent
     6  
     7  import (
     8  	"internal/abi"
     9  	"internal/goarch"
    10  	"math/rand/v2"
    11  	"sync"
    12  	"sync/atomic"
    13  	"unsafe"
    14  )
    15  
// HashTrieMap is an implementation of a concurrent hash-trie. The implementation
// is designed around frequent loads, but offers decent performance for stores
// and deletes as well, especially if the map is larger. Its primary use-case is
// the unique package, but can be used elsewhere as well.
//
// Use NewHashTrieMap to obtain a map; it fills in every field below.
type HashTrieMap[K, V comparable] struct {
	// root is the top-level indirect node; it is created with a nil parent.
	// keyHash hashes a key, passed by pointer, together with seed.
	// keyEqual and valEqual compare two keys or two values, each passed
	// by pointer.
	// seed is the value handed to keyHash on every hash computation; it is
	// chosen randomly when the map is created.
	root     *indirect[K, V]
	keyHash  hashFunc
	keyEqual equalFunc
	valEqual equalFunc
	seed     uintptr
}
    27  
    28  // NewHashTrieMap creates a new HashTrieMap for the provided key and value.
    29  func NewHashTrieMap[K, V comparable]() *HashTrieMap[K, V] {
    30  	var m map[K]V
    31  	mapType := abi.TypeOf(m).MapType()
    32  	ht := &HashTrieMap[K, V]{
    33  		root:     newIndirectNode[K, V](nil),
    34  		keyHash:  mapType.Hasher,
    35  		keyEqual: mapType.Key.Equal,
    36  		valEqual: mapType.Elem.Equal,
    37  		seed:     uintptr(rand.Uint64()),
    38  	}
    39  	return ht
    40  }
    41  
    42  type hashFunc func(unsafe.Pointer, uintptr) uintptr
    43  type equalFunc func(unsafe.Pointer, unsafe.Pointer) bool
    44  
    45  // Load returns the value stored in the map for a key, or nil if no
    46  // value is present.
    47  // The ok result indicates whether value was found in the map.
    48  func (ht *HashTrieMap[K, V]) Load(key K) (value V, ok bool) {
    49  	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
    50  
    51  	i := ht.root
    52  	hashShift := 8 * goarch.PtrSize
    53  	for hashShift != 0 {
    54  		hashShift -= nChildrenLog2
    55  
    56  		n := i.children[(hash>>hashShift)&nChildrenMask].Load()
    57  		if n == nil {
    58  			return *new(V), false
    59  		}
    60  		if n.isEntry {
    61  			return n.entry().lookup(key, ht.keyEqual)
    62  		}
    63  		i = n.indirect()
    64  	}
    65  	panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
    66  }
    67  
    68  // LoadOrStore returns the existing value for the key if present.
    69  // Otherwise, it stores and returns the given value.
    70  // The loaded result is true if the value was loaded, false if stored.
    71  func (ht *HashTrieMap[K, V]) LoadOrStore(key K, value V) (result V, loaded bool) {
    72  	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
    73  	var i *indirect[K, V]
    74  	var hashShift uint
    75  	var slot *atomic.Pointer[node[K, V]]
    76  	var n *node[K, V]
    77  	for {
    78  		// Find the key or a candidate location for insertion.
    79  		i = ht.root
    80  		hashShift = 8 * goarch.PtrSize
    81  		haveInsertPoint := false
    82  		for hashShift != 0 {
    83  			hashShift -= nChildrenLog2
    84  
    85  			slot = &i.children[(hash>>hashShift)&nChildrenMask]
    86  			n = slot.Load()
    87  			if n == nil {
    88  				// We found a nil slot which is a candidate for insertion.
    89  				haveInsertPoint = true
    90  				break
    91  			}
    92  			if n.isEntry {
    93  				// We found an existing entry, which is as far as we can go.
    94  				// If it stays this way, we'll have to replace it with an
    95  				// indirect node.
    96  				if v, ok := n.entry().lookup(key, ht.keyEqual); ok {
    97  					return v, true
    98  				}
    99  				haveInsertPoint = true
   100  				break
   101  			}
   102  			i = n.indirect()
   103  		}
   104  		if !haveInsertPoint {
   105  			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
   106  		}
   107  
   108  		// Grab the lock and double-check what we saw.
   109  		i.mu.Lock()
   110  		n = slot.Load()
   111  		if (n == nil || n.isEntry) && !i.dead.Load() {
   112  			// What we saw is still true, so we can continue with the insert.
   113  			break
   114  		}
   115  		// We have to start over.
   116  		i.mu.Unlock()
   117  	}
   118  	// N.B. This lock is held from when we broke out of the outer loop above.
   119  	// We specifically break this out so that we can use defer here safely.
   120  	// One option is to break this out into a new function instead, but
   121  	// there's so much local iteration state used below that this turns out
   122  	// to be cleaner.
   123  	defer i.mu.Unlock()
   124  
   125  	var oldEntry *entry[K, V]
   126  	if n != nil {
   127  		oldEntry = n.entry()
   128  		if v, ok := oldEntry.lookup(key, ht.keyEqual); ok {
   129  			// Easy case: by loading again, it turns out exactly what we wanted is here!
   130  			return v, true
   131  		}
   132  	}
   133  	newEntry := newEntryNode(key, value)
   134  	if oldEntry == nil {
   135  		// Easy case: create a new entry and store it.
   136  		slot.Store(&newEntry.node)
   137  	} else {
   138  		// We possibly need to expand the entry already there into one or more new nodes.
   139  		//
   140  		// Publish the node last, which will make both oldEntry and newEntry visible. We
   141  		// don't want readers to be able to observe that oldEntry isn't in the tree.
   142  		slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
   143  	}
   144  	return value, false
   145  }
   146  
   147  // expand takes oldEntry and newEntry whose hashes conflict from bit 64 down to hashShift and
   148  // produces a subtree of indirect nodes to hold the two new entries.
   149  func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uintptr, hashShift uint, parent *indirect[K, V]) *node[K, V] {
   150  	// Check for a hash collision.
   151  	oldHash := ht.keyHash(unsafe.Pointer(&oldEntry.key), ht.seed)
   152  	if oldHash == newHash {
   153  		// Store the old entry in the new entry's overflow list, then store
   154  		// the new entry.
   155  		newEntry.overflow.Store(oldEntry)
   156  		return &newEntry.node
   157  	}
   158  	// We have to add an indirect node. Worse still, we may need to add more than one.
   159  	newIndirect := newIndirectNode(parent)
   160  	top := newIndirect
   161  	for {
   162  		if hashShift == 0 {
   163  			panic("internal/concurrent.HashMapTrie: ran out of hash bits while inserting")
   164  		}
   165  		hashShift -= nChildrenLog2 // hashShift is for the level parent is at. We need to go deeper.
   166  		oi := (oldHash >> hashShift) & nChildrenMask
   167  		ni := (newHash >> hashShift) & nChildrenMask
   168  		if oi != ni {
   169  			newIndirect.children[oi].Store(&oldEntry.node)
   170  			newIndirect.children[ni].Store(&newEntry.node)
   171  			break
   172  		}
   173  		nextIndirect := newIndirectNode(newIndirect)
   174  		newIndirect.children[oi].Store(&nextIndirect.node)
   175  		newIndirect = nextIndirect
   176  	}
   177  	return &top.node
   178  }
   179  
   180  // CompareAndDelete deletes the entry for key if its value is equal to old.
   181  //
   182  // If there is no current value for key in the map, CompareAndDelete returns false
   183  // (even if the old value is the nil interface value).
   184  func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
   185  	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
   186  	var i *indirect[K, V]
   187  	var hashShift uint
   188  	var slot *atomic.Pointer[node[K, V]]
   189  	var n *node[K, V]
   190  	for {
   191  		// Find the key or return when there's nothing to delete.
   192  		i = ht.root
   193  		hashShift = 8 * goarch.PtrSize
   194  		found := false
   195  		for hashShift != 0 {
   196  			hashShift -= nChildrenLog2
   197  
   198  			slot = &i.children[(hash>>hashShift)&nChildrenMask]
   199  			n = slot.Load()
   200  			if n == nil {
   201  				// Nothing to delete. Give up.
   202  				return
   203  			}
   204  			if n.isEntry {
   205  				// We found an entry. Check if it matches.
   206  				if _, ok := n.entry().lookup(key, ht.keyEqual); !ok {
   207  					// No match, nothing to delete.
   208  					return
   209  				}
   210  				// We've got something to delete.
   211  				found = true
   212  				break
   213  			}
   214  			i = n.indirect()
   215  		}
   216  		if !found {
   217  			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
   218  		}
   219  
   220  		// Grab the lock and double-check what we saw.
   221  		i.mu.Lock()
   222  		n = slot.Load()
   223  		if !i.dead.Load() {
   224  			if n == nil {
   225  				// Valid node that doesn't contain what we need. Nothing to delete.
   226  				i.mu.Unlock()
   227  				return
   228  			}
   229  			if n.isEntry {
   230  				// What we saw is still true, so we can continue with the delete.
   231  				break
   232  			}
   233  		}
   234  		// We have to start over.
   235  		i.mu.Unlock()
   236  	}
   237  	// Try to delete the entry.
   238  	e, deleted := n.entry().compareAndDelete(key, old, ht.keyEqual, ht.valEqual)
   239  	if !deleted {
   240  		// Nothing was actually deleted, which means the node is no longer there.
   241  		i.mu.Unlock()
   242  		return false
   243  	}
   244  	if e != nil {
   245  		// We didn't actually delete the whole entry, just one entry in the chain.
   246  		// Nothing else to do, since the parent is definitely not empty.
   247  		slot.Store(&e.node)
   248  		i.mu.Unlock()
   249  		return true
   250  	}
   251  	// Delete the entry.
   252  	slot.Store(nil)
   253  
   254  	// Check if the node is now empty (and isn't the root), and delete it if able.
   255  	for i.parent != nil && i.empty() {
   256  		if hashShift == 8*goarch.PtrSize {
   257  			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
   258  		}
   259  		hashShift += nChildrenLog2
   260  
   261  		// Delete the current node in the parent.
   262  		parent := i.parent
   263  		parent.mu.Lock()
   264  		i.dead.Store(true)
   265  		parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
   266  		i.mu.Unlock()
   267  		i = parent
   268  	}
   269  	i.mu.Unlock()
   270  	return true
   271  }
   272  
   273  // All returns an iter.Seq2 that produces all key-value pairs in the map.
   274  // The enumeration does not represent any consistent snapshot of the map,
   275  // but is guaranteed to visit each unique key-value pair only once. It is
   276  // safe to operate on the tree during iteration. No particular enumeration
   277  // order is guaranteed.
   278  func (ht *HashTrieMap[K, V]) All() func(yield func(K, V) bool) {
   279  	return func(yield func(key K, value V) bool) {
   280  		ht.iter(ht.root, yield)
   281  	}
   282  }
   283  
   284  func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) bool) bool {
   285  	for j := range i.children {
   286  		n := i.children[j].Load()
   287  		if n == nil {
   288  			continue
   289  		}
   290  		if !n.isEntry {
   291  			if !ht.iter(n.indirect(), yield) {
   292  				return false
   293  			}
   294  			continue
   295  		}
   296  		e := n.entry()
   297  		for e != nil {
   298  			if !yield(e.key, e.value) {
   299  				return false
   300  			}
   301  			e = e.overflow.Load()
   302  		}
   303  	}
   304  	return true
   305  }
   306  
const (
	// 16 children. This seems to be the sweet spot for
	// load performance: any smaller and we lose out on
	// 50% or more in CPU performance. Any larger and the
	// returns are minuscule (~1% improvement for 32 children).
	//
	// nChildrenLog2 is the number of hash bits consumed at each level of
	// the trie, nChildren is the number of child slots in an indirect node,
	// and nChildrenMask reduces a shifted hash to a child index.
	nChildrenLog2 = 4
	nChildren     = 1 << nChildrenLog2
	nChildrenMask = nChildren - 1
)
   316  
// indirect is an internal node in the hash-trie.
//
// The embedded node header comes first so that a *node can be converted back
// to an *indirect. dead is set when the node is unlinked from its parent
// during a delete; writers that find it set after locking mu retry from the
// root. parent is nil for the root node. Each slot in children holds nil, an
// entry node, or another indirect node.
type indirect[K, V comparable] struct {
	node[K, V]
	dead     atomic.Bool
	mu       sync.Mutex // Protects mutation to children and any children that are entry nodes.
	parent   *indirect[K, V]
	children [nChildren]atomic.Pointer[node[K, V]]
}
   325  
   326  func newIndirectNode[K, V comparable](parent *indirect[K, V]) *indirect[K, V] {
   327  	return &indirect[K, V]{node: node[K, V]{isEntry: false}, parent: parent}
   328  }
   329  
   330  func (i *indirect[K, V]) empty() bool {
   331  	nc := 0
   332  	for j := range i.children {
   333  		if i.children[j].Load() != nil {
   334  			nc++
   335  		}
   336  	}
   337  	return nc == 0
   338  }
   339  
// entry is a leaf node in the hash-trie.
//
// The embedded node header comes first so that a *node can be converted back
// to an *entry. key and value hold the stored pair; overflow links to the
// next entry whose key has the same hash, forming a singly linked chain.
type entry[K, V comparable] struct {
	node[K, V]
	overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions.
	key      K
	value    V
}
   347  
   348  func newEntryNode[K, V comparable](key K, value V) *entry[K, V] {
   349  	return &entry[K, V]{
   350  		node:  node[K, V]{isEntry: true},
   351  		key:   key,
   352  		value: value,
   353  	}
   354  }
   355  
   356  func (e *entry[K, V]) lookup(key K, equal equalFunc) (V, bool) {
   357  	for e != nil {
   358  		if equal(unsafe.Pointer(&e.key), abi.NoEscape(unsafe.Pointer(&key))) {
   359  			return e.value, true
   360  		}
   361  		e = e.overflow.Load()
   362  	}
   363  	return *new(V), false
   364  }
   365  
   366  // compareAndDelete deletes an entry in the overflow chain if both the key and value compare
   367  // equal. Returns the new entry chain and whether or not anything was deleted.
   368  //
   369  // compareAndDelete must be called under the mutex of the indirect node which e is a child of.
   370  func (head *entry[K, V]) compareAndDelete(key K, value V, keyEqual, valEqual equalFunc) (*entry[K, V], bool) {
   371  	if keyEqual(unsafe.Pointer(&head.key), abi.NoEscape(unsafe.Pointer(&key))) &&
   372  		valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
   373  		// Drop the head of the list.
   374  		return head.overflow.Load(), true
   375  	}
   376  	i := &head.overflow
   377  	e := i.Load()
   378  	for e != nil {
   379  		if keyEqual(unsafe.Pointer(&e.key), abi.NoEscape(unsafe.Pointer(&key))) &&
   380  			valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
   381  			i.Store(e.overflow.Load())
   382  			return head, true
   383  		}
   384  		i = &e.overflow
   385  		e = e.overflow.Load()
   386  	}
   387  	return head, false
   388  }
   389  
// node is the header for a node. It's polymorphic and
// is actually either an entry or an indirect.
//
// Both entry and indirect embed node as their first field, and the entry and
// indirect methods use that to convert a *node back to the enclosing type.
type node[K, V comparable] struct {
	// isEntry is true when this header belongs to an entry and false when
	// it belongs to an indirect.
	isEntry bool
}
   395  
   396  func (n *node[K, V]) entry() *entry[K, V] {
   397  	if !n.isEntry {
   398  		panic("called entry on non-entry node")
   399  	}
   400  	return (*entry[K, V])(unsafe.Pointer(n))
   401  }
   402  
   403  func (n *node[K, V]) indirect() *indirect[K, V] {
   404  	if n.isEntry {
   405  		panic("called indirect on entry node")
   406  	}
   407  	return (*indirect[K, V])(unsafe.Pointer(n))
   408  }
   409  

View as plain text