├── README.md ├── counter.go ├── counter_atomic.go ├── counter_cas.go ├── counter_cas_float.go ├── counter_cas_float_test.go ├── counter_channel.go ├── counter_mutex.go ├── counter_notsafe.go ├── counter_test.go └── go.mod /README.md: -------------------------------------------------------------------------------- 1 | # Shared Counter Implementations 2 | 3 | - Not Safe Counter 4 | - Mutex Counter 5 | - Channel Counter 6 | - Atomic Counter 7 | - Compare and Swap Counter 8 | - Compare and Swap Float Counter 9 | 10 | This project grew out of my curiosity about the different ways a shared counter can be implemented in Go. Along the way I came across the talk [Prometheus: Designing and Implementing a Modern Monitoring Solution in Go](https://www.youtube.com/watch?v=1V7eJ0jN8-E), which heavily influenced some of these implementations. 11 | 12 | The Not Safe Counter deliberately performs no synchronization. It is only a baseline for comparison: its test can fail because concurrent increments get lost, and `go test -race` reports a data race for it. 13 | 14 | Run the tests with `go test` and the benchmarks with `go test -bench .`. 15 | 
-------------------------------------------------------------------------------- /counter.go: -------------------------------------------------------------------------------- 1 | // Package counter provides several implementations of a shared counter, 2 | // from an unsynchronized baseline to mutex-, channel-, atomic- and 3 | // compare-and-swap-based versions. 4 | package counter 5 | 6 | // Counter accumulates uint64 increments. Implementations differ in whether 7 | // they are safe for concurrent use. 8 | type Counter interface { 9 | // Add adds the given amount to the counter. 10 | Add(uint64) 11 | // Read returns the current value of the counter. 12 | Read() uint64 13 | } 14 | 15 | // FloatCounter accumulates float64 increments. 16 | type FloatCounter interface { 17 | // Add adds the given amount to the counter. 18 | Add(float64) 19 | // Read returns the current value of the counter. 20 | Read() float64 21 | } 22 | 
-------------------------------------------------------------------------------- /counter_atomic.go: -------------------------------------------------------------------------------- 1 | package counter 2 | 3 | import "sync/atomic" 4 | 5 | // AtomicCounter is a Counter that updates its value with sync/atomic 6 | // operations, so it is safe for concurrent use. 7 | type AtomicCounter struct { 8 | // number is accessed only through sync/atomic. It must stay the first 9 | // field so it is 64-bit aligned on 32-bit platforms. 10 | number uint64 11 | } 12 | 13 | // NewAtomicCounter returns a Counter backed by an AtomicCounter starting at zero. 14 | func NewAtomicCounter() Counter { 15 | return &AtomicCounter{0} 16 | } 17 | 18 | // Add atomically adds num to the counter. 19 | func (c *AtomicCounter) Add(num uint64) { 20 | atomic.AddUint64(&c.number, num) 21 | } 22 | 23 | // Read atomically loads the current value. 24 | func (c *AtomicCounter) Read() uint64 { 25 | return atomic.LoadUint64(&c.number) 26 | } 27 | 
-------------------------------------------------------------------------------- /counter_cas.go: -------------------------------------------------------------------------------- 1 | package counter 2 | 3 | import "sync/atomic" 4 | 5 | 
// CASCounter is a Counter that updates its value with an explicit 6 | // compare-and-swap retry loop rather than a single atomic add. 7 | type CASCounter struct { 8 | // number is accessed only through sync/atomic. It must stay the first 9 | // field so it is 64-bit aligned on 32-bit platforms. 10 | number uint64 11 | } 12 | 13 | // NewCASCounter returns a Counter backed by a CASCounter starting at zero. 14 | func NewCASCounter() Counter { 15 | return &CASCounter{0} 16 | } 17 | 18 | // Add adds num to the counter. It computes the new value from a snapshot 19 | // and publishes it only if no other writer changed the value in the 20 | // meantime; otherwise it retries with a fresh snapshot. 21 | func (c *CASCounter) Add(num uint64) { 22 | for { 23 | v := atomic.LoadUint64(&c.number) 24 | if atomic.CompareAndSwapUint64(&c.number, v, v+num) { 25 | return 26 | } 27 | } 28 | } 29 | 30 | // Read atomically loads the current value. 31 | func (c *CASCounter) Read() uint64 { 32 | return atomic.LoadUint64(&c.number) 33 | } 34 | 
-------------------------------------------------------------------------------- /counter_cas_float.go: -------------------------------------------------------------------------------- 1 | package counter 2 | 3 | import ( 4 | "math" 5 | "sync/atomic" 6 | ) 7 | 8 | // Compile-time check that CASFloatCounter implements FloatCounter; no other 9 | // code in the package uses it through that interface. 10 | var _ FloatCounter = (*CASFloatCounter)(nil) 11 | 12 | // CASFloatCounter is a FloatCounter that keeps the IEEE 754 bit pattern of its 13 | // float64 value in a uint64, so it can be updated with sync/atomic 14 | // compare-and-swap operations. 15 | type CASFloatCounter struct { 16 | // number holds math.Float64bits of the current value (bit pattern 0 is 17 | // +0.0). It is accessed only through sync/atomic and must stay the first 18 | // field so it is 64-bit aligned on 32-bit platforms. 19 | number uint64 20 | } 21 | 22 | // NewCASFloatCounter returns a CASFloatCounter starting at zero. 23 | func NewCASFloatCounter() *CASFloatCounter { 24 | return &CASFloatCounter{0} 25 | } 26 | 27 | // Add adds num to the counter. It decodes a snapshot, adds num, and publishes 28 | // the result only if no other writer changed the value in the meantime; 29 | // otherwise it retries with a fresh snapshot. 30 | func (c *CASFloatCounter) Add(num float64) { 31 | for { 32 | v := atomic.LoadUint64(&c.number) 33 | newValue := math.Float64bits(math.Float64frombits(v) + num) 34 | if atomic.CompareAndSwapUint64(&c.number, v, newValue) { 35 | return 36 | } 37 | } 38 | } 39 | 40 | // Read atomically loads the bit pattern and decodes it as a float64. 41 | func (c *CASFloatCounter) Read() float64 { 42 | return math.Float64frombits(atomic.LoadUint64(&c.number)) 43 | } 44 | 
-------------------------------------------------------------------------------- /counter_cas_float_test.go: -------------------------------------------------------------------------------- 1 | package counter 2 | 3 | import ( 4 | "math" 5 | "sync" 6 | "testing" 7 | ) 8 | 9 | // testFloatCorrectness starts 100 goroutines that read from and add to 10 | // counter concurrently and checks that no addition was lost. 11 | func testFloatCorrectness(t *testing.T, counter FloatCounter) { 12 | wg := &sync.WaitGroup{} 13 | counter.Add(0.8) 14 | for i := 0; i < 100; i++ { 15 | wg.Add(1) 16 | switch i % 3 { 17 | case 0: 18 | go func(counter FloatCounter) { 19 | defer wg.Done() 20 | counter.Read() 21 | }(counter) 22 | case 1: 23 | go func(counter FloatCounter) { 24 | defer wg.Done() 25 | counter.Add(1.1) 26 | counter.Read() 27 | }(counter) 28 | default: 29 | go func(counter FloatCounter) { 30 | defer wg.Done() 31 | counter.Add(2.3) 32 | }(counter) 33 | } 34 | } 35 | 36 | wg.Wait() 37 | // 34 goroutines only read, 33 add 1.1 and 33 add 2.3, so the expected 38 | // total is 0.8 + 33*1.1 + 33*2.3 = 113. Float addition is not associative, 39 | // so the last bits depend on the interleaving; compare with a tolerance. 40 | const want = 113.0 41 | if got := counter.Read(); math.Abs(got-want) > 1e-9 { 42 | t.Errorf("counter should be %f and was %f", want, got) 43 | } 44 | } 45 | 46 | func TestCASFloatCounter(t *testing.T) { 47 | testFloatCorrectness(t, NewCASFloatCounter()) 48 | } 49 | 
-------------------------------------------------------------------------------- /counter_channel.go: -------------------------------------------------------------------------------- 1 | package counter 2 | 3 | // ChannelCounter is a Counter whose state is owned by a single goroutine. 4 | // Add and Read send closures over ch and that goroutine runs them one at a 5 | // time, so number is never accessed concurrently. 6 | type ChannelCounter struct { 7 | ch chan func() 8 | number uint64 // only touched by the goroutine started in NewChannelCounter 9 | } 10 | 11 | // NewChannelCounter returns a Counter backed by a ChannelCounter starting at 12 | // zero and starts the goroutine that serves it. Call Close on the returned 13 | // *ChannelCounter to stop that goroutine once the counter is no longer needed. 14 | func NewChannelCounter() Counter { 15 | counter := &ChannelCounter{make(chan func(), 100), 0} 16 | go func(counter *ChannelCounter) { 17 | for f := range counter.ch { 18 | f() 19 | } 20 | }(counter) 21 | return counter 22 | } 23 | 24 | // Add sends an increment of num to the serving goroutine. It may return 25 | // before the increment is applied (the channel buffers up to 100 pending 26 | // operations), but operations sent from one goroutine run in order, so a 27 | // later Read from the same goroutine observes it. 28 | func (c *ChannelCounter) Add(num uint64) { 29 | c.ch <- func() { 30 | c.number = c.number + num 31 | } 32 | } 33 | 34 | // Read asks the serving goroutine for the current value and waits for it. 35 | func (c *ChannelCounter) Read() uint64 { 36 | // Buffered so the serving goroutine can hand off the value and move on 37 | // without waiting for this goroutine to be scheduled. 38 | ret := make(chan uint64, 1) 39 | c.ch <- func() { 40 | ret <- c.number 41 | } 42 | return <-ret 43 | } 44 | 45 | // Close stops the goroutine started by NewChannelCounter after it has run any 46 | // already queued operations. Call it at most once, and not concurrently with 47 | // Add or Read: using a closed ChannelCounter panics. 48 | func (c *ChannelCounter) Close() { 49 | close(c.ch) 50 | } 51 | 
-------------------------------------------------------------------------------- /counter_mutex.go: -------------------------------------------------------------------------------- 1 | package counter 2 | 3 | import "sync" 4 | 5 | // MutexCounter is a Counter guarded by a sync.RWMutex: Add takes the write 6 | // lock and Read takes the read lock, so concurrent Reads do not block each 7 | // other. It contains a mutex and must not be copied after first use. 8 | type MutexCounter struct { 9 | mu sync.RWMutex // guards number 10 | number uint64 11 | } 12 | 13 | // NewMutexCounter returns a Counter backed by a MutexCounter starting at zero. 14 | func NewMutexCounter() Counter { 15 | return &MutexCounter{} 16 | } 17 | 18 | // Add adds num to the counter while holding the write lock. 19 | func (c *MutexCounter) Add(num uint64) { 20 | c.mu.Lock() 21 | defer c.mu.Unlock() 22 | c.number = c.number + num 23 | } 24 | 25 | // Read returns the current value while holding the read lock. 26 | func (c *MutexCounter) Read() uint64 { 27 | c.mu.RLock() 28 | defer c.mu.RUnlock() 29 | return c.number 30 | } 31 | -------------------------------------------------------------------------------- /counter_notsafe.go:
-------------------------------------------------------------------------------- 1 | package counter 2 | 3 | // NotSafeCounter is a Counter with no synchronization at all. It is a 4 | // baseline for comparison only: concurrent Add calls race and can lose 5 | // updates, and `go test -race` reports them. 6 | type NotSafeCounter struct { 7 | number uint64 8 | } 9 | 10 | // NewNotSafeCounter returns a Counter backed by a NotSafeCounter starting at zero. 11 | func NewNotSafeCounter() Counter { 12 | return &NotSafeCounter{0} 13 | } 14 | 15 | // Add adds num to the counter with a plain, non-atomic read-modify-write. 16 | func (c *NotSafeCounter) Add(num uint64) { 17 | c.number = c.number + num 18 | } 19 | 20 | // Read returns the current value without synchronization. 21 | func (c *NotSafeCounter) Read() uint64 { 22 | return c.number 23 | } 24 | 
-------------------------------------------------------------------------------- /counter_test.go: -------------------------------------------------------------------------------- 1 | package counter 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | ) 7 | 8 | // testCorrectness starts 100 goroutines that read from and add to counter 9 | // concurrently and checks that no increment was lost. 10 | func testCorrectness(t *testing.T, counter Counter) { 11 | wg := &sync.WaitGroup{} 12 | for i := 0; i < 100; i++ { 13 | wg.Add(1) 14 | switch i % 3 { 15 | case 0: 16 | go func(counter Counter) { 17 | defer wg.Done() 18 | counter.Read() 19 | }(counter) 20 | case 1: 21 | go func(counter Counter) { 22 | defer wg.Done() 23 | counter.Add(1) 24 | counter.Read() 25 | }(counter) 26 | default: 27 | go func(counter Counter) { 28 | defer wg.Done() 29 | counter.Add(1) 30 | }(counter) 31 | } 32 | } 33 | 34 | wg.Wait() 35 | 36 | // 34 goroutines only read; the other 66 (i%3 == 1 or 2) each add 1. 37 | if got := counter.Read(); got != 66 { 38 | t.Errorf("counter should be %d and was %d", 66, got) 39 | } 40 | } 41 | 42 | // benchmark runs concurrency goroutines against counter at the same time. 43 | // Each goroutine performs b.N Add+Read pairs, so the reported ns/op is the 44 | // wall time of one round in which every goroutine does one Add and one Read. 45 | func benchmark(b *testing.B, counter Counter, concurrency int) { 46 | start, end := &sync.WaitGroup{}, &sync.WaitGroup{} 47 | start.Add(1) 48 | end.Add(concurrency) 49 | for i := 0; i < concurrency; i++ { 50 | go func(counter Counter) { 51 | defer end.Done() 52 | // Hold every goroutine until the timer is reset so they all start contending together. 53 | start.Wait() 54 | for n := 0; n < b.N; n++ { 55 | counter.Add(1) 56 | counter.Read() 57 | } 58 | }(counter) 59 | } 60 | 61 | // Exclude goroutine creation from the measurement. 62 | b.ResetTimer() 63 | start.Done() 64 | end.Wait() 65 | } 66 | 67 | // TestNotSafeCounter can fail (and always fails under -race): NotSafeCounter 68 | // performs no synchronization, so concurrent increments can be lost. 69 | func TestNotSafeCounter(t *testing.T) { 70 | testCorrectness(t, NewNotSafeCounter()) 71 | } 72 | 73 | func TestMutexCounter(t *testing.T) { 74 | testCorrectness(t, NewMutexCounter()) 75 | } 76 | 77 | func TestChannelCounter(t *testing.T) { 78 | testCorrectness(t, NewChannelCounter()) 79 | } 80 | 81 | func TestCASCounter(t *testing.T) { 82 | testCorrectness(t, NewCASCounter()) 83 | } 84 | 85 | func TestAtomicCounter(t *testing.T) { 86 | testCorrectness(t, NewAtomicCounter()) 87 | } 88 | 89 | func BenchmarkNotSafeCounter1(b *testing.B) { 90 | benchmark(b, NewNotSafeCounter(), 1) 91 | } 92 | 93 | func BenchmarkMutexCounter1(b *testing.B) { 94 | benchmark(b, NewMutexCounter(), 1) 95 | } 96 | 97 | func BenchmarkChannelCounter1(b *testing.B) { 98 | benchmark(b, NewChannelCounter(), 1) 99 | } 100 | 101 | func BenchmarkCASCounter1(b *testing.B) { 102 | benchmark(b, NewCASCounter(), 1) 103 | } 104 | 105 | func BenchmarkAtomicCounter1(b *testing.B) { 106 | benchmark(b, NewAtomicCounter(), 1) 107 | } 108 | 109 | func BenchmarkNotSafeCounter10(b *testing.B) { 110 | benchmark(b, NewNotSafeCounter(), 10) 111 | } 112 | 113 | func BenchmarkMutexCounter10(b *testing.B) { 114 | benchmark(b, NewMutexCounter(), 10) 115 | } 116 | 117 | func BenchmarkChannelCounter10(b *testing.B) { 118 | benchmark(b, NewChannelCounter(), 10) 119 | } 120 | 121 | func BenchmarkCASCounter10(b *testing.B) { 122 | benchmark(b, NewCASCounter(), 10) 123 | } 124 | 125 | func BenchmarkAtomicCounter10(b *testing.B) { 126 | benchmark(b, NewAtomicCounter(), 10) 127 | } 128 | 129 | func BenchmarkNotSafeCounter100(b *testing.B) { 130 | benchmark(b, NewNotSafeCounter(), 100) 131 | } 132 | 133 | func BenchmarkMutexCounter100(b *testing.B) { 134 | benchmark(b, NewMutexCounter(), 100) 135 | } 136 | 137 | func BenchmarkChannelCounter100(b *testing.B) { 138 | benchmark(b, NewChannelCounter(), 100) 139 | } 140 | 141 | func BenchmarkCASCounter100(b *testing.B) { 142 | benchmark(b, NewCASCounter(), 100) 143 | } 144 | 145 | func BenchmarkAtomicCounter100(b *testing.B) { 146 | benchmark(b, NewAtomicCounter(), 100) 147 | } 148 | 
-------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/brunocalza/sharedcounter 2 | 3 | go 1.14 4 | --------------------------------------------------------------------------------