1
2
3
4
5
6
7
8 package counter
9
10 import (
11 "fmt"
12 "os"
13 "runtime"
14 "strings"
15 "sync/atomic"
16 )
17
// debugCounter turns on verbose tracing of counter operations. It is true
// when the GODEBUG environment variable contains "countertrace=1".
var debugCounter = strings.Contains(os.Getenv("GODEBUG"), "countertrace=1")

// CrashOnBugs, when true, makes debugFatalf print its message and exit the
// process even if countertrace debugging is not enabled.
var CrashOnBugs = false
24
25
26 func debugPrintf(format string, args ...any) {
27 if debugCounter {
28 if len(format) == 0 || format[len(format)-1] != '\n' {
29 format += "\n"
30 }
31 fmt.Fprintf(os.Stderr, "counter: "+format, args...)
32 }
33 }
34
35
36 func debugFatalf(format string, args ...any) {
37 if debugCounter || CrashOnBugs {
38 if len(format) == 0 || format[len(format)-1] != '\n' {
39 format += "\n"
40 }
41 fmt.Fprintf(os.Stderr, "counter bug: "+format, args...)
42 os.Exit(1)
43 }
44 }
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66 type Counter struct {
67 name string
68 file *file
69
70 next atomic.Pointer[Counter]
71 state counterState
72 ptr counterPtr
73 }
74
75 func (c *Counter) Name() string {
76 return c.name
77 }
78
79 type counterPtr struct {
80 m *mappedFile
81 count *atomic.Uint64
82 }
83
// counterState is the single atomically updated state word of a Counter.
type counterState struct {
	bits atomic.Uint64
}

// load returns a snapshot of the current state bits.
func (s *counterState) load() counterStateBits {
	raw := s.bits.Load()
	return counterStateBits(raw)
}

// update tries to atomically replace the state *old with new. On success it
// stores new into *old and reports true; if the state no longer equals *old
// it leaves *old untouched and reports false.
func (s *counterState) update(old *counterStateBits, new counterStateBits) bool {
	swapped := s.bits.CompareAndSwap(uint64(*old), uint64(new))
	if swapped {
		*old = new
	}
	return swapped
}

// counterStateBits is the packed form of a counterState:
//
//	bits 0-29   reader count; all ones means exclusively locked
//	bit  30     havePtr flag
//	bits 31-63  extra: pending increments, saturating at the field maximum
type counterStateBits uint64

const (
	stateHavePtr    counterStateBits = 1 << 30
	stateReaders    counterStateBits = stateHavePtr - 1
	stateLocked     counterStateBits = stateReaders
	stateExtraShift                  = 31
	// Everything above the reader and havePtr bits.
	stateExtra counterStateBits = ^(stateHavePtr | stateReaders)
)

// Accessors.
func (b counterStateBits) readers() int  { return int(b & stateReaders) }
func (b counterStateBits) locked() bool  { return b&stateLocked == stateLocked }
func (b counterStateBits) havePtr() bool { return b&stateHavePtr == stateHavePtr }
func (b counterStateBits) extra() uint64 { return uint64(b >> stateExtraShift) }

// Modifiers; each returns a new value and leaves b unchanged.
func (b counterStateBits) incReader() counterStateBits    { return b + 1 }
func (b counterStateBits) decReader() counterStateBits    { return b - 1 }
func (b counterStateBits) setLocked() counterStateBits    { return b | stateLocked }
func (b counterStateBits) clearLocked() counterStateBits  { return b &^ stateReaders }
func (b counterStateBits) setHavePtr() counterStateBits   { return b | stateHavePtr }
func (b counterStateBits) clearHavePtr() counterStateBits { return b &^ stateHavePtr }
func (b counterStateBits) clearExtra() counterStateBits   { return b & (stateReaders | stateHavePtr) }

