Source file src/internal/concurrent/hashtriemap_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package concurrent
     6  
     7  import (
     8  	"fmt"
     9  	"internal/abi"
    10  	"math"
    11  	"runtime"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  	"unsafe"
    17  )
    18  
    19  func TestHashTrieMap(t *testing.T) {
    20  	testHashTrieMap(t, func() *HashTrieMap[string, int] {
    21  		return NewHashTrieMap[string, int]()
    22  	})
    23  }
    24  
    25  func TestHashTrieMapBadHash(t *testing.T) {
    26  	testHashTrieMap(t, func() *HashTrieMap[string, int] {
    27  		// Stub out the good hash function with a terrible one.
    28  		// Everything should still work as expected.
    29  		m := NewHashTrieMap[string, int]()
    30  		m.keyHash = func(_ unsafe.Pointer, _ uintptr) uintptr {
    31  			return 0
    32  		}
    33  		return m
    34  	})
    35  }
    36  
    37  func TestHashTrieMapTruncHash(t *testing.T) {
    38  	testHashTrieMap(t, func() *HashTrieMap[string, int] {
    39  		// Stub out the good hash function with a different terrible one
    40  		// (truncated hash). Everything should still work as expected.
    41  		// This is useful to test independently to catch issues with
    42  		// near collisions, where only the last few bits of the hash differ.
    43  		m := NewHashTrieMap[string, int]()
    44  		var mx map[string]int
    45  		mapType := abi.TypeOf(mx).MapType()
    46  		hasher := mapType.Hasher
    47  		m.keyHash = func(p unsafe.Pointer, n uintptr) uintptr {
    48  			return hasher(p, n) & ((uintptr(1) << 4) - 1)
    49  		}
    50  		return m
    51  	})
    52  }
    53  
    54  func testHashTrieMap(t *testing.T, newMap func() *HashTrieMap[string, int]) {
    55  	t.Run("LoadEmpty", func(t *testing.T) {
    56  		m := newMap()
    57  
    58  		for _, s := range testData {
    59  			expectMissing(t, s, 0)(m.Load(s))
    60  		}
    61  	})
    62  	t.Run("LoadOrStore", func(t *testing.T) {
    63  		m := newMap()
    64  
    65  		for i, s := range testData {
    66  			expectMissing(t, s, 0)(m.Load(s))
    67  			expectStored(t, s, i)(m.LoadOrStore(s, i))
    68  			expectPresent(t, s, i)(m.Load(s))
    69  			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
    70  		}
    71  		for i, s := range testData {
    72  			expectPresent(t, s, i)(m.Load(s))
    73  			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
    74  		}
    75  	})
    76  	t.Run("CompareAndDeleteAll", func(t *testing.T) {
    77  		m := newMap()
    78  
    79  		for range 3 {
    80  			for i, s := range testData {
    81  				expectMissing(t, s, 0)(m.Load(s))
    82  				expectStored(t, s, i)(m.LoadOrStore(s, i))
    83  				expectPresent(t, s, i)(m.Load(s))
    84  				expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
    85  			}
    86  			for i, s := range testData {
    87  				expectPresent(t, s, i)(m.Load(s))
    88  				expectNotDeleted(t, s, math.MaxInt)(m.CompareAndDelete(s, math.MaxInt))
    89  				expectDeleted(t, s, i)(m.CompareAndDelete(s, i))
    90  				expectNotDeleted(t, s, i)(m.CompareAndDelete(s, i))
    91  				expectMissing(t, s, 0)(m.Load(s))
    92  			}
    93  			for _, s := range testData {
    94  				expectMissing(t, s, 0)(m.Load(s))
    95  			}
    96  		}
    97  	})
    98  	t.Run("CompareAndDeleteOne", func(t *testing.T) {
    99  		m := newMap()
   100  
   101  		for i, s := range testData {
   102  			expectMissing(t, s, 0)(m.Load(s))
   103  			expectStored(t, s, i)(m.LoadOrStore(s, i))
   104  			expectPresent(t, s, i)(m.Load(s))
   105  			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
   106  		}
   107  		expectNotDeleted(t, testData[15], math.MaxInt)(m.CompareAndDelete(testData[15], math.MaxInt))
   108  		expectDeleted(t, testData[15], 15)(m.CompareAndDelete(testData[15], 15))
   109  		expectNotDeleted(t, testData[15], 15)(m.CompareAndDelete(testData[15], 15))
   110  		for i, s := range testData {
   111  			if i == 15 {
   112  				expectMissing(t, s, 0)(m.Load(s))
   113  			} else {
   114  				expectPresent(t, s, i)(m.Load(s))
   115  			}
   116  		}
   117  	})
   118  	t.Run("DeleteMultiple", func(t *testing.T) {
   119  		m := newMap()
   120  
   121  		for i, s := range testData {
   122  			expectMissing(t, s, 0)(m.Load(s))
   123  			expectStored(t, s, i)(m.LoadOrStore(s, i))
   124  			expectPresent(t, s, i)(m.Load(s))
   125  			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
   126  		}
   127  		for _, i := range []int{1, 105, 6, 85} {
   128  			expectNotDeleted(t, testData[i], math.MaxInt)(m.CompareAndDelete(testData[i], math.MaxInt))
   129  			expectDeleted(t, testData[i], i)(m.CompareAndDelete(testData[i], i))
   130  			expectNotDeleted(t, testData[i], i)(m.CompareAndDelete(testData[i], i))
   131  		}
   132  		for i, s := range testData {
   133  			if i == 1 || i == 105 || i == 6 || i == 85 {
   134  				expectMissing(t, s, 0)(m.Load(s))
   135  			} else {
   136  				expectPresent(t, s, i)(m.Load(s))
   137  			}
   138  		}
   139  	})
   140  	t.Run("All", func(t *testing.T) {
   141  		m := newMap()
   142  
   143  		testAll(t, m, testDataMap(testData[:]), func(_ string, _ int) bool {
   144  			return true
   145  		})
   146  	})
   147  	t.Run("AllDelete", func(t *testing.T) {
   148  		m := newMap()
   149  
   150  		testAll(t, m, testDataMap(testData[:]), func(s string, i int) bool {
   151  			expectDeleted(t, s, i)(m.CompareAndDelete(s, i))
   152  			return true
   153  		})
   154  		for _, s := range testData {
   155  			expectMissing(t, s, 0)(m.Load(s))
   156  		}
   157  	})
   158  	t.Run("ConcurrentLifecycleUnsharedKeys", func(t *testing.T) {
   159  		m := newMap()
   160  
   161  		gmp := runtime.GOMAXPROCS(-1)
   162  		var wg sync.WaitGroup
   163  		for i := range gmp {
   164  			wg.Add(1)
   165  			go func(id int) {
   166  				defer wg.Done()
   167  
   168  				makeKey := func(s string) string {
   169  					return s + "-" + strconv.Itoa(id)
   170  				}
   171  				for _, s := range testData {
   172  					key := makeKey(s)
   173  					expectMissing(t, key, 0)(m.Load(key))
   174  					expectStored(t, key, id)(m.LoadOrStore(key, id))
   175  					expectPresent(t, key, id)(m.Load(key))
   176  					expectLoaded(t, key, id)(m.LoadOrStore(key, 0))
   177  				}
   178  				for _, s := range testData {
   179  					key := makeKey(s)
   180  					expectPresent(t, key, id)(m.Load(key))
   181  					expectDeleted(t, key, id)(m.CompareAndDelete(key, id))
   182  					expectMissing(t, key, 0)(m.Load(key))
   183  				}
   184  				for _, s := range testData {
   185  					key := makeKey(s)
   186  					expectMissing(t, key, 0)(m.Load(key))
   187  				}
   188  			}(i)
   189  		}
   190  		wg.Wait()
   191  	})
   192  	t.Run("ConcurrentDeleteSharedKeys", func(t *testing.T) {
   193  		m := newMap()
   194  
   195  		// Load up the map.
   196  		for i, s := range testData {
   197  			expectMissing(t, s, 0)(m.Load(s))
   198  			expectStored(t, s, i)(m.LoadOrStore(s, i))
   199  		}
   200  		gmp := runtime.GOMAXPROCS(-1)
   201  		var wg sync.WaitGroup
   202  		for i := range gmp {
   203  			wg.Add(1)
   204  			go func(id int) {
   205  				defer wg.Done()
   206  
   207  				for i, s := range testData {
   208  					expectNotDeleted(t, s, math.MaxInt)(m.CompareAndDelete(s, math.MaxInt))
   209  					m.CompareAndDelete(s, i)
   210  					expectMissing(t, s, 0)(m.Load(s))
   211  				}
   212  				for _, s := range testData {
   213  					expectMissing(t, s, 0)(m.Load(s))
   214  				}
   215  			}(i)
   216  		}
   217  		wg.Wait()
   218  	})
   219  }
   220  
   221  func testAll[K, V comparable](t *testing.T, m *HashTrieMap[K, V], testData map[K]V, yield func(K, V) bool) {
   222  	for k, v := range testData {
   223  		expectStored(t, k, v)(m.LoadOrStore(k, v))
   224  	}
   225  	visited := make(map[K]int)
   226  	m.All()(func(key K, got V) bool {
   227  		want, ok := testData[key]
   228  		if !ok {
   229  			t.Errorf("unexpected key %v in map", key)
   230  			return false
   231  		}
   232  		if got != want {
   233  			t.Errorf("expected key %v to have value %v, got %v", key, want, got)
   234  			return false
   235  		}
   236  		visited[key]++
   237  		return yield(key, got)
   238  	})
   239  	for key, n := range visited {
   240  		if n > 1 {
   241  			t.Errorf("visited key %v more than once", key)
   242  		}
   243  	}
   244  }
   245  
   246  func expectPresent[K, V comparable](t *testing.T, key K, want V) func(got V, ok bool) {
   247  	t.Helper()
   248  	return func(got V, ok bool) {
   249  		t.Helper()
   250  
   251  		if !ok {
   252  			t.Errorf("expected key %v to be present in map", key)
   253  		}
   254  		if ok && got != want {
   255  			t.Errorf("expected key %v to have value %v, got %v", key, want, got)
   256  		}
   257  	}
   258  }
   259  
   260  func expectMissing[K, V comparable](t *testing.T, key K, want V) func(got V, ok bool) {
   261  	t.Helper()
   262  	if want != *new(V) {
   263  		// This is awkward, but the want argument is necessary to smooth over type inference.
   264  		// Just make sure the want argument always looks the same.
   265  		panic("expectMissing must always have a zero value variable")
   266  	}
   267  	return func(got V, ok bool) {
   268  		t.Helper()
   269  
   270  		if ok {
   271  			t.Errorf("expected key %v to be missing from map, got value %v", key, got)
   272  		}
   273  		if !ok && got != want {
   274  			t.Errorf("expected missing key %v to be paired with the zero value; got %v", key, got)
   275  		}
   276  	}
   277  }
   278  
   279  func expectLoaded[K, V comparable](t *testing.T, key K, want V) func(got V, loaded bool) {
   280  	t.Helper()
   281  	return func(got V, loaded bool) {
   282  		t.Helper()
   283  
   284  		if !loaded {
   285  			t.Errorf("expected key %v to have been loaded, not stored", key)
   286  		}
   287  		if got != want {
   288  			t.Errorf("expected key %v to have value %v, got %v", key, want, got)
   289  		}
   290  	}
   291  }
   292  
   293  func expectStored[K, V comparable](t *testing.T, key K, want V) func(got V, loaded bool) {
   294  	t.Helper()
   295  	return func(got V, loaded bool) {
   296  		t.Helper()
   297  
   298  		if loaded {
   299  			t.Errorf("expected inserted key %v to have been stored, not loaded", key)
   300  		}
   301  		if got != want {
   302  			t.Errorf("expected inserted key %v to have value %v, got %v", key, want, got)
   303  		}
   304  	}
   305  }
   306  
   307  func expectDeleted[K, V comparable](t *testing.T, key K, old V) func(deleted bool) {
   308  	t.Helper()
   309  	return func(deleted bool) {
   310  		t.Helper()
   311  
   312  		if !deleted {
   313  			t.Errorf("expected key %v with value %v to be in map and deleted", key, old)
   314  		}
   315  	}
   316  }
   317  
   318  func expectNotDeleted[K, V comparable](t *testing.T, key K, old V) func(deleted bool) {
   319  	t.Helper()
   320  	return func(deleted bool) {
   321  		t.Helper()
   322  
   323  		if deleted {
   324  			t.Errorf("expected key %v with value %v to not be in map and thus not deleted", key, old)
   325  		}
   326  	}
   327  }
   328  
   329  func testDataMap(data []string) map[string]int {
   330  	m := make(map[string]int)
   331  	for i, s := range data {
   332  		m[s] = i
   333  	}
   334  	return m
   335  }
   336  
   337  var (
   338  	testDataSmall [8]string
   339  	testData      [128]string
   340  	testDataLarge [128 << 10]string
   341  )
   342  
   343  func init() {
   344  	for i := range testDataSmall {
   345  		testDataSmall[i] = fmt.Sprintf("%b", i)
   346  	}
   347  	for i := range testData {
   348  		testData[i] = fmt.Sprintf("%b", i)
   349  	}
   350  	for i := range testDataLarge {
   351  		testDataLarge[i] = fmt.Sprintf("%b", i)
   352  	}
   353  }
   354  
   355  func dumpMap[K, V comparable](ht *HashTrieMap[K, V]) {
   356  	dumpNode(ht, &ht.root.node, 0)
   357  }
   358  
   359  func dumpNode[K, V comparable](ht *HashTrieMap[K, V], n *node[K, V], depth int) {
   360  	var sb strings.Builder
   361  	for range depth {
   362  		fmt.Fprintf(&sb, "\t")
   363  	}
   364  	prefix := sb.String()
   365  	if n.isEntry {
   366  		e := n.entry()
   367  		for e != nil {
   368  			fmt.Printf("%s%p [Entry Key=%v Value=%v Overflow=%p, Hash=%016x]\n", prefix, e, e.key, e.value, e.overflow.Load(), ht.keyHash(unsafe.Pointer(&e.key), ht.seed))
   369  			e = e.overflow.Load()
   370  		}
   371  		return
   372  	}
   373  	i := n.indirect()
   374  	fmt.Printf("%s%p [Indirect Parent=%p Dead=%t Children=[", prefix, i, i.parent, i.dead.Load())
   375  	for j := range i.children {
   376  		c := i.children[j].Load()
   377  		fmt.Printf("%p", c)
   378  		if j != len(i.children)-1 {
   379  			fmt.Printf(", ")
   380  		}
   381  	}
   382  	fmt.Printf("]]\n")
   383  	for j := range i.children {
   384  		c := i.children[j].Load()
   385  		if c != nil {
   386  			dumpNode(ht, c, depth+1)
   387  		}
   388  	}
   389  }
   390  

View as plain text