1
2
3
4
5 package singleflight
6
7 import (
8 "errors"
9 "fmt"
10 "sync"
11 "sync/atomic"
12 "testing"
13 "time"
14 )
15
16 func TestDo(t *testing.T) {
17 var g Group
18 v, err, _ := g.Do("key", func() (any, error) {
19 return "bar", nil
20 })
21 if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
22 t.Errorf("Do = %v; want %v", got, want)
23 }
24 if err != nil {
25 t.Errorf("Do error = %v", err)
26 }
27 }
28
29 func TestDoErr(t *testing.T) {
30 var g Group
31 someErr := errors.New("some error")
32 v, err, _ := g.Do("key", func() (any, error) {
33 return nil, someErr
34 })
35 if err != someErr {
36 t.Errorf("Do error = %v; want someErr %v", err, someErr)
37 }
38 if v != nil {
39 t.Errorf("unexpected non-nil value %#v", v)
40 }
41 }
42
43 func TestDoDupSuppress(t *testing.T) {
44 var g Group
45 var wg1, wg2 sync.WaitGroup
46 c := make(chan string, 1)
47 var calls atomic.Int32
48 fn := func() (any, error) {
49 if calls.Add(1) == 1 {
50
51 wg1.Done()
52 }
53 v := <-c
54 c <- v
55
56 time.Sleep(10 * time.Millisecond)
57
58 return v, nil
59 }
60
61 const n = 10
62 wg1.Add(1)
63 for i := 0; i < n; i++ {
64 wg1.Add(1)
65 wg2.Add(1)
66 go func() {
67 defer wg2.Done()
68 wg1.Done()
69 v, err, _ := g.Do("key", fn)
70 if err != nil {
71 t.Errorf("Do error: %v", err)
72 return
73 }
74 if s, _ := v.(string); s != "bar" {
75 t.Errorf("Do = %T %v; want %q", v, v, "bar")
76 }
77 }()
78 }
79 wg1.Wait()
80
81
82 c <- "bar"
83 wg2.Wait()
84 if got := calls.Load(); got <= 0 || got >= n {
85 t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
86 }
87 }
88
89 func TestForgetUnshared(t *testing.T) {
90 var g Group
91
92 var firstStarted, firstFinished sync.WaitGroup
93
94 firstStarted.Add(1)
95 firstFinished.Add(1)
96
97 key := "key"
98 firstCh := make(chan struct{})
99 go func() {
100 g.Do(key, func() (i interface{}, e error) {
101 firstStarted.Done()
102 <-firstCh
103 return
104 })
105 firstFinished.Done()
106 }()
107
108 firstStarted.Wait()
109 g.ForgetUnshared(key)
110
111 secondCh := make(chan struct{})
112 go func() {
113 g.Do(key, func() (i interface{}, e error) {
114
115 secondCh <- struct{}{}
116 <-secondCh
117 return 2, nil
118 })
119 }()
120
121 <-secondCh
122
123 resultCh := g.DoChan(key, func() (i interface{}, e error) {
124 panic("third must not be started")
125 })
126
127 if g.ForgetUnshared(key) {
128 t.Errorf("Before first goroutine finished, key %q is shared, should return false", key)
129 }
130
131 close(firstCh)
132 firstFinished.Wait()
133
134 if g.ForgetUnshared(key) {
135 t.Errorf("After first goroutine finished, key %q is still shared, should return false", key)
136 }
137
138 secondCh <- struct{}{}
139
140 if result := <-resultCh; result.Val != 2 {
141 t.Errorf("We should receive result produced by second call, expected: 2, got %d", result.Val)
142 }
143 }
144
145 func TestDoAndForgetUnsharedRace(t *testing.T) {
146 t.Parallel()
147
148 var g Group
149 key := "key"
150 d := time.Millisecond
151 for {
152 var calls, shared atomic.Int64
153 const n = 1000
154 var wg sync.WaitGroup
155 wg.Add(n)
156 for i := 0; i < n; i++ {
157 go func() {
158 g.Do(key, func() (interface{}, error) {
159 time.Sleep(d)
160 return calls.Add(1), nil
161 })
162 if !g.ForgetUnshared(key) {
163 shared.Add(1)
164 }
165 wg.Done()
166 }()
167 }
168 wg.Wait()
169
170 if calls.Load() != 1 {
171
172
173
174 d *= 2
175 continue
176 }
177
178
179
180
181 if shared.Load() > 0 {
182 t.Errorf("after a single shared Do, ForgetUnshared returned false %d times", shared.Load())
183 }
184 break
185 }
186 }
187
View as plain text