// addExtra returns b with n added to its extra field, saturating at the
// largest value the field can hold instead of wrapping.
func (b counterStateBits) addExtra(n uint64) counterStateBits {
	const limit = uint64(stateExtra >> stateExtraShift)
	total := b.extra() + n
	if total < n || total > limit {
		total = limit
	}
	return b.clearExtra() | counterStateBits(total)<<stateExtraShift
}
132
133
134
135
136
137 func New(name string) *Counter {
138
139
140
141 return &Counter{name: name, file: &defaultFile}
142 }
143
144
145 func (c *Counter) Inc() {
146 c.Add(1)
147 }
148
149
150 func (c *Counter) Add(n int64) {
151 debugPrintf("Add %q += %d", c.name, n)
152
153 if n < 0 {
154 panic("Counter.Add negative")
155 }
156 if n == 0 {
157 return
158 }
159 c.file.register(c)
160
161 state := c.state.load()
162 for ; ; state = c.state.load() {
163 switch {
164 case !state.locked() && state.havePtr():
165 if !c.state.update(&state, state.incReader()) {
166 continue
167 }
168
169 if c.ptr.count == nil {
170 for !c.state.update(&state, state.addExtra(uint64(n))) {
171
172 state = c.state.load()
173 }
174 debugPrintf("Add %q += %d: nil extra=%d\n", c.name, n, state.extra())
175 } else {
176 sum := c.add(uint64(n))
177 debugPrintf("Add %q += %d: count=%d\n", c.name, n, sum)
178 }
179 c.releaseReader(state)
180 return
181
182 case state.locked():
183 if !c.state.update(&state, state.addExtra(uint64(n))) {
184 continue
185 }
186 debugPrintf("Add %q += %d: locked extra=%d\n", c.name, n, state.extra())
187 return
188
189 case !state.havePtr():
190 if !c.state.update(&state, state.addExtra(uint64(n)).setLocked()) {
191 continue
192 }
193 debugPrintf("Add %q += %d: noptr extra=%d\n", c.name, n, state.extra())
194 c.releaseLock(state)
195 return
196 }
197 }
198 }
199
// releaseReader drops the shared reader hold that Add took on c.state.
// state is the caller's most recent snapshot of c.state.
//
// If this is the last reader and havePtr was cleared while the readers were
// using c.ptr (invalidate clears it), the hold is upgraded to the exclusive
// lock instead, and releaseLock re-establishes c.ptr and flushes any extra.
func (c *Counter) releaseReader(state counterStateBits) {
	for ; ; state = c.state.load() {
		// Last reader with havePtr cleared: upgrade to the full lock.
		// setLocked fills every reader bit, which also absorbs this reader's
		// own count; releaseLock zeroes all reader bits when it unlocks.
		if state.readers() == 1 && !state.havePtr() {
			if !c.state.update(&state, state.setLocked()) {
				continue
			}
			debugPrintf("releaseReader %s: last reader, need ptr\n", c.name)
			c.releaseLock(state)
			return
		}

		// Otherwise just decrement the reader count. A failed CAS means the
		// state changed under us (e.g. extra grew), so reload and retry.
		if !c.state.update(&state, state.decReader()) {
			continue
		}
		debugPrintf("releaseReader %s: released (%d readers now)\n", c.name, state.readers())
		return
	}
}
225
// releaseLock releases the exclusive lock on c.state, which the caller holds;
// state is the caller's latest snapshot of c.state.
//
// Before unlocking it makes sure c.ptr is set up and flushes pending extra:
//   - if havePtr is clear, it sets havePtr and resets c.ptr, looking the
//     counter up in c.file only when there is extra to add;
//   - if extra is non-zero and a count location exists, it moves extra into
//     the count.
//
// Other goroutines may keep adding to extra while the lock is held, so every
// step is a CAS retried against a freshly loaded state.
func (c *Counter) releaseLock(state counterStateBits) {
	for ; ; state = c.state.load() {
		if !state.havePtr() {
			// Claim havePtr first. Holding the lock is what permits rewriting
			// c.ptr below.
			if !c.state.update(&state, state.setHavePtr()) {
				continue
			}
			debugPrintf("releaseLock %s: reset havePtr (extra=%d)\n", c.name, state.extra())

			// Only look up a location when there is something to add. With a
			// nil count, later Adds on the reader path accumulate in extra.
			c.ptr = counterPtr{nil, nil}
			if state.extra() != 0 {
				c.ptr = c.file.lookup(c.name)
				debugPrintf("releaseLock %s: ptr=%v\n", c.name, c.ptr)
			}
		}

		// Flush pending increments. If c.ptr.count is nil (no lookup was
		// done, or lookup returned no location) extra is left in place.
		if extra := state.extra(); extra != 0 && c.ptr.count != nil {
			if !c.state.update(&state, state.clearExtra()) {
				continue
			}
			sum := c.add(extra)
			debugPrintf("releaseLock %s: flush extra=%d -> count=%d\n", c.name, extra, sum)
		}

		// Unlock (clearLocked zeroes all reader bits). If extra changed since
		// the snapshot, this CAS fails and the loop flushes again first.
		if !c.state.update(&state, state.clearLocked()) {
			continue
		}
		debugPrintf("releaseLock %s: unlocked\n", c.name)
		return
	}
}
264
265
266 func (c *Counter) add(n uint64) uint64 {
267 count := c.ptr.count
268 for {
269 old := count.Load()
270 sum := old + n
271 if sum < old {
272 sum = ^uint64(0)
273 }
274 if count.CompareAndSwap(old, sum) {
275 runtime.KeepAlive(c.ptr.m)
276 return sum
277 }
278 }
279 }
280
281 func (c *Counter) invalidate() {
282 for {
283 state := c.state.load()
284 if !state.havePtr() {
285 debugPrintf("invalidate %s: no ptr\n", c.name)
286 return
287 }
288 if c.state.update(&state, state.clearHavePtr()) {
289 debugPrintf("invalidate %s: cleared havePtr\n", c.name)
290 return
291 }
292 }
293 }
294
295 func (c *Counter) refresh() {
296 for {
297 state := c.state.load()
298 if state.havePtr() || state.readers() > 0 || state.extra() == 0 {
299 debugPrintf("refresh %s: havePtr=%v readers=%d extra=%d\n", c.name, state.havePtr(), state.readers(), state.extra())
300 return
301 }
302 if c.state.update(&state, state.setLocked()) {
303 debugPrintf("refresh %s: locked havePtr=%v readers=%d extra=%d\n", c.name, state.havePtr(), state.readers(), state.extra())
304 c.releaseLock(state)
305 return
306 }
307 }
308 }
309
310
311
312 func Read(c *Counter) (uint64, error) {
313 if c.file.current.Load() == nil {
314 return c.state.load().extra(), nil
315 }
316 pf, err := readFile(c.file)
317 if err != nil {
318 return 0, err
319 }
320 v, ok := pf.Count[DecodeStack(c.Name())]
321 if !ok {
322 return v, fmt.Errorf("not found:%q", DecodeStack(c.Name()))
323 }
324 return v, nil
325 }
326
327 func readFile(f *file) (*File, error) {
328 if f == nil {
329 debugPrintf("No file")
330 return nil, fmt.Errorf("counter is not initialized - was Open called?")
331 }
332
333
334 f.rotate1()
335
336 if f.err != nil {
337 return nil, fmt.Errorf("failed to rotate mapped file - %v", f.err)
338 }
339 current := f.current.Load()
340 if current == nil {
341 return nil, fmt.Errorf("counter has no mapped file")
342 }
343 name := current.f.Name()
344 data, err := os.ReadFile(name)
345 if err != nil {
346 return nil, fmt.Errorf("failed to read from file: %v", err)
347 }
348 pf, err := Parse(name, data)
349 if err != nil {
350 return nil, fmt.Errorf("failed to parse: %v", err)
351 }
352 return pf, nil
353 }
354
355
356
357 func ReadFile(name string) (counters, stackCounters map[string]uint64, _ error) {
358
359
360 data, err := ReadMapped(name)
361 if err != nil {
362 return nil, nil, fmt.Errorf("failed to read from file: %v", err)
363 }
364 pf, err := Parse(name, data)
365 if err != nil {
366 return nil, nil, fmt.Errorf("failed to parse: %v", err)
367 }
368 counters = make(map[string]uint64)
369 stackCounters = make(map[string]uint64)
370 for k, v := range pf.Count {
371 if IsStackCounter(k) {
372 stackCounters[DecodeStack(k)] = v
373 } else {
374 counters[k] = v
375 }
376 }
377 return counters, stackCounters, nil
378 }
379
380
381
382
383 func ReadMapped(name string) ([]byte, error) {
384 f, err := os.OpenFile(name, os.O_RDWR, 0666)
385 if err != nil {
386 return nil, err
387 }
388 defer f.Close()
389 fi, err := f.Stat()
390 if err != nil {
391 return nil, err
392 }
393 mapping, err := memmap(f)
394 if err != nil {
395 return nil, err
396 }
397 data := make([]byte, fi.Size())
398 copy(data, mapping.Data)
399 munmap(mapping)
400 return data, nil
401 }
402
View as plain text