├── .github ├── dependabot.yaml └── workflows │ ├── golangci-lint.yml │ ├── staticcheck.yml │ ├── tests.yml │ └── vet.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── Makefile ├── README.md ├── codecov.yml ├── concurrent_buffer.go ├── concurrent_buffer_test.go ├── cosmos_test.go ├── cosmosdb.go ├── docker-compose.yml ├── dynamodb.go ├── dynamodb_test.go ├── examples ├── example_grpc_ip_limiter_test.go ├── example_grpc_simple_limiter_test.go ├── gprc_service.go └── helloworld │ ├── helloworld.pb.go │ ├── helloworld.proto │ └── helloworld_grpc.pb.go ├── fixedwindow.go ├── fixedwindow_test.go ├── go.mod ├── go.sum ├── leakybucket.go ├── leakybucket_test.go ├── limiters.go ├── limiters_test.go ├── locks.go ├── locks_test.go ├── registry.go ├── registry_test.go ├── revive.toml ├── slidingwindow.go ├── slidingwindow_test.go ├── tokenbucket.go └── tokenbucket_test.go /.github/dependabot.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: gomod 4 | directory: / 5 | schedule: 6 | interval: monthly 7 | -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | on: 3 | push: 4 | branches: 5 | - main 6 | - master 7 | pull_request: 8 | 9 | permissions: 10 | contents: read 11 | # Optional: allow read access to pull request. Use with `only-new-issues` option. 12 | # pull-requests: read 13 | 14 | jobs: 15 | golangci: 16 | name: lint 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: actions/setup-go@v5 21 | with: 22 | go-version: stable 23 | - name: golangci-lint 24 | uses: golangci/golangci-lint-action@v8 25 | with: 26 | version: latest 27 | -------------------------------------------------------------------------------- /.github/workflows/staticcheck.yml: -------------------------------------------------------------------------------- 1 | name: Static Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | staticcheck: 13 | name: Linter 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: dominikh/staticcheck-action@v1.3.1 18 | with: 19 | version: "latest" -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Go tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | go-test: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | go-version: [ '1.22', '1.23', '1.24' ] 18 | 19 | services: 20 | etcd: 21 | image: bitnami/etcd 22 | env: 23 | ALLOW_NONE_AUTHENTICATION: yes 24 | ports: 25 | - 2379:2379 26 | redis: 27 | image: bitnami/redis 28 | env: 29 | ALLOW_EMPTY_PASSWORD: yes 30 | ports: 31 | - 6379:6379 32 | redis-cluster: 33 | image: grokzen/redis-cluster:7.0.10 34 | env: 35 | IP: 0.0.0.0 36 | INITIAL_PORT: 11000 37 | ports: 38 | - 11000-11005:11000-11005 39 | memcached: 40 | image: bitnami/memcached 41 | ports: 42 | - 11211:11211 43 | consul: 44 | image: bitnami/consul 45 | ports: 46 | - 8500:8500 47 | zookeeper: 48 | image: bitnami/zookeeper 49 | env: 50 | ALLOW_ANONYMOUS_LOGIN: yes 51 | ports: 52 | - 2181:2181 53 | postgresql: 54 | image: bitnami/postgresql 55 | env: 56 | ALLOW_EMPTY_PASSWORD: 
yes 57 | ports: 58 | - 5432:5432 59 | cosmos: 60 | image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview 61 | env: 62 | PROTOCOL: http 63 | COSMOS_HTTP_CONNECTION_WITHOUT_TLS_ALLOWED: "true" 64 | ports: 65 | - "8081:8081" 66 | 67 | steps: 68 | - uses: actions/checkout@v4 69 | - name: Setup Go ${{ matrix.go-version }} 70 | uses: actions/setup-go@v5 71 | with: 72 | go-version: ${{ matrix.go-version }} 73 | - name: Setup DynamoDB Local 74 | uses: rrainn/dynamodb-action@v4.0.0 75 | with: 76 | port: 8000 77 | - name: Run tests 78 | env: 79 | ETCD_ENDPOINTS: 'localhost:2379' 80 | REDIS_ADDR: 'localhost:6379' 81 | REDIS_NODES: 'localhost:11000,localhost:11001,localhost:11002,localhost:11003,localhost:11004,localhost:11005' 82 | CONSUL_ADDR: 'localhost:8500' 83 | ZOOKEEPER_ENDPOINTS: 'localhost:2181' 84 | AWS_ADDR: 'localhost:8000' 85 | MEMCACHED_ADDR: '127.0.0.1:11211' 86 | POSTGRES_URL: postgres://postgres@localhost:5432/?sslmode=disable 87 | COSMOS_ADDR: '127.0.0.1:8081' 88 | run: go test -race -v -coverprofile=coverage.txt -covermode=atomic ./... 89 | - uses: codecov/codecov-action@v5 90 | with: 91 | verbose: true 92 | -------------------------------------------------------------------------------- /.github/workflows/vet.yml: -------------------------------------------------------------------------------- 1 | name: Linter 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | vet: 13 | name: Go vet 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Setup Go 1.24 18 | uses: actions/setup-go@v5 19 | with: 20 | go-version: '1.24' 21 | - name: Go vet 22 | run: go vet ./... 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | coverage.txt 3 | .vscode -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | linters: 3 | default: all 4 | disable: 5 | - cyclop 6 | - depguard 7 | - dupl 8 | - dupword 9 | - exhaustruct 10 | - forcetypeassert 11 | - funcorder 12 | - funlen 13 | - gochecknoglobals 14 | - goconst 15 | - gocritic 16 | - godox 17 | - gomoddirectives 18 | - gosec 19 | - inamedparam 20 | - intrange 21 | - lll 22 | - mnd 23 | - musttag 24 | - paralleltest 25 | - perfsprint 26 | - recvcheck 27 | - revive 28 | - tagliatelle 29 | - testifylint 30 | - unparam 31 | - varnamelen 32 | - wrapcheck 33 | - wsl 34 | exclusions: 35 | generated: lax 36 | presets: 37 | - comments 38 | - common-false-positives 39 | - legacy 40 | - std-error-handling 41 | paths: 42 | - third_party$ 43 | - builtin$ 44 | - examples$ 45 | formatters: 46 | enable: 47 | - gci 48 | - gofmt 49 | - gofumpt 50 | - goimports 51 | exclusions: 52 | generated: lax 53 | paths: 54 | - third_party$ 55 | - builtin$ 56 | - examples$ 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Renat Mennanov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: gofumpt goimports lint test benchmark 2 | 3 | docker-compose-up: 4 | docker compose up -d 5 | 6 | test: docker-compose-up 7 | ETCD_ENDPOINTS="127.0.0.1:2379" REDIS_ADDR="127.0.0.1:6379" REDIS_NODES="127.0.0.1:11000,127.0.0.1:11001,127.0.0.1:11002,127.0.0.1:11003,127.0.0.1:11004,127.0.0.1:11005" ZOOKEEPER_ENDPOINTS="127.0.0.1" CONSUL_ADDR="127.0.0.1:8500" AWS_ADDR="127.0.0.1:8000" MEMCACHED_ADDR="127.0.0.1:11211" POSTGRES_URL="postgres://postgres@localhost:5432/?sslmode=disable" COSMOS_ADDR="127.0.0.1:8081" go test -race -v -failfast 8 | 9 | benchmark: docker-compose-up 10 | ETCD_ENDPOINTS="127.0.0.1:2379" REDIS_ADDR="127.0.0.1:6379" REDIS_NODES="127.0.0.1:11000,127.0.0.1:11001,127.0.0.1:11002,127.0.0.1:11003,127.0.0.1:11004,127.0.0.1:11005" ZOOKEEPER_ENDPOINTS="127.0.0.1" CONSUL_ADDR="127.0.0.1:8500" AWS_ADDR="127.0.0.1:8000" MEMCACHED_ADDR="127.0.0.1:11211" POSTGRES_URL="postgres://postgres@localhost:5432/?sslmode=disable" COSMOS_ADDR="127.0.0.1:8081" go test -race -run=nonexistent -bench=. 11 | 12 | lint: 13 | @(which golangci-lint && golangci-lint --version | grep 2.1.6) || (curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v2.1.6) 14 | golangci-lint run --fix ./... 15 | 16 | goimports: 17 | @which goimports 2>&1 > /dev/null || go install golang.org/x/tools/cmd/goimports@latest 18 | goimports -w . 19 | 20 | gofumpt: 21 | @which gofumpt 2>&1 > /dev/null || go install mvdan.cc/gofumpt@latest 22 | gofumpt -l -w . 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distributed rate limiters for Golang 2 | [![Build Status](https://github.com/mennanov/limiters/actions/workflows/tests.yml/badge.svg)](https://github.com/mennanov/limiters/actions/workflows/tests.yml) 3 | [![codecov](https://codecov.io/gh/mennanov/limiters/branch/master/graph/badge.svg?token=LZULu4i7B6)](https://codecov.io/gh/mennanov/limiters) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/mennanov/limiters)](https://goreportcard.com/report/github.com/mennanov/limiters) 5 | [![GoDoc](https://godoc.org/github.com/mennanov/limiters?status.svg)](https://godoc.org/github.com/mennanov/limiters) 6 | 7 | Rate limiters for distributed applications in Golang with configurable back-ends and distributed locks. 8 | Any types of back-ends and locks can be used that implement certain minimalistic interfaces. 
9 | The most common implementations are already provided.
10 | 
11 | - [`Token bucket`](https://en.wikipedia.org/wiki/Token_bucket)
12 |     - in-memory (local)
13 |     - redis
14 |     - memcached
15 |     - etcd
16 |     - dynamodb
17 |     - cosmos db
18 | 
19 | Allows requests at a certain input rate with possible bursts configured by the capacity parameter.
20 | The output rate equals the input rate.
21 | Precise (no over- or under-limiting), but requires a lock (provided).
22 | 
23 | - [`Leaky bucket`](https://en.wikipedia.org/wiki/Leaky_bucket#As_a_queue)
24 |     - in-memory (local)
25 |     - redis
26 |     - memcached
27 |     - etcd
28 |     - dynamodb
29 |     - cosmos db
30 | 
31 | Puts requests in a FIFO queue to be processed at a constant rate.
32 | There are no restrictions on the input rate except for the capacity of the queue.
33 | Requires a lock (provided).
34 | 
35 | - [`Fixed window counter`](https://konghq.com/blog/how-to-design-a-scalable-rate-limiting-algorithm/)
36 |     - in-memory (local)
37 |     - redis
38 |     - memcached
39 |     - dynamodb
40 |     - cosmos db
41 | 
42 | A simple and resource-efficient algorithm that does not need a lock.
43 | Precision may be adjusted by the size of the window.
44 | May be lenient when there are many requests around the boundary between 2 adjacent windows.
45 | 
46 | - [`Sliding window counter`](https://konghq.com/blog/how-to-design-a-scalable-rate-limiting-algorithm/)
47 |     - in-memory (local)
48 |     - redis
49 |     - memcached
50 |     - dynamodb
51 |     - cosmos db
52 | 
53 | Smooths out the bursts around the boundary between 2 adjacent windows.
54 | Needs twice as much memory as the `Fixed Window` algorithm (2 windows instead of 1 at a time).
55 | It will disallow _all_ requests when a client is flooding the service with requests.
56 | It's the client's responsibility to handle a disallowed request properly: wait before making a new one.
57 | 
58 | - `Concurrent buffer`
59 |     - in-memory (local)
60 |     - redis
61 |     - memcached
62 | 
63 | Allows concurrent requests up to the given capacity.
64 | Requires a lock (provided).
65 | 
66 | ## gRPC example
67 | 
68 | An example of a global token bucket rate limiter for a gRPC service:
69 | ```go
70 | // examples/example_grpc_simple_limiter_test.go
71 | rate := time.Second * 3
72 | limiter := limiters.NewTokenBucket(
73 | 	2,
74 | 	rate,
75 | 	limiters.NewLockEtcd(etcdClient, "/ratelimiter_lock/simple/", limiters.NewStdLogger()),
76 | 	limiters.NewTokenBucketRedis(
77 | 		redisClient,
78 | 		"ratelimiter/simple",
79 | 		rate, false),
80 | 	limiters.NewSystemClock(), limiters.NewStdLogger(),
81 | )
82 | 
83 | // Add a unary interceptor middleware to rate limit all requests.
84 | s := grpc.NewServer(grpc.UnaryInterceptor(
85 | 	func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
86 | 		w, err := limiter.Limit(ctx)
87 | 		if errors.Is(err, limiters.ErrLimitExhausted) {
88 | 			return nil, status.Errorf(codes.ResourceExhausted, "try again later in %s", w)
89 | 		} else if err != nil {
90 | 			// The limiter failed. This error should be logged and examined.
91 | 			log.Println(err)
92 | 			return nil, status.Error(codes.Internal, "internal error")
93 | 		}
94 | 		return handler(ctx, req)
95 | 	}))
96 | ```
97 | 
98 | For something closer to a real-world example, see the IP address based gRPC global rate limiter in the
99 | [examples](examples/example_grpc_ip_limiter_test.go) directory.
100 | 
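## In-memory example

Rate limiting within a single application instance does not need any external services: an in-memory state backend can be combined with the no-op lock mentioned below. This is only a minimal sketch; the `NewTokenBucketInMemory` and `NewLockNoop` constructor names are assumed here, so check the GoDoc for the exact API:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/mennanov/limiters"
)

func main() {
	// A token bucket with capacity 3 that refills one token every second.
	// NewTokenBucketInMemory and NewLockNoop are assumed constructor names for the
	// in-memory state backend and the no-op lock.
	bucket := limiters.NewTokenBucket(
		3,
		time.Second,
		limiters.NewLockNoop(),
		limiters.NewTokenBucketInMemory(),
		limiters.NewSystemClock(),
		limiters.NewStdLogger(),
	)

	for i := 0; i < 5; i++ {
		wait, err := bucket.Limit(context.Background())
		if errors.Is(err, limiters.ErrLimitExhausted) {
			fmt.Printf("request %d rejected, retry in %s\n", i, wait)
		} else if err != nil {
			// The limiter itself failed: log and examine the error.
			fmt.Println(err)
		} else {
			fmt.Printf("request %d allowed\n", i)
		}
	}
}
```

Swapping in a distributed backend (Redis, DynamoDB, etc.) only changes the constructor arguments; the `Limit` call stays the same.
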
101 | ## DynamoDB
102 | 
103 | The use of DynamoDB requires the creation of a DynamoDB table prior to use. An existing table can be used or a new one can be created. The required attributes depend on the limiter backend:
104 | 
105 | * Partition Key
106 |     - String
107 |     - Required for all Backends
108 | * Sort Key
109 |     - String
110 |     - Backends:
111 |         - FixedWindow
112 |         - SlidingWindow
113 | * TTL
114 |     - Number
115 |     - Backends:
116 |         - FixedWindow
117 |         - SlidingWindow
118 |         - LeakyBucket
119 |         - TokenBucket
120 | 
121 | All DynamoDB backends accept a `DynamoDBTableProperties` struct as a parameter. It can either be created manually or populated by calling `LoadDynamoDBTableProperties` with the table name. When using `LoadDynamoDBTableProperties`, the table description is fetched from AWS and verified to make sure the table can be used by the limiter backends. Results of `LoadDynamoDBTableProperties` are cached.
122 | 
123 | ## Azure Cosmos DB for NoSQL
124 | 
125 | The use of Cosmos DB requires the creation of a database and container prior to use.
126 | 
127 | The container must have a default TTL set, otherwise TTL [will not take effect](https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/time-to-live#time-to-live-configurations).
128 | 
129 | The partition key should be `/partitionKey`.
130 | 
131 | ## Distributed locks
132 | 
133 | Some algorithms require a distributed lock to guarantee consistency during concurrent requests.
134 | If there is only one running application instance then no distributed lock is needed,
135 | as all the algorithms are thread-safe (use `LockNoop`).
136 | 
137 | Supported backends:
138 | - [etcd](https://etcd.io/)
139 | - [Consul](https://www.consul.io/)
140 | - [Zookeeper](https://zookeeper.apache.org/)
141 | - [Redis](https://redis.io/)
142 | - [Memcached](https://memcached.org/)
143 | - [PostgreSQL](https://www.postgresql.org/)
144 | 
145 | ## Memcached
146 | 
147 | It's important to understand that Memcached is not ideal for implementing reliable locks or data persistence due to its inherent limitations:
148 | 
149 | - No guaranteed data retention: Memcached can evict data at any point due to memory pressure, even if it appears to have space available. This can lead to unexpected lock releases or data loss.
150 | - Lack of distributed locking features: Memcached doesn't offer functionality such as the distributed coordination required for consistent locking across multiple servers.
151 | 
152 | If Memcached is already in place and occasional burst traffic caused by unexpectedly evicted data is acceptable, the Memcached-based implementations are convenient; otherwise the Redis-based implementations are a better choice.
153 | 
154 | ## Testing
155 | 
156 | Run tests locally:
157 | ```bash
158 | make test
159 | ```
160 | Run benchmarks locally:
161 | ```bash
162 | make benchmark
163 | ```
164 | Run both locally:
165 | ```bash
166 | make
167 | ```
168 | 
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 |   status:
3 |     patch: off
4 | 
--------------------------------------------------------------------------------
/concurrent_buffer.go:
--------------------------------------------------------------------------------
1 | package limiters
2 | 
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"encoding/gob"
7 | 	"fmt"
8 | 	"sync"
9 | 	"time"
10 | 
11 | 	"github.com/bradfitz/gomemcache/memcache"
12 | 	"github.com/pkg/errors"
13 | 	"github.com/redis/go-redis/v9"
14 | )
15 | 
16 | // ConcurrentBufferBackend wraps the Add and Remove methods.
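// In-memory, Redis and Memcached implementations are provided below in this file.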
17 | type ConcurrentBufferBackend interface { 18 | // Add adds the request with the given key to the buffer and returns the total number of requests in it. 19 | Add(ctx context.Context, key string) (int64, error) 20 | // Remove removes the request from the buffer. 21 | Remove(ctx context.Context, key string) error 22 | } 23 | 24 | // ConcurrentBuffer implements a limiter that allows concurrent requests up to the given capacity. 25 | type ConcurrentBuffer struct { 26 | locker DistLocker 27 | backend ConcurrentBufferBackend 28 | logger Logger 29 | capacity int64 30 | mu sync.Mutex 31 | } 32 | 33 | // NewConcurrentBuffer creates a new ConcurrentBuffer instance. 34 | func NewConcurrentBuffer(locker DistLocker, concurrentStateBackend ConcurrentBufferBackend, capacity int64, logger Logger) *ConcurrentBuffer { 35 | return &ConcurrentBuffer{locker: locker, backend: concurrentStateBackend, capacity: capacity, logger: logger} 36 | } 37 | 38 | // Limit puts the request identified by the key in a buffer. 39 | func (c *ConcurrentBuffer) Limit(ctx context.Context, key string) error { 40 | c.mu.Lock() 41 | defer c.mu.Unlock() 42 | if err := c.locker.Lock(ctx); err != nil { 43 | return err 44 | } 45 | defer func() { 46 | if err := c.locker.Unlock(ctx); err != nil { 47 | c.logger.Log(err) 48 | } 49 | }() 50 | // Optimistically add the new request. 51 | counter, err := c.backend.Add(ctx, key) 52 | if err != nil { 53 | return err 54 | } 55 | if counter > c.capacity { 56 | // Rollback the Add() operation. 57 | if err = c.backend.Remove(ctx, key); err != nil { 58 | c.logger.Log(err) 59 | } 60 | 61 | return ErrLimitExhausted 62 | } 63 | 64 | return nil 65 | } 66 | 67 | // Done removes the request identified by the key from the buffer. 68 | func (c *ConcurrentBuffer) Done(ctx context.Context, key string) error { 69 | return c.backend.Remove(ctx, key) 70 | } 71 | 72 | // ConcurrentBufferInMemory is an in-memory implementation of ConcurrentBufferBackend. 73 | type ConcurrentBufferInMemory struct { 74 | clock Clock 75 | ttl time.Duration 76 | mu sync.Mutex 77 | registry *Registry 78 | } 79 | 80 | // NewConcurrentBufferInMemory creates a new instance of ConcurrentBufferInMemory. 81 | // When the TTL of a key exceeds the key is removed from the buffer. This is needed in case if the process that added 82 | // that key to the buffer did not call Done() for some reason. 83 | func NewConcurrentBufferInMemory(registry *Registry, ttl time.Duration, clock Clock) *ConcurrentBufferInMemory { 84 | return &ConcurrentBufferInMemory{clock: clock, ttl: ttl, registry: registry} 85 | } 86 | 87 | // Add adds the request with the given key to the buffer and returns the total number of requests in it. 88 | // It also removes the keys with expired TTL. 89 | func (c *ConcurrentBufferInMemory) Add(ctx context.Context, key string) (int64, error) { 90 | c.mu.Lock() 91 | defer c.mu.Unlock() 92 | now := c.clock.Now() 93 | c.registry.DeleteExpired(now) 94 | c.registry.GetOrCreate(key, func() interface{} { 95 | return struct{}{} 96 | }, c.ttl, now) 97 | 98 | return int64(c.registry.Len()), ctx.Err() 99 | } 100 | 101 | // Remove removes the request from the buffer. 102 | func (c *ConcurrentBufferInMemory) Remove(_ context.Context, key string) error { 103 | c.mu.Lock() 104 | defer c.mu.Unlock() 105 | c.registry.Delete(key) 106 | 107 | return nil 108 | } 109 | 110 | // ConcurrentBufferRedis implements ConcurrentBufferBackend in Redis. 
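// Requests are kept in a sorted set scored by their insertion timestamps, so expired entries can be pruned on every Add call.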
111 | type ConcurrentBufferRedis struct { 112 | clock Clock 113 | cli redis.UniversalClient 114 | key string 115 | ttl time.Duration 116 | } 117 | 118 | // NewConcurrentBufferRedis creates a new instance of ConcurrentBufferRedis. 119 | // When the TTL of a key exceeds the key is removed from the buffer. This is needed in case if the process that added 120 | // that key to the buffer did not call Done() for some reason. 121 | func NewConcurrentBufferRedis(cli redis.UniversalClient, key string, ttl time.Duration, clock Clock) *ConcurrentBufferRedis { 122 | return &ConcurrentBufferRedis{clock: clock, cli: cli, key: key, ttl: ttl} 123 | } 124 | 125 | // Add adds the request with the given key to the sorted set in Redis and returns the total number of requests in it. 126 | // It also removes the keys with expired TTL. 127 | func (c *ConcurrentBufferRedis) Add(ctx context.Context, key string) (int64, error) { 128 | var countCmd *redis.IntCmd 129 | var err error 130 | done := make(chan struct{}) 131 | go func() { 132 | defer close(done) 133 | _, err = c.cli.Pipelined(ctx, func(pipeliner redis.Pipeliner) error { 134 | // Remove expired items. 135 | now := c.clock.Now() 136 | pipeliner.ZRemRangeByScore(ctx, c.key, "-inf", fmt.Sprintf("%d", now.Add(-c.ttl).UnixNano())) 137 | pipeliner.ZAdd(ctx, c.key, redis.Z{ 138 | Score: float64(now.UnixNano()), 139 | Member: key, 140 | }) 141 | countCmd = pipeliner.ZCard(ctx, c.key) 142 | 143 | return nil 144 | }) 145 | }() 146 | 147 | select { 148 | case <-ctx.Done(): 149 | return 0, ctx.Err() 150 | 151 | case <-done: 152 | if err != nil { 153 | return 0, errors.Wrap(err, "failed to add an item to redis set") 154 | } 155 | 156 | return countCmd.Val(), nil 157 | } 158 | } 159 | 160 | // Remove removes the request identified by the key from the sorted set in Redis. 161 | func (c *ConcurrentBufferRedis) Remove(ctx context.Context, key string) error { 162 | return errors.Wrap(c.cli.ZRem(ctx, c.key, key).Err(), "failed to remove an item from redis set") 163 | } 164 | 165 | // ConcurrentBufferMemcached implements ConcurrentBufferBackend in Memcached. 166 | type ConcurrentBufferMemcached struct { 167 | clock Clock 168 | cli *memcache.Client 169 | key string 170 | ttl time.Duration 171 | } 172 | 173 | // NewConcurrentBufferMemcached creates a new instance of ConcurrentBufferMemcached. 174 | // When the TTL of a key exceeds the key is removed from the buffer. This is needed in case if the process that added 175 | // that key to the buffer did not call Done() for some reason. 176 | func NewConcurrentBufferMemcached(cli *memcache.Client, key string, ttl time.Duration, clock Clock) *ConcurrentBufferMemcached { 177 | return &ConcurrentBufferMemcached{clock: clock, cli: cli, key: key, ttl: ttl} 178 | } 179 | 180 | type SortedSetNode struct { 181 | CreatedAt int64 182 | Value string 183 | } 184 | 185 | // Add adds the request with the given key to the slice in Memcached and returns the total number of requests in it. 186 | // It also removes the keys with expired TTL. 
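// Concurrent updates are detected via memcached CAS and the Add is retried on conflict.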
187 | func (c *ConcurrentBufferMemcached) Add(ctx context.Context, element string) (int64, error) { 188 | var err error 189 | done := make(chan struct{}) 190 | now := c.clock.Now() 191 | var newNodes []SortedSetNode 192 | var casId uint64 = 0 193 | 194 | go func() { 195 | defer close(done) 196 | var item *memcache.Item 197 | item, err = c.cli.Get(c.key) 198 | if err != nil { 199 | if !errors.Is(err, memcache.ErrCacheMiss) { 200 | return 201 | } 202 | } else { 203 | casId = item.CasID 204 | b := bytes.NewBuffer(item.Value) 205 | var oldNodes []SortedSetNode 206 | _ = gob.NewDecoder(b).Decode(&oldNodes) 207 | for _, node := range oldNodes { 208 | if node.CreatedAt > now.UnixNano() && node.Value != element { 209 | newNodes = append(newNodes, node) 210 | } 211 | } 212 | } 213 | newNodes = append(newNodes, SortedSetNode{CreatedAt: now.Add(c.ttl).UnixNano(), Value: element}) 214 | var b bytes.Buffer 215 | _ = gob.NewEncoder(&b).Encode(newNodes) 216 | item = &memcache.Item{ 217 | Key: c.key, 218 | Value: b.Bytes(), 219 | CasID: casId, 220 | } 221 | if casId > 0 { 222 | err = c.cli.CompareAndSwap(item) 223 | } else { 224 | err = c.cli.Add(item) 225 | } 226 | }() 227 | 228 | select { 229 | case <-ctx.Done(): 230 | return 0, ctx.Err() 231 | 232 | case <-done: 233 | if err != nil { 234 | if errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss) { 235 | return c.Add(ctx, element) 236 | } 237 | 238 | return 0, errors.Wrap(err, "failed to add in memcached") 239 | } 240 | 241 | return int64(len(newNodes)), nil 242 | } 243 | } 244 | 245 | // Remove removes the request identified by the key from the slice in Memcached. 246 | func (c *ConcurrentBufferMemcached) Remove(ctx context.Context, key string) error { 247 | var err error 248 | now := c.clock.Now() 249 | var newNodes []SortedSetNode 250 | var casID uint64 251 | item, err := c.cli.Get(c.key) 252 | if err != nil { 253 | if errors.Is(err, memcache.ErrCacheMiss) { 254 | return nil 255 | } 256 | 257 | return errors.Wrap(err, "failed to Get") 258 | } 259 | casID = item.CasID 260 | var oldNodes []SortedSetNode 261 | _ = gob.NewDecoder(bytes.NewBuffer(item.Value)).Decode(&oldNodes) 262 | for _, node := range oldNodes { 263 | if node.CreatedAt > now.UnixNano() && node.Value != key { 264 | newNodes = append(newNodes, node) 265 | } 266 | } 267 | 268 | var b bytes.Buffer 269 | _ = gob.NewEncoder(&b).Encode(newNodes) 270 | item = &memcache.Item{ 271 | Key: c.key, 272 | Value: b.Bytes(), 273 | CasID: casID, 274 | } 275 | err = c.cli.CompareAndSwap(item) 276 | if err != nil && (errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss)) { 277 | return c.Remove(ctx, key) 278 | } 279 | 280 | return errors.Wrap(err, "failed to CompareAndSwap") 281 | } 282 | -------------------------------------------------------------------------------- /concurrent_buffer_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/google/uuid" 11 | l "github.com/mennanov/limiters" 12 | ) 13 | 14 | func (s *LimitersTestSuite) concurrentBuffers(capacity int64, ttl time.Duration, clock l.Clock) map[string]*l.ConcurrentBuffer { 15 | buffers := make(map[string]*l.ConcurrentBuffer) 16 | for lockerName, locker := range s.lockers(true) { 17 | for bName, b := range s.concurrentBufferBackends(ttl, clock) { 18 | 
buffers[lockerName+":"+bName] = l.NewConcurrentBuffer(locker, b, capacity, s.logger) 19 | } 20 | } 21 | 22 | return buffers 23 | } 24 | 25 | func (s *LimitersTestSuite) concurrentBufferBackends(ttl time.Duration, clock l.Clock) map[string]l.ConcurrentBufferBackend { 26 | return map[string]l.ConcurrentBufferBackend{ 27 | "ConcurrentBufferInMemory": l.NewConcurrentBufferInMemory(l.NewRegistry(), ttl, clock), 28 | "ConcurrentBufferRedis": l.NewConcurrentBufferRedis(s.redisClient, uuid.New().String(), ttl, clock), 29 | "ConcurrentBufferRedisCluster": l.NewConcurrentBufferRedis(s.redisClusterClient, uuid.New().String(), ttl, clock), 30 | "ConcurrentBufferMemcached": l.NewConcurrentBufferMemcached(s.memcacheClient, uuid.New().String(), ttl, clock), 31 | } 32 | } 33 | 34 | func (s *LimitersTestSuite) TestConcurrentBufferNoOverflow() { 35 | clock := newFakeClock() 36 | capacity := int64(10) 37 | ttl := time.Second 38 | for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) { 39 | s.Run(name, func() { 40 | wg := sync.WaitGroup{} 41 | for i := int64(0); i < capacity; i++ { 42 | wg.Add(1) 43 | go func(i int64, buffer *l.ConcurrentBuffer) { 44 | defer wg.Done() 45 | key := fmt.Sprintf("key%d", i) 46 | s.NoError(buffer.Limit(context.TODO(), key)) 47 | s.NoError(buffer.Done(context.TODO(), key)) 48 | }(i, buffer) 49 | } 50 | wg.Wait() 51 | s.NoError(buffer.Limit(context.TODO(), "last")) 52 | s.NoError(buffer.Done(context.TODO(), "last")) 53 | }) 54 | } 55 | } 56 | 57 | func (s *LimitersTestSuite) TestConcurrentBufferOverflow() { 58 | clock := newFakeClock() 59 | capacity := int64(3) 60 | ttl := time.Second 61 | for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) { 62 | s.Run(name, func() { 63 | mu := sync.Mutex{} 64 | var errors []error 65 | wg := sync.WaitGroup{} 66 | for i := int64(0); i <= capacity; i++ { 67 | wg.Add(1) 68 | go func(i int64, buffer *l.ConcurrentBuffer) { 69 | defer wg.Done() 70 | if err := buffer.Limit(context.TODO(), fmt.Sprintf("key%d", i)); err != nil { 71 | mu.Lock() 72 | errors = append(errors, err) 73 | mu.Unlock() 74 | } 75 | }(i, buffer) 76 | } 77 | wg.Wait() 78 | s.Equal([]error{l.ErrLimitExhausted}, errors) 79 | }) 80 | } 81 | } 82 | 83 | func (s *LimitersTestSuite) TestConcurrentBufferExpiredKeys() { 84 | clock := newFakeClock() 85 | capacity := int64(2) 86 | ttl := time.Second 87 | for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) { 88 | s.Run(name, func() { 89 | s.Require().NoError(buffer.Limit(context.TODO(), "key1")) 90 | clock.Sleep(ttl / 2) 91 | s.Require().NoError(buffer.Limit(context.TODO(), "key2")) 92 | clock.Sleep(ttl / 2) 93 | // No error is expected (despite the following request overflows the capacity) as the first key has already 94 | // expired by this time. 95 | s.NoError(buffer.Limit(context.TODO(), "key3")) 96 | }) 97 | } 98 | } 99 | 100 | func (s *LimitersTestSuite) TestConcurrentBufferDuplicateKeys() { 101 | clock := newFakeClock() 102 | capacity := int64(2) 103 | ttl := time.Second 104 | for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) { 105 | s.Run(name, func() { 106 | s.Require().NoError(buffer.Limit(context.TODO(), "key1")) 107 | s.Require().NoError(buffer.Limit(context.TODO(), "key2")) 108 | // No error is expected as it should just update the timestamp of the existing key. 
109 | s.NoError(buffer.Limit(context.TODO(), "key1")) 110 | }) 111 | } 112 | } 113 | 114 | func BenchmarkConcurrentBuffers(b *testing.B) { 115 | s := new(LimitersTestSuite) 116 | s.SetT(&testing.T{}) 117 | s.SetupSuite() 118 | capacity := int64(1) 119 | ttl := time.Second 120 | clock := newFakeClock() 121 | buffers := s.concurrentBuffers(capacity, ttl, clock) 122 | for name, buffer := range buffers { 123 | b.Run(name, func(b *testing.B) { 124 | for i := 0; i < b.N; i++ { 125 | s.Require().NoError(buffer.Limit(context.TODO(), "key1")) 126 | } 127 | }) 128 | } 129 | s.TearDownSuite() 130 | } 131 | -------------------------------------------------------------------------------- /cosmos_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 7 | ) 8 | 9 | const ( 10 | testCosmosDBName = "limiters-db-test" 11 | testCosmosContainerName = "limiters-container-test" 12 | ) 13 | 14 | var defaultTTL int32 = 86400 15 | 16 | func CreateCosmosDBContainer(ctx context.Context, client *azcosmos.Client) error { 17 | resp, err := client.CreateDatabase(ctx, azcosmos.DatabaseProperties{ 18 | ID: testCosmosDBName, 19 | }, &azcosmos.CreateDatabaseOptions{}) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | dbClient, err := client.NewDatabase(resp.DatabaseProperties.ID) 25 | if err != nil { 26 | return err 27 | } 28 | 29 | _, err = dbClient.CreateContainer(ctx, azcosmos.ContainerProperties{ 30 | ID: testCosmosContainerName, 31 | DefaultTimeToLive: &defaultTTL, 32 | PartitionKeyDefinition: azcosmos.PartitionKeyDefinition{ 33 | Paths: []string{`/partitionKey`}, 34 | }, 35 | }, &azcosmos.CreateContainerOptions{}) 36 | 37 | return err 38 | } 39 | 40 | func DeleteCosmosDBContainer(ctx context.Context, client *azcosmos.Client) error { 41 | dbClient, err := client.NewDatabase(testCosmosDBName) 42 | if err != nil { 43 | return err 44 | } 45 | 46 | _, err = dbClient.Delete(ctx, &azcosmos.DeleteDatabaseOptions{}) 47 | 48 | return err 49 | } 50 | -------------------------------------------------------------------------------- /cosmosdb.go: -------------------------------------------------------------------------------- 1 | package limiters 2 | 3 | type cosmosItem struct { 4 | Count int64 `json:"Count"` 5 | PartitionKey string `json:"partitionKey"` 6 | ID string `json:"id"` 7 | TTL int32 `json:"ttl"` 8 | } 9 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | etcd: 5 | image: bitnami/etcd 6 | environment: 7 | ALLOW_NONE_AUTHENTICATION: "yes" 8 | ports: 9 | - "2379:2379" 10 | 11 | redis: 12 | image: bitnami/redis 13 | environment: 14 | ALLOW_EMPTY_PASSWORD: "yes" 15 | ports: 16 | - "6379:6379" 17 | 18 | redis-cluster: 19 | image: grokzen/redis-cluster:7.0.10 20 | environment: 21 | IP: "0.0.0.0" 22 | INITIAL_PORT: 11000 23 | ports: 24 | - "11000-11005:11000-11005" 25 | 26 | memcached: 27 | image: bitnami/memcached 28 | ports: 29 | - "11211:11211" 30 | 31 | consul: 32 | image: bitnami/consul 33 | ports: 34 | - "8500:8500" 35 | 36 | zookeeper: 37 | image: bitnami/zookeeper 38 | environment: 39 | ALLOW_ANONYMOUS_LOGIN: "yes" 40 | ports: 41 | - "2181:2181" 42 | 43 | dynamodb-local: 44 | image: "amazon/dynamodb-local:latest" 45 | command: "-jar DynamoDBLocal.jar -inMemory" 46 | ports: 47 | - "8000:8000" 48 | 
49 | postgresql: 50 | image: bitnami/postgresql 51 | environment: 52 | ALLOW_EMPTY_PASSWORD: yes 53 | ports: 54 | - "5432:5432" 55 | 56 | cosmos: 57 | image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview 58 | command: [ "--protocol", "http" ] 59 | environment: 60 | PROTOCOL: http 61 | COSMOS_HTTP_CONNECTION_WITHOUT_TLS_ALLOWED: "true" 62 | ports: 63 | - "8081:8081" 64 | -------------------------------------------------------------------------------- /dynamodb.go: -------------------------------------------------------------------------------- 1 | package limiters 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/aws/aws-sdk-go-v2/aws" 7 | "github.com/aws/aws-sdk-go-v2/service/dynamodb" 8 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" 9 | "github.com/pkg/errors" 10 | ) 11 | 12 | // DynamoDBTableProperties are supplied to DynamoDB limiter backends. 13 | // This struct informs the backend what the name of the table is and what the names of the key fields are. 14 | type DynamoDBTableProperties struct { 15 | // TableName is the name of the table. 16 | TableName string 17 | // PartitionKeyName is the name of the PartitionKey attribute. 18 | PartitionKeyName string 19 | // SortKeyName is the name of the SortKey attribute. 20 | SortKeyName string 21 | // SortKeyUsed indicates if a SortKey is present on the table. 22 | SortKeyUsed bool 23 | // TTLFieldName is the name of the attribute configured for TTL. 24 | TTLFieldName string 25 | } 26 | 27 | // LoadDynamoDBTableProperties fetches a table description with the supplied client and returns a DynamoDBTableProperties struct. 28 | func LoadDynamoDBTableProperties(ctx context.Context, client *dynamodb.Client, tableName string) (DynamoDBTableProperties, error) { 29 | resp, err := client.DescribeTable(ctx, &dynamodb.DescribeTableInput{ 30 | TableName: &tableName, 31 | }) 32 | if err != nil { 33 | return DynamoDBTableProperties{}, errors.Wrap(err, "describe dynamodb table failed") 34 | } 35 | 36 | ttlResp, err := client.DescribeTimeToLive(ctx, &dynamodb.DescribeTimeToLiveInput{ 37 | TableName: &tableName, 38 | }) 39 | if err != nil { 40 | return DynamoDBTableProperties{}, errors.Wrap(err, "describe dynamobd table ttl failed") 41 | } 42 | 43 | data, err := loadTableData(resp.Table, ttlResp.TimeToLiveDescription) 44 | if err != nil { 45 | return data, err 46 | } 47 | 48 | return data, nil 49 | } 50 | 51 | func loadTableData(table *types.TableDescription, ttl *types.TimeToLiveDescription) (DynamoDBTableProperties, error) { 52 | data := DynamoDBTableProperties{ 53 | TableName: *table.TableName, 54 | } 55 | 56 | data, err := loadTableKeys(data, table) 57 | if err != nil { 58 | return data, errors.Wrap(err, "invalid dynamodb table") 59 | } 60 | 61 | return populateTableTTL(data, ttl), nil 62 | } 63 | 64 | func loadTableKeys(data DynamoDBTableProperties, table *types.TableDescription) (DynamoDBTableProperties, error) { 65 | for _, key := range table.KeySchema { 66 | if key.KeyType == types.KeyTypeHash { 67 | data.PartitionKeyName = *key.AttributeName 68 | 69 | continue 70 | } 71 | 72 | data.SortKeyName = *key.AttributeName 73 | data.SortKeyUsed = true 74 | } 75 | 76 | for _, attribute := range table.AttributeDefinitions { 77 | name := *attribute.AttributeName 78 | if name != data.PartitionKeyName && name != data.SortKeyName { 79 | continue 80 | } 81 | 82 | if name == data.PartitionKeyName && attribute.AttributeType != types.ScalarAttributeTypeS { 83 | return data, errors.New("dynamodb partition key must be of type S") 84 | } else 
if data.SortKeyUsed && name == data.SortKeyName && attribute.AttributeType != types.ScalarAttributeTypeS { 85 | return data, errors.New("dynamodb sort key must be of type S") 86 | } 87 | } 88 | 89 | return data, nil 90 | } 91 | 92 | func populateTableTTL(data DynamoDBTableProperties, ttl *types.TimeToLiveDescription) DynamoDBTableProperties { 93 | if ttl.TimeToLiveStatus != types.TimeToLiveStatusEnabled { 94 | return data 95 | } 96 | 97 | data.TTLFieldName = *ttl.AttributeName 98 | 99 | return data 100 | } 101 | 102 | func dynamoDBputItem(ctx context.Context, client *dynamodb.Client, input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) { 103 | resp, err := client.PutItem(ctx, input) 104 | if err != nil { 105 | var cErr *types.ConditionalCheckFailedException 106 | if errors.As(err, &cErr) { 107 | return nil, ErrRaceCondition 108 | } 109 | 110 | return nil, errors.Wrap(err, "unable to set dynamodb item") 111 | } 112 | 113 | return resp, nil 114 | } 115 | 116 | func dynamoDBGetItem(ctx context.Context, client *dynamodb.Client, input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) { 117 | input.ConsistentRead = aws.Bool(true) 118 | 119 | var resp *dynamodb.GetItemOutput 120 | var err error 121 | 122 | done := make(chan struct{}) 123 | go func() { 124 | defer close(done) 125 | resp, err = client.GetItem(ctx, input) 126 | }() 127 | 128 | select { 129 | case <-done: 130 | case <-ctx.Done(): 131 | return nil, ctx.Err() 132 | } 133 | 134 | if err != nil { 135 | return nil, errors.Wrap(err, "unable to retrieve dynamodb item") 136 | } 137 | 138 | return resp, nil 139 | } 140 | -------------------------------------------------------------------------------- /dynamodb_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/aws/aws-sdk-go-v2/aws" 8 | "github.com/aws/aws-sdk-go-v2/service/dynamodb" 9 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" 10 | "github.com/mennanov/limiters" 11 | "github.com/pkg/errors" 12 | ) 13 | 14 | const testDynamoDBTableName = "limiters-test" 15 | 16 | func CreateTestDynamoDBTable(ctx context.Context, client *dynamodb.Client) error { 17 | _, err := client.CreateTable(ctx, &dynamodb.CreateTableInput{ 18 | TableName: aws.String(testDynamoDBTableName), 19 | BillingMode: types.BillingModePayPerRequest, 20 | KeySchema: []types.KeySchemaElement{ 21 | { 22 | AttributeName: aws.String("PK"), 23 | KeyType: types.KeyTypeHash, 24 | }, 25 | { 26 | AttributeName: aws.String("SK"), 27 | KeyType: types.KeyTypeRange, 28 | }, 29 | }, 30 | AttributeDefinitions: []types.AttributeDefinition{ 31 | { 32 | AttributeName: aws.String("PK"), 33 | AttributeType: types.ScalarAttributeTypeS, 34 | }, 35 | { 36 | AttributeName: aws.String("SK"), 37 | AttributeType: types.ScalarAttributeTypeS, 38 | }, 39 | }, 40 | }) 41 | if err != nil { 42 | return errors.Wrap(err, "create test dynamodb table failed") 43 | } 44 | 45 | _, err = client.UpdateTimeToLive(ctx, &dynamodb.UpdateTimeToLiveInput{ 46 | TableName: aws.String(testDynamoDBTableName), 47 | TimeToLiveSpecification: &types.TimeToLiveSpecification{ 48 | AttributeName: aws.String("TTL"), 49 | Enabled: aws.Bool(true), 50 | }, 51 | }) 52 | if err != nil { 53 | return errors.Wrap(err, "set dynamodb ttl failed") 54 | } 55 | 56 | ctx, cancel := context.WithTimeout(ctx, time.Second*10) 57 | defer cancel() 58 | for { 59 | resp, err := client.DescribeTable(ctx, &dynamodb.DescribeTableInput{ 60 | TableName: 
aws.String(testDynamoDBTableName),
61 | 		})
62 | 
63 | 		if err != nil {
64 | 			return errors.Wrap(err, "failed to describe test table")
65 | 		}
66 | 
67 | 		if resp.Table.TableStatus == types.TableStatusActive {
68 | 			return nil
69 | 		}
70 | 
71 | 		select {
72 | 		case <-ctx.Done():
73 | 			return errors.New("failed to verify dynamodb test table is created")
74 | 		default:
75 | 		}
76 | 	}
77 | }
78 | 
79 | func DeleteTestDynamoDBTable(ctx context.Context, client *dynamodb.Client) error {
80 | 	_, err := client.DeleteTable(ctx, &dynamodb.DeleteTableInput{
81 | 		TableName: aws.String(testDynamoDBTableName),
82 | 	})
83 | 	if err != nil {
84 | 		return errors.Wrap(err, "delete test dynamodb table failed")
85 | 	}
86 | 
87 | 	return nil
88 | }
89 | 
90 | func (s *LimitersTestSuite) TestDynamoRaceCondition() {
91 | 	backend := limiters.NewLeakyBucketDynamoDB(s.dynamodbClient, "race-check", s.dynamoDBTableProps, time.Minute, true)
92 | 
93 | 	err := backend.SetState(context.Background(), limiters.LeakyBucketState{})
94 | 	s.Require().NoError(err)
95 | 
96 | 	_, err = backend.State(context.Background())
97 | 	s.Require().NoError(err)
98 | 
99 | 	_, err = s.dynamodbClient.PutItem(context.Background(), &dynamodb.PutItemInput{
100 | 		Item: map[string]types.AttributeValue{
101 | 			s.dynamoDBTableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: "race-check"},
102 | 			s.dynamoDBTableProps.SortKeyName: &types.AttributeValueMemberS{Value: "race-check"},
103 | 			"Version": &types.AttributeValueMemberN{Value: "5"},
104 | 		},
105 | 		TableName: &s.dynamoDBTableProps.TableName,
106 | 	})
107 | 	s.Require().NoError(err)
108 | 
109 | 	err = backend.SetState(context.Background(), limiters.LeakyBucketState{})
110 | 	s.Require().ErrorIs(err, limiters.ErrRaceCondition, err)
111 | }
112 | 
--------------------------------------------------------------------------------
/examples/example_grpc_ip_limiter_test.go:
--------------------------------------------------------------------------------
1 | // Package examples implements a gRPC server for Greeter service using rate limiters.
2 | package examples_test
3 | 
4 | import (
5 | 	"context"
6 | 	"errors"
7 | 	"fmt"
8 | 	"log"
9 | 	"net"
10 | 	"os"
11 | 	"strings"
12 | 	"time"
13 | 
14 | 	"github.com/mennanov/limiters"
15 | 	"github.com/mennanov/limiters/examples"
16 | 	pb "github.com/mennanov/limiters/examples/helloworld"
17 | 	"github.com/redis/go-redis/v9"
18 | 	clientv3 "go.etcd.io/etcd/client/v3"
19 | 	"google.golang.org/grpc"
20 | 	"google.golang.org/grpc/codes"
21 | 	"google.golang.org/grpc/credentials/insecure"
22 | 	"google.golang.org/grpc/peer"
23 | 	"google.golang.org/grpc/status"
24 | )
25 | 
26 | func Example_ipGRPCLimiter() {
27 | 	// Set up a gRPC server.
28 | 	lis, err := net.Listen("tcp", examples.Port)
29 | 	if err != nil {
30 | 		log.Fatalf("failed to listen: %v", err)
31 | 	}
32 | 	defer lis.Close()
33 | 	// Connect to etcd.
34 | 	etcdClient, err := clientv3.New(clientv3.Config{
35 | 		Endpoints: strings.Split(os.Getenv("ETCD_ENDPOINTS"), ","),
36 | 		DialTimeout: time.Second,
37 | 	})
38 | 	if err != nil {
39 | 		log.Fatalf("could not connect to etcd: %v", err)
40 | 	}
41 | 	defer etcdClient.Close()
42 | 	// Connect to Redis.
43 | 	redisClient := redis.NewClient(&redis.Options{
44 | 		Addr: os.Getenv("REDIS_ADDR"),
45 | 	})
46 | 	defer redisClient.Close()
47 | 	logger := limiters.NewStdLogger()
48 | 	// Registry is needed to keep track of previously created limiters. It can remove the expired limiters to free up
49 | 	// memory.
50 | registry := limiters.NewRegistry() 51 | // The rate is used to define the token bucket refill rate and also the TTL for the limiters (both in Redis and in 52 | // the registry). 53 | rate := time.Second * 3 54 | clock := limiters.NewSystemClock() 55 | go func() { 56 | // Garbage collect the old limiters to prevent memory leaks. 57 | for { 58 | <-time.After(rate) 59 | registry.DeleteExpired(clock.Now()) 60 | } 61 | }() 62 | 63 | // Add a unary interceptor middleware to rate limit requests. 64 | s := grpc.NewServer(grpc.UnaryInterceptor( 65 | func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 66 | p, ok := peer.FromContext(ctx) 67 | var ip string 68 | if !ok { 69 | log.Println("no peer info available") 70 | ip = "unknown" 71 | } else { 72 | ip = p.Addr.String() 73 | } 74 | 75 | // Create an IP address based rate limiter. 76 | bucket := registry.GetOrCreate(ip, func() interface{} { 77 | return limiters.NewTokenBucket( 78 | 2, 79 | rate, 80 | limiters.NewLockEtcd(etcdClient, fmt.Sprintf("/lock/ip/%s", ip), logger), 81 | limiters.NewTokenBucketRedis( 82 | redisClient, 83 | fmt.Sprintf("/ratelimiter/ip/%s", ip), 84 | rate, false), 85 | clock, logger) 86 | }, rate, clock.Now()) 87 | w, err := bucket.(*limiters.TokenBucket).Limit(ctx) 88 | if errors.Is(err, limiters.ErrLimitExhausted) { 89 | return nil, status.Errorf(codes.ResourceExhausted, "try again later in %s", w) 90 | } else if err != nil { 91 | // The limiter failed. This error should be logged and examined. 92 | log.Println(err) 93 | 94 | return nil, status.Error(codes.Internal, "internal error") 95 | } 96 | 97 | return handler(ctx, req) 98 | })) 99 | 100 | pb.RegisterGreeterServer(s, &examples.Server{}) 101 | go func() { 102 | // Start serving. 103 | if err = s.Serve(lis); err != nil { 104 | log.Fatalf("failed to serve: %v", err) 105 | } 106 | }() 107 | defer s.GracefulStop() 108 | 109 | // Set up a client connection to the server. 110 | conn, err := grpc.NewClient(fmt.Sprintf("localhost%s", examples.Port), grpc.WithTransportCredentials(insecure.NewCredentials())) 111 | if err != nil { 112 | log.Fatalf("did not connect: %v", err) 113 | } 114 | defer conn.Close() 115 | c := pb.NewGreeterClient(conn) 116 | 117 | // Contact the server and print out its response. 118 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 119 | defer cancel() 120 | r, err := c.SayHello(ctx, &pb.HelloRequest{Name: "Alice"}) 121 | if err != nil { 122 | log.Fatalf("could not greet: %v", err) 123 | } 124 | fmt.Println(r.GetMessage()) 125 | r, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Bob"}) 126 | if err != nil { 127 | log.Fatalf("could not greet: %v", err) 128 | } 129 | fmt.Println(r.GetMessage()) 130 | _, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Peter"}) 131 | if err == nil { 132 | log.Fatal("error expected, but got nil") 133 | } 134 | fmt.Println(err) 135 | // Output: Hello Alice 136 | // Hello Bob 137 | // rpc error: code = ResourceExhausted desc = try again later in 3s 138 | } 139 | -------------------------------------------------------------------------------- /examples/example_grpc_simple_limiter_test.go: -------------------------------------------------------------------------------- 1 | // Package examples implements a gRPC server for Greeter service using rate limiters. 
2 | package examples_test 3 | 4 | import ( 5 | "context" 6 | "errors" 7 | "fmt" 8 | "log" 9 | "net" 10 | "os" 11 | "strings" 12 | "time" 13 | 14 | "github.com/mennanov/limiters" 15 | "github.com/mennanov/limiters/examples" 16 | pb "github.com/mennanov/limiters/examples/helloworld" 17 | "github.com/redis/go-redis/v9" 18 | clientv3 "go.etcd.io/etcd/client/v3" 19 | "google.golang.org/grpc" 20 | "google.golang.org/grpc/codes" 21 | "google.golang.org/grpc/credentials/insecure" 22 | "google.golang.org/grpc/status" 23 | ) 24 | 25 | func Example_simpleGRPCLimiter() { 26 | // Set up a gRPC server. 27 | lis, err := net.Listen("tcp", examples.Port) 28 | if err != nil { 29 | log.Fatalf("failed to listen: %v", err) 30 | } 31 | defer lis.Close() 32 | // Connect to etcd. 33 | etcdClient, err := clientv3.New(clientv3.Config{ 34 | Endpoints: strings.Split(os.Getenv("ETCD_ENDPOINTS"), ","), 35 | DialTimeout: time.Second, 36 | }) 37 | if err != nil { 38 | log.Fatalf("could not connect to etcd: %v", err) 39 | } 40 | defer etcdClient.Close() 41 | // Connect to Redis. 42 | redisClient := redis.NewClient(&redis.Options{ 43 | Addr: os.Getenv("REDIS_ADDR"), 44 | }) 45 | defer redisClient.Close() 46 | 47 | rate := time.Second * 3 48 | limiter := limiters.NewTokenBucket( 49 | 2, 50 | rate, 51 | limiters.NewLockEtcd(etcdClient, "/ratelimiter_lock/simple/", limiters.NewStdLogger()), 52 | limiters.NewTokenBucketRedis( 53 | redisClient, 54 | "ratelimiter/simple", 55 | rate, false), 56 | limiters.NewSystemClock(), limiters.NewStdLogger(), 57 | ) 58 | 59 | // Add a unary interceptor middleware to rate limit all requests. 60 | s := grpc.NewServer(grpc.UnaryInterceptor( 61 | func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 62 | w, err := limiter.Limit(ctx) 63 | if errors.Is(err, limiters.ErrLimitExhausted) { 64 | return nil, status.Errorf(codes.ResourceExhausted, "try again later in %s", w) 65 | } else if err != nil { 66 | // The limiter failed. This error should be logged and examined. 67 | log.Println(err) 68 | 69 | return nil, status.Error(codes.Internal, "internal error") 70 | } 71 | 72 | return handler(ctx, req) 73 | })) 74 | 75 | pb.RegisterGreeterServer(s, &examples.Server{}) 76 | go func() { 77 | // Start serving. 78 | if err = s.Serve(lis); err != nil { 79 | log.Fatalf("failed to serve: %v", err) 80 | } 81 | }() 82 | defer s.GracefulStop() 83 | 84 | // Set up a client connection to the server. 85 | conn, err := grpc.NewClient(fmt.Sprintf("localhost%s", examples.Port), grpc.WithTransportCredentials(insecure.NewCredentials())) 86 | if err != nil { 87 | log.Fatalf("did not connect: %v", err) 88 | } 89 | defer conn.Close() 90 | c := pb.NewGreeterClient(conn) 91 | 92 | // Contact the server and print out its response. 
93 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 94 | defer cancel() 95 | r, err := c.SayHello(ctx, &pb.HelloRequest{Name: "Alice"}) 96 | if err != nil { 97 | log.Fatalf("could not greet: %v", err) 98 | } 99 | fmt.Println(r.GetMessage()) 100 | r, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Bob"}) 101 | if err != nil { 102 | log.Fatalf("could not greet: %v", err) 103 | } 104 | fmt.Println(r.GetMessage()) 105 | _, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Peter"}) 106 | if err == nil { 107 | log.Fatal("error expected, but got nil") 108 | } 109 | fmt.Println(err) 110 | // Output: Hello Alice 111 | // Hello Bob 112 | // rpc error: code = ResourceExhausted desc = try again later in 3s 113 | } 114 | -------------------------------------------------------------------------------- /examples/gprc_service.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "context" 5 | 6 | pb "github.com/mennanov/limiters/examples/helloworld" 7 | ) 8 | 9 | const ( 10 | Port = ":50051" 11 | ) 12 | 13 | // server is used to implement helloworld.GreeterServer. 14 | type Server struct { 15 | pb.UnimplementedGreeterServer 16 | } 17 | 18 | // SayHello implements helloworld.GreeterServer. 19 | func (s *Server) SayHello(_ context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { 20 | return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil 21 | } 22 | 23 | var _ pb.GreeterServer = new(Server) 24 | -------------------------------------------------------------------------------- /examples/helloworld/helloworld.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.35.1 4 | // protoc v5.28.3 5 | // source: helloworld.proto 6 | 7 | package helloworld 8 | 9 | import ( 10 | reflect "reflect" 11 | sync "sync" 12 | 13 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 14 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 15 | ) 16 | 17 | const ( 18 | // Verify that this generated code is sufficiently up-to-date. 19 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 20 | // Verify that runtime/protoimpl is sufficiently up-to-date. 21 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 22 | ) 23 | 24 | // The request message containing the user's name. 25 | type HelloRequest struct { 26 | state protoimpl.MessageState 27 | sizeCache protoimpl.SizeCache 28 | unknownFields protoimpl.UnknownFields 29 | 30 | Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` 31 | } 32 | 33 | func (x *HelloRequest) Reset() { 34 | *x = HelloRequest{} 35 | mi := &file_helloworld_proto_msgTypes[0] 36 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 37 | ms.StoreMessageInfo(mi) 38 | } 39 | 40 | func (x *HelloRequest) String() string { 41 | return protoimpl.X.MessageStringOf(x) 42 | } 43 | 44 | func (*HelloRequest) ProtoMessage() {} 45 | 46 | func (x *HelloRequest) ProtoReflect() protoreflect.Message { 47 | mi := &file_helloworld_proto_msgTypes[0] 48 | if x != nil { 49 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 50 | if ms.LoadMessageInfo() == nil { 51 | ms.StoreMessageInfo(mi) 52 | } 53 | return ms 54 | } 55 | return mi.MessageOf(x) 56 | } 57 | 58 | // Deprecated: Use HelloRequest.ProtoReflect.Descriptor instead. 
59 | func (*HelloRequest) Descriptor() ([]byte, []int) { 60 | return file_helloworld_proto_rawDescGZIP(), []int{0} 61 | } 62 | 63 | func (x *HelloRequest) GetName() string { 64 | if x != nil { 65 | return x.Name 66 | } 67 | return "" 68 | } 69 | 70 | // The response message containing the greetings 71 | type HelloReply struct { 72 | state protoimpl.MessageState 73 | sizeCache protoimpl.SizeCache 74 | unknownFields protoimpl.UnknownFields 75 | 76 | Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` 77 | } 78 | 79 | func (x *HelloReply) Reset() { 80 | *x = HelloReply{} 81 | mi := &file_helloworld_proto_msgTypes[1] 82 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 83 | ms.StoreMessageInfo(mi) 84 | } 85 | 86 | func (x *HelloReply) String() string { 87 | return protoimpl.X.MessageStringOf(x) 88 | } 89 | 90 | func (*HelloReply) ProtoMessage() {} 91 | 92 | func (x *HelloReply) ProtoReflect() protoreflect.Message { 93 | mi := &file_helloworld_proto_msgTypes[1] 94 | if x != nil { 95 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 96 | if ms.LoadMessageInfo() == nil { 97 | ms.StoreMessageInfo(mi) 98 | } 99 | return ms 100 | } 101 | return mi.MessageOf(x) 102 | } 103 | 104 | // Deprecated: Use HelloReply.ProtoReflect.Descriptor instead. 105 | func (*HelloReply) Descriptor() ([]byte, []int) { 106 | return file_helloworld_proto_rawDescGZIP(), []int{1} 107 | } 108 | 109 | func (x *HelloReply) GetMessage() string { 110 | if x != nil { 111 | return x.Message 112 | } 113 | return "" 114 | } 115 | 116 | type GoodbyeRequest struct { 117 | state protoimpl.MessageState 118 | sizeCache protoimpl.SizeCache 119 | unknownFields protoimpl.UnknownFields 120 | 121 | Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` 122 | } 123 | 124 | func (x *GoodbyeRequest) Reset() { 125 | *x = GoodbyeRequest{} 126 | mi := &file_helloworld_proto_msgTypes[2] 127 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 128 | ms.StoreMessageInfo(mi) 129 | } 130 | 131 | func (x *GoodbyeRequest) String() string { 132 | return protoimpl.X.MessageStringOf(x) 133 | } 134 | 135 | func (*GoodbyeRequest) ProtoMessage() {} 136 | 137 | func (x *GoodbyeRequest) ProtoReflect() protoreflect.Message { 138 | mi := &file_helloworld_proto_msgTypes[2] 139 | if x != nil { 140 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 141 | if ms.LoadMessageInfo() == nil { 142 | ms.StoreMessageInfo(mi) 143 | } 144 | return ms 145 | } 146 | return mi.MessageOf(x) 147 | } 148 | 149 | // Deprecated: Use GoodbyeRequest.ProtoReflect.Descriptor instead. 
150 | func (*GoodbyeRequest) Descriptor() ([]byte, []int) { 151 | return file_helloworld_proto_rawDescGZIP(), []int{2} 152 | } 153 | 154 | func (x *GoodbyeRequest) GetName() string { 155 | if x != nil { 156 | return x.Name 157 | } 158 | return "" 159 | } 160 | 161 | type GoodbyeReply struct { 162 | state protoimpl.MessageState 163 | sizeCache protoimpl.SizeCache 164 | unknownFields protoimpl.UnknownFields 165 | 166 | Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` 167 | } 168 | 169 | func (x *GoodbyeReply) Reset() { 170 | *x = GoodbyeReply{} 171 | mi := &file_helloworld_proto_msgTypes[3] 172 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 173 | ms.StoreMessageInfo(mi) 174 | } 175 | 176 | func (x *GoodbyeReply) String() string { 177 | return protoimpl.X.MessageStringOf(x) 178 | } 179 | 180 | func (*GoodbyeReply) ProtoMessage() {} 181 | 182 | func (x *GoodbyeReply) ProtoReflect() protoreflect.Message { 183 | mi := &file_helloworld_proto_msgTypes[3] 184 | if x != nil { 185 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 186 | if ms.LoadMessageInfo() == nil { 187 | ms.StoreMessageInfo(mi) 188 | } 189 | return ms 190 | } 191 | return mi.MessageOf(x) 192 | } 193 | 194 | // Deprecated: Use GoodbyeReply.ProtoReflect.Descriptor instead. 195 | func (*GoodbyeReply) Descriptor() ([]byte, []int) { 196 | return file_helloworld_proto_rawDescGZIP(), []int{3} 197 | } 198 | 199 | func (x *GoodbyeReply) GetMessage() string { 200 | if x != nil { 201 | return x.Message 202 | } 203 | return "" 204 | } 205 | 206 | var File_helloworld_proto protoreflect.FileDescriptor 207 | 208 | var file_helloworld_proto_rawDesc = []byte{ 209 | 0x0a, 0x10, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x70, 0x72, 0x6f, 210 | 0x74, 0x6f, 0x12, 0x0a, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x22, 0x22, 211 | 0x0a, 0x0c, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 212 | 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 213 | 0x6d, 0x65, 0x22, 0x26, 0x0a, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x70, 0x6c, 0x79, 214 | 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 215 | 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x24, 0x0a, 0x0e, 0x47, 0x6f, 216 | 0x6f, 0x64, 0x62, 0x79, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 217 | 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 218 | 0x22, 0x28, 0x0a, 0x0c, 0x47, 0x6f, 0x6f, 0x64, 0x62, 0x79, 0x65, 0x52, 0x65, 0x70, 0x6c, 0x79, 219 | 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 220 | 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0x49, 0x0a, 0x07, 0x47, 0x72, 221 | 0x65, 0x65, 0x74, 0x65, 0x72, 0x12, 0x3e, 0x0a, 0x08, 0x53, 0x61, 0x79, 0x48, 0x65, 0x6c, 0x6c, 222 | 0x6f, 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 223 | 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x68, 0x65, 224 | 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 225 | 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 226 | 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x65, 0x6e, 0x6e, 0x61, 0x6e, 0x6f, 0x76, 0x2f, 0x6c, 0x69, 0x6d, 227 | 0x69, 0x74, 0x65, 0x72, 0x73, 
0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x68, 228 | 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 229 | 0x33, 230 | } 231 | 232 | var ( 233 | file_helloworld_proto_rawDescOnce sync.Once 234 | file_helloworld_proto_rawDescData = file_helloworld_proto_rawDesc 235 | ) 236 | 237 | func file_helloworld_proto_rawDescGZIP() []byte { 238 | file_helloworld_proto_rawDescOnce.Do(func() { 239 | file_helloworld_proto_rawDescData = protoimpl.X.CompressGZIP(file_helloworld_proto_rawDescData) 240 | }) 241 | return file_helloworld_proto_rawDescData 242 | } 243 | 244 | var file_helloworld_proto_msgTypes = make([]protoimpl.MessageInfo, 4) 245 | var file_helloworld_proto_goTypes = []any{ 246 | (*HelloRequest)(nil), // 0: helloworld.HelloRequest 247 | (*HelloReply)(nil), // 1: helloworld.HelloReply 248 | (*GoodbyeRequest)(nil), // 2: helloworld.GoodbyeRequest 249 | (*GoodbyeReply)(nil), // 3: helloworld.GoodbyeReply 250 | } 251 | var file_helloworld_proto_depIdxs = []int32{ 252 | 0, // 0: helloworld.Greeter.SayHello:input_type -> helloworld.HelloRequest 253 | 1, // 1: helloworld.Greeter.SayHello:output_type -> helloworld.HelloReply 254 | 1, // [1:2] is the sub-list for method output_type 255 | 0, // [0:1] is the sub-list for method input_type 256 | 0, // [0:0] is the sub-list for extension type_name 257 | 0, // [0:0] is the sub-list for extension extendee 258 | 0, // [0:0] is the sub-list for field type_name 259 | } 260 | 261 | func init() { file_helloworld_proto_init() } 262 | func file_helloworld_proto_init() { 263 | if File_helloworld_proto != nil { 264 | return 265 | } 266 | type x struct{} 267 | out := protoimpl.TypeBuilder{ 268 | File: protoimpl.DescBuilder{ 269 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 270 | RawDescriptor: file_helloworld_proto_rawDesc, 271 | NumEnums: 0, 272 | NumMessages: 4, 273 | NumExtensions: 0, 274 | NumServices: 1, 275 | }, 276 | GoTypes: file_helloworld_proto_goTypes, 277 | DependencyIndexes: file_helloworld_proto_depIdxs, 278 | MessageInfos: file_helloworld_proto_msgTypes, 279 | }.Build() 280 | File_helloworld_proto = out.File 281 | file_helloworld_proto_rawDesc = nil 282 | file_helloworld_proto_goTypes = nil 283 | file_helloworld_proto_depIdxs = nil 284 | } 285 | -------------------------------------------------------------------------------- /examples/helloworld/helloworld.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package helloworld; 4 | option go_package = "github.com/mennanov/limiters/examples/helloworld"; 5 | 6 | // The greeting service definition. 7 | service Greeter { 8 | // Sends a greeting 9 | rpc SayHello (HelloRequest) returns (HelloReply) { 10 | } 11 | } 12 | 13 | // The request message containing the user's name. 14 | message HelloRequest { 15 | string name = 1; 16 | } 17 | 18 | // The response message containing the greetings 19 | message HelloReply { 20 | string message = 1; 21 | } 22 | 23 | message GoodbyeRequest { 24 | string name = 1; 25 | } 26 | 27 | message GoodbyeReply { 28 | string message = 1; 29 | } -------------------------------------------------------------------------------- /examples/helloworld/helloworld_grpc.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
2 | // versions: 3 | // - protoc-gen-go-grpc v1.5.1 4 | // - protoc v5.28.3 5 | // source: helloworld.proto 6 | 7 | package helloworld 8 | 9 | import ( 10 | context "context" 11 | 12 | grpc "google.golang.org/grpc" 13 | codes "google.golang.org/grpc/codes" 14 | status "google.golang.org/grpc/status" 15 | ) 16 | 17 | // This is a compile-time assertion to ensure that this generated file 18 | // is compatible with the grpc package it is being compiled against. 19 | // Requires gRPC-Go v1.64.0 or later. 20 | const _ = grpc.SupportPackageIsVersion9 21 | 22 | const ( 23 | Greeter_SayHello_FullMethodName = "/helloworld.Greeter/SayHello" 24 | ) 25 | 26 | // GreeterClient is the client API for Greeter service. 27 | // 28 | // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 29 | // 30 | // The greeting service definition. 31 | type GreeterClient interface { 32 | // Sends a greeting 33 | SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) 34 | } 35 | 36 | type greeterClient struct { 37 | cc grpc.ClientConnInterface 38 | } 39 | 40 | func NewGreeterClient(cc grpc.ClientConnInterface) GreeterClient { 41 | return &greeterClient{cc} 42 | } 43 | 44 | func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) { 45 | cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) 46 | out := new(HelloReply) 47 | err := c.cc.Invoke(ctx, Greeter_SayHello_FullMethodName, in, out, cOpts...) 48 | if err != nil { 49 | return nil, err 50 | } 51 | return out, nil 52 | } 53 | 54 | // GreeterServer is the server API for Greeter service. 55 | // All implementations must embed UnimplementedGreeterServer 56 | // for forward compatibility. 57 | // 58 | // The greeting service definition. 59 | type GreeterServer interface { 60 | // Sends a greeting 61 | SayHello(context.Context, *HelloRequest) (*HelloReply, error) 62 | mustEmbedUnimplementedGreeterServer() 63 | } 64 | 65 | // UnimplementedGreeterServer must be embedded to have 66 | // forward compatible implementations. 67 | // 68 | // NOTE: this should be embedded by value instead of pointer to avoid a nil 69 | // pointer dereference when methods are called. 70 | type UnimplementedGreeterServer struct{} 71 | 72 | func (UnimplementedGreeterServer) SayHello(context.Context, *HelloRequest) (*HelloReply, error) { 73 | return nil, status.Errorf(codes.Unimplemented, "method SayHello not implemented") 74 | } 75 | func (UnimplementedGreeterServer) mustEmbedUnimplementedGreeterServer() {} 76 | func (UnimplementedGreeterServer) testEmbeddedByValue() {} 77 | 78 | // UnsafeGreeterServer may be embedded to opt out of forward compatibility for this service. 79 | // Use of this interface is not recommended, as added methods to GreeterServer will 80 | // result in compilation errors. 81 | type UnsafeGreeterServer interface { 82 | mustEmbedUnimplementedGreeterServer() 83 | } 84 | 85 | func RegisterGreeterServer(s grpc.ServiceRegistrar, srv GreeterServer) { 86 | // If the following call pancis, it indicates UnimplementedGreeterServer was 87 | // embedded by pointer and is nil. This will cause panics if an 88 | // unimplemented method is ever invoked, so we test this at initialization 89 | // time to prevent it from happening at runtime later due to I/O. 
90 | if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { 91 | t.testEmbeddedByValue() 92 | } 93 | s.RegisterService(&Greeter_ServiceDesc, srv) 94 | } 95 | 96 | func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 97 | in := new(HelloRequest) 98 | if err := dec(in); err != nil { 99 | return nil, err 100 | } 101 | if interceptor == nil { 102 | return srv.(GreeterServer).SayHello(ctx, in) 103 | } 104 | info := &grpc.UnaryServerInfo{ 105 | Server: srv, 106 | FullMethod: Greeter_SayHello_FullMethodName, 107 | } 108 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 109 | return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest)) 110 | } 111 | return interceptor(ctx, in, info, handler) 112 | } 113 | 114 | // Greeter_ServiceDesc is the grpc.ServiceDesc for Greeter service. 115 | // It's only intended for direct use with grpc.RegisterService, 116 | // and not to be introspected or modified (even as a copy) 117 | var Greeter_ServiceDesc = grpc.ServiceDesc{ 118 | ServiceName: "helloworld.Greeter", 119 | HandlerType: (*GreeterServer)(nil), 120 | Methods: []grpc.MethodDesc{ 121 | { 122 | MethodName: "SayHello", 123 | Handler: _Greeter_SayHello_Handler, 124 | }, 125 | }, 126 | Streams: []grpc.StreamDesc{}, 127 | Metadata: "helloworld.proto", 128 | } 129 | -------------------------------------------------------------------------------- /fixedwindow.go: -------------------------------------------------------------------------------- 1 | package limiters 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "strconv" 10 | "sync" 11 | "time" 12 | 13 | "github.com/Azure/azure-sdk-for-go/sdk/azcore" 14 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 15 | "github.com/aws/aws-sdk-go-v2/aws" 16 | "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" 17 | "github.com/aws/aws-sdk-go-v2/service/dynamodb" 18 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" 19 | "github.com/bradfitz/gomemcache/memcache" 20 | "github.com/pkg/errors" 21 | "github.com/redis/go-redis/v9" 22 | ) 23 | 24 | // FixedWindowIncrementer wraps the Increment method. 25 | type FixedWindowIncrementer interface { 26 | // Increment increments the request counter for the window and returns the counter value. 27 | // TTL is the time duration before the next window. 28 | Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) 29 | } 30 | 31 | // FixedWindow implements a Fixed Window rate limiting algorithm. 32 | // 33 | // Simple and memory efficient algorithm that does not need a distributed lock. 34 | // However it may be lenient when there are many requests around the boundary between 2 adjacent windows. 35 | type FixedWindow struct { 36 | backend FixedWindowIncrementer 37 | clock Clock 38 | rate time.Duration 39 | capacity int64 40 | mu sync.Mutex 41 | window time.Time 42 | overflow bool 43 | } 44 | 45 | // NewFixedWindow creates a new instance of FixedWindow. 46 | // Capacity is the maximum amount of requests allowed per window. 47 | // Rate is the window size. 48 | func NewFixedWindow(capacity int64, rate time.Duration, fixedWindowIncrementer FixedWindowIncrementer, clock Clock) *FixedWindow { 49 | return &FixedWindow{backend: fixedWindowIncrementer, clock: clock, rate: rate, capacity: capacity} 50 | } 51 | 52 | // Limit returns the time duration to wait before the request can be processed. 
53 | // It returns ErrLimitExhausted if the request overflows the window's capacity. 54 | func (f *FixedWindow) Limit(ctx context.Context) (time.Duration, error) { 55 | f.mu.Lock() 56 | defer f.mu.Unlock() 57 | now := f.clock.Now() 58 | window := now.Truncate(f.rate) 59 | if f.window != window { 60 | f.window = window 61 | f.overflow = false 62 | } 63 | ttl := f.rate - now.Sub(window) 64 | if f.overflow { 65 | // If the window is already overflowed don't increment the counter. 66 | return ttl, ErrLimitExhausted 67 | } 68 | c, err := f.backend.Increment(ctx, window, ttl) 69 | if err != nil { 70 | return 0, err 71 | } 72 | if c > f.capacity { 73 | f.overflow = true 74 | 75 | return ttl, ErrLimitExhausted 76 | } 77 | 78 | return 0, nil 79 | } 80 | 81 | // FixedWindowInMemory is an in-memory implementation of FixedWindowIncrementer. 82 | type FixedWindowInMemory struct { 83 | mu sync.Mutex 84 | c int64 85 | window time.Time 86 | } 87 | 88 | // NewFixedWindowInMemory creates a new instance of FixedWindowInMemory. 89 | func NewFixedWindowInMemory() *FixedWindowInMemory { 90 | return &FixedWindowInMemory{} 91 | } 92 | 93 | // Increment increments the window's counter. 94 | func (f *FixedWindowInMemory) Increment(ctx context.Context, window time.Time, _ time.Duration) (int64, error) { 95 | f.mu.Lock() 96 | defer f.mu.Unlock() 97 | if window != f.window { 98 | f.c = 0 99 | f.window = window 100 | } 101 | f.c++ 102 | 103 | return f.c, ctx.Err() 104 | } 105 | 106 | // FixedWindowRedis implements FixedWindow in Redis. 107 | type FixedWindowRedis struct { 108 | cli redis.UniversalClient 109 | prefix string 110 | } 111 | 112 | // NewFixedWindowRedis returns a new instance of FixedWindowRedis. 113 | // Prefix is the key prefix used to store all the keys used in this implementation in Redis. 114 | func NewFixedWindowRedis(cli redis.UniversalClient, prefix string) *FixedWindowRedis { 115 | return &FixedWindowRedis{cli: cli, prefix: prefix} 116 | } 117 | 118 | // Increment increments the window's counter in Redis. 119 | func (f *FixedWindowRedis) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) { 120 | var incr *redis.IntCmd 121 | var err error 122 | done := make(chan struct{}) 123 | go func() { 124 | defer close(done) 125 | _, err = f.cli.Pipelined(ctx, func(pipeliner redis.Pipeliner) error { 126 | key := fmt.Sprintf("%d", window.UnixNano()) 127 | incr = pipeliner.Incr(ctx, redisKey(f.prefix, key)) 128 | pipeliner.PExpire(ctx, redisKey(f.prefix, key), ttl) 129 | 130 | return nil 131 | }) 132 | }() 133 | 134 | select { 135 | case <-done: 136 | if err != nil { 137 | return 0, errors.Wrap(err, "redis transaction failed") 138 | } 139 | 140 | return incr.Val(), incr.Err() 141 | case <-ctx.Done(): 142 | return 0, ctx.Err() 143 | } 144 | } 145 | 146 | // FixedWindowMemcached implements FixedWindow in Memcached. 147 | type FixedWindowMemcached struct { 148 | cli *memcache.Client 149 | prefix string 150 | } 151 | 152 | // NewFixedWindowMemcached returns a new instance of FixedWindowMemcached. 153 | // Prefix is the key prefix used to store all the keys used in this implementation in Memcached. 154 | func NewFixedWindowMemcached(cli *memcache.Client, prefix string) *FixedWindowMemcached { 155 | return &FixedWindowMemcached{cli: cli, prefix: prefix} 156 | } 157 | 158 | // Increment increments the window's counter in Memcached. 
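// The counter key is derived from the prefix and the window start time (UnixNano),
// so every caller in the same window increments the same Memcached item. If the key
// does not exist yet (memcache.ErrCacheMiss), an Add with an initial value of 1 is
// attempted; a concurrent Add by another process surfaces as memcache.ErrNotStored,
// in which case the whole Increment call is retried.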
159 | func (f *FixedWindowMemcached) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) { 160 | var newValue uint64 161 | var err error 162 | done := make(chan struct{}) 163 | go func() { 164 | defer close(done) 165 | key := fmt.Sprintf("%s:%d", f.prefix, window.UnixNano()) 166 | newValue, err = f.cli.Increment(key, 1) 167 | if err != nil && errors.Is(err, memcache.ErrCacheMiss) { 168 | newValue = 1 169 | item := &memcache.Item{ 170 | Key: key, 171 | Value: []byte(strconv.FormatUint(newValue, 10)), 172 | } 173 | err = f.cli.Add(item) 174 | } 175 | }() 176 | 177 | select { 178 | case <-done: 179 | if err != nil { 180 | if errors.Is(err, memcache.ErrNotStored) { 181 | return f.Increment(ctx, window, ttl) 182 | } 183 | 184 | return 0, errors.Wrap(err, "failed to Increment or Add") 185 | } 186 | 187 | return int64(newValue), err 188 | case <-ctx.Done(): 189 | return 0, ctx.Err() 190 | } 191 | } 192 | 193 | // FixedWindowDynamoDB implements FixedWindow in DynamoDB. 194 | type FixedWindowDynamoDB struct { 195 | client *dynamodb.Client 196 | partitionKey string 197 | tableProps DynamoDBTableProperties 198 | } 199 | 200 | // NewFixedWindowDynamoDB creates a new instance of FixedWindowDynamoDB. 201 | // PartitionKey is the key used to store all the this implementation in DynamoDB. 202 | // 203 | // TableProps describe the table that this backend should work with. This backend requires the following on the table: 204 | // * SortKey 205 | // * TTL. 206 | func NewFixedWindowDynamoDB(client *dynamodb.Client, partitionKey string, props DynamoDBTableProperties) *FixedWindowDynamoDB { 207 | return &FixedWindowDynamoDB{ 208 | client: client, 209 | partitionKey: partitionKey, 210 | tableProps: props, 211 | } 212 | } 213 | 214 | type contextKey int 215 | 216 | var fixedWindowDynamoDBPartitionKey contextKey 217 | 218 | // NewFixedWindowDynamoDBContext creates a context for FixedWindowDynamoDB with a partition key. 219 | // 220 | // This context can be used to control the partition key per-request. 221 | // 222 | // DEPRECATED: NewFixedWindowDynamoDBContext is deprecated and will be removed in future versions. 223 | // Separate FixedWindow rate limiters should be used for different partition keys instead. 224 | // Consider using the `Registry` to manage multiple FixedWindow instances with different partition keys. 225 | func NewFixedWindowDynamoDBContext(ctx context.Context, partitionKey string) context.Context { 226 | log.Printf("DEPRECATED: NewFixedWindowDynamoDBContext is deprecated and will be removed in future versions.") 227 | 228 | return context.WithValue(ctx, fixedWindowDynamoDBPartitionKey, partitionKey) 229 | } 230 | 231 | const ( 232 | fixedWindowDynamoDBUpdateExpression = "SET #C = if_not_exists(#C, :def) + :inc, #TTL = :ttl" 233 | dynamodbWindowCountKey = "Count" 234 | ) 235 | 236 | // Increment increments the window's counter in DynamoDB. 
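// The increment is a single atomic UpdateItem call: the update expression
// "SET #C = if_not_exists(#C, :def) + :inc, #TTL = :ttl" creates the counter on first
// use and adds 1 on every subsequent call, so there is no read-modify-write race.
// The window timestamp (UnixNano) is used as the sort key and the item expires via
// the table's TTL attribute.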
237 | func (f *FixedWindowDynamoDB) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) { 238 | var resp *dynamodb.UpdateItemOutput 239 | var err error 240 | 241 | done := make(chan struct{}) 242 | go func() { 243 | defer close(done) 244 | partitionKey := f.partitionKey 245 | if key, ok := ctx.Value(fixedWindowDynamoDBPartitionKey).(string); ok { 246 | partitionKey = key 247 | } 248 | resp, err = f.client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ 249 | Key: map[string]types.AttributeValue{ 250 | f.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: partitionKey}, 251 | f.tableProps.SortKeyName: &types.AttributeValueMemberS{Value: strconv.FormatInt(window.UnixNano(), 10)}, 252 | }, 253 | UpdateExpression: aws.String(fixedWindowDynamoDBUpdateExpression), 254 | ExpressionAttributeNames: map[string]string{ 255 | "#TTL": f.tableProps.TTLFieldName, 256 | "#C": dynamodbWindowCountKey, 257 | }, 258 | ExpressionAttributeValues: map[string]types.AttributeValue{ 259 | ":ttl": &types.AttributeValueMemberN{Value: strconv.FormatInt(time.Now().Add(ttl).Unix(), 10)}, 260 | ":def": &types.AttributeValueMemberN{Value: "0"}, 261 | ":inc": &types.AttributeValueMemberN{Value: "1"}, 262 | }, 263 | TableName: &f.tableProps.TableName, 264 | ReturnValues: types.ReturnValueAllNew, 265 | }) 266 | }() 267 | 268 | select { 269 | case <-done: 270 | case <-ctx.Done(): 271 | return 0, ctx.Err() 272 | } 273 | 274 | if err != nil { 275 | return 0, errors.Wrap(err, "dynamodb update item failed") 276 | } 277 | 278 | var count float64 279 | err = attributevalue.Unmarshal(resp.Attributes[dynamodbWindowCountKey], &count) 280 | if err != nil { 281 | return 0, errors.Wrap(err, "unmarshal of dynamodb attribute value failed") 282 | } 283 | 284 | return int64(count), nil 285 | } 286 | 287 | // FixedWindowCosmosDB implements FixedWindow in CosmosDB. 288 | type FixedWindowCosmosDB struct { 289 | client *azcosmos.ContainerClient 290 | partitionKey string 291 | } 292 | 293 | // NewFixedWindowCosmosDB creates a new instance of FixedWindowCosmosDB. 294 | // PartitionKey is the key used for partitioning data into multiple partitions. 
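//
// A minimal usage sketch (illustrative only: "containerClient" stands for an already
// configured *azcosmos.ContainerClient, "ctx" for a request context, and the capacity
// and rate values are arbitrary):
//
//	inc := NewFixedWindowCosmosDB(containerClient, "client-42")
//	fw := NewFixedWindow(100, time.Minute, inc, NewSystemClock())
//	if _, err := fw.Limit(ctx); errors.Is(err, ErrLimitExhausted) {
//		// Reject or postpone the request.
//	}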
295 | func NewFixedWindowCosmosDB(client *azcosmos.ContainerClient, partitionKey string) *FixedWindowCosmosDB { 296 | return &FixedWindowCosmosDB{ 297 | client: client, 298 | partitionKey: partitionKey, 299 | } 300 | } 301 | 302 | func (f *FixedWindowCosmosDB) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) { 303 | id := strconv.FormatInt(window.UnixNano(), 10) 304 | tmp := cosmosItem{ 305 | ID: id, 306 | PartitionKey: f.partitionKey, 307 | Count: 1, 308 | TTL: int32(ttl), 309 | } 310 | 311 | ops := azcosmos.PatchOperations{} 312 | ops.AppendIncrement(`/Count`, 1) 313 | 314 | patchResp, err := f.client.PatchItem(ctx, azcosmos.NewPartitionKey().AppendString(f.partitionKey), id, ops, &azcosmos.ItemOptions{ 315 | EnableContentResponseOnWrite: true, 316 | }) 317 | if err == nil { 318 | // value exists and was updated 319 | err = json.Unmarshal(patchResp.Value, &tmp) 320 | if err != nil { 321 | return 0, errors.Wrap(err, "unmarshal of cosmos value failed") 322 | } 323 | 324 | return tmp.Count, nil 325 | } 326 | 327 | var respErr *azcore.ResponseError 328 | if !errors.As(err, &respErr) || respErr.StatusCode != http.StatusNotFound { 329 | return 0, errors.Wrap(err, `patch of cosmos value failed`) 330 | } 331 | 332 | newValue, err := json.Marshal(tmp) 333 | if err != nil { 334 | return 0, errors.Wrap(err, "marshal of cosmos value failed") 335 | } 336 | 337 | _, err = f.client.CreateItem(ctx, azcosmos.NewPartitionKey().AppendString(f.partitionKey), newValue, &azcosmos.ItemOptions{ 338 | SessionToken: patchResp.SessionToken, 339 | IfMatchEtag: &patchResp.ETag, 340 | }) 341 | if err != nil { 342 | return 0, errors.Wrap(err, "upsert of cosmos value failed") 343 | } 344 | 345 | return tmp.Count, nil 346 | } 347 | -------------------------------------------------------------------------------- /fixedwindow_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | l "github.com/mennanov/limiters" 10 | ) 11 | 12 | // fixedWindows returns all the possible FixedWindow combinations. 
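// It builds one FixedWindow per incrementer backend returned by
// fixedWindowIncrementers(), so every test case below runs against every storage backend.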
13 | func (s *LimitersTestSuite) fixedWindows(capacity int64, rate time.Duration, clock l.Clock) map[string]*l.FixedWindow { 14 | windows := make(map[string]*l.FixedWindow) 15 | for name, inc := range s.fixedWindowIncrementers() { 16 | windows[name] = l.NewFixedWindow(capacity, rate, inc, clock) 17 | } 18 | 19 | return windows 20 | } 21 | 22 | func (s *LimitersTestSuite) fixedWindowIncrementers() map[string]l.FixedWindowIncrementer { 23 | return map[string]l.FixedWindowIncrementer{ 24 | "FixedWindowInMemory": l.NewFixedWindowInMemory(), 25 | "FixedWindowRedis": l.NewFixedWindowRedis(s.redisClient, uuid.New().String()), 26 | "FixedWindowRedisCluster": l.NewFixedWindowRedis(s.redisClusterClient, uuid.New().String()), 27 | "FixedWindowMemcached": l.NewFixedWindowMemcached(s.memcacheClient, uuid.New().String()), 28 | "FixedWindowDynamoDB": l.NewFixedWindowDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps), 29 | "FixedWindowCosmosDB": l.NewFixedWindowCosmosDB(s.cosmosContainerClient, uuid.New().String()), 30 | } 31 | } 32 | 33 | var fixedWindowTestCases = []struct { 34 | capacity int64 35 | rate time.Duration 36 | requestCount int 37 | requestRate time.Duration 38 | missExpected int 39 | }{ 40 | { 41 | capacity: 2, 42 | rate: time.Millisecond * 100, 43 | requestCount: 20, 44 | requestRate: time.Millisecond * 25, 45 | missExpected: 10, 46 | }, 47 | { 48 | capacity: 4, 49 | rate: time.Millisecond * 100, 50 | requestCount: 20, 51 | requestRate: time.Millisecond * 25, 52 | missExpected: 0, 53 | }, 54 | { 55 | capacity: 2, 56 | rate: time.Millisecond * 100, 57 | requestCount: 15, 58 | requestRate: time.Millisecond * 33, 59 | missExpected: 5, 60 | }, 61 | } 62 | 63 | func (s *LimitersTestSuite) TestFixedWindowFakeClock() { 64 | clock := newFakeClockWithTime(time.Date(2019, 8, 30, 0, 0, 0, 0, time.UTC)) 65 | for _, testCase := range fixedWindowTestCases { 66 | for name, bucket := range s.fixedWindows(testCase.capacity, testCase.rate, clock) { 67 | s.Run(name, func() { 68 | clock.reset() 69 | miss := 0 70 | for i := 0; i < testCase.requestCount; i++ { 71 | // No pause for the first request. 
72 | if i > 0 { 73 | clock.Sleep(testCase.requestRate) 74 | } 75 | if _, err := bucket.Limit(context.TODO()); err != nil { 76 | s.Equal(l.ErrLimitExhausted, err) 77 | miss++ 78 | } 79 | } 80 | s.Equal(testCase.missExpected, miss, testCase) 81 | }) 82 | } 83 | } 84 | } 85 | 86 | func (s *LimitersTestSuite) TestFixedWindowOverflow() { 87 | clock := newFakeClockWithTime(time.Date(2019, 8, 30, 0, 0, 0, 0, time.UTC)) 88 | for name, bucket := range s.fixedWindows(2, time.Second, clock) { 89 | s.Run(name, func() { 90 | clock.reset() 91 | w, err := bucket.Limit(context.TODO()) 92 | s.Require().NoError(err) 93 | s.Equal(time.Duration(0), w) 94 | w, err = bucket.Limit(context.TODO()) 95 | s.Require().NoError(err) 96 | s.Equal(time.Duration(0), w) 97 | w, err = bucket.Limit(context.TODO()) 98 | s.Require().Equal(l.ErrLimitExhausted, err) 99 | s.Equal(time.Second, w) 100 | clock.Sleep(time.Second) 101 | w, err = bucket.Limit(context.TODO()) 102 | s.Require().NoError(err) 103 | s.Equal(time.Duration(0), w) 104 | }) 105 | } 106 | } 107 | 108 | func (s *LimitersTestSuite) TestFixedWindowDynamoDBPartitionKey() { 109 | clock := newFakeClockWithTime(time.Date(2019, 8, 30, 0, 0, 0, 0, time.UTC)) 110 | incrementor := l.NewFixedWindowDynamoDB(s.dynamodbClient, "partitionKey1", s.dynamoDBTableProps) 111 | window := l.NewFixedWindow(2, time.Millisecond*100, incrementor, clock) 112 | 113 | w, err := window.Limit(context.TODO()) 114 | s.Require().NoError(err) 115 | s.Equal(time.Duration(0), w) 116 | w, err = window.Limit(context.TODO()) 117 | s.Require().NoError(err) 118 | s.Equal(time.Duration(0), w) 119 | // The third call should fail for the "partitionKey1", but succeed for "partitionKey2". 120 | w, err = window.Limit(l.NewFixedWindowDynamoDBContext(context.Background(), "partitionKey2")) 121 | s.Require().NoError(err) 122 | s.Equal(time.Duration(0), w) 123 | } 124 | 125 | func BenchmarkFixedWindows(b *testing.B) { 126 | s := new(LimitersTestSuite) 127 | s.SetT(&testing.T{}) 128 | s.SetupSuite() 129 | capacity := int64(1) 130 | rate := time.Second 131 | clock := newFakeClock() 132 | windows := s.fixedWindows(capacity, rate, clock) 133 | for name, window := range windows { 134 | b.Run(name, func(b *testing.B) { 135 | for i := 0; i < b.N; i++ { 136 | _, err := window.Limit(context.TODO()) 137 | s.Require().NoError(err) 138 | } 139 | }) 140 | } 141 | s.TearDownSuite() 142 | } 143 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mennanov/limiters 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.1 6 | 7 | replace github.com/armon/go-metrics => github.com/hashicorp/go-metrics v0.4.1 8 | 9 | require ( 10 | github.com/aws/aws-sdk-go-v2/config v1.29.12 11 | github.com/aws/aws-sdk-go-v2/credentials v1.17.65 12 | github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.18.8 13 | github.com/go-redsync/redsync/v4 v4.13.0 14 | github.com/google/uuid v1.6.0 15 | github.com/hashicorp/consul/api v1.31.2 16 | github.com/pkg/errors v0.9.1 17 | github.com/redis/go-redis/v9 v9.7.3 18 | github.com/samuel/go-zookeeper v0.0.0-20201211165307-7117e9ea2414 19 | github.com/stretchr/testify v1.10.0 20 | go.etcd.io/etcd/api/v3 v3.5.21 21 | go.etcd.io/etcd/client/v3 v3.5.21 22 | google.golang.org/grpc v1.71.0 23 | google.golang.org/protobuf v1.36.6 24 | ) 25 | 26 | require ( 27 | github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.1 28 | github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.4.0 29 
| github.com/aws/aws-sdk-go-v2 v1.36.3 30 | github.com/aws/aws-sdk-go-v2/service/dynamodb v1.42.0 31 | github.com/cenkalti/backoff/v3 v3.2.2 32 | github.com/lib/pq v1.10.9 33 | ) 34 | 35 | require ( 36 | github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect 37 | github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect 38 | github.com/armon/go-metrics v0.5.4 // indirect 39 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect 40 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect 41 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect 42 | github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect 43 | github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.25.1 // indirect 44 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect 45 | github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.15 // indirect 46 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect 47 | github.com/aws/aws-sdk-go-v2/service/sso v1.25.2 // indirect 48 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.0 // indirect 49 | github.com/aws/aws-sdk-go-v2/service/sts v1.33.17 // indirect 50 | github.com/aws/smithy-go v1.22.3 // indirect 51 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 52 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 53 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 54 | github.com/gogo/protobuf v1.3.2 // indirect 55 | github.com/golang/protobuf v1.5.4 // indirect 56 | github.com/hashicorp/errwrap v1.1.0 // indirect 57 | github.com/hashicorp/go-metrics v0.5.4 // indirect 58 | github.com/hashicorp/go-multierror v1.1.1 // indirect 59 | github.com/hashicorp/go-rootcerts v1.0.2 // indirect 60 | github.com/hashicorp/golang-lru v1.0.2 // indirect 61 | github.com/hashicorp/serf v0.10.2 // indirect 62 | github.com/mattn/go-colorable v0.1.14 // indirect 63 | github.com/mattn/go-isatty v0.0.20 // indirect 64 | github.com/mitchellh/go-homedir v1.1.0 // indirect 65 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 66 | github.com/thanhpk/randstr v1.0.6 // indirect 67 | go.etcd.io/etcd/client/pkg/v3 v3.5.21 // indirect 68 | golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect 69 | golang.org/x/sys v0.31.0 // indirect 70 | golang.org/x/text v0.23.0 // indirect 71 | google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 // indirect 72 | google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 // indirect 73 | gopkg.in/yaml.v3 v3.0.1 // indirect 74 | ) 75 | 76 | require ( 77 | github.com/alessandro-c/gomemcached-lock v1.0.0 78 | github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 79 | github.com/coreos/go-semver v0.3.1 // indirect 80 | github.com/coreos/go-systemd/v22 v22.5.0 // indirect 81 | github.com/fatih/color v1.18.0 // indirect 82 | github.com/hashicorp/go-cleanhttp v0.5.2 // indirect 83 | github.com/hashicorp/go-hclog v1.6.3 // indirect 84 | github.com/hashicorp/go-immutable-radix v1.3.1 // indirect 85 | github.com/mitchellh/mapstructure v1.5.0 // indirect 86 | go.uber.org/multierr v1.11.0 // indirect 87 | go.uber.org/zap v1.27.0 // indirect 88 | golang.org/x/net v0.38.0 // indirect 89 | ) 90 | -------------------------------------------------------------------------------- /leakybucket.go: -------------------------------------------------------------------------------- 1 | package limiters 2 
| 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/gob" 7 | "encoding/json" 8 | "fmt" 9 | "net/http" 10 | "strconv" 11 | "sync" 12 | "time" 13 | 14 | "github.com/Azure/azure-sdk-for-go/sdk/azcore" 15 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 16 | "github.com/aws/aws-sdk-go-v2/aws" 17 | "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" 18 | "github.com/aws/aws-sdk-go-v2/service/dynamodb" 19 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" 20 | "github.com/bradfitz/gomemcache/memcache" 21 | "github.com/pkg/errors" 22 | "github.com/redis/go-redis/v9" 23 | "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" 24 | clientv3 "go.etcd.io/etcd/client/v3" 25 | ) 26 | 27 | // LeakyBucketState represents the state of a LeakyBucket. 28 | type LeakyBucketState struct { 29 | // Last is the Unix timestamp in nanoseconds of the most recent request. 30 | Last int64 31 | } 32 | 33 | // IzZero returns true if the bucket state is zero valued. 34 | func (s LeakyBucketState) IzZero() bool { 35 | return s.Last == 0 36 | } 37 | 38 | // LeakyBucketStateBackend interface encapsulates the logic of retrieving and persisting the state of a LeakyBucket. 39 | type LeakyBucketStateBackend interface { 40 | // State gets the current state of the LeakyBucket. 41 | State(ctx context.Context) (LeakyBucketState, error) 42 | // SetState sets (persists) the current state of the LeakyBucket. 43 | SetState(ctx context.Context, state LeakyBucketState) error 44 | // Reset resets (persists) the current state of the LeakyBucket. 45 | Reset(ctx context.Context) error 46 | } 47 | 48 | // LeakyBucket implements the https://en.wikipedia.org/wiki/Leaky_bucket#As_a_queue algorithm. 49 | type LeakyBucket struct { 50 | locker DistLocker 51 | backend LeakyBucketStateBackend 52 | clock Clock 53 | logger Logger 54 | // Capacity is the maximum allowed number of tokens in the bucket. 55 | capacity int64 56 | // Rate is the output rate: 1 request per the rate duration (in nanoseconds). 57 | rate int64 58 | mu sync.Mutex 59 | } 60 | 61 | // NewLeakyBucket creates a new instance of LeakyBucket. 62 | func NewLeakyBucket(capacity int64, rate time.Duration, locker DistLocker, leakyBucketStateBackend LeakyBucketStateBackend, clock Clock, logger Logger) *LeakyBucket { 63 | return &LeakyBucket{ 64 | locker: locker, 65 | backend: leakyBucketStateBackend, 66 | clock: clock, 67 | logger: logger, 68 | capacity: capacity, 69 | rate: int64(rate), 70 | } 71 | } 72 | 73 | // Limit returns the time duration to wait before the request can be processed. 74 | // It returns ErrLimitExhausted if the request overflows the bucket's capacity. In this case the returned duration 75 | // means how long it would have taken to wait for the request to be processed if the bucket was not overflowed. 76 | func (t *LeakyBucket) Limit(ctx context.Context) (time.Duration, error) { 77 | t.mu.Lock() 78 | defer t.mu.Unlock() 79 | if err := t.locker.Lock(ctx); err != nil { 80 | return 0, err 81 | } 82 | defer func() { 83 | if err := t.locker.Unlock(ctx); err != nil { 84 | t.logger.Log(err) 85 | } 86 | }() 87 | state, err := t.backend.State(ctx) 88 | if err != nil { 89 | return 0, err 90 | } 91 | now := t.clock.Now().UnixNano() 92 | if now < state.Last { 93 | // The queue has requests in it: move the current request to the last position + 1. 94 | state.Last += t.rate 95 | } else { 96 | // The queue is empty. 97 | // The offset is the duration to wait in case the last request happened less than rate duration ago. 
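// Worked example (illustrative numbers): with rate = 100ms and the last request
// seen 30ms ago, delta = 30ms < rate, so offset = 70ms and the current request is
// scheduled 70ms from now; if the last request was 100ms or more ago, offset stays
// 0 and the request proceeds immediately.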
98 | var offset int64 99 | delta := now - state.Last 100 | if delta < t.rate { 101 | offset = t.rate - delta 102 | } 103 | state.Last = now + offset 104 | } 105 | 106 | wait := state.Last - now 107 | if wait/t.rate >= t.capacity { 108 | return time.Duration(wait), ErrLimitExhausted 109 | } 110 | if err = t.backend.SetState(ctx, state); err != nil { 111 | return 0, err 112 | } 113 | 114 | return time.Duration(wait), nil 115 | } 116 | 117 | // Reset resets the bucket. 118 | func (t *LeakyBucket) Reset(ctx context.Context) error { 119 | return t.backend.Reset(ctx) 120 | } 121 | 122 | // LeakyBucketInMemory is an in-memory implementation of LeakyBucketStateBackend. 123 | type LeakyBucketInMemory struct { 124 | state LeakyBucketState 125 | } 126 | 127 | // NewLeakyBucketInMemory creates a new instance of LeakyBucketInMemory. 128 | func NewLeakyBucketInMemory() *LeakyBucketInMemory { 129 | return &LeakyBucketInMemory{} 130 | } 131 | 132 | // State gets the current state of the bucket. 133 | func (l *LeakyBucketInMemory) State(ctx context.Context) (LeakyBucketState, error) { 134 | return l.state, ctx.Err() 135 | } 136 | 137 | // SetState sets the current state of the bucket. 138 | func (l *LeakyBucketInMemory) SetState(ctx context.Context, state LeakyBucketState) error { 139 | l.state = state 140 | 141 | return ctx.Err() 142 | } 143 | 144 | // Reset resets the current state of the bucket. 145 | func (l *LeakyBucketInMemory) Reset(ctx context.Context) error { 146 | state := LeakyBucketState{ 147 | Last: 0, 148 | } 149 | 150 | return l.SetState(ctx, state) 151 | } 152 | 153 | const ( 154 | etcdKeyLBLease = "lease" 155 | etcdKeyLBLast = "last" 156 | ) 157 | 158 | // LeakyBucketEtcd is an etcd implementation of a LeakyBucketStateBackend. 159 | // See the TokenBucketEtcd description for the details on etcd usage. 160 | type LeakyBucketEtcd struct { 161 | // prefix is the etcd key prefix. 162 | prefix string 163 | cli *clientv3.Client 164 | leaseID clientv3.LeaseID 165 | ttl time.Duration 166 | raceCheck bool 167 | lastVersion int64 168 | } 169 | 170 | // NewLeakyBucketEtcd creates a new LeakyBucketEtcd instance. 171 | // Prefix is used as an etcd key prefix for all keys stored in etcd by this algorithm. 172 | // TTL is a TTL of the etcd lease used to store all the keys. 173 | // 174 | // If raceCheck is true and the keys in etcd are modified in between State() and SetState() calls then 175 | // ErrRaceCondition is returned. 176 | func NewLeakyBucketEtcd(cli *clientv3.Client, prefix string, ttl time.Duration, raceCheck bool) *LeakyBucketEtcd { 177 | return &LeakyBucketEtcd{ 178 | prefix: prefix, 179 | cli: cli, 180 | ttl: ttl, 181 | raceCheck: raceCheck, 182 | } 183 | } 184 | 185 | // State gets the bucket's current state from etcd. 186 | // If there is no state available in etcd then the initial bucket's state is returned. 187 | func (l *LeakyBucketEtcd) State(ctx context.Context) (LeakyBucketState, error) { 188 | // Reset the lease ID as it will be updated by the successful Get operation below. 189 | l.leaseID = 0 190 | // Get all the keys under the prefix in a single request. 
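// The single range read [prefix, incPrefix(prefix)) fetches both keys that save()
// writes under the prefix (the "last" timestamp and the lease id), so the state and
// its lease are loaded in one round trip; if only one of the two keys is present an
// error is returned below.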
191 | r, err := l.cli.Get(ctx, l.prefix, clientv3.WithRange(incPrefix(l.prefix))) 192 | if err != nil { 193 | return LeakyBucketState{}, errors.Wrapf(err, "failed to get keys in range ['%s', '%s') from etcd", l.prefix, incPrefix(l.prefix)) 194 | } 195 | if len(r.Kvs) == 0 { 196 | return LeakyBucketState{}, nil 197 | } 198 | state := LeakyBucketState{} 199 | parsed := 0 200 | var v int64 201 | for _, kv := range r.Kvs { 202 | switch string(kv.Key) { 203 | case etcdKey(l.prefix, etcdKeyLBLast): 204 | v, err = parseEtcdInt64(kv) 205 | if err != nil { 206 | return LeakyBucketState{}, err 207 | } 208 | state.Last = v 209 | parsed |= 1 210 | l.lastVersion = kv.Version 211 | 212 | case etcdKey(l.prefix, etcdKeyLBLease): 213 | v, err = parseEtcdInt64(kv) 214 | if err != nil { 215 | return LeakyBucketState{}, err 216 | } 217 | l.leaseID = clientv3.LeaseID(v) 218 | parsed |= 2 219 | } 220 | } 221 | if parsed != 3 { 222 | return LeakyBucketState{}, errors.New("failed to get state from etcd: some keys are missing") 223 | } 224 | 225 | return state, nil 226 | } 227 | 228 | // createLease creates a new lease in etcd and updates the t.leaseID value. 229 | func (l *LeakyBucketEtcd) createLease(ctx context.Context) error { 230 | lease, err := l.cli.Grant(ctx, int64(l.ttl/time.Nanosecond)) 231 | if err != nil { 232 | return errors.Wrap(err, "failed to create a new lease in etcd") 233 | } 234 | l.leaseID = lease.ID 235 | 236 | return nil 237 | } 238 | 239 | // save saves the state to etcd using the existing lease. 240 | func (l *LeakyBucketEtcd) save(ctx context.Context, state LeakyBucketState) error { 241 | if !l.raceCheck { 242 | if _, err := l.cli.Txn(ctx).Then( 243 | clientv3.OpPut(etcdKey(l.prefix, etcdKeyLBLast), fmt.Sprintf("%d", state.Last), clientv3.WithLease(l.leaseID)), 244 | clientv3.OpPut(etcdKey(l.prefix, etcdKeyLBLease), fmt.Sprintf("%d", l.leaseID), clientv3.WithLease(l.leaseID)), 245 | ).Commit(); err != nil { 246 | return errors.Wrap(err, "failed to commit a transaction to etcd") 247 | } 248 | 249 | return nil 250 | } 251 | // Put the keys only if they have not been modified since the most recent read. 252 | r, err := l.cli.Txn(ctx).If( 253 | clientv3.Compare(clientv3.Version(etcdKey(l.prefix, etcdKeyLBLast)), ">", l.lastVersion), 254 | ).Else( 255 | clientv3.OpPut(etcdKey(l.prefix, etcdKeyLBLast), fmt.Sprintf("%d", state.Last), clientv3.WithLease(l.leaseID)), 256 | clientv3.OpPut(etcdKey(l.prefix, etcdKeyLBLease), fmt.Sprintf("%d", l.leaseID), clientv3.WithLease(l.leaseID)), 257 | ).Commit() 258 | if err != nil { 259 | return errors.Wrap(err, "failed to commit a transaction to etcd") 260 | } 261 | 262 | if !r.Succeeded { 263 | return nil 264 | } 265 | 266 | return ErrRaceCondition 267 | } 268 | 269 | // SetState updates the state of the bucket in etcd. 270 | func (l *LeakyBucketEtcd) SetState(ctx context.Context, state LeakyBucketState) error { 271 | if l.leaseID == 0 { 272 | // Lease does not exist, create one. 273 | if err := l.createLease(ctx); err != nil { 274 | return err 275 | } 276 | // No need to send KeepAlive for the newly creates lease: save the state immediately. 277 | return l.save(ctx, state) 278 | } 279 | // Send the KeepAlive request to extend the existing lease. 280 | if _, err := l.cli.KeepAliveOnce(ctx, l.leaseID); errors.Is(err, rpctypes.ErrLeaseNotFound) { 281 | // Create a new lease since the current one has expired. 
282 | if err = l.createLease(ctx); err != nil { 283 | return err 284 | } 285 | } else if err != nil { 286 | return errors.Wrapf(err, "failed to extend the lease '%d'", l.leaseID) 287 | } 288 | 289 | return l.save(ctx, state) 290 | } 291 | 292 | // Reset resets the state of the bucket in etcd. 293 | func (l *LeakyBucketEtcd) Reset(ctx context.Context) error { 294 | state := LeakyBucketState{ 295 | Last: 0, 296 | } 297 | 298 | return l.SetState(ctx, state) 299 | } 300 | 301 | // Deprecated: These legacy keys will be removed in a future version. 302 | // The state is now stored in a single JSON document under the "state" key. 303 | const ( 304 | redisKeyLBLast = "last" 305 | redisKeyLBVersion = "version" 306 | ) 307 | 308 | // LeakyBucketRedis is a Redis implementation of a LeakyBucketStateBackend. 309 | type LeakyBucketRedis struct { 310 | cli redis.UniversalClient 311 | prefix string 312 | ttl time.Duration 313 | raceCheck bool 314 | lastVersion int64 315 | } 316 | 317 | // NewLeakyBucketRedis creates a new LeakyBucketRedis instance. 318 | // Prefix is the key prefix used to store all the keys used in this implementation in Redis. 319 | // TTL is the TTL of the stored keys. 320 | // 321 | // If raceCheck is true and the keys in Redis are modified in between State() and SetState() calls then 322 | // ErrRaceCondition is returned. 323 | func NewLeakyBucketRedis(cli redis.UniversalClient, prefix string, ttl time.Duration, raceCheck bool) *LeakyBucketRedis { 324 | return &LeakyBucketRedis{cli: cli, prefix: prefix, ttl: ttl, raceCheck: raceCheck} 325 | } 326 | 327 | // Deprecated: Legacy format support will be removed in a future version. 328 | func (t *LeakyBucketRedis) oldState(ctx context.Context) (LeakyBucketState, error) { 329 | var values []interface{} 330 | var err error 331 | done := make(chan struct{}, 1) 332 | go func() { 333 | defer close(done) 334 | keys := []string{ 335 | redisKey(t.prefix, redisKeyLBLast), 336 | } 337 | if t.raceCheck { 338 | keys = append(keys, redisKey(t.prefix, redisKeyLBVersion)) 339 | } 340 | values, err = t.cli.MGet(ctx, keys...).Result() 341 | }() 342 | 343 | select { 344 | case <-done: 345 | 346 | case <-ctx.Done(): 347 | return LeakyBucketState{}, ctx.Err() 348 | } 349 | 350 | if err != nil { 351 | return LeakyBucketState{}, errors.Wrap(err, "failed to get keys from redis") 352 | } 353 | nilAny := false 354 | for _, v := range values { 355 | if v == nil { 356 | nilAny = true 357 | 358 | break 359 | } 360 | } 361 | if nilAny || errors.Is(err, redis.Nil) { 362 | // Keys don't exist, return an empty state. 363 | return LeakyBucketState{}, nil 364 | } 365 | 366 | last, err := strconv.ParseInt(values[0].(string), 10, 64) 367 | if err != nil { 368 | return LeakyBucketState{}, err 369 | } 370 | if t.raceCheck { 371 | t.lastVersion, err = strconv.ParseInt(values[1].(string), 10, 64) 372 | if err != nil { 373 | return LeakyBucketState{}, err 374 | } 375 | } 376 | 377 | return LeakyBucketState{ 378 | Last: last, 379 | }, nil 380 | } 381 | 382 | // State gets the bucket's state from Redis. 
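// The state is stored as a single JSON document under the "state" key. If that key
// is absent (redis.Nil), the legacy per-field layout is read via oldState() so data
// written by older versions of this package is still readable. When raceCheck is
// enabled the document's version is remembered here and compared by SetState()
// inside a Lua script before overwriting.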
383 | func (t *LeakyBucketRedis) State(ctx context.Context) (LeakyBucketState, error) { 384 | var err error 385 | done := make(chan struct{}, 1) 386 | errCh := make(chan error, 1) 387 | var state LeakyBucketState 388 | 389 | if t.raceCheck { 390 | // reset in a case of returning an empty LeakyBucketState 391 | t.lastVersion = 0 392 | } 393 | 394 | go func() { 395 | defer close(done) 396 | key := redisKey(t.prefix, "state") 397 | value, err := t.cli.Get(ctx, key).Result() 398 | if err != nil && !errors.Is(err, redis.Nil) { 399 | errCh <- err 400 | 401 | return 402 | } 403 | 404 | if errors.Is(err, redis.Nil) { 405 | state, err = t.oldState(ctx) 406 | errCh <- err 407 | 408 | return 409 | } 410 | 411 | // Try new format 412 | var item struct { 413 | State LeakyBucketState `json:"state"` 414 | Version int64 `json:"version"` 415 | } 416 | if err = json.Unmarshal([]byte(value), &item); err != nil { 417 | errCh <- err 418 | 419 | return 420 | } 421 | 422 | state = item.State 423 | if t.raceCheck { 424 | t.lastVersion = item.Version 425 | } 426 | errCh <- nil 427 | }() 428 | 429 | select { 430 | case <-done: 431 | err = <-errCh 432 | case <-ctx.Done(): 433 | return LeakyBucketState{}, ctx.Err() 434 | } 435 | 436 | if err != nil { 437 | return LeakyBucketState{}, errors.Wrap(err, "failed to get state from redis") 438 | } 439 | 440 | return state, nil 441 | } 442 | 443 | // SetState updates the state in Redis. 444 | func (t *LeakyBucketRedis) SetState(ctx context.Context, state LeakyBucketState) error { 445 | var err error 446 | done := make(chan struct{}, 1) 447 | errCh := make(chan error, 1) 448 | 449 | go func() { 450 | defer close(done) 451 | key := redisKey(t.prefix, "state") 452 | item := struct { 453 | State LeakyBucketState `json:"state"` 454 | Version int64 `json:"version"` 455 | }{ 456 | State: state, 457 | Version: t.lastVersion + 1, 458 | } 459 | 460 | value, err := json.Marshal(item) 461 | if err != nil { 462 | errCh <- err 463 | 464 | return 465 | } 466 | 467 | if !t.raceCheck { 468 | errCh <- t.cli.Set(ctx, key, value, t.ttl).Err() 469 | 470 | return 471 | } 472 | 473 | script := ` 474 | local current = redis.call('get', KEYS[1]) 475 | if current then 476 | local data = cjson.decode(current) 477 | if data.version > tonumber(ARGV[2]) then 478 | return 'RACE_CONDITION' 479 | end 480 | end 481 | redis.call('set', KEYS[1], ARGV[1], 'PX', ARGV[3]) 482 | return 'OK' 483 | ` 484 | result, err := t.cli.Eval(ctx, script, []string{key}, value, t.lastVersion, int64(t.ttl/time.Millisecond)).Result() 485 | if err != nil { 486 | errCh <- err 487 | 488 | return 489 | } 490 | if result == "RACE_CONDITION" { 491 | errCh <- ErrRaceCondition 492 | 493 | return 494 | } 495 | errCh <- nil 496 | }() 497 | 498 | select { 499 | case <-done: 500 | err = <-errCh 501 | case <-ctx.Done(): 502 | return ctx.Err() 503 | } 504 | 505 | if err != nil { 506 | return errors.Wrap(err, "failed to save state to redis") 507 | } 508 | 509 | return nil 510 | } 511 | 512 | // Reset resets the state in Redis. 513 | func (t *LeakyBucketRedis) Reset(ctx context.Context) error { 514 | state := LeakyBucketState{ 515 | Last: 0, 516 | } 517 | 518 | return t.SetState(ctx, state) 519 | } 520 | 521 | // LeakyBucketMemcached is a Memcached implementation of a LeakyBucketStateBackend. 522 | type LeakyBucketMemcached struct { 523 | cli *memcache.Client 524 | key string 525 | ttl time.Duration 526 | raceCheck bool 527 | casId uint64 528 | } 529 | 530 | // NewLeakyBucketMemcached creates a new LeakyBucketMemcached instance. 
531 | // Key is the key used to store all the keys used in this implementation in Memcached. 532 | // TTL is the TTL of the stored keys. 533 | // 534 | // If raceCheck is true and the keys in Memcached are modified in between State() and SetState() calls then 535 | // ErrRaceCondition is returned. 536 | func NewLeakyBucketMemcached(cli *memcache.Client, key string, ttl time.Duration, raceCheck bool) *LeakyBucketMemcached { 537 | return &LeakyBucketMemcached{cli: cli, key: key, ttl: ttl, raceCheck: raceCheck} 538 | } 539 | 540 | // State gets the bucket's state from Memcached. 541 | func (t *LeakyBucketMemcached) State(ctx context.Context) (LeakyBucketState, error) { 542 | var item *memcache.Item 543 | var err error 544 | state := LeakyBucketState{} 545 | done := make(chan struct{}, 1) 546 | go func() { 547 | defer close(done) 548 | item, err = t.cli.Get(t.key) 549 | }() 550 | 551 | select { 552 | case <-done: 553 | 554 | case <-ctx.Done(): 555 | return state, ctx.Err() 556 | } 557 | 558 | if err != nil { 559 | if errors.Is(err, memcache.ErrCacheMiss) { 560 | // Keys don't exist, return an empty state. 561 | return state, nil 562 | } 563 | 564 | return state, errors.Wrap(err, "failed to get keys from memcached") 565 | } 566 | b := bytes.NewBuffer(item.Value) 567 | err = gob.NewDecoder(b).Decode(&state) 568 | if err != nil { 569 | return state, errors.Wrap(err, "failed to Decode") 570 | } 571 | t.casId = item.CasID 572 | 573 | return state, nil 574 | } 575 | 576 | // SetState updates the state in Memcached. 577 | // The provided fencing token is checked on the Memcached side before saving the keys. 578 | func (t *LeakyBucketMemcached) SetState(ctx context.Context, state LeakyBucketState) error { 579 | var err error 580 | done := make(chan struct{}, 1) 581 | var b bytes.Buffer 582 | err = gob.NewEncoder(&b).Encode(state) 583 | if err != nil { 584 | return errors.Wrap(err, "failed to Encode") 585 | } 586 | go func() { 587 | defer close(done) 588 | item := &memcache.Item{ 589 | Key: t.key, 590 | Value: b.Bytes(), 591 | CasID: t.casId, 592 | } 593 | if t.raceCheck && t.casId > 0 { 594 | err = t.cli.CompareAndSwap(item) 595 | } else { 596 | err = t.cli.Set(item) 597 | } 598 | }() 599 | 600 | select { 601 | case <-done: 602 | 603 | case <-ctx.Done(): 604 | return ctx.Err() 605 | } 606 | 607 | if err != nil && (errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss)) { 608 | return ErrRaceCondition 609 | } 610 | 611 | return errors.Wrap(err, "failed to save keys to memcached") 612 | } 613 | 614 | // Reset resets the state in Memcached. 615 | func (t *LeakyBucketMemcached) Reset(ctx context.Context) error { 616 | state := LeakyBucketState{ 617 | Last: 0, 618 | } 619 | t.casId = 0 620 | 621 | return t.SetState(ctx, state) 622 | } 623 | 624 | // LeakyBucketDynamoDB is a DyanamoDB implementation of a LeakyBucketStateBackend. 625 | type LeakyBucketDynamoDB struct { 626 | client *dynamodb.Client 627 | tableProps DynamoDBTableProperties 628 | partitionKey string 629 | ttl time.Duration 630 | raceCheck bool 631 | latestVersion int64 632 | keys map[string]types.AttributeValue 633 | } 634 | 635 | // NewLeakyBucketDynamoDB creates a new LeakyBucketDynamoDB instance. 636 | // PartitionKey is the key used to store all the this implementation in DynamoDB. 637 | // 638 | // TableProps describe the table that this backend should work with. 
This backend requires the following on the table: 639 | // * TTL 640 | // 641 | // TTL is the TTL of the stored item. 642 | // 643 | // If raceCheck is true and the item in DynamoDB are modified in between State() and SetState() calls then 644 | // ErrRaceCondition is returned. 645 | func NewLeakyBucketDynamoDB(client *dynamodb.Client, partitionKey string, tableProps DynamoDBTableProperties, ttl time.Duration, raceCheck bool) *LeakyBucketDynamoDB { 646 | keys := map[string]types.AttributeValue{ 647 | tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: partitionKey}, 648 | } 649 | 650 | if tableProps.SortKeyUsed { 651 | keys[tableProps.SortKeyName] = &types.AttributeValueMemberS{Value: partitionKey} 652 | } 653 | 654 | return &LeakyBucketDynamoDB{ 655 | client: client, 656 | partitionKey: partitionKey, 657 | tableProps: tableProps, 658 | ttl: ttl, 659 | raceCheck: raceCheck, 660 | keys: keys, 661 | } 662 | } 663 | 664 | // State gets the bucket's state from DynamoDB. 665 | func (t *LeakyBucketDynamoDB) State(ctx context.Context) (LeakyBucketState, error) { 666 | resp, err := dynamoDBGetItem(ctx, t.client, t.getGetItemInput()) 667 | if err != nil { 668 | return LeakyBucketState{}, err 669 | } 670 | 671 | return t.loadStateFromDynamoDB(resp) 672 | } 673 | 674 | // SetState updates the state in DynamoDB. 675 | func (t *LeakyBucketDynamoDB) SetState(ctx context.Context, state LeakyBucketState) error { 676 | input := t.getPutItemInputFromState(state) 677 | 678 | var err error 679 | done := make(chan struct{}) 680 | go func() { 681 | defer close(done) 682 | _, err = dynamoDBputItem(ctx, t.client, input) 683 | }() 684 | 685 | select { 686 | case <-done: 687 | case <-ctx.Done(): 688 | return ctx.Err() 689 | } 690 | 691 | return err 692 | } 693 | 694 | // Reset resets the state in DynamoDB. 
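// Reset simply writes a zero-valued state through SetState, so with raceCheck
// enabled it is subject to the same conditional (version-checked) put.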
695 | func (t *LeakyBucketDynamoDB) Reset(ctx context.Context) error { 696 | state := LeakyBucketState{ 697 | Last: 0, 698 | } 699 | 700 | return t.SetState(ctx, state) 701 | } 702 | 703 | const ( 704 | dynamodbBucketRaceConditionExpression = "Version <= :version" 705 | dynamoDBBucketLastKey = "Last" 706 | dynamoDBBucketVersionKey = "Version" 707 | ) 708 | 709 | func (t *LeakyBucketDynamoDB) getPutItemInputFromState(state LeakyBucketState) *dynamodb.PutItemInput { 710 | item := map[string]types.AttributeValue{} 711 | for k, v := range t.keys { 712 | item[k] = v 713 | } 714 | 715 | item[dynamoDBBucketLastKey] = &types.AttributeValueMemberN{Value: strconv.FormatInt(state.Last, 10)} 716 | item[dynamoDBBucketVersionKey] = &types.AttributeValueMemberN{Value: strconv.FormatInt(t.latestVersion+1, 10)} 717 | item[t.tableProps.TTLFieldName] = &types.AttributeValueMemberN{Value: strconv.FormatInt(time.Now().Add(t.ttl).Unix(), 10)} 718 | 719 | input := &dynamodb.PutItemInput{ 720 | TableName: &t.tableProps.TableName, 721 | Item: item, 722 | } 723 | 724 | if t.raceCheck && t.latestVersion > 0 { 725 | input.ConditionExpression = aws.String(dynamodbBucketRaceConditionExpression) 726 | input.ExpressionAttributeValues = map[string]types.AttributeValue{ 727 | ":version": &types.AttributeValueMemberN{Value: strconv.FormatInt(t.latestVersion, 10)}, 728 | } 729 | } 730 | 731 | return input 732 | } 733 | 734 | func (t *LeakyBucketDynamoDB) getGetItemInput() *dynamodb.GetItemInput { 735 | return &dynamodb.GetItemInput{ 736 | TableName: &t.tableProps.TableName, 737 | Key: t.keys, 738 | } 739 | } 740 | 741 | func (t *LeakyBucketDynamoDB) loadStateFromDynamoDB(resp *dynamodb.GetItemOutput) (LeakyBucketState, error) { 742 | state := LeakyBucketState{} 743 | err := attributevalue.Unmarshal(resp.Item[dynamoDBBucketLastKey], &state.Last) 744 | if err != nil { 745 | return state, fmt.Errorf("unmarshal dynamodb Last attribute failed: %w", err) 746 | } 747 | 748 | if t.raceCheck { 749 | err = attributevalue.Unmarshal(resp.Item[dynamoDBBucketVersionKey], &t.latestVersion) 750 | if err != nil { 751 | return state, fmt.Errorf("unmarshal dynamodb Version attribute failed: %w", err) 752 | } 753 | } 754 | 755 | return state, nil 756 | } 757 | 758 | // CosmosDBLeakyBucketItem represents a document in CosmosDB for LeakyBucket. 759 | type CosmosDBLeakyBucketItem struct { 760 | ID string `json:"id"` 761 | PartitionKey string `json:"partitionKey"` 762 | State LeakyBucketState `json:"state"` 763 | Version int64 `json:"version"` 764 | TTL int64 `json:"ttl"` 765 | } 766 | 767 | // LeakyBucketCosmosDB is a CosmosDB implementation of a LeakyBucketStateBackend. 768 | type LeakyBucketCosmosDB struct { 769 | client *azcosmos.ContainerClient 770 | partitionKey string 771 | id string 772 | ttl time.Duration 773 | raceCheck bool 774 | latestVersion int64 775 | } 776 | 777 | // NewLeakyBucketCosmosDB creates a new LeakyBucketCosmosDB instance. 778 | // PartitionKey is the key used to store all the implementation in CosmosDB. 779 | // TTL is the TTL of the stored item. 780 | // 781 | // If raceCheck is true and the item in CosmosDB is modified in between State() and SetState() calls then 782 | // ErrRaceCondition is returned. 
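// Each bucket is stored as a single document whose id is derived from the partition
// key ("leaky-bucket-<partitionKey>"). Expiry is additionally checked client-side in
// State(), which treats an item past its TTL as an empty state. SetState upserts the
// document with an incremented version and, when raceCheck is enabled, maps a
// Conflict response to ErrRaceCondition.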
783 | func NewLeakyBucketCosmosDB(client *azcosmos.ContainerClient, partitionKey string, ttl time.Duration, raceCheck bool) *LeakyBucketCosmosDB { 784 | return &LeakyBucketCosmosDB{ 785 | client: client, 786 | partitionKey: partitionKey, 787 | id: "leaky-bucket-" + partitionKey, 788 | ttl: ttl, 789 | raceCheck: raceCheck, 790 | } 791 | } 792 | 793 | func (t *LeakyBucketCosmosDB) State(ctx context.Context) (LeakyBucketState, error) { 794 | var item CosmosDBLeakyBucketItem 795 | resp, err := t.client.ReadItem(ctx, azcosmos.NewPartitionKey().AppendString(t.partitionKey), t.id, &azcosmos.ItemOptions{}) 796 | if err != nil { 797 | var respErr *azcore.ResponseError 798 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { 799 | return LeakyBucketState{}, nil 800 | } 801 | 802 | return LeakyBucketState{}, err 803 | } 804 | 805 | err = json.Unmarshal(resp.Value, &item) 806 | if err != nil { 807 | return LeakyBucketState{}, errors.Wrap(err, "failed to decode state from Cosmos DB") 808 | } 809 | 810 | if time.Now().Unix() > item.TTL { 811 | return LeakyBucketState{}, nil 812 | } 813 | 814 | if t.raceCheck { 815 | t.latestVersion = item.Version 816 | } 817 | 818 | return item.State, nil 819 | } 820 | 821 | func (t *LeakyBucketCosmosDB) SetState(ctx context.Context, state LeakyBucketState) error { 822 | var err error 823 | done := make(chan struct{}, 1) 824 | 825 | item := CosmosDBLeakyBucketItem{ 826 | ID: t.id, 827 | PartitionKey: t.partitionKey, 828 | State: state, 829 | Version: t.latestVersion + 1, 830 | TTL: time.Now().Add(t.ttl).Unix(), 831 | } 832 | 833 | value, err := json.Marshal(item) 834 | if err != nil { 835 | return errors.Wrap(err, "failed to encode state to JSON") 836 | } 837 | 838 | go func() { 839 | defer close(done) 840 | _, err = t.client.UpsertItem(ctx, azcosmos.NewPartitionKey().AppendString(t.partitionKey), value, &azcosmos.ItemOptions{}) 841 | }() 842 | 843 | select { 844 | case <-done: 845 | case <-ctx.Done(): 846 | return ctx.Err() 847 | } 848 | 849 | if err != nil { 850 | var respErr *azcore.ResponseError 851 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict && t.raceCheck { 852 | return ErrRaceCondition 853 | } 854 | 855 | return errors.Wrap(err, "failed to save keys to Cosmos DB") 856 | } 857 | 858 | return nil 859 | } 860 | 861 | func (t *LeakyBucketCosmosDB) Reset(ctx context.Context) error { 862 | state := LeakyBucketState{ 863 | Last: 0, 864 | } 865 | 866 | return t.SetState(ctx, state) 867 | } 868 | -------------------------------------------------------------------------------- /leakybucket_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/google/uuid" 11 | l "github.com/mennanov/limiters" 12 | "github.com/redis/go-redis/v9" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | // leakyBuckets returns all the possible leakyBuckets combinations. 
17 | func (s *LimitersTestSuite) leakyBuckets(capacity int64, rate time.Duration, clock l.Clock) map[string]*l.LeakyBucket { 18 | buckets := make(map[string]*l.LeakyBucket) 19 | for lockerName, locker := range s.lockers(true) { 20 | for backendName, backend := range s.leakyBucketBackends() { 21 | buckets[lockerName+":"+backendName] = l.NewLeakyBucket(capacity, rate, locker, backend, clock, s.logger) 22 | } 23 | } 24 | 25 | return buckets 26 | } 27 | 28 | func (s *LimitersTestSuite) leakyBucketBackends() map[string]l.LeakyBucketStateBackend { 29 | return map[string]l.LeakyBucketStateBackend{ 30 | "LeakyBucketInMemory": l.NewLeakyBucketInMemory(), 31 | "LeakyBucketEtcdNoRaceCheck": l.NewLeakyBucketEtcd(s.etcdClient, uuid.New().String(), time.Second, false), 32 | "LeakyBucketEtcdWithRaceCheck": l.NewLeakyBucketEtcd(s.etcdClient, uuid.New().String(), time.Second, true), 33 | "LeakyBucketRedisNoRaceCheck": l.NewLeakyBucketRedis(s.redisClient, uuid.New().String(), time.Second, false), 34 | "LeakyBucketRedisWithRaceCheck": l.NewLeakyBucketRedis(s.redisClient, uuid.New().String(), time.Second, true), 35 | "LeakyBucketRedisClusterNoRaceCheck": l.NewLeakyBucketRedis(s.redisClusterClient, uuid.New().String(), time.Second, false), 36 | "LeakyBucketRedisClusterWithRaceCheck": l.NewLeakyBucketRedis(s.redisClusterClient, uuid.New().String(), time.Second, true), 37 | "LeakyBucketMemcachedNoRaceCheck": l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), time.Second, false), 38 | "LeakyBucketMemcachedWithRaceCheck": l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), time.Second, true), 39 | "LeakyBucketDynamoDBNoRaceCheck": l.NewLeakyBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, false), 40 | "LeakyBucketDynamoDBWithRaceCheck": l.NewLeakyBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, true), 41 | "LeakyBucketCosmosDBNoRaceCheck": l.NewLeakyBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), time.Second, false), 42 | "LeakyBucketCosmosDBWithRaceCheck": l.NewLeakyBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), time.Second, true), 43 | } 44 | } 45 | 46 | func (s *LimitersTestSuite) TestLeakyBucketRealClock() { 47 | capacity := int64(10) 48 | rate := time.Millisecond * 10 49 | clock := l.NewSystemClock() 50 | for _, requestRate := range []time.Duration{rate / 2} { 51 | for name, bucket := range s.leakyBuckets(capacity, rate, clock) { 52 | s.Run(name, func() { 53 | wg := sync.WaitGroup{} 54 | mu := sync.Mutex{} 55 | var totalWait time.Duration 56 | for i := int64(0); i < capacity; i++ { 57 | // No pause for the first request. 58 | if i > 0 { 59 | clock.Sleep(requestRate) 60 | } 61 | wg.Add(1) 62 | go func(bucket *l.LeakyBucket) { 63 | defer wg.Done() 64 | wait, err := bucket.Limit(context.TODO()) 65 | s.Require().NoError(err) 66 | if wait > 0 { 67 | mu.Lock() 68 | totalWait += wait 69 | mu.Unlock() 70 | clock.Sleep(wait) 71 | } 72 | }(bucket) 73 | } 74 | wg.Wait() 75 | expectedWait := time.Duration(0) 76 | if rate > requestRate { 77 | expectedWait = time.Duration(float64(rate-requestRate) * float64(capacity-1) / 2 * float64(capacity)) 78 | } 79 | 80 | // Allow 5ms lag for each request. 81 | // TODO: figure out if this is enough for slow PCs and possibly avoid hard-coding it. 
82 | delta := float64(time.Duration(capacity) * time.Millisecond * 25) 83 | s.InDelta(expectedWait, totalWait, delta, "request rate: %d, bucket: %v", requestRate, bucket) 84 | }) 85 | } 86 | } 87 | } 88 | 89 | func (s *LimitersTestSuite) TestLeakyBucketFakeClock() { 90 | capacity := int64(10) 91 | rate := time.Millisecond * 100 92 | clock := newFakeClock() 93 | for _, requestRate := range []time.Duration{rate * 2, rate, rate / 2, rate / 3, rate / 4, 0} { 94 | for name, bucket := range s.leakyBuckets(capacity, rate, clock) { 95 | s.Run(name, func() { 96 | clock.reset() 97 | start := clock.Now() 98 | for i := int64(0); i < capacity; i++ { 99 | // No pause for the first request. 100 | if i > 0 { 101 | clock.Sleep(requestRate) 102 | } 103 | wait, err := bucket.Limit(context.TODO()) 104 | s.Require().NoError(err) 105 | clock.Sleep(wait) 106 | } 107 | interval := rate 108 | if requestRate > rate { 109 | interval = requestRate 110 | } 111 | s.Equal(interval*time.Duration(capacity-1), clock.Now().Sub(start), "request rate: %d, bucket: %v", requestRate, bucket) 112 | }) 113 | } 114 | } 115 | } 116 | 117 | func (s *LimitersTestSuite) TestLeakyBucketOverflow() { 118 | rate := time.Second 119 | capacity := int64(2) 120 | clock := newFakeClock() 121 | for name, bucket := range s.leakyBuckets(capacity, rate, clock) { 122 | s.Run(name, func() { 123 | clock.reset() 124 | // The first call has no wait since there were no calls before. 125 | wait, err := bucket.Limit(context.TODO()) 126 | s.Require().NoError(err) 127 | s.Equal(time.Duration(0), wait) 128 | // The second call increments the queue size by 1. 129 | wait, err = bucket.Limit(context.TODO()) 130 | s.Require().NoError(err) 131 | s.Equal(rate, wait) 132 | // The third call overflows the bucket capacity. 133 | wait, err = bucket.Limit(context.TODO()) 134 | s.Require().Equal(l.ErrLimitExhausted, err) 135 | s.Equal(rate*2, wait) 136 | // Move the Clock 1 position forward. 137 | clock.Sleep(rate) 138 | // Retry the last call. This time it should succeed. 139 | wait, err = bucket.Limit(context.TODO()) 140 | s.Require().NoError(err) 141 | s.Equal(rate, wait) 142 | }) 143 | } 144 | } 145 | 146 | func (s *LimitersTestSuite) TestLeakyBucketReset() { 147 | rate := time.Second 148 | capacity := int64(2) 149 | clock := newFakeClock() 150 | for name, bucket := range s.leakyBuckets(capacity, rate, clock) { 151 | s.Run(name, func() { 152 | clock.reset() 153 | // The first call has no wait since there were no calls before. 154 | wait, err := bucket.Limit(context.TODO()) 155 | s.Require().NoError(err) 156 | s.Equal(time.Duration(0), wait) 157 | // The second call increments the queue size by 1. 158 | wait, err = bucket.Limit(context.TODO()) 159 | s.Require().NoError(err) 160 | s.Equal(rate, wait) 161 | // The third call overflows the bucket capacity. 162 | wait, err = bucket.Limit(context.TODO()) 163 | s.Require().Equal(l.ErrLimitExhausted, err) 164 | s.Equal(rate*2, wait) 165 | // Reset the bucket 166 | err = bucket.Reset(context.TODO()) 167 | s.Require().NoError(err) 168 | // Retry the last call. This time it should succeed. 
169 | wait, err = bucket.Limit(context.TODO()) 170 | s.Require().NoError(err) 171 | s.Equal(time.Duration(0), wait) 172 | }) 173 | } 174 | } 175 | 176 | func TestLeakyBucket_ZeroCapacity_ReturnsError(t *testing.T) { 177 | capacity := int64(0) 178 | rate := time.Hour 179 | logger := l.NewStdLogger() 180 | bucket := l.NewLeakyBucket(capacity, rate, l.NewLockNoop(), l.NewLeakyBucketInMemory(), newFakeClock(), logger) 181 | wait, err := bucket.Limit(context.TODO()) 182 | require.Equal(t, l.ErrLimitExhausted, err) 183 | require.Equal(t, time.Duration(0), wait) 184 | } 185 | 186 | func BenchmarkLeakyBuckets(b *testing.B) { 187 | s := new(LimitersTestSuite) 188 | s.SetT(&testing.T{}) 189 | s.SetupSuite() 190 | capacity := int64(1) 191 | rate := time.Second 192 | clock := newFakeClock() 193 | buckets := s.leakyBuckets(capacity, rate, clock) 194 | for name, bucket := range buckets { 195 | b.Run(name, func(b *testing.B) { 196 | for i := 0; i < b.N; i++ { 197 | _, err := bucket.Limit(context.TODO()) 198 | s.Require().NoError(err) 199 | } 200 | }) 201 | } 202 | s.TearDownSuite() 203 | } 204 | 205 | // setStateInOldFormat is a test utility method for writing state in the old format to Redis. 206 | func setStateInOldFormat(ctx context.Context, cli *redis.Client, prefix string, state l.LeakyBucketState, ttl time.Duration) error { 207 | _, err := cli.TxPipelined(ctx, func(pipeliner redis.Pipeliner) error { 208 | if err := pipeliner.Set(ctx, fmt.Sprintf("{%s}last", prefix), state.Last, ttl).Err(); err != nil { 209 | return err 210 | } 211 | 212 | return nil 213 | }) 214 | 215 | return err 216 | } 217 | 218 | // TestLeakyBucketRedisBackwardCompatibility tests that the new State method can read data written in the old format. 219 | func (s *LimitersTestSuite) TestLeakyBucketRedisBackwardCompatibility() { 220 | // Create a new LeakyBucketRedis instance 221 | prefix := uuid.New().String() 222 | backend := l.NewLeakyBucketRedis(s.redisClient, prefix, time.Second, false) 223 | 224 | // Write state using old format 225 | ctx := context.Background() 226 | expectedState := l.LeakyBucketState{ 227 | Last: 12345, 228 | } 229 | 230 | // Write directly to Redis using old format 231 | err := setStateInOldFormat(ctx, s.redisClient, prefix, expectedState, time.Second) 232 | s.Require().NoError(err, "Failed to set state using old format") 233 | 234 | // Read state using new format (State) 235 | actualState, err := backend.State(ctx) 236 | s.Require().NoError(err, "Failed to get state using new format") 237 | 238 | // Verify the state is correctly read 239 | s.Equal(expectedState.Last, actualState.Last, "Last values should match") 240 | } 241 | -------------------------------------------------------------------------------- /limiters.go: -------------------------------------------------------------------------------- 1 | // Package limiters provides general purpose rate limiter implementations. 2 | package limiters 3 | 4 | import ( 5 | "errors" 6 | "log" 7 | "time" 8 | ) 9 | 10 | var ( 11 | // ErrLimitExhausted is returned by the Limiter in case the number of requests overflows the capacity of a Limiter. 12 | ErrLimitExhausted = errors.New("requests limit exhausted") 13 | 14 | // ErrRaceCondition is returned when there is a race condition while saving a state of a rate limiter. 15 | ErrRaceCondition = errors.New("race condition detected") 16 | ) 17 | 18 | // Logger wraps the Log method for logging. 19 | type Logger interface { 20 | // Log logs the given arguments. 
21 | Log(v ...interface{}) 22 | } 23 | 24 | // StdLogger implements the Logger interface. 25 | type StdLogger struct{} 26 | 27 | // NewStdLogger creates a new instance of StdLogger. 28 | func NewStdLogger() *StdLogger { 29 | return &StdLogger{} 30 | } 31 | 32 | // Log delegates the logging to the std logger. 33 | func (l *StdLogger) Log(v ...interface{}) { 34 | log.Println(v...) 35 | } 36 | 37 | // Clock encapsulates a system Clock. 38 | // Used. 39 | type Clock interface { 40 | // Now returns the current system time. 41 | Now() time.Time 42 | } 43 | 44 | // SystemClock implements the Clock interface by using the real system clock. 45 | type SystemClock struct{} 46 | 47 | // NewSystemClock creates a new instance of SystemClock. 48 | func NewSystemClock() *SystemClock { 49 | return &SystemClock{} 50 | } 51 | 52 | // Now returns the current system time. 53 | func (c *SystemClock) Now() time.Time { 54 | return time.Now() 55 | } 56 | 57 | // Sleep blocks (sleeps) for the given duration. 58 | func (c *SystemClock) Sleep(d time.Duration) { 59 | time.Sleep(d) 60 | } 61 | -------------------------------------------------------------------------------- /limiters_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "hash/fnv" 8 | "os" 9 | "strings" 10 | "sync" 11 | "testing" 12 | "time" 13 | 14 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 15 | "github.com/aws/aws-sdk-go-v2/aws" 16 | "github.com/aws/aws-sdk-go-v2/config" 17 | "github.com/aws/aws-sdk-go-v2/credentials" 18 | "github.com/aws/aws-sdk-go-v2/service/dynamodb" 19 | "github.com/bradfitz/gomemcache/memcache" 20 | "github.com/go-redsync/redsync/v4/redis/goredis/v9" 21 | "github.com/google/uuid" 22 | "github.com/hashicorp/consul/api" 23 | l "github.com/mennanov/limiters" 24 | "github.com/redis/go-redis/v9" 25 | "github.com/samuel/go-zookeeper/zk" 26 | "github.com/stretchr/testify/suite" 27 | clientv3 "go.etcd.io/etcd/client/v3" 28 | ) 29 | 30 | type fakeClock struct { 31 | mu sync.Mutex 32 | initial time.Time 33 | t time.Time 34 | } 35 | 36 | func newFakeClock() *fakeClock { 37 | now := time.Now() 38 | 39 | return &fakeClock{t: now, initial: now} 40 | } 41 | 42 | func newFakeClockWithTime(t time.Time) *fakeClock { 43 | return &fakeClock{t: t, initial: t} 44 | } 45 | 46 | func (c *fakeClock) Now() time.Time { 47 | c.mu.Lock() 48 | defer c.mu.Unlock() 49 | 50 | return c.t 51 | } 52 | 53 | func (c *fakeClock) Sleep(d time.Duration) { 54 | if d == 0 { 55 | return 56 | } 57 | c.mu.Lock() 58 | defer c.mu.Unlock() 59 | c.t = c.t.Add(d) 60 | } 61 | 62 | func (c *fakeClock) reset() { 63 | c.mu.Lock() 64 | defer c.mu.Unlock() 65 | c.t = c.initial 66 | } 67 | 68 | type LimitersTestSuite struct { 69 | suite.Suite 70 | etcdClient *clientv3.Client 71 | redisClient *redis.Client 72 | redisClusterClient *redis.ClusterClient 73 | consulClient *api.Client 74 | zkConn *zk.Conn 75 | logger *l.StdLogger 76 | dynamodbClient *dynamodb.Client 77 | dynamoDBTableProps l.DynamoDBTableProperties 78 | memcacheClient *memcache.Client 79 | pgDb *sql.DB 80 | cosmosClient *azcosmos.Client 81 | cosmosContainerClient *azcosmos.ContainerClient 82 | } 83 | 84 | func (s *LimitersTestSuite) SetupSuite() { 85 | var err error 86 | s.etcdClient, err = clientv3.New(clientv3.Config{ 87 | Endpoints: strings.Split(os.Getenv("ETCD_ENDPOINTS"), ","), 88 | DialTimeout: time.Second, 89 | }) 90 | s.Require().NoError(err) 91 | s.redisClient = 
redis.NewClient(&redis.Options{ 92 | Addr: os.Getenv("REDIS_ADDR"), 93 | }) 94 | s.redisClusterClient = redis.NewClusterClient(&redis.ClusterOptions{ 95 | Addrs: strings.Split(os.Getenv("REDIS_NODES"), ","), 96 | }) 97 | s.consulClient, err = api.NewClient(&api.Config{Address: os.Getenv("CONSUL_ADDR")}) 98 | s.Require().NoError(err) 99 | s.zkConn, _, err = zk.Connect(strings.Split(os.Getenv("ZOOKEEPER_ENDPOINTS"), ","), time.Second) 100 | s.Require().NoError(err) 101 | s.logger = l.NewStdLogger() 102 | 103 | awsCfg, err := config.LoadDefaultConfig(context.Background(), 104 | config.WithRegion("us-east-1"), 105 | config.WithCredentialsProvider(credentials.StaticCredentialsProvider{ 106 | Value: aws.Credentials{ 107 | AccessKeyID: "dummy", SecretAccessKey: "dummy", SessionToken: "dummy", 108 | Source: "Hard-coded credentials; values are irrelevant for local DynamoDB", 109 | }, 110 | }), 111 | ) 112 | s.Require().NoError(err) 113 | 114 | endpoint := fmt.Sprintf("http://%s", os.Getenv("AWS_ADDR")) 115 | s.dynamodbClient = dynamodb.NewFromConfig(awsCfg, func(options *dynamodb.Options) { 116 | options.BaseEndpoint = &endpoint 117 | }) 118 | 119 | _ = DeleteTestDynamoDBTable(context.Background(), s.dynamodbClient) 120 | s.Require().NoError(CreateTestDynamoDBTable(context.Background(), s.dynamodbClient)) 121 | s.dynamoDBTableProps, err = l.LoadDynamoDBTableProperties(context.Background(), s.dynamodbClient, testDynamoDBTableName) 122 | s.Require().NoError(err) 123 | 124 | s.memcacheClient = memcache.New(strings.Split(os.Getenv("MEMCACHED_ADDR"), ",")...) 125 | s.Require().NoError(s.memcacheClient.Ping()) 126 | 127 | s.pgDb, err = sql.Open("postgres", os.Getenv("POSTGRES_URL")) 128 | s.Require().NoError(err) 129 | s.Require().NoError(s.pgDb.Ping()) 130 | 131 | // https://learn.microsoft.com/en-us/azure/cosmos-db/emulator#authentication 132 | connString := fmt.Sprintf("AccountEndpoint=http://%s/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;", os.Getenv("COSMOS_ADDR")) 133 | s.cosmosClient, err = azcosmos.NewClientFromConnectionString(connString, &azcosmos.ClientOptions{}) 134 | s.Require().NoError(err) 135 | 136 | if err = CreateCosmosDBContainer(context.Background(), s.cosmosClient); err != nil { 137 | s.Require().ErrorContains(err, "Database 'limiters-db-test' already exists") 138 | } 139 | 140 | s.cosmosContainerClient, err = s.cosmosClient.NewContainer(testCosmosDBName, testCosmosContainerName) 141 | s.Require().NoError(err) 142 | } 143 | 144 | func (s *LimitersTestSuite) TearDownSuite() { 145 | s.Assert().NoError(s.etcdClient.Close()) 146 | s.Assert().NoError(s.redisClient.Close()) 147 | s.Assert().NoError(s.redisClusterClient.Close()) 148 | s.Assert().NoError(DeleteTestDynamoDBTable(context.Background(), s.dynamodbClient)) 149 | s.Assert().NoError(s.memcacheClient.Close()) 150 | s.Assert().NoError(s.pgDb.Close()) 151 | s.Assert().NoError(DeleteCosmosDBContainer(context.Background(), s.cosmosClient)) 152 | } 153 | 154 | func TestLimitersTestSuite(t *testing.T) { 155 | suite.Run(t, new(LimitersTestSuite)) 156 | } 157 | 158 | // lockers returns all possible lockers (including noop). 
159 | func (s *LimitersTestSuite) lockers(generateKeys bool) map[string]l.DistLocker { 160 | lockers := s.distLockers(generateKeys) 161 | lockers["LockNoop"] = l.NewLockNoop() 162 | 163 | return lockers 164 | } 165 | 166 | func hash(s string) int64 { 167 | h := fnv.New32a() 168 | _, err := h.Write([]byte(s)) 169 | if err != nil { 170 | panic(err) 171 | } 172 | 173 | return int64(h.Sum32()) 174 | } 175 | 176 | // distLockers returns distributed lockers only. 177 | func (s *LimitersTestSuite) distLockers(generateKeys bool) map[string]l.DistLocker { 178 | randomKey := uuid.New().String() 179 | consulKey := randomKey 180 | etcdKey := randomKey 181 | zkKey := "/" + randomKey 182 | redisKey := randomKey 183 | memcacheKey := randomKey 184 | pgKey := randomKey 185 | if !generateKeys { 186 | consulKey = "dist_locker" 187 | etcdKey = "dist_locker" 188 | zkKey = "/dist_locker" 189 | redisKey = "dist_locker" 190 | memcacheKey = "dist_locker" 191 | pgKey = "dist_locker" 192 | } 193 | consulLock, err := s.consulClient.LockKey(consulKey) 194 | s.Require().NoError(err) 195 | 196 | return map[string]l.DistLocker{ 197 | "LockEtcd": l.NewLockEtcd(s.etcdClient, etcdKey, s.logger), 198 | "LockConsul": l.NewLockConsul(consulLock), 199 | "LockZookeeper": l.NewLockZookeeper(zk.NewLock(s.zkConn, zkKey, zk.WorldACL(zk.PermAll))), 200 | "LockRedis": l.NewLockRedis(goredis.NewPool(s.redisClient), redisKey), 201 | "LockRedisCluster": l.NewLockRedis(goredis.NewPool(s.redisClusterClient), redisKey), 202 | "LockMemcached": l.NewLockMemcached(s.memcacheClient, memcacheKey), 203 | "LockPostgreSQL": l.NewLockPostgreSQL(s.pgDb, hash(pgKey)), 204 | } 205 | } 206 | 207 | func (s *LimitersTestSuite) TestLimitContextCancelled() { 208 | clock := newFakeClock() 209 | capacity := int64(2) 210 | rate := time.Second 211 | limiters := make(map[string]interface{}) 212 | for n, b := range s.tokenBuckets(capacity, rate, clock) { 213 | limiters[n] = b 214 | } 215 | for n, b := range s.leakyBuckets(capacity, rate, clock) { 216 | limiters[n] = b 217 | } 218 | for n, w := range s.fixedWindows(capacity, rate, clock) { 219 | limiters[n] = w 220 | } 221 | for n, w := range s.slidingWindows(capacity, rate, clock, 1e-9) { 222 | limiters[n] = w 223 | } 224 | for n, b := range s.concurrentBuffers(capacity, rate, clock) { 225 | limiters[n] = b 226 | } 227 | type rateLimiter interface { 228 | Limit(context.Context) (time.Duration, error) 229 | } 230 | type concurrentLimiter interface { 231 | Limit(context.Context, string) error 232 | } 233 | 234 | for name, limiter := range limiters { 235 | s.Run(name, func() { 236 | done1 := make(chan struct{}) 237 | go func(limiter interface{}) { 238 | defer close(done1) 239 | // The context is expired shortly after it is created. 240 | ctx, cancel := context.WithCancel(context.Background()) 241 | cancel() 242 | switch lim := limiter.(type) { 243 | case rateLimiter: 244 | _, err := lim.Limit(ctx) 245 | s.Error(err, "%T", limiter) 246 | 247 | case concurrentLimiter: 248 | s.Error(lim.Limit(ctx, "key"), "%T", limiter) 249 | } 250 | }(limiter) 251 | done2 := make(chan struct{}) 252 | go func(limiter interface{}) { 253 | defer close(done2) 254 | <-done1 255 | ctx := context.Background() 256 | switch lim := limiter.(type) { 257 | case rateLimiter: 258 | _, err := lim.Limit(ctx) 259 | s.NoError(err, "%T", limiter) 260 | 261 | case concurrentLimiter: 262 | s.NoError(lim.Limit(ctx, "key"), "%T", limiter) 263 | } 264 | }(limiter) 265 | // Verify that the second go routine succeeded calling the Limit() method. 
266 | <-done2 267 | }) 268 | } 269 | } 270 | -------------------------------------------------------------------------------- /locks.go: -------------------------------------------------------------------------------- 1 | package limiters 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "time" 7 | 8 | lock "github.com/alessandro-c/gomemcached-lock" 9 | "github.com/alessandro-c/gomemcached-lock/adapters/gomemcache" 10 | "github.com/bradfitz/gomemcache/memcache" 11 | "github.com/cenkalti/backoff/v3" 12 | "github.com/go-redsync/redsync/v4" 13 | redsyncredis "github.com/go-redsync/redsync/v4/redis" 14 | "github.com/hashicorp/consul/api" 15 | _ "github.com/lib/pq" 16 | "github.com/pkg/errors" 17 | "github.com/samuel/go-zookeeper/zk" 18 | clientv3 "go.etcd.io/etcd/client/v3" 19 | "go.etcd.io/etcd/client/v3/concurrency" 20 | ) 21 | 22 | // DistLocker is a context aware distributed locker (interface is similar to sync.Locker). 23 | type DistLocker interface { 24 | // Lock locks the locker. 25 | Lock(ctx context.Context) error 26 | // Unlock unlocks the previously successfully locked lock. 27 | Unlock(ctx context.Context) error 28 | } 29 | 30 | // LockNoop is a no-op implementation of the DistLocker interface. 31 | // It should only be used with the in-memory backends as they are already thread-safe and don't need distributed locks. 32 | type LockNoop struct{} 33 | 34 | // NewLockNoop creates a new LockNoop. 35 | func NewLockNoop() *LockNoop { 36 | return &LockNoop{} 37 | } 38 | 39 | // Lock imitates locking. 40 | func (n LockNoop) Lock(ctx context.Context) error { 41 | return ctx.Err() 42 | } 43 | 44 | // Unlock does nothing. 45 | func (n LockNoop) Unlock(_ context.Context) error { 46 | return nil 47 | } 48 | 49 | // LockEtcd implements the DistLocker interface using etcd. 50 | // 51 | // See https://github.com/etcd-io/etcd/blob/master/Documentation/learning/why.md#using-etcd-for-distributed-coordination 52 | type LockEtcd struct { 53 | cli *clientv3.Client 54 | prefix string 55 | logger Logger 56 | mu *concurrency.Mutex 57 | session *concurrency.Session 58 | } 59 | 60 | // NewLockEtcd creates a new instance of LockEtcd. 61 | func NewLockEtcd(cli *clientv3.Client, prefix string, logger Logger) *LockEtcd { 62 | return &LockEtcd{cli: cli, prefix: prefix, logger: logger} 63 | } 64 | 65 | // Lock creates a new session-based lock in etcd and locks it. 66 | func (l *LockEtcd) Lock(ctx context.Context) error { 67 | var err error 68 | l.session, err = concurrency.NewSession(l.cli, concurrency.WithTTL(1)) 69 | if err != nil { 70 | return errors.Wrap(err, "failed to create an etcd session") 71 | } 72 | l.mu = concurrency.NewMutex(l.session, l.prefix) 73 | 74 | return errors.Wrap(l.mu.Lock(ctx), "failed to lock a mutex in etcd") 75 | } 76 | 77 | // Unlock unlocks the previously locked lock. 78 | func (l *LockEtcd) Unlock(ctx context.Context) error { 79 | defer func() { 80 | if err := l.session.Close(); err != nil { 81 | l.logger.Log(err) 82 | } 83 | }() 84 | 85 | return errors.Wrap(l.mu.Unlock(ctx), "failed to unlock a mutex in etcd") 86 | } 87 | 88 | // LockConsul is a wrapper around github.com/hashicorp/consul/api.Lock that implements the DistLocker interface. 89 | type LockConsul struct { 90 | lock *api.Lock 91 | } 92 | 93 | // NewLockConsul creates a new LockConsul instance. 94 | func NewLockConsul(lock *api.Lock) *LockConsul { 95 | return &LockConsul{lock: lock} 96 | } 97 | 98 | // Lock locks the lock in Consul. 
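//
// For context, a LockConsul is typically constructed the way the test suite does it (a sketch;
// consulClient and the key name are placeholders):
//
//	apiLock, err := consulClient.LockKey("limiters-lock")
//	if err != nil {
//		// handle the error
//	}
//	locker := NewLockConsul(apiLock)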
99 | func (l *LockConsul) Lock(ctx context.Context) error {
100 | 	_, err := l.lock.Lock(ctx.Done())
101 | 
102 | 	return errors.Wrap(err, "failed to lock a mutex in consul")
103 | }
104 | 
105 | // Unlock unlocks the lock in Consul.
106 | func (l *LockConsul) Unlock(_ context.Context) error {
107 | 	return l.lock.Unlock()
108 | }
109 | 
110 | // LockZookeeper is a wrapper around github.com/samuel/go-zookeeper/zk.Lock that implements the DistLocker interface.
111 | type LockZookeeper struct {
112 | 	lock *zk.Lock
113 | }
114 | 
115 | // NewLockZookeeper creates a new instance of LockZookeeper.
116 | func NewLockZookeeper(lock *zk.Lock) *LockZookeeper {
117 | 	return &LockZookeeper{lock: lock}
118 | }
119 | 
120 | // Lock locks the lock in Zookeeper.
121 | // TODO: add context-aware support once https://github.com/samuel/go-zookeeper/pull/168 is merged.
122 | func (l *LockZookeeper) Lock(_ context.Context) error {
123 | 	return l.lock.Lock()
124 | }
125 | 
126 | // Unlock unlocks the lock in Zookeeper.
127 | func (l *LockZookeeper) Unlock(_ context.Context) error {
128 | 	return l.lock.Unlock()
129 | }
130 | 
131 | // LockRedis is a wrapper around github.com/go-redsync/redsync that implements the DistLocker interface.
132 | type LockRedis struct {
133 | 	mutex *redsync.Mutex
134 | }
135 | 
136 | // NewLockRedis creates a new instance of LockRedis.
137 | func NewLockRedis(pool redsyncredis.Pool, mutexName string, options ...redsync.Option) *LockRedis {
138 | 	rs := redsync.New(pool)
139 | 	mutex := rs.NewMutex(mutexName, options...)
140 | 
141 | 	return &LockRedis{mutex: mutex}
142 | }
143 | 
144 | // Lock locks the lock in Redis.
145 | func (l *LockRedis) Lock(ctx context.Context) error {
146 | 	err := l.mutex.LockContext(ctx)
147 | 
148 | 	return errors.Wrap(err, "failed to lock a mutex in redis")
149 | }
150 | 
151 | // Unlock unlocks the lock in Redis.
152 | func (l *LockRedis) Unlock(ctx context.Context) error {
153 | 	if ok, err := l.mutex.UnlockContext(ctx); !ok || err != nil {
154 | 		return errors.Wrap(err, "failed to unlock a mutex in redis")
155 | 	}
156 | 
157 | 	return nil
158 | }
159 | 
160 | // LockMemcached is a wrapper around github.com/alessandro-c/gomemcached-lock that implements the DistLocker interface.
161 | // It is the caller's responsibility to ensure that mutexName is unique and that the same key is not reused by multiple
162 | // Memcached-based implementations.
163 | type LockMemcached struct {
164 | 	locker    *lock.Locker
165 | 	mutexName string
166 | 	backoff   backoff.BackOff
167 | }
168 | 
169 | // NewLockMemcached creates a new instance of LockMemcached.
170 | // The default backoff retries every 100ms up to 100 times (10 seconds in total).
171 | func NewLockMemcached(client *memcache.Client, mutexName string) *LockMemcached {
172 | 	adapter := gomemcache.New(client)
173 | 	locker := lock.New(adapter, mutexName, "")
174 | 	b := backoff.WithMaxRetries(backoff.NewConstantBackOff(100*time.Millisecond), 100)
175 | 
176 | 	return &LockMemcached{
177 | 		locker:    locker,
178 | 		mutexName: mutexName,
179 | 		backoff:   b,
180 | 	}
181 | }
182 | 
183 | // WithLockAcquireBackoff sets the backoff policy for retrying the lock acquisition.
184 | func (l *LockMemcached) WithLockAcquireBackoff(b backoff.BackOff) *LockMemcached {
185 | 	l.backoff = b
186 | 
187 | 	return l
188 | }
189 | 
190 | // Lock locks the lock in Memcached.
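//
// A sketch of overriding the default acquisition backoff (memcacheClient, the key and the retry
// numbers are placeholders):
//
//	locker := NewLockMemcached(memcacheClient, "limiters-lock").
//		WithLockAcquireBackoff(backoff.WithMaxRetries(backoff.NewConstantBackOff(50*time.Millisecond), 20))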
191 | func (l *LockMemcached) Lock(ctx context.Context) error { 192 | o := func() error { return l.locker.Lock(time.Minute) } 193 | 194 | return backoff.Retry(o, l.backoff) 195 | } 196 | 197 | // Unlock unlocks the lock in Memcached. 198 | func (l *LockMemcached) Unlock(ctx context.Context) error { 199 | return l.locker.Release() 200 | } 201 | 202 | // LockPostgreSQL is an implementation of the DistLocker interface using PostgreSQL's advisory lock. 203 | type LockPostgreSQL struct { 204 | db *sql.DB 205 | id int64 206 | tx *sql.Tx 207 | } 208 | 209 | // NewLockPostgreSQL creates a new LockPostgreSQL. 210 | func NewLockPostgreSQL(db *sql.DB, id int64) *LockPostgreSQL { 211 | return &LockPostgreSQL{db, id, nil} 212 | } 213 | 214 | // Make sure LockPostgreSQL implements DistLocker interface. 215 | var _ DistLocker = (*LockPostgreSQL)(nil) 216 | 217 | // Lock acquire an advisory lock in PostgreSQL. 218 | func (l *LockPostgreSQL) Lock(ctx context.Context) error { 219 | var err error 220 | l.tx, err = l.db.BeginTx(ctx, &sql.TxOptions{}) 221 | if err != nil { 222 | return err 223 | } 224 | _, err = l.tx.ExecContext(ctx, "SELECT pg_advisory_xact_lock($1)", l.id) 225 | 226 | return err 227 | } 228 | 229 | // Unlock releases an advisory lock in PostgreSQL. 230 | func (l *LockPostgreSQL) Unlock(ctx context.Context) error { 231 | return l.tx.Rollback() 232 | } 233 | -------------------------------------------------------------------------------- /locks_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | "github.com/mennanov/limiters" 10 | ) 11 | 12 | func (s *LimitersTestSuite) useLock(lock limiters.DistLocker, shared *int, sleep time.Duration) { 13 | s.NoError(lock.Lock(context.TODO())) 14 | sh := *shared 15 | // Imitate heavy work... 16 | time.Sleep(sleep) 17 | // Check for the race condition. 18 | s.Equal(sh, *shared) 19 | *shared++ 20 | s.NoError(lock.Unlock(context.Background())) 21 | } 22 | 23 | func (s *LimitersTestSuite) TestDistLockers() { 24 | locks1 := s.distLockers(false) 25 | locks2 := s.distLockers(false) 26 | for name := range locks1 { 27 | s.Run(name, func() { 28 | var shared int 29 | rounds := 6 30 | sleep := time.Millisecond * 50 31 | for i := 0; i < rounds; i++ { 32 | wg := sync.WaitGroup{} 33 | wg.Add(2) 34 | go func(k string) { 35 | defer wg.Done() 36 | s.useLock(locks1[k], &shared, sleep) 37 | }(name) 38 | go func(k string) { 39 | defer wg.Done() 40 | s.useLock(locks2[k], &shared, sleep) 41 | }(name) 42 | wg.Wait() 43 | } 44 | s.Equal(rounds*2, shared) 45 | }) 46 | } 47 | } 48 | 49 | func BenchmarkDistLockers(b *testing.B) { 50 | s := new(LimitersTestSuite) 51 | s.SetT(&testing.T{}) 52 | s.SetupSuite() 53 | lockers := s.distLockers(false) 54 | for name, locker := range lockers { 55 | b.Run(name, func(b *testing.B) { 56 | for i := 0; i < b.N; i++ { 57 | s.Require().NoError(locker.Lock(context.Background())) 58 | s.Require().NoError(locker.Unlock(context.Background())) 59 | } 60 | }) 61 | } 62 | s.TearDownSuite() 63 | } 64 | -------------------------------------------------------------------------------- /registry.go: -------------------------------------------------------------------------------- 1 | package limiters 2 | 3 | import ( 4 | "container/heap" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // pqItem is an item in the priority queue. 
10 | type pqItem struct { 11 | value interface{} 12 | exp time.Time 13 | index int 14 | key string 15 | } 16 | 17 | // gcPq is a priority queue. 18 | type gcPq []*pqItem 19 | 20 | func (pq gcPq) Len() int { return len(pq) } 21 | 22 | func (pq gcPq) Less(i, j int) bool { 23 | return pq[i].exp.Before(pq[j].exp) 24 | } 25 | 26 | func (pq gcPq) Swap(i, j int) { 27 | pq[i], pq[j] = pq[j], pq[i] 28 | pq[i].index = i 29 | pq[j].index = j 30 | } 31 | 32 | func (pq *gcPq) Push(x interface{}) { 33 | n := len(*pq) 34 | item := x.(*pqItem) 35 | item.index = n 36 | *pq = append(*pq, item) 37 | } 38 | 39 | func (pq *gcPq) Pop() interface{} { 40 | old := *pq 41 | n := len(old) 42 | item := old[n-1] 43 | item.index = -1 // for safety 44 | *pq = old[0 : n-1] 45 | 46 | return item 47 | } 48 | 49 | // Registry is a thread-safe garbage-collectable registry of values. 50 | type Registry struct { 51 | // Guards all the fields below it. 52 | mx sync.Mutex 53 | pq *gcPq 54 | m map[string]*pqItem 55 | } 56 | 57 | // NewRegistry creates a new instance of Registry. 58 | func NewRegistry() *Registry { 59 | pq := make(gcPq, 0) 60 | 61 | return &Registry{pq: &pq, m: make(map[string]*pqItem)} 62 | } 63 | 64 | // GetOrCreate gets an existing value by key and updates its expiration time. 65 | // If the key lookup fails it creates a new value by calling the provided value closure and puts it on the queue. 66 | func (r *Registry) GetOrCreate(key string, value func() interface{}, ttl time.Duration, now time.Time) interface{} { 67 | r.mx.Lock() 68 | defer r.mx.Unlock() 69 | item, ok := r.m[key] 70 | if ok { 71 | // Update the expiration time. 72 | item.exp = now.Add(ttl) 73 | heap.Fix(r.pq, item.index) 74 | } else { 75 | item = &pqItem{ 76 | value: value(), 77 | exp: now.Add(ttl), 78 | key: key, 79 | } 80 | heap.Push(r.pq, item) 81 | r.m[key] = item 82 | } 83 | 84 | return item.value 85 | } 86 | 87 | // DeleteExpired deletes expired items from the registry and returns the number of deleted items. 88 | func (r *Registry) DeleteExpired(now time.Time) int { 89 | r.mx.Lock() 90 | defer r.mx.Unlock() 91 | c := 0 92 | for len(*r.pq) != 0 { 93 | item := (*r.pq)[0] 94 | if now.Before(item.exp) { 95 | break 96 | } 97 | delete(r.m, item.key) 98 | heap.Pop(r.pq) 99 | c++ 100 | } 101 | 102 | return c 103 | } 104 | 105 | // Delete deletes an item from the registry. 106 | func (r *Registry) Delete(key string) { 107 | r.mx.Lock() 108 | defer r.mx.Unlock() 109 | item, ok := r.m[key] 110 | if !ok { 111 | return 112 | } 113 | delete(r.m, key) 114 | heap.Remove(r.pq, item.index) 115 | } 116 | 117 | // Exists returns true if an item with the given key exists in the registry. 118 | func (r *Registry) Exists(key string) bool { 119 | r.mx.Lock() 120 | defer r.mx.Unlock() 121 | _, ok := r.m[key] 122 | 123 | return ok 124 | } 125 | 126 | // Len returns the number of items in the registry. 
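//
// For context, a sketch of how the registry is typically used (the key, the TTL and newLimiter are
// hypothetical placeholders; any value type can be stored):
//
//	lim := registry.GetOrCreate("client-42", func() interface{} {
//		return newLimiter()
//	}, time.Hour, time.Now())
//	// ...and, periodically, drop the entries that were not touched within their TTL:
//	_ = registry.DeleteExpired(time.Now())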
127 | func (r *Registry) Len() int { 128 | r.mx.Lock() 129 | defer r.mx.Unlock() 130 | 131 | return len(*r.pq) 132 | } 133 | -------------------------------------------------------------------------------- /registry_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strconv" 7 | "testing" 8 | "time" 9 | 10 | "github.com/mennanov/limiters" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | type testingLimiter struct{} 16 | 17 | func newTestingLimiter() *testingLimiter { 18 | return &testingLimiter{} 19 | } 20 | 21 | func (l *testingLimiter) Limit(context.Context) (time.Duration, error) { 22 | return 0, nil 23 | } 24 | 25 | func TestRegistry_GetOrCreate(t *testing.T) { 26 | registry := limiters.NewRegistry() 27 | called := false 28 | clock := newFakeClock() 29 | limiter := newTestingLimiter() 30 | l := registry.GetOrCreate("key", func() interface{} { 31 | called = true 32 | 33 | return limiter 34 | }, time.Second, clock.Now()) 35 | assert.Equal(t, limiter, l) 36 | // Verify that the closure was called to create a value. 37 | assert.True(t, called) 38 | called = false 39 | l = registry.GetOrCreate("key", func() interface{} { 40 | called = true 41 | 42 | return newTestingLimiter() 43 | }, time.Second, clock.Now()) 44 | assert.Equal(t, limiter, l) 45 | // Verify that the closure was NOT called to create a value as it already exists. 46 | assert.False(t, called) 47 | } 48 | 49 | func TestRegistry_DeleteExpired(t *testing.T) { 50 | registry := limiters.NewRegistry() 51 | clock := newFakeClock() 52 | // Add limiters to the registry. 53 | for i := 1; i <= 10; i++ { 54 | registry.GetOrCreate(fmt.Sprintf("key%d", i), func() interface{} { 55 | return newTestingLimiter() 56 | }, time.Second*time.Duration(i), clock.Now()) 57 | } 58 | clock.Sleep(time.Second * 3) 59 | // "touch" the "key3" value that is about to be expired so that its expiration time is extended for 1s. 60 | registry.GetOrCreate("key3", func() interface{} { 61 | return newTestingLimiter() 62 | }, time.Second, clock.Now()) 63 | 64 | assert.Equal(t, 2, registry.DeleteExpired(clock.Now())) 65 | for i := 1; i <= 10; i++ { 66 | if i <= 2 { 67 | assert.False(t, registry.Exists(fmt.Sprintf("key%d", i))) 68 | } else { 69 | assert.True(t, registry.Exists(fmt.Sprintf("key%d", i))) 70 | } 71 | } 72 | } 73 | 74 | func TestRegistry_Delete(t *testing.T) { 75 | registry := limiters.NewRegistry() 76 | clock := newFakeClock() 77 | item := &struct{}{} 78 | require.Equal(t, item, registry.GetOrCreate("key", func() interface{} { 79 | return item 80 | }, time.Second, clock.Now())) 81 | require.Equal(t, item, registry.GetOrCreate("key", func() interface{} { 82 | return &struct{}{} 83 | }, time.Second, clock.Now())) 84 | registry.Delete("key") 85 | assert.False(t, registry.Exists("key")) 86 | } 87 | 88 | // This test is expected to fail when run with the --race flag. 
89 | func TestRegistry_ConcurrentUsage(t *testing.T) {
90 | 	registry := limiters.NewRegistry()
91 | 	clock := newFakeClock()
92 | 	for i := 0; i < 10; i++ {
93 | 		go func(i int) {
94 | 			registry.GetOrCreate(strconv.Itoa(i), func() interface{} { return &struct{}{} }, 0, clock.Now())
95 | 		}(i)
96 | 	}
97 | 	for i := 0; i < 10; i++ {
98 | 		go func(i int) {
99 | 			registry.DeleteExpired(clock.Now())
100 | 		}(i)
101 | 	}
102 | }
103 | 
--------------------------------------------------------------------------------
/revive.toml:
--------------------------------------------------------------------------------
1 | ignoreGeneratedHeader = false
2 | severity = "warning"
3 | confidence = 0.8
4 | errorCode = 0
5 | warningCode = 0
6 | 
7 | [rule.blank-imports]
8 | [rule.context-as-argument]
9 | [rule.context-keys-type]
10 | [rule.dot-imports]
11 | [rule.error-return]
12 | [rule.error-strings]
13 | [rule.error-naming]
14 | [rule.exported]
15 | [rule.if-return]
16 | [rule.increment-decrement]
17 | [rule.var-naming]
18 | [rule.var-declaration]
19 | [rule.package-comments]
20 | [rule.range]
21 | [rule.receiver-naming]
22 | [rule.time-naming]
23 | [rule.unexported-return]
24 | [rule.indent-error-flow]
25 | [rule.errorf]
26 | [rule.empty-block]
27 | [rule.superfluous-else]
28 | [rule.unused-parameter]
29 | [rule.unreachable-code]
30 | [rule.redefines-builtin-id]
31 | 
--------------------------------------------------------------------------------
/slidingwindow.go:
--------------------------------------------------------------------------------
1 | package limiters
2 | 
3 | import (
4 | 	"context"
5 | 	"encoding/json"
6 | 	"fmt"
7 | 	"math"
8 | 	"net/http"
9 | 	"strconv"
10 | 	"sync"
11 | 	"time"
12 | 
13 | 	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
14 | 	"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
15 | 	"github.com/aws/aws-sdk-go-v2/aws"
16 | 	"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
17 | 	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
18 | 	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
19 | 	"github.com/bradfitz/gomemcache/memcache"
20 | 	"github.com/pkg/errors"
21 | 	"github.com/redis/go-redis/v9"
22 | )
23 | 
24 | // SlidingWindowIncrementer wraps the Increment method.
25 | type SlidingWindowIncrementer interface {
26 | 	// Increment increments the request counter for the current window and returns the counter values for the previous
27 | 	// window and the current one.
28 | 	// TTL is the time duration before the next window.
29 | 	Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (prevCount, currCount int64, err error)
30 | }
31 | 
32 | // SlidingWindow implements a Sliding Window rate limiting algorithm.
33 | //
34 | // It does not require a distributed lock and uses a minimal amount of memory; however, it will disallow all requests
35 | // when a client is flooding the service with requests.
36 | // It is the client's responsibility to handle a disallowed request and wait before making a new request.
37 | type SlidingWindow struct {
38 | 	backend  SlidingWindowIncrementer
39 | 	clock    Clock
40 | 	rate     time.Duration
41 | 	capacity int64
42 | 	epsilon  float64
43 | }
44 | 
45 | // NewSlidingWindow creates a new instance of SlidingWindow.
46 | // Capacity is the maximum number of requests allowed per window.
47 | // Rate is the window size.
48 | // Epsilon is the max-allowed range of difference when comparing the current weighted number of requests with capacity.
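//
// A minimal sketch with the in-memory backend (the capacity, rate and epsilon values are
// illustrative placeholders):
//
//	sw := NewSlidingWindow(100, time.Minute, NewSlidingWindowInMemory(), NewSystemClock(), 1e-9)
//	if wait, err := sw.Limit(ctx); err == ErrLimitExhausted {
//		// Reject the request and tell the client to retry after "wait".
//	}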
49 | func NewSlidingWindow(capacity int64, rate time.Duration, slidingWindowIncrementer SlidingWindowIncrementer, clock Clock, epsilon float64) *SlidingWindow { 50 | return &SlidingWindow{backend: slidingWindowIncrementer, clock: clock, rate: rate, capacity: capacity, epsilon: epsilon} 51 | } 52 | 53 | // Limit returns the time duration to wait before the request can be processed. 54 | // It returns ErrLimitExhausted if the request overflows the capacity. 55 | func (s *SlidingWindow) Limit(ctx context.Context) (time.Duration, error) { 56 | now := s.clock.Now() 57 | currWindow := now.Truncate(s.rate) 58 | prevWindow := currWindow.Add(-s.rate) 59 | ttl := s.rate - now.Sub(currWindow) 60 | prev, curr, err := s.backend.Increment(ctx, prevWindow, currWindow, ttl+s.rate) 61 | if err != nil { 62 | return 0, err 63 | } 64 | 65 | // "prev" and "curr" are capped at "s.capacity + s.epsilon" using math.Ceil to round up any fractional values, 66 | // ensuring that in the worst case, "total" can be slightly greater than "s.capacity". 67 | prev = int64(math.Min(float64(prev), math.Ceil(float64(s.capacity)+s.epsilon))) 68 | curr = int64(math.Min(float64(curr), math.Ceil(float64(s.capacity)+s.epsilon))) 69 | 70 | total := float64(prev*int64(ttl))/float64(s.rate) + float64(curr) 71 | if total-float64(s.capacity) >= s.epsilon { 72 | var wait time.Duration 73 | if curr <= s.capacity-1 && prev > 0 { 74 | wait = ttl - time.Duration(float64(s.capacity-1-curr)/float64(prev)*float64(s.rate)) 75 | } else { 76 | // If prev == 0. 77 | wait = ttl + time.Duration((1-float64(s.capacity-1)/float64(curr))*float64(s.rate)) 78 | } 79 | 80 | return wait, ErrLimitExhausted 81 | } 82 | 83 | return 0, nil 84 | } 85 | 86 | // SlidingWindowInMemory is an in-memory implementation of SlidingWindowIncrementer. 87 | type SlidingWindowInMemory struct { 88 | mu sync.Mutex 89 | prevC, currC int64 90 | prevW, currW time.Time 91 | } 92 | 93 | // NewSlidingWindowInMemory creates a new instance of SlidingWindowInMemory. 94 | func NewSlidingWindowInMemory() *SlidingWindowInMemory { 95 | return &SlidingWindowInMemory{} 96 | } 97 | 98 | // Increment increments the current window's counter and returns the number of requests in the previous window and the 99 | // current one. 100 | func (s *SlidingWindowInMemory) Increment(ctx context.Context, prev, curr time.Time, _ time.Duration) (int64, int64, error) { 101 | s.mu.Lock() 102 | defer s.mu.Unlock() 103 | if curr != s.currW { 104 | if prev.Equal(s.currW) { 105 | s.prevW = s.currW 106 | s.prevC = s.currC 107 | } else { 108 | s.prevW = time.Time{} 109 | s.prevC = 0 110 | } 111 | s.currW = curr 112 | s.currC = 0 113 | } 114 | s.currC++ 115 | 116 | return s.prevC, s.currC, ctx.Err() 117 | } 118 | 119 | // SlidingWindowRedis implements SlidingWindow in Redis. 120 | type SlidingWindowRedis struct { 121 | cli redis.UniversalClient 122 | prefix string 123 | } 124 | 125 | // NewSlidingWindowRedis creates a new instance of SlidingWindowRedis. 126 | func NewSlidingWindowRedis(cli redis.UniversalClient, prefix string) *SlidingWindowRedis { 127 | return &SlidingWindowRedis{cli: cli, prefix: prefix} 128 | } 129 | 130 | // Increment increments the current window's counter in Redis and returns the number of requests in the previous window 131 | // and the current one. 
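//
// For context, this backend is wired into the limiter the same way the tests do it (a sketch;
// redisClient and the prefix are placeholders):
//
//	sw := NewSlidingWindow(100, time.Minute, NewSlidingWindowRedis(redisClient, "client-42"), NewSystemClock(), 1e-9)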
132 | func (s *SlidingWindowRedis) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) { 133 | var incr *redis.IntCmd 134 | var prevCountCmd *redis.StringCmd 135 | var err error 136 | done := make(chan struct{}) 137 | go func() { 138 | defer close(done) 139 | _, err = s.cli.Pipelined(ctx, func(pipeliner redis.Pipeliner) error { 140 | currKey := fmt.Sprintf("%d", curr.UnixNano()) 141 | incr = pipeliner.Incr(ctx, redisKey(s.prefix, currKey)) 142 | pipeliner.PExpire(ctx, redisKey(s.prefix, currKey), ttl) 143 | prevCountCmd = pipeliner.Get(ctx, redisKey(s.prefix, fmt.Sprintf("%d", prev.UnixNano()))) 144 | 145 | return nil 146 | }) 147 | }() 148 | 149 | var prevCount int64 150 | select { 151 | case <-done: 152 | if errors.Is(err, redis.TxFailedErr) { 153 | return 0, 0, errors.Wrap(err, "redis transaction failed") 154 | } else if errors.Is(err, redis.Nil) { 155 | prevCount = 0 156 | } else if err != nil { 157 | return 0, 0, errors.Wrap(err, "unexpected error from redis") 158 | } else { 159 | prevCount, err = strconv.ParseInt(prevCountCmd.Val(), 10, 64) 160 | if err != nil { 161 | return 0, 0, errors.Wrap(err, "failed to parse response from redis") 162 | } 163 | } 164 | 165 | return prevCount, incr.Val(), nil 166 | case <-ctx.Done(): 167 | return 0, 0, ctx.Err() 168 | } 169 | } 170 | 171 | // SlidingWindowMemcached implements SlidingWindow in Memcached. 172 | type SlidingWindowMemcached struct { 173 | cli *memcache.Client 174 | prefix string 175 | } 176 | 177 | // NewSlidingWindowMemcached creates a new instance of SlidingWindowMemcached. 178 | func NewSlidingWindowMemcached(cli *memcache.Client, prefix string) *SlidingWindowMemcached { 179 | return &SlidingWindowMemcached{cli: cli, prefix: prefix} 180 | } 181 | 182 | // Increment increments the current window's counter in Memcached and returns the number of requests in the previous window 183 | // and the current one. 184 | func (s *SlidingWindowMemcached) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) { 185 | var prevCount uint64 186 | var currCount uint64 187 | var err error 188 | done := make(chan struct{}) 189 | go func() { 190 | defer close(done) 191 | 192 | var item *memcache.Item 193 | prevKey := fmt.Sprintf("%s:%d", s.prefix, prev.UnixNano()) 194 | item, err = s.cli.Get(prevKey) 195 | if err != nil { 196 | if errors.Is(err, memcache.ErrCacheMiss) { 197 | err = nil 198 | prevCount = 0 199 | } else { 200 | return 201 | } 202 | } else { 203 | prevCount, err = strconv.ParseUint(string(item.Value), 10, 64) 204 | if err != nil { 205 | return 206 | } 207 | } 208 | 209 | currKey := fmt.Sprintf("%s:%d", s.prefix, curr.UnixNano()) 210 | currCount, err = s.cli.Increment(currKey, 1) 211 | if err != nil && errors.Is(err, memcache.ErrCacheMiss) { 212 | currCount = 1 213 | item = &memcache.Item{ 214 | Key: currKey, 215 | Value: []byte(strconv.FormatUint(currCount, 10)), 216 | } 217 | err = s.cli.Add(item) 218 | } 219 | }() 220 | 221 | select { 222 | case <-done: 223 | if err != nil { 224 | if errors.Is(err, memcache.ErrNotStored) { 225 | return s.Increment(ctx, prev, curr, ttl) 226 | } 227 | 228 | return 0, 0, err 229 | } 230 | 231 | return int64(prevCount), int64(currCount), nil 232 | case <-ctx.Done(): 233 | return 0, 0, ctx.Err() 234 | } 235 | } 236 | 237 | // SlidingWindowDynamoDB implements SlidingWindow in DynamoDB. 
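//
// For context, the table description is typically loaded once and reused, following the test
// suite (a sketch; the client, table name and partition key are placeholders):
//
//	props, err := LoadDynamoDBTableProperties(ctx, dynamoClient, "limiters-table")
//	if err != nil {
//		// handle the error
//	}
//	sw := NewSlidingWindow(100, time.Minute, NewSlidingWindowDynamoDB(dynamoClient, "client-42", props), NewSystemClock(), 1e-9)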
238 | type SlidingWindowDynamoDB struct { 239 | client *dynamodb.Client 240 | partitionKey string 241 | tableProps DynamoDBTableProperties 242 | } 243 | 244 | // NewSlidingWindowDynamoDB creates a new instance of SlidingWindowDynamoDB. 245 | // PartitionKey is the key used to store all the this implementation in DynamoDB. 246 | // 247 | // TableProps describe the table that this backend should work with. This backend requires the following on the table: 248 | // * SortKey 249 | // * TTL. 250 | func NewSlidingWindowDynamoDB(client *dynamodb.Client, partitionKey string, props DynamoDBTableProperties) *SlidingWindowDynamoDB { 251 | return &SlidingWindowDynamoDB{ 252 | client: client, 253 | partitionKey: partitionKey, 254 | tableProps: props, 255 | } 256 | } 257 | 258 | // Increment increments the current window's counter in DynamoDB and returns the number of requests in the previous window 259 | // and the current one. 260 | func (s *SlidingWindowDynamoDB) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) { 261 | wg := &sync.WaitGroup{} 262 | wg.Add(2) 263 | 264 | done := make(chan struct{}) 265 | go func() { 266 | defer close(done) 267 | wg.Wait() 268 | }() 269 | 270 | var currentCount int64 271 | var currentErr error 272 | go func() { 273 | defer wg.Done() 274 | resp, err := s.client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ 275 | Key: map[string]types.AttributeValue{ 276 | s.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: s.partitionKey}, 277 | s.tableProps.SortKeyName: &types.AttributeValueMemberS{Value: strconv.FormatInt(curr.UnixNano(), 10)}, 278 | }, 279 | UpdateExpression: aws.String(fixedWindowDynamoDBUpdateExpression), 280 | ExpressionAttributeNames: map[string]string{ 281 | "#TTL": s.tableProps.TTLFieldName, 282 | "#C": dynamodbWindowCountKey, 283 | }, 284 | ExpressionAttributeValues: map[string]types.AttributeValue{ 285 | ":ttl": &types.AttributeValueMemberN{Value: strconv.FormatInt(time.Now().Add(ttl).Unix(), 10)}, 286 | ":def": &types.AttributeValueMemberN{Value: "0"}, 287 | ":inc": &types.AttributeValueMemberN{Value: "1"}, 288 | }, 289 | TableName: &s.tableProps.TableName, 290 | ReturnValues: types.ReturnValueAllNew, 291 | }) 292 | if err != nil { 293 | currentErr = errors.Wrap(err, "dynamodb get item failed") 294 | 295 | return 296 | } 297 | 298 | var tmp float64 299 | err = attributevalue.Unmarshal(resp.Attributes[dynamodbWindowCountKey], &tmp) 300 | if err != nil { 301 | currentErr = errors.Wrap(err, "unmarshal of dynamodb attribute value failed") 302 | 303 | return 304 | } 305 | 306 | currentCount = int64(tmp) 307 | }() 308 | 309 | var priorCount int64 310 | var priorErr error 311 | go func() { 312 | defer wg.Done() 313 | resp, err := s.client.GetItem(ctx, &dynamodb.GetItemInput{ 314 | TableName: aws.String(s.tableProps.TableName), 315 | Key: map[string]types.AttributeValue{ 316 | s.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: s.partitionKey}, 317 | s.tableProps.SortKeyName: &types.AttributeValueMemberS{Value: strconv.FormatInt(prev.UnixNano(), 10)}, 318 | }, 319 | ConsistentRead: aws.Bool(true), 320 | }) 321 | if err != nil { 322 | priorCount, priorErr = 0, errors.Wrap(err, "dynamodb get item failed") 323 | 324 | return 325 | } 326 | 327 | if len(resp.Item) == 0 { 328 | priorCount = 0 329 | 330 | return 331 | } 332 | 333 | var count float64 334 | err = attributevalue.Unmarshal(resp.Item[dynamodbWindowCountKey], &count) 335 | if err != nil { 336 | priorCount, priorErr = 0, errors.Wrap(err, 
"unmarshal of dynamodb attribute value failed") 337 | 338 | return 339 | } 340 | 341 | priorCount = int64(count) 342 | }() 343 | 344 | select { 345 | case <-done: 346 | case <-ctx.Done(): 347 | return 0, 0, ctx.Err() 348 | } 349 | 350 | if currentErr != nil { 351 | return 0, 0, errors.Wrap(currentErr, "failed to update current count") 352 | } else if priorErr != nil { 353 | return 0, 0, errors.Wrap(priorErr, "failed to get previous count") 354 | } 355 | 356 | return priorCount, currentCount, nil 357 | } 358 | 359 | // SlidingWindowCosmosDB implements SlidingWindow in Azure Cosmos DB. 360 | type SlidingWindowCosmosDB struct { 361 | client *azcosmos.ContainerClient 362 | partitionKey string 363 | } 364 | 365 | // NewSlidingWindowCosmosDB creates a new instance of SlidingWindowCosmosDB. 366 | // PartitionKey is the key used to store all the this implementation in Cosmos. 367 | func NewSlidingWindowCosmosDB(client *azcosmos.ContainerClient, partitionKey string) *SlidingWindowCosmosDB { 368 | return &SlidingWindowCosmosDB{ 369 | client: client, 370 | partitionKey: partitionKey, 371 | } 372 | } 373 | 374 | // Increment increments the current window's counter in Cosmos and returns the number of requests in the previous window 375 | // and the current one. 376 | func (s *SlidingWindowCosmosDB) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) { 377 | wg := &sync.WaitGroup{} 378 | wg.Add(2) 379 | 380 | done := make(chan struct{}) 381 | go func() { 382 | defer close(done) 383 | wg.Wait() 384 | }() 385 | 386 | var currentCount int64 387 | var currentErr error 388 | go func() { 389 | defer wg.Done() 390 | 391 | id := strconv.FormatInt(curr.UnixNano(), 10) 392 | tmp := cosmosItem{ 393 | ID: id, 394 | PartitionKey: s.partitionKey, 395 | Count: 1, 396 | TTL: int32(ttl), 397 | } 398 | 399 | ops := azcosmos.PatchOperations{} 400 | ops.AppendIncrement(`/Count`, 1) 401 | 402 | patchResp, err := s.client.PatchItem(ctx, azcosmos.NewPartitionKey().AppendString(s.partitionKey), id, ops, &azcosmos.ItemOptions{ 403 | EnableContentResponseOnWrite: true, 404 | }) 405 | if err == nil { 406 | // value exists and was updated 407 | err = json.Unmarshal(patchResp.Value, &tmp) 408 | if err != nil { 409 | currentErr = errors.Wrap(err, "unmarshal of cosmos value current failed") 410 | 411 | return 412 | } 413 | currentCount = tmp.Count 414 | 415 | return 416 | } 417 | 418 | var respErr *azcore.ResponseError 419 | if !errors.As(err, &respErr) || respErr.StatusCode != http.StatusNotFound { 420 | currentErr = errors.Wrap(err, `patch of cosmos value current failed`) 421 | 422 | return 423 | } 424 | 425 | newValue, err := json.Marshal(tmp) 426 | if err != nil { 427 | currentErr = errors.Wrap(err, "marshal of cosmos value current failed") 428 | 429 | return 430 | } 431 | 432 | _, err = s.client.CreateItem(ctx, azcosmos.NewPartitionKey().AppendString(s.partitionKey), newValue, &azcosmos.ItemOptions{ 433 | SessionToken: patchResp.SessionToken, 434 | IfMatchEtag: &patchResp.ETag, 435 | }) 436 | if err != nil { 437 | currentErr = errors.Wrap(err, "upsert of cosmos value current failed") 438 | 439 | return 440 | } 441 | 442 | currentCount = tmp.Count 443 | }() 444 | 445 | var priorCount int64 446 | var priorErr error 447 | go func() { 448 | defer wg.Done() 449 | 450 | id := strconv.FormatInt(prev.UnixNano(), 10) 451 | resp, err := s.client.ReadItem(ctx, azcosmos.NewPartitionKey().AppendString(s.partitionKey), id, &azcosmos.ItemOptions{}) 452 | if err != nil { 453 | var azerr 
*azcore.ResponseError 454 | if errors.As(err, &azerr) && azerr.StatusCode == http.StatusNotFound { 455 | priorCount, priorErr = 0, nil 456 | 457 | return 458 | } 459 | priorErr = errors.Wrap(err, "cosmos get item prior failed") 460 | 461 | return 462 | } 463 | 464 | var tmp cosmosItem 465 | err = json.Unmarshal(resp.Value, &tmp) 466 | if err != nil { 467 | priorErr = errors.Wrap(err, "unmarshal of cosmos value prior failed") 468 | 469 | return 470 | } 471 | 472 | priorCount = tmp.Count 473 | }() 474 | 475 | select { 476 | case <-done: 477 | case <-ctx.Done(): 478 | return 0, 0, ctx.Err() 479 | } 480 | 481 | if currentErr != nil { 482 | return 0, 0, errors.Wrap(currentErr, "failed to update current count") 483 | } else if priorErr != nil { 484 | return 0, 0, errors.Wrap(priorErr, "failed to get previous count") 485 | } 486 | 487 | return priorCount, currentCount, nil 488 | } 489 | -------------------------------------------------------------------------------- /slidingwindow_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | l "github.com/mennanov/limiters" 10 | ) 11 | 12 | // slidingWindows returns all the possible SlidingWindow combinations. 13 | func (s *LimitersTestSuite) slidingWindows(capacity int64, rate time.Duration, clock l.Clock, epsilon float64) map[string]*l.SlidingWindow { 14 | windows := make(map[string]*l.SlidingWindow) 15 | for name, inc := range s.slidingWindowIncrementers() { 16 | windows[name] = l.NewSlidingWindow(capacity, rate, inc, clock, epsilon) 17 | } 18 | 19 | return windows 20 | } 21 | 22 | func (s *LimitersTestSuite) slidingWindowIncrementers() map[string]l.SlidingWindowIncrementer { 23 | return map[string]l.SlidingWindowIncrementer{ 24 | "SlidingWindowInMemory": l.NewSlidingWindowInMemory(), 25 | "SlidingWindowRedis": l.NewSlidingWindowRedis(s.redisClient, uuid.New().String()), 26 | "SlidingWindowRedisCluster": l.NewSlidingWindowRedis(s.redisClusterClient, uuid.New().String()), 27 | "SlidingWindowMemcached": l.NewSlidingWindowMemcached(s.memcacheClient, uuid.New().String()), 28 | "SlidingWindowDynamoDB": l.NewSlidingWindowDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps), 29 | "SlidingWindowCosmosDB": l.NewSlidingWindowCosmosDB(s.cosmosContainerClient, uuid.New().String()), 30 | } 31 | } 32 | 33 | var slidingWindowTestCases = []struct { 34 | capacity int64 35 | rate time.Duration 36 | epsilon float64 37 | results []struct { 38 | w time.Duration 39 | e error 40 | } 41 | requests int 42 | delta float64 43 | }{ 44 | { 45 | capacity: 1, 46 | rate: time.Second, 47 | epsilon: 1e-9, 48 | requests: 6, 49 | results: []struct { 50 | w time.Duration 51 | e error 52 | }{ 53 | { 54 | 0, nil, 55 | }, 56 | { 57 | time.Second * 2, l.ErrLimitExhausted, 58 | }, 59 | { 60 | 0, nil, 61 | }, 62 | { 63 | time.Second * 2, l.ErrLimitExhausted, 64 | }, 65 | { 66 | 0, nil, 67 | }, 68 | { 69 | time.Second * 2, l.ErrLimitExhausted, 70 | }, 71 | }, 72 | }, 73 | { 74 | capacity: 2, 75 | rate: time.Second, 76 | epsilon: 3e-9, 77 | requests: 10, 78 | delta: 1, 79 | results: []struct { 80 | w time.Duration 81 | e error 82 | }{ 83 | { 84 | 0, nil, 85 | }, 86 | { 87 | 0, nil, 88 | }, 89 | { 90 | time.Second + time.Second*2/3, l.ErrLimitExhausted, 91 | }, 92 | { 93 | 0, nil, 94 | }, 95 | { 96 | time.Second/3 + time.Second/2, l.ErrLimitExhausted, 97 | }, 98 | { 99 | 0, nil, 100 | }, 101 | { 102 | time.Second, 
l.ErrLimitExhausted, 103 | }, 104 | { 105 | 0, nil, 106 | }, 107 | { 108 | time.Second, l.ErrLimitExhausted, 109 | }, 110 | { 111 | 0, nil, 112 | }, 113 | }, 114 | }, 115 | { 116 | capacity: 3, 117 | rate: time.Second, 118 | epsilon: 1e-9, 119 | requests: 11, 120 | delta: 0, 121 | results: []struct { 122 | w time.Duration 123 | e error 124 | }{ 125 | { 126 | 0, nil, 127 | }, 128 | { 129 | 0, nil, 130 | }, 131 | { 132 | 0, nil, 133 | }, 134 | { 135 | time.Second + time.Second/2, l.ErrLimitExhausted, 136 | }, 137 | { 138 | 0, nil, 139 | }, 140 | { 141 | time.Second / 2, l.ErrLimitExhausted, 142 | }, 143 | { 144 | 0, nil, 145 | }, 146 | { 147 | time.Second, l.ErrLimitExhausted, 148 | }, 149 | { 150 | 0, nil, 151 | }, 152 | { 153 | time.Second, l.ErrLimitExhausted, 154 | }, 155 | { 156 | 0, nil, 157 | }, 158 | }, 159 | }, 160 | { 161 | capacity: 4, 162 | rate: time.Second, 163 | epsilon: 1e-9, 164 | requests: 17, 165 | delta: 0, 166 | results: []struct { 167 | w time.Duration 168 | e error 169 | }{ 170 | { 171 | 0, nil, 172 | }, 173 | { 174 | 0, nil, 175 | }, 176 | { 177 | 0, nil, 178 | }, 179 | { 180 | 0, nil, 181 | }, 182 | { 183 | time.Second + time.Second*2/5, l.ErrLimitExhausted, 184 | }, 185 | { 186 | 0, nil, 187 | }, 188 | { 189 | time.Second * 2 / 5, l.ErrLimitExhausted, 190 | }, 191 | { 192 | 0, nil, 193 | }, 194 | { 195 | time.Second/5 + time.Second/4, l.ErrLimitExhausted, 196 | }, 197 | { 198 | 0, nil, 199 | }, 200 | { 201 | time.Second / 2, l.ErrLimitExhausted, 202 | }, 203 | { 204 | 0, nil, 205 | }, 206 | { 207 | time.Second / 2, l.ErrLimitExhausted, 208 | }, 209 | { 210 | 0, nil, 211 | }, 212 | { 213 | time.Second / 2, l.ErrLimitExhausted, 214 | }, 215 | { 216 | 0, nil, 217 | }, 218 | { 219 | time.Second / 2, l.ErrLimitExhausted, 220 | }, 221 | }, 222 | }, 223 | { 224 | capacity: 5, 225 | rate: time.Second, 226 | epsilon: 3e-9, 227 | requests: 18, 228 | delta: 1, 229 | results: []struct { 230 | w time.Duration 231 | e error 232 | }{ 233 | { 234 | 0, nil, 235 | }, 236 | { 237 | 0, nil, 238 | }, 239 | { 240 | 0, nil, 241 | }, 242 | { 243 | 0, nil, 244 | }, 245 | { 246 | 0, nil, 247 | }, 248 | { 249 | time.Second + time.Second/3, l.ErrLimitExhausted, 250 | }, 251 | { 252 | 0, nil, 253 | }, 254 | { 255 | time.Second * 2 / 6, l.ErrLimitExhausted, 256 | }, 257 | { 258 | 0, nil, 259 | }, 260 | { 261 | time.Second * 2 / 6, l.ErrLimitExhausted, 262 | }, 263 | { 264 | 0, nil, 265 | }, 266 | { 267 | time.Second / 2, l.ErrLimitExhausted, 268 | }, 269 | { 270 | 0, nil, 271 | }, 272 | { 273 | time.Second / 2, l.ErrLimitExhausted, 274 | }, 275 | { 276 | 0, nil, 277 | }, 278 | { 279 | time.Second / 2, l.ErrLimitExhausted, 280 | }, 281 | { 282 | 0, nil, 283 | }, 284 | { 285 | time.Second / 2, l.ErrLimitExhausted, 286 | }, 287 | { 288 | 0, nil, 289 | }, 290 | }, 291 | }, 292 | } 293 | 294 | func (s *LimitersTestSuite) TestSlidingWindowOverflowAndWait() { 295 | clock := newFakeClockWithTime(time.Date(2019, 9, 3, 0, 0, 0, 0, time.UTC)) 296 | for _, testCase := range slidingWindowTestCases { 297 | for name, bucket := range s.slidingWindows(testCase.capacity, testCase.rate, clock, testCase.epsilon) { 298 | s.Run(name, func() { 299 | clock.reset() 300 | for i := 0; i < testCase.requests; i++ { 301 | w, err := bucket.Limit(context.TODO()) 302 | s.Require().LessOrEqual(i, len(testCase.results)-1) 303 | s.InDelta(testCase.results[i].w, w, testCase.delta, i) 304 | s.Equal(testCase.results[i].e, err, i) 305 | clock.Sleep(w) 306 | } 307 | }) 308 | } 309 | } 310 | } 311 | 312 | func (s *LimitersTestSuite) 
TestSlidingWindowOverflowAndNoWait() { 313 | capacity := int64(3) 314 | clock := newFakeClock() 315 | for name, bucket := range s.slidingWindows(capacity, time.Second, clock, 1e-9) { 316 | s.Run(name, func() { 317 | clock.reset() 318 | 319 | // Keep sending requests until it reaches the capacity. 320 | for i := int64(0); i < capacity; i++ { 321 | w, err := bucket.Limit(context.TODO()) 322 | s.Require().NoError(err) 323 | s.Require().Equal(time.Duration(0), w) 324 | clock.Sleep(time.Millisecond) 325 | } 326 | 327 | // The next request will be the first one to be rejected. 328 | w, err := bucket.Limit(context.TODO()) 329 | s.Require().Equal(l.ErrLimitExhausted, err) 330 | expected := clock.Now().Add(w) 331 | 332 | // Send a few more requests, all of them should be told to come back at the same time. 333 | for i := int64(0); i < capacity; i++ { 334 | w, err = bucket.Limit(context.TODO()) 335 | s.Require().Equal(l.ErrLimitExhausted, err) 336 | actual := clock.Now().Add(w) 337 | s.Require().Equal(expected, actual, i) 338 | clock.Sleep(time.Millisecond) 339 | } 340 | 341 | // Wait until it is ready. 342 | clock.Sleep(expected.Sub(clock.Now())) 343 | w, err = bucket.Limit(context.TODO()) 344 | s.Require().NoError(err) 345 | s.Require().Equal(time.Duration(0), w) 346 | }) 347 | } 348 | } 349 | 350 | func BenchmarkSlidingWindows(b *testing.B) { 351 | s := new(LimitersTestSuite) 352 | s.SetT(&testing.T{}) 353 | s.SetupSuite() 354 | capacity := int64(1) 355 | rate := time.Second 356 | clock := newFakeClock() 357 | epsilon := 1e-9 358 | windows := s.slidingWindows(capacity, rate, clock, epsilon) 359 | for name, window := range windows { 360 | b.Run(name, func(b *testing.B) { 361 | for i := 0; i < b.N; i++ { 362 | _, err := window.Limit(context.TODO()) 363 | s.Require().NoError(err) 364 | } 365 | }) 366 | } 367 | s.TearDownSuite() 368 | } 369 | -------------------------------------------------------------------------------- /tokenbucket.go: -------------------------------------------------------------------------------- 1 | package limiters 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/gob" 7 | "encoding/json" 8 | "fmt" 9 | "net/http" 10 | "strconv" 11 | "sync" 12 | "time" 13 | 14 | "github.com/Azure/azure-sdk-for-go/sdk/azcore" 15 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 16 | "github.com/aws/aws-sdk-go-v2/aws" 17 | "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" 18 | "github.com/aws/aws-sdk-go-v2/service/dynamodb" 19 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" 20 | "github.com/bradfitz/gomemcache/memcache" 21 | "github.com/pkg/errors" 22 | "github.com/redis/go-redis/v9" 23 | "go.etcd.io/etcd/api/v3/mvccpb" 24 | "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" 25 | clientv3 "go.etcd.io/etcd/client/v3" 26 | ) 27 | 28 | // TokenBucketState represents a state of a token bucket. 29 | type TokenBucketState struct { 30 | // Last is the last time the state was updated (Unix timestamp in nanoseconds). 31 | Last int64 32 | // Available is the number of available tokens in the bucket. 33 | Available int64 34 | } 35 | 36 | // isZero returns true if the bucket state is zero valued. 37 | func (s TokenBucketState) isZero() bool { 38 | return s.Last == 0 && s.Available == 0 39 | } 40 | 41 | // TokenBucketStateBackend interface encapsulates the logic of retrieving and persisting the state of a TokenBucket. 42 | type TokenBucketStateBackend interface { 43 | // State gets the current state of the TokenBucket. 
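	// Implementations should return a zero-valued TokenBucketState when no state has been persisted
	// yet: TokenBucket.Take treats such a state as a full bucket (see isZero above).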
44 | State(ctx context.Context) (TokenBucketState, error) 45 | // SetState sets (persists) the current state of the TokenBucket. 46 | SetState(ctx context.Context, state TokenBucketState) error 47 | // Reset resets (persists) the current state of the TokenBucket. 48 | Reset(ctx context.Context) error 49 | } 50 | 51 | // TokenBucket implements the https://en.wikipedia.org/wiki/Token_bucket algorithm. 52 | type TokenBucket struct { 53 | locker DistLocker 54 | backend TokenBucketStateBackend 55 | clock Clock 56 | logger Logger 57 | // refillRate is the tokens refill rate (1 token per duration). 58 | refillRate time.Duration 59 | // capacity is the bucket's capacity. 60 | capacity int64 61 | mu sync.Mutex 62 | } 63 | 64 | // NewTokenBucket creates a new instance of TokenBucket. 65 | func NewTokenBucket(capacity int64, refillRate time.Duration, locker DistLocker, tokenBucketStateBackend TokenBucketStateBackend, clock Clock, logger Logger) *TokenBucket { 66 | return &TokenBucket{ 67 | locker: locker, 68 | backend: tokenBucketStateBackend, 69 | clock: clock, 70 | logger: logger, 71 | refillRate: refillRate, 72 | capacity: capacity, 73 | } 74 | } 75 | 76 | // Take takes tokens from the bucket. 77 | // 78 | // It returns a zero duration and a nil error if the bucket has sufficient amount of tokens. 79 | // 80 | // It returns ErrLimitExhausted if the amount of available tokens is less than requested. In this case the returned 81 | // duration is the amount of time to wait to retry the request. 82 | func (t *TokenBucket) Take(ctx context.Context, tokens int64) (time.Duration, error) { 83 | t.mu.Lock() 84 | defer t.mu.Unlock() 85 | if err := t.locker.Lock(ctx); err != nil { 86 | return 0, err 87 | } 88 | defer func() { 89 | if err := t.locker.Unlock(ctx); err != nil { 90 | t.logger.Log(err) 91 | } 92 | }() 93 | state, err := t.backend.State(ctx) 94 | if err != nil { 95 | return 0, err 96 | } 97 | if state.isZero() { 98 | // Initially the bucket is full. 99 | state.Available = t.capacity 100 | } 101 | now := t.clock.Now().UnixNano() 102 | // Refill the bucket. 103 | tokensToAdd := (now - state.Last) / int64(t.refillRate) 104 | partialTime := (now - state.Last) % int64(t.refillRate) 105 | if tokensToAdd > 0 { 106 | if tokensToAdd+state.Available < t.capacity { 107 | state.Available += tokensToAdd 108 | state.Last = now - partialTime 109 | } else { 110 | state.Available = t.capacity 111 | state.Last = now 112 | } 113 | } 114 | 115 | if tokens > state.Available { 116 | return t.refillRate * time.Duration(tokens-state.Available), ErrLimitExhausted 117 | } 118 | // Take the tokens from the bucket. 119 | state.Available -= tokens 120 | if err = t.backend.SetState(ctx, state); err != nil { 121 | return 0, err 122 | } 123 | 124 | return 0, nil 125 | } 126 | 127 | // Limit takes 1 token from the bucket. 128 | func (t *TokenBucket) Limit(ctx context.Context) (time.Duration, error) { 129 | return t.Take(ctx, 1) 130 | } 131 | 132 | // Reset resets the bucket. 133 | func (t *TokenBucket) Reset(ctx context.Context) error { 134 | return t.backend.Reset(ctx) 135 | } 136 | 137 | // TokenBucketInMemory is an in-memory implementation of TokenBucketStateBackend. 138 | // 139 | // The state is not shared nor persisted so it won't survive restarts or failures. 140 | // Due to the local nature of the state the rate at which some endpoints are accessed can't be reliably predicted or 141 | // limited. 142 | // 143 | // Although it can be used as a global rate limiter with a round-robin load-balancer. 
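//
// A minimal usage sketch (illustrative only, not part of the original file; it wires the in-memory
// backend together with the no-op lock and the system clock from this package, and "logger" stands for
// any Logger implementation):
//
//	bucket := NewTokenBucket(10, time.Second, NewLockNoop(), NewTokenBucketInMemory(), NewSystemClock(), logger)
//	wait, err := bucket.Limit(context.Background())
//	if errors.Is(err, ErrLimitExhausted) {
//		// Over the limit: ask the caller to retry after `wait`.
//	}
//
// As a worked example of the refill arithmetic in Take: with refillRate = 100ms, Last = t0 and a call
// at t0+250ms, tokensToAdd = 2 and Last advances to t0+200ms, carrying the 50ms remainder over to the
// next refill (unless the bucket reaches capacity, in which case Last is simply set to now).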
144 | type TokenBucketInMemory struct { 145 | state TokenBucketState 146 | } 147 | 148 | // NewTokenBucketInMemory creates a new instance of TokenBucketInMemory. 149 | func NewTokenBucketInMemory() *TokenBucketInMemory { 150 | return &TokenBucketInMemory{} 151 | } 152 | 153 | // State returns the current bucket's state. 154 | func (t *TokenBucketInMemory) State(ctx context.Context) (TokenBucketState, error) { 155 | return t.state, ctx.Err() 156 | } 157 | 158 | // SetState sets the current bucket's state. 159 | func (t *TokenBucketInMemory) SetState(ctx context.Context, state TokenBucketState) error { 160 | t.state = state 161 | 162 | return ctx.Err() 163 | } 164 | 165 | // Reset resets the current bucket's state. 166 | func (t *TokenBucketInMemory) Reset(ctx context.Context) error { 167 | state := TokenBucketState{ 168 | Last: 0, 169 | Available: 0, 170 | } 171 | 172 | return t.SetState(ctx, state) 173 | } 174 | 175 | const ( 176 | etcdKeyTBLease = "lease" 177 | etcdKeyTBAvailable = "available" 178 | etcdKeyTBLast = "last" 179 | ) 180 | 181 | // TokenBucketEtcd is an etcd implementation of a TokenBucketStateBackend. 182 | // 183 | // See https://github.com/etcd-io/etcd/blob/master/Documentation/learning/data_model.md 184 | // 185 | // etcd is designed to reliably store infrequently updated data, thus it should only be used for the API endpoints which 186 | // are accessed less frequently than it can be processed by the rate limiter. 187 | // 188 | // Aggressive compaction and defragmentation has to be enabled in etcd to prevent the size of the storage 189 | // to grow indefinitely: every change of the state of the bucket (every access) will create a new revision in etcd. 190 | // 191 | // It probably makes it impractical for the high load cases, but can be used to reliably and precisely rate limit an 192 | // access to the business critical endpoints where each access must be reliably logged. 193 | type TokenBucketEtcd struct { 194 | // prefix is the etcd key prefix. 195 | prefix string 196 | cli *clientv3.Client 197 | leaseID clientv3.LeaseID 198 | ttl time.Duration 199 | raceCheck bool 200 | lastVersion int64 201 | } 202 | 203 | // NewTokenBucketEtcd creates a new TokenBucketEtcd instance. 204 | // Prefix is used as an etcd key prefix for all keys stored in etcd by this algorithm. 205 | // TTL is a TTL of the etcd lease in seconds used to store all the keys: all the keys are automatically deleted after 206 | // the TTL expires. 207 | // 208 | // If raceCheck is true and the keys in etcd are modified in between State() and SetState() calls then 209 | // ErrRaceCondition is returned. 210 | // It does not add any significant overhead as it can be trivially checked on etcd side before updating the keys. 211 | func NewTokenBucketEtcd(cli *clientv3.Client, prefix string, ttl time.Duration, raceCheck bool) *TokenBucketEtcd { 212 | return &TokenBucketEtcd{ 213 | prefix: prefix, 214 | cli: cli, 215 | ttl: ttl, 216 | raceCheck: raceCheck, 217 | } 218 | } 219 | 220 | // etcdKey returns a full etcd key from the provided key and prefix. 221 | func etcdKey(prefix, key string) string { 222 | return fmt.Sprintf("%s/%s", prefix, key) 223 | } 224 | 225 | // parseEtcdInt64 parses the etcd value into int64. 
226 | func parseEtcdInt64(kv *mvccpb.KeyValue) (int64, error) { 227 | v, err := strconv.ParseInt(string(kv.Value), 10, 64) 228 | if err != nil { 229 | return 0, errors.Wrapf(err, "failed to parse key '%s' as int64", string(kv.Key)) 230 | } 231 | 232 | return v, nil 233 | } 234 | 235 | func incPrefix(p string) string { 236 | b := []byte(p) 237 | b[len(b)-1]++ 238 | 239 | return string(b) 240 | } 241 | 242 | // State gets the bucket's current state from etcd. 243 | // If there is no state available in etcd then the initial bucket's state is returned. 244 | func (t *TokenBucketEtcd) State(ctx context.Context) (TokenBucketState, error) { 245 | // Get all the keys under the prefix in a single request. 246 | r, err := t.cli.Get(ctx, t.prefix, clientv3.WithRange(incPrefix(t.prefix))) 247 | if err != nil { 248 | return TokenBucketState{}, errors.Wrapf(err, "failed to get keys in range ['%s', '%s') from etcd", t.prefix, incPrefix(t.prefix)) 249 | } 250 | if len(r.Kvs) == 0 { 251 | // State not found, return zero valued state. 252 | return TokenBucketState{}, nil 253 | } 254 | state := TokenBucketState{} 255 | parsed := 0 256 | var v int64 257 | for _, kv := range r.Kvs { 258 | switch string(kv.Key) { 259 | case etcdKey(t.prefix, etcdKeyTBAvailable): 260 | v, err = parseEtcdInt64(kv) 261 | if err != nil { 262 | return TokenBucketState{}, err 263 | } 264 | state.Available = v 265 | parsed |= 1 266 | 267 | case etcdKey(t.prefix, etcdKeyTBLast): 268 | v, err = parseEtcdInt64(kv) 269 | if err != nil { 270 | return TokenBucketState{}, err 271 | } 272 | state.Last = v 273 | parsed |= 2 274 | t.lastVersion = kv.Version 275 | 276 | case etcdKey(t.prefix, etcdKeyTBLease): 277 | v, err = parseEtcdInt64(kv) 278 | if err != nil { 279 | return TokenBucketState{}, err 280 | } 281 | t.leaseID = clientv3.LeaseID(v) 282 | parsed |= 4 283 | } 284 | } 285 | if parsed != 7 { 286 | return TokenBucketState{}, errors.New("failed to get state from etcd: some keys are missing") 287 | } 288 | 289 | return state, nil 290 | } 291 | 292 | // createLease creates a new lease in etcd and updates the t.leaseID value. 293 | func (t *TokenBucketEtcd) createLease(ctx context.Context) error { 294 | lease, err := t.cli.Grant(ctx, int64(t.ttl/time.Nanosecond)) 295 | if err != nil { 296 | return errors.Wrap(err, "failed to create a new lease in etcd") 297 | } 298 | t.leaseID = lease.ID 299 | 300 | return nil 301 | } 302 | 303 | // save saves the state to etcd using the existing lease and the fencing token. 304 | func (t *TokenBucketEtcd) save(ctx context.Context, state TokenBucketState) error { 305 | if !t.raceCheck { 306 | if _, err := t.cli.Txn(ctx).Then( 307 | clientv3.OpPut(etcdKey(t.prefix, etcdKeyTBAvailable), fmt.Sprintf("%d", state.Available), clientv3.WithLease(t.leaseID)), 308 | clientv3.OpPut(etcdKey(t.prefix, etcdKeyTBLast), fmt.Sprintf("%d", state.Last), clientv3.WithLease(t.leaseID)), 309 | clientv3.OpPut(etcdKey(t.prefix, etcdKeyTBLease), fmt.Sprintf("%d", t.leaseID), clientv3.WithLease(t.leaseID)), 310 | ).Commit(); err != nil { 311 | return errors.Wrap(err, "failed to commit a transaction to etcd") 312 | } 313 | 314 | return nil 315 | } 316 | // Put the keys only if they have not been modified since the most recent read. 
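	// The transaction below is an optimistic concurrency check: the If clause compares the stored
	// version of the "last" key against the version observed by the most recent State() call, while the
	// actual writes live in the Else branch. Consequently r.Succeeded == true means the key was modified
	// concurrently (no writes were performed) and ErrRaceCondition is returned, whereas
	// r.Succeeded == false means the Else branch wrote the keys and the call succeeds.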
317 | r, err := t.cli.Txn(ctx).If( 318 | clientv3.Compare(clientv3.Version(etcdKey(t.prefix, etcdKeyTBLast)), ">", t.lastVersion), 319 | ).Else( 320 | clientv3.OpPut(etcdKey(t.prefix, etcdKeyTBAvailable), fmt.Sprintf("%d", state.Available), clientv3.WithLease(t.leaseID)), 321 | clientv3.OpPut(etcdKey(t.prefix, etcdKeyTBLast), fmt.Sprintf("%d", state.Last), clientv3.WithLease(t.leaseID)), 322 | clientv3.OpPut(etcdKey(t.prefix, etcdKeyTBLease), fmt.Sprintf("%d", t.leaseID), clientv3.WithLease(t.leaseID)), 323 | ).Commit() 324 | if err != nil { 325 | return errors.Wrap(err, "failed to commit a transaction to etcd") 326 | } 327 | 328 | if !r.Succeeded { 329 | return nil 330 | } 331 | 332 | return ErrRaceCondition 333 | } 334 | 335 | // SetState updates the state of the bucket. 336 | func (t *TokenBucketEtcd) SetState(ctx context.Context, state TokenBucketState) error { 337 | if t.leaseID == 0 { 338 | // Lease does not exist, create one. 339 | if err := t.createLease(ctx); err != nil { 340 | return err 341 | } 342 | // No need to send KeepAlive for the newly created lease: save the state immediately. 343 | return t.save(ctx, state) 344 | } 345 | // Send the KeepAlive request to extend the existing lease. 346 | if _, err := t.cli.KeepAliveOnce(ctx, t.leaseID); errors.Is(err, rpctypes.ErrLeaseNotFound) { 347 | // Create a new lease since the current one has expired. 348 | if err = t.createLease(ctx); err != nil { 349 | return err 350 | } 351 | } else if err != nil { 352 | return errors.Wrapf(err, "failed to extend the lease '%d'", t.leaseID) 353 | } 354 | 355 | return t.save(ctx, state) 356 | } 357 | 358 | // Reset resets the state of the bucket. 359 | func (t *TokenBucketEtcd) Reset(ctx context.Context) error { 360 | state := TokenBucketState{ 361 | Last: 0, 362 | Available: 0, 363 | } 364 | 365 | return t.SetState(ctx, state) 366 | } 367 | 368 | // Deprecated: These legacy keys will be removed in a future version. 369 | // The state is now stored in a single JSON document under the "state" key. 370 | const ( 371 | redisKeyTBAvailable = "available" 372 | redisKeyTBLast = "last" 373 | redisKeyTBVersion = "version" 374 | ) 375 | 376 | // If we do use cluster client and if the cluster is large enough, it is possible that when accessing multiple keys 377 | // in leaky bucket or token bucket, these keys might go different slots and it will fail with error message 378 | // `CROSSSLOT Keys in request don't hash to the same slot`. Adding hash tags in redisKey will force them into the 379 | // same slot for keys with the same prefix. 380 | // 381 | // https://redis.io/docs/latest/operate/oss_and_stack/reference/cluster-spec/#hash-tags 382 | func redisKey(prefix, key string) string { 383 | return fmt.Sprintf("{%s}%s", prefix, key) 384 | } 385 | 386 | // TokenBucketRedis is a Redis implementation of a TokenBucketStateBackend. 387 | // 388 | // Redis is an in-memory key-value data storage which also supports persistence. 389 | // It is a better choice for high load cases than etcd as it does not keep old values of the keys thus does not need 390 | // the compaction/defragmentation. 391 | // 392 | // Although depending on a persistence and a cluster configuration some data might be lost in case of a failure 393 | // resulting in an under-limiting the accesses to the service. 
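//
// A construction sketch (illustrative only; the address, prefix and TTL below are placeholders):
//
//	cli := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
//	backend := NewTokenBucketRedis(cli, "rate-limiter-endpoint-a", time.Minute, true)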
394 | type TokenBucketRedis struct { 395 | cli redis.UniversalClient 396 | prefix string 397 | ttl time.Duration 398 | raceCheck bool 399 | lastVersion int64 400 | } 401 | 402 | // NewTokenBucketRedis creates a new TokenBucketRedis instance. 403 | // Prefix is the key prefix used to store all the keys used in this implementation in Redis. 404 | // TTL is the TTL of the stored keys. 405 | // 406 | // If raceCheck is true and the keys in Redis are modified in between State() and SetState() calls then 407 | // ErrRaceCondition is returned. 408 | // This adds an extra overhead since a Lua script has to be executed on the Redis side which locks the entire database. 409 | func NewTokenBucketRedis(cli redis.UniversalClient, prefix string, ttl time.Duration, raceCheck bool) *TokenBucketRedis { 410 | return &TokenBucketRedis{cli: cli, prefix: prefix, ttl: ttl, raceCheck: raceCheck} 411 | } 412 | 413 | // Deprecated: Legacy format support will be removed in a future version. 414 | func (t *TokenBucketRedis) oldState(ctx context.Context) (TokenBucketState, error) { 415 | var values []interface{} 416 | var err error 417 | done := make(chan struct{}, 1) 418 | 419 | if t.raceCheck { 420 | // reset in a case of returning an empty TokenBucketState 421 | t.lastVersion = 0 422 | } 423 | 424 | go func() { 425 | defer close(done) 426 | keys := []string{ 427 | redisKey(t.prefix, redisKeyTBLast), 428 | redisKey(t.prefix, redisKeyTBAvailable), 429 | } 430 | if t.raceCheck { 431 | keys = append(keys, redisKey(t.prefix, redisKeyTBVersion)) 432 | } 433 | values, err = t.cli.MGet(ctx, keys...).Result() 434 | }() 435 | 436 | select { 437 | case <-done: 438 | 439 | case <-ctx.Done(): 440 | return TokenBucketState{}, ctx.Err() 441 | } 442 | 443 | if err != nil { 444 | return TokenBucketState{}, errors.Wrap(err, "failed to get keys from redis") 445 | } 446 | nilAny := false 447 | for _, v := range values { 448 | if v == nil { 449 | nilAny = true 450 | 451 | break 452 | } 453 | } 454 | if nilAny || errors.Is(err, redis.Nil) { 455 | // Keys don't exist, return the initial state. 456 | return TokenBucketState{}, nil 457 | } 458 | 459 | last, err := strconv.ParseInt(values[0].(string), 10, 64) 460 | if err != nil { 461 | return TokenBucketState{}, err 462 | } 463 | available, err := strconv.ParseInt(values[1].(string), 10, 64) 464 | if err != nil { 465 | return TokenBucketState{}, err 466 | } 467 | if t.raceCheck { 468 | t.lastVersion, err = strconv.ParseInt(values[2].(string), 10, 64) 469 | if err != nil { 470 | return TokenBucketState{}, err 471 | } 472 | } 473 | 474 | return TokenBucketState{ 475 | Last: last, 476 | Available: available, 477 | }, nil 478 | } 479 | 480 | // State gets the bucket's state from Redis. 
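// It first tries the consolidated JSON document stored under the "{prefix}state" key and falls back to
// the deprecated per-field keys (see oldState) when that document is absent, so state written by older
// versions of this backend remains readable.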
481 | func (t *TokenBucketRedis) State(ctx context.Context) (TokenBucketState, error) { 482 | var err error 483 | done := make(chan struct{}, 1) 484 | errCh := make(chan error, 1) 485 | var state TokenBucketState 486 | 487 | if t.raceCheck { 488 | // reset in a case of returning an empty TokenBucketState 489 | t.lastVersion = 0 490 | } 491 | 492 | go func() { 493 | defer close(done) 494 | key := redisKey(t.prefix, "state") 495 | value, err := t.cli.Get(ctx, key).Result() 496 | if err != nil && !errors.Is(err, redis.Nil) { 497 | errCh <- err 498 | 499 | return 500 | } 501 | 502 | if errors.Is(err, redis.Nil) { 503 | state, err = t.oldState(ctx) 504 | errCh <- err 505 | 506 | return 507 | } 508 | 509 | // Try new format 510 | var item struct { 511 | State TokenBucketState `json:"state"` 512 | Version int64 `json:"version"` 513 | } 514 | if err = json.Unmarshal([]byte(value), &item); err != nil { 515 | errCh <- err 516 | 517 | return 518 | } 519 | 520 | state = item.State 521 | if t.raceCheck { 522 | t.lastVersion = item.Version 523 | } 524 | errCh <- nil 525 | }() 526 | 527 | select { 528 | case <-done: 529 | err = <-errCh 530 | case <-ctx.Done(): 531 | return TokenBucketState{}, ctx.Err() 532 | } 533 | 534 | if err != nil { 535 | return TokenBucketState{}, errors.Wrap(err, "failed to get state from redis") 536 | } 537 | 538 | return state, nil 539 | } 540 | 541 | // SetState updates the state in Redis. 542 | func (t *TokenBucketRedis) SetState(ctx context.Context, state TokenBucketState) error { 543 | var err error 544 | done := make(chan struct{}, 1) 545 | errCh := make(chan error, 1) 546 | 547 | go func() { 548 | defer close(done) 549 | key := redisKey(t.prefix, "state") 550 | item := struct { 551 | State TokenBucketState `json:"state"` 552 | Version int64 `json:"version"` 553 | }{ 554 | State: state, 555 | Version: t.lastVersion + 1, 556 | } 557 | 558 | value, err := json.Marshal(item) 559 | if err != nil { 560 | errCh <- err 561 | 562 | return 563 | } 564 | 565 | if !t.raceCheck { 566 | errCh <- t.cli.Set(ctx, key, value, t.ttl).Err() 567 | 568 | return 569 | } 570 | 571 | script := ` 572 | local current = redis.call('get', KEYS[1]) 573 | if current then 574 | local data = cjson.decode(current) 575 | if data.version > tonumber(ARGV[2]) then 576 | return 'RACE_CONDITION' 577 | end 578 | end 579 | redis.call('set', KEYS[1], ARGV[1], 'PX', ARGV[3]) 580 | return 'OK' 581 | ` 582 | result, err := t.cli.Eval(ctx, script, []string{key}, value, t.lastVersion, int64(t.ttl/time.Millisecond)).Result() 583 | if err != nil { 584 | errCh <- err 585 | 586 | return 587 | } 588 | if result == "RACE_CONDITION" { 589 | errCh <- ErrRaceCondition 590 | 591 | return 592 | } 593 | errCh <- nil 594 | }() 595 | 596 | select { 597 | case <-done: 598 | err = <-errCh 599 | case <-ctx.Done(): 600 | return ctx.Err() 601 | } 602 | 603 | if err != nil { 604 | return errors.Wrap(err, "failed to save state to redis") 605 | } 606 | 607 | return nil 608 | } 609 | 610 | // Reset resets the state in Redis. 611 | func (t *TokenBucketRedis) Reset(ctx context.Context) error { 612 | state := TokenBucketState{ 613 | Last: 0, 614 | Available: 0, 615 | } 616 | 617 | return t.SetState(ctx, state) 618 | } 619 | 620 | // TokenBucketMemcached is a Memcached implementation of a TokenBucketStateBackend. 621 | // 622 | // Memcached is a distributed memory object caching system. 
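//
// A construction sketch (illustrative only; the server address, key and TTL are placeholders):
//
//	cli := memcache.New("localhost:11211")
//	backend := NewTokenBucketMemcached(cli, "rate-limiter-endpoint-a", time.Minute, true)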
623 | type TokenBucketMemcached struct {
624 | 	cli       *memcache.Client
625 | 	key       string
626 | 	ttl       time.Duration
627 | 	raceCheck bool
628 | 	casId     uint64
629 | }
630 | 
631 | // NewTokenBucketMemcached creates a new TokenBucketMemcached instance.
632 | // Key is the Memcached key under which the bucket's state is stored.
633 | // TTL is the TTL of the stored item.
634 | //
635 | // If raceCheck is true and the item in Memcached is modified in between State() and SetState() calls then
636 | // ErrRaceCondition is returned.
637 | // The race check relies on Memcached's compare-and-swap (CAS) operation, so it adds no significant overhead.
638 | func NewTokenBucketMemcached(cli *memcache.Client, key string, ttl time.Duration, raceCheck bool) *TokenBucketMemcached {
639 | 	return &TokenBucketMemcached{cli: cli, key: key, ttl: ttl, raceCheck: raceCheck}
640 | }
641 | 
642 | // State gets the bucket's state from Memcached.
643 | func (t *TokenBucketMemcached) State(ctx context.Context) (TokenBucketState, error) {
644 | 	var item *memcache.Item
645 | 	var state TokenBucketState
646 | 	var err error
647 | 
648 | 	done := make(chan struct{}, 1)
649 | 	t.casId = 0
650 | 
651 | 	go func() {
652 | 		defer close(done)
653 | 		item, err = t.cli.Get(t.key)
654 | 	}()
655 | 
656 | 	select {
657 | 	case <-done:
658 | 
659 | 	case <-ctx.Done():
660 | 		return state, ctx.Err()
661 | 	}
662 | 
663 | 	if err != nil {
664 | 		if errors.Is(err, memcache.ErrCacheMiss) {
665 | 			// The key doesn't exist, return the initial state.
666 | 			return state, nil
667 | 		}
668 | 
669 | 		return state, errors.Wrap(err, "failed to get key from memcached")
670 | 	}
671 | 	b := bytes.NewBuffer(item.Value)
672 | 	err = gob.NewDecoder(b).Decode(&state)
673 | 	if err != nil {
674 | 		return state, errors.Wrap(err, "failed to Decode")
675 | 	}
676 | 	t.casId = item.CasID
677 | 
678 | 	return state, nil
679 | }
680 | 
681 | // SetState updates the state in Memcached.
682 | func (t *TokenBucketMemcached) SetState(ctx context.Context, state TokenBucketState) error {
683 | 	var err error
684 | 	done := make(chan struct{}, 1)
685 | 	var b bytes.Buffer
686 | 	err = gob.NewEncoder(&b).Encode(state)
687 | 	if err != nil {
688 | 		return errors.Wrap(err, "failed to Encode")
689 | 	}
690 | 	go func() {
691 | 		defer close(done)
692 | 		item := &memcache.Item{
693 | 			Key:   t.key,
694 | 			Value: b.Bytes(),
695 | 			CasID: t.casId,
696 | 		}
697 | 		if t.raceCheck && t.casId > 0 {
698 | 			err = t.cli.CompareAndSwap(item)
699 | 		} else {
700 | 			err = t.cli.Set(item)
701 | 		}
702 | 	}()
703 | 
704 | 	select {
705 | 	case <-done:
706 | 
707 | 	case <-ctx.Done():
708 | 		return ctx.Err()
709 | 	}
710 | 
711 | 	return errors.Wrap(err, "failed to save keys to memcached")
712 | }
713 | 
714 | // Reset resets the state in Memcached.
715 | func (t *TokenBucketMemcached) Reset(ctx context.Context) error {
716 | 	state := TokenBucketState{
717 | 		Last:      0,
718 | 		Available: 0,
719 | 	}
720 | 	// Reset casId to 0 so that SetState uses Set instead of CompareAndSwap.
721 | 	t.casId = 0
722 | 
723 | 	return t.SetState(ctx, state)
724 | }
725 | 
726 | // TokenBucketDynamoDB is a DynamoDB implementation of a TokenBucketStateBackend.
727 | type TokenBucketDynamoDB struct {
728 | 	client        *dynamodb.Client
729 | 	tableProps    DynamoDBTableProperties
730 | 	partitionKey  string
731 | 	ttl           time.Duration
732 | 	raceCheck     bool
733 | 	latestVersion int64
734 | 	keys          map[string]types.AttributeValue
735 | }
736 | 
737 | // NewTokenBucketDynamoDB creates a new TokenBucketDynamoDB instance.
738 | // PartitionKey is the key used to store all the data for this implementation in DynamoDB.
739 | //
740 | // TableProps describes the table that this backend should work with. This backend requires the following on the table:
741 | // * TTL
742 | //
743 | // TTL is the TTL of the stored item.
744 | //
745 | // If raceCheck is true and the item in DynamoDB is modified in between State() and SetState() calls then
746 | // ErrRaceCondition is returned.
747 | func NewTokenBucketDynamoDB(client *dynamodb.Client, partitionKey string, tableProps DynamoDBTableProperties, ttl time.Duration, raceCheck bool) *TokenBucketDynamoDB {
748 | 	keys := map[string]types.AttributeValue{
749 | 		tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: partitionKey},
750 | 	}
751 | 
752 | 	if tableProps.SortKeyUsed {
753 | 		keys[tableProps.SortKeyName] = &types.AttributeValueMemberS{Value: partitionKey}
754 | 	}
755 | 
756 | 	return &TokenBucketDynamoDB{
757 | 		client:       client,
758 | 		partitionKey: partitionKey,
759 | 		tableProps:   tableProps,
760 | 		ttl:          ttl,
761 | 		raceCheck:    raceCheck,
762 | 		keys:         keys,
763 | 	}
764 | }
765 | 
766 | // State gets the bucket's state from DynamoDB.
767 | func (t *TokenBucketDynamoDB) State(ctx context.Context) (TokenBucketState, error) {
768 | 	resp, err := dynamoDBGetItem(ctx, t.client, t.getGetItemInput())
769 | 	if err != nil {
770 | 		return TokenBucketState{}, err
771 | 	}
772 | 
773 | 	return t.loadStateFromDynamoDB(resp)
774 | }
775 | 
776 | // SetState updates the state in DynamoDB.
777 | func (t *TokenBucketDynamoDB) SetState(ctx context.Context, state TokenBucketState) error {
778 | 	input := t.getPutItemInputFromState(state)
779 | 
780 | 	var err error
781 | 	done := make(chan struct{})
782 | 	go func() {
783 | 		defer close(done)
784 | 		_, err = dynamoDBputItem(ctx, t.client, input)
785 | 	}()
786 | 
787 | 	select {
788 | 	case <-done:
789 | 	case <-ctx.Done():
790 | 		return ctx.Err()
791 | 	}
792 | 
793 | 	return err
794 | }
795 | 
796 | // Reset resets the state in DynamoDB.
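// Note that Reset does not delete the item: it persists a zero-valued state (with a fresh TTL), which
// TokenBucket.Take then interprets as a full bucket.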
797 | func (t *TokenBucketDynamoDB) Reset(ctx context.Context) error { 798 | state := TokenBucketState{ 799 | Last: 0, 800 | Available: 0, 801 | } 802 | 803 | return t.SetState(ctx, state) 804 | } 805 | 806 | const dynamoDBBucketAvailableKey = "Available" 807 | 808 | func (t *TokenBucketDynamoDB) getGetItemInput() *dynamodb.GetItemInput { 809 | return &dynamodb.GetItemInput{ 810 | TableName: &t.tableProps.TableName, 811 | Key: t.keys, 812 | } 813 | } 814 | 815 | func (t *TokenBucketDynamoDB) getPutItemInputFromState(state TokenBucketState) *dynamodb.PutItemInput { 816 | item := map[string]types.AttributeValue{} 817 | for k, v := range t.keys { 818 | item[k] = v 819 | } 820 | 821 | item[dynamoDBBucketLastKey] = &types.AttributeValueMemberN{Value: strconv.FormatInt(state.Last, 10)} 822 | item[dynamoDBBucketVersionKey] = &types.AttributeValueMemberN{Value: strconv.FormatInt(t.latestVersion+1, 10)} 823 | item[t.tableProps.TTLFieldName] = &types.AttributeValueMemberN{Value: strconv.FormatInt(time.Now().Add(t.ttl).Unix(), 10)} 824 | item[dynamoDBBucketAvailableKey] = &types.AttributeValueMemberN{Value: strconv.FormatInt(state.Available, 10)} 825 | 826 | input := &dynamodb.PutItemInput{ 827 | TableName: &t.tableProps.TableName, 828 | Item: item, 829 | } 830 | 831 | if t.raceCheck && t.latestVersion > 0 { 832 | input.ConditionExpression = aws.String(dynamodbBucketRaceConditionExpression) 833 | input.ExpressionAttributeValues = map[string]types.AttributeValue{ 834 | ":version": &types.AttributeValueMemberN{Value: strconv.FormatInt(t.latestVersion, 10)}, 835 | } 836 | } 837 | 838 | return input 839 | } 840 | 841 | func (t *TokenBucketDynamoDB) loadStateFromDynamoDB(resp *dynamodb.GetItemOutput) (TokenBucketState, error) { 842 | state := TokenBucketState{} 843 | 844 | err := attributevalue.Unmarshal(resp.Item[dynamoDBBucketLastKey], &state.Last) 845 | if err != nil { 846 | return state, fmt.Errorf("unmarshal dynamodb Last attribute failed: %w", err) 847 | } 848 | 849 | err = attributevalue.Unmarshal(resp.Item[dynamoDBBucketAvailableKey], &state.Available) 850 | if err != nil { 851 | return state, errors.Wrap(err, "unmarshal of dynamodb item attribute failed") 852 | } 853 | 854 | if t.raceCheck { 855 | err = attributevalue.Unmarshal(resp.Item[dynamoDBBucketVersionKey], &t.latestVersion) 856 | if err != nil { 857 | return state, fmt.Errorf("unmarshal dynamodb Version attribute failed: %w", err) 858 | } 859 | } 860 | 861 | return state, nil 862 | } 863 | 864 | // CosmosDBTokenBucketItem represents a document in CosmosDB. 865 | type CosmosDBTokenBucketItem struct { 866 | ID string `json:"id"` 867 | PartitionKey string `json:"partitionKey"` 868 | State TokenBucketState `json:"state"` 869 | Version int64 `json:"version"` 870 | TTL int64 `json:"ttl"` 871 | } 872 | 873 | // TokenBucketCosmosDB is a CosmosDB implementation of a TokenBucketStateBackend. 874 | type TokenBucketCosmosDB struct { 875 | client *azcosmos.ContainerClient 876 | partitionKey string 877 | id string 878 | ttl time.Duration 879 | raceCheck bool 880 | latestVersion int64 881 | } 882 | 883 | // NewTokenBucketCosmosDB creates a new TokenBucketCosmosDB instance. 884 | // PartitionKey is the key used to store all the implementation in CosmosDB. 885 | // TTL is the TTL of the stored item. 886 | // 887 | // If raceCheck is true and the item in CosmosDB is modified in between State() and SetState() calls then 888 | // ErrRaceCondition is returned. 
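//
// A construction sketch (illustrative only; it assumes an *azcosmos.ContainerClient has already been
// created for the target database and container, and the partition key and TTL values are placeholders):
//
//	backend := NewTokenBucketCosmosDB(containerClient, "endpoint-a", time.Minute, true)
//
// Note that partitionKey is also used to derive the document id ("token-bucket-" + partitionKey), so
// each distinct partition key value maps to a single state document.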
889 | func NewTokenBucketCosmosDB(client *azcosmos.ContainerClient, partitionKey string, ttl time.Duration, raceCheck bool) *TokenBucketCosmosDB { 890 | return &TokenBucketCosmosDB{ 891 | client: client, 892 | partitionKey: partitionKey, 893 | id: "token-bucket-" + partitionKey, 894 | ttl: ttl, 895 | raceCheck: raceCheck, 896 | } 897 | } 898 | 899 | func (t *TokenBucketCosmosDB) State(ctx context.Context) (TokenBucketState, error) { 900 | var item CosmosDBTokenBucketItem 901 | resp, err := t.client.ReadItem(ctx, azcosmos.NewPartitionKey().AppendString(t.partitionKey), t.id, &azcosmos.ItemOptions{}) 902 | if err != nil { 903 | var respErr *azcore.ResponseError 904 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { 905 | return TokenBucketState{}, nil 906 | } 907 | 908 | return TokenBucketState{}, err 909 | } 910 | 911 | err = json.Unmarshal(resp.Value, &item) 912 | if err != nil { 913 | return TokenBucketState{}, errors.Wrap(err, "failed to decode state from Cosmos DB") 914 | } 915 | 916 | if time.Now().Unix() > item.TTL { 917 | return TokenBucketState{}, nil 918 | } 919 | 920 | if t.raceCheck { 921 | t.latestVersion = item.Version 922 | } 923 | 924 | return item.State, nil 925 | } 926 | 927 | func (t *TokenBucketCosmosDB) SetState(ctx context.Context, state TokenBucketState) error { 928 | var err error 929 | done := make(chan struct{}, 1) 930 | 931 | item := CosmosDBTokenBucketItem{ 932 | ID: t.id, 933 | PartitionKey: t.partitionKey, 934 | State: state, 935 | Version: t.latestVersion + 1, 936 | TTL: time.Now().Add(t.ttl).Unix(), 937 | } 938 | 939 | value, err := json.Marshal(item) 940 | if err != nil { 941 | return errors.Wrap(err, "failed to encode state to JSON") 942 | } 943 | 944 | go func() { 945 | defer close(done) 946 | _, err = t.client.UpsertItem(ctx, azcosmos.NewPartitionKey().AppendString(t.partitionKey), value, &azcosmos.ItemOptions{}) 947 | }() 948 | 949 | select { 950 | case <-done: 951 | case <-ctx.Done(): 952 | return ctx.Err() 953 | } 954 | 955 | if err != nil { 956 | var respErr *azcore.ResponseError 957 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict && t.raceCheck { 958 | return ErrRaceCondition 959 | } 960 | 961 | return errors.Wrap(err, "failed to save keys to Cosmos DB") 962 | } 963 | 964 | return nil 965 | } 966 | 967 | func (t *TokenBucketCosmosDB) Reset(ctx context.Context) error { 968 | state := TokenBucketState{ 969 | Last: 0, 970 | Available: 0, 971 | } 972 | 973 | return t.SetState(ctx, state) 974 | } 975 | -------------------------------------------------------------------------------- /tokenbucket_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/google/uuid" 11 | l "github.com/mennanov/limiters" 12 | "github.com/redis/go-redis/v9" 13 | ) 14 | 15 | var tokenBucketUniformTestCases = []struct { 16 | capacity int64 17 | refillRate time.Duration 18 | requestCount int64 19 | requestRate time.Duration 20 | missExpected int 21 | delta float64 22 | }{ 23 | { 24 | 5, 25 | time.Millisecond * 100, 26 | 10, 27 | time.Millisecond * 25, 28 | 3, 29 | 1, 30 | }, 31 | { 32 | 10, 33 | time.Millisecond * 100, 34 | 13, 35 | time.Millisecond * 25, 36 | 0, 37 | 1, 38 | }, 39 | { 40 | 10, 41 | time.Millisecond * 100, 42 | 15, 43 | time.Millisecond * 33, 44 | 2, 45 | 1, 46 | }, 47 | { 48 | 10, 49 | time.Millisecond * 100, 50 | 16, 51 | time.Millisecond * 33, 52 | 2, 53 
| 1, 54 | }, 55 | { 56 | 10, 57 | time.Millisecond * 10, 58 | 20, 59 | time.Millisecond * 10, 60 | 0, 61 | 1, 62 | }, 63 | { 64 | 1, 65 | time.Millisecond * 200, 66 | 10, 67 | time.Millisecond * 100, 68 | 5, 69 | 2, 70 | }, 71 | } 72 | 73 | // tokenBuckets returns all the possible TokenBucket combinations. 74 | func (s *LimitersTestSuite) tokenBuckets(capacity int64, refillRate time.Duration, clock l.Clock) map[string]*l.TokenBucket { 75 | buckets := make(map[string]*l.TokenBucket) 76 | for lockerName, locker := range s.lockers(true) { 77 | for backendName, backend := range s.tokenBucketBackends() { 78 | buckets[lockerName+":"+backendName] = l.NewTokenBucket(capacity, refillRate, locker, backend, clock, s.logger) 79 | } 80 | } 81 | 82 | return buckets 83 | } 84 | 85 | func (s *LimitersTestSuite) tokenBucketBackends() map[string]l.TokenBucketStateBackend { 86 | return map[string]l.TokenBucketStateBackend{ 87 | "TokenBucketInMemory": l.NewTokenBucketInMemory(), 88 | "TokenBucketEtcdNoRaceCheck": l.NewTokenBucketEtcd(s.etcdClient, uuid.New().String(), time.Second, false), 89 | "TokenBucketEtcdWithRaceCheck": l.NewTokenBucketEtcd(s.etcdClient, uuid.New().String(), time.Second, true), 90 | "TokenBucketRedisNoRaceCheck": l.NewTokenBucketRedis(s.redisClient, uuid.New().String(), time.Second, false), 91 | "TokenBucketRedisWithRaceCheck": l.NewTokenBucketRedis(s.redisClient, uuid.New().String(), time.Second, true), 92 | "TokenBucketRedisClusterNoRaceCheck": l.NewTokenBucketRedis(s.redisClusterClient, uuid.New().String(), time.Second, false), 93 | "TokenBucketRedisClusterWithRaceCheck": l.NewTokenBucketRedis(s.redisClusterClient, uuid.New().String(), time.Second, true), 94 | "TokenBucketMemcachedNoRaceCheck": l.NewTokenBucketMemcached(s.memcacheClient, uuid.New().String(), time.Second, false), 95 | "TokenBucketMemcachedWithRaceCheck": l.NewTokenBucketMemcached(s.memcacheClient, uuid.New().String(), time.Second, true), 96 | "TokenBucketDynamoDBNoRaceCheck": l.NewTokenBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, false), 97 | "TokenBucketDynamoDBWithRaceCheck": l.NewTokenBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, time.Second, true), 98 | "TokenBucketCosmosDBNoRaceCheck": l.NewTokenBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), time.Second, false), 99 | "TokenBucketCosmosDBWithRaceCheck": l.NewTokenBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), time.Second, true), 100 | } 101 | } 102 | 103 | func (s *LimitersTestSuite) TestTokenBucketRealClock() { 104 | clock := l.NewSystemClock() 105 | for _, testCase := range tokenBucketUniformTestCases { 106 | for name, bucket := range s.tokenBuckets(testCase.capacity, testCase.refillRate, clock) { 107 | s.Run(name, func() { 108 | wg := sync.WaitGroup{} 109 | // mu guards the miss variable below. 110 | var mu sync.Mutex 111 | miss := 0 112 | for i := int64(0); i < testCase.requestCount; i++ { 113 | // No pause for the first request. 
114 | if i > 0 { 115 | clock.Sleep(testCase.requestRate) 116 | } 117 | wg.Add(1) 118 | go func(bucket *l.TokenBucket) { 119 | defer wg.Done() 120 | if _, err := bucket.Limit(context.TODO()); err != nil { 121 | s.Equal(l.ErrLimitExhausted, err, "%T %v", bucket, bucket) 122 | mu.Lock() 123 | miss++ 124 | mu.Unlock() 125 | } 126 | }(bucket) 127 | } 128 | wg.Wait() 129 | s.InDelta(testCase.missExpected, miss, testCase.delta, testCase) 130 | }) 131 | } 132 | } 133 | } 134 | 135 | func (s *LimitersTestSuite) TestTokenBucketFakeClock() { 136 | for _, testCase := range tokenBucketUniformTestCases { 137 | clock := newFakeClock() 138 | for name, bucket := range s.tokenBuckets(testCase.capacity, testCase.refillRate, clock) { 139 | s.Run(name, func() { 140 | clock.reset() 141 | miss := 0 142 | for i := int64(0); i < testCase.requestCount; i++ { 143 | // No pause for the first request. 144 | if i > 0 { 145 | clock.Sleep(testCase.requestRate) 146 | } 147 | if _, err := bucket.Limit(context.TODO()); err != nil { 148 | s.Equal(l.ErrLimitExhausted, err) 149 | miss++ 150 | } 151 | } 152 | s.InDelta(testCase.missExpected, miss, testCase.delta, testCase) 153 | }) 154 | } 155 | } 156 | } 157 | 158 | func (s *LimitersTestSuite) TestTokenBucketOverflow() { 159 | clock := newFakeClock() 160 | rate := time.Second 161 | for name, bucket := range s.tokenBuckets(2, rate, clock) { 162 | s.Run(name, func() { 163 | clock.reset() 164 | wait, err := bucket.Limit(context.TODO()) 165 | s.Require().NoError(err) 166 | s.Equal(time.Duration(0), wait) 167 | wait, err = bucket.Limit(context.TODO()) 168 | s.Require().NoError(err) 169 | s.Equal(time.Duration(0), wait) 170 | // The third call should fail. 171 | wait, err = bucket.Limit(context.TODO()) 172 | s.Require().Equal(l.ErrLimitExhausted, err) 173 | s.Equal(rate, wait) 174 | clock.Sleep(wait) 175 | // Retry the last call. 176 | wait, err = bucket.Limit(context.TODO()) 177 | s.Require().NoError(err) 178 | s.Equal(time.Duration(0), wait) 179 | }) 180 | } 181 | } 182 | 183 | func (s *LimitersTestSuite) TestTokenBucketReset() { 184 | clock := newFakeClock() 185 | rate := time.Second 186 | for name, bucket := range s.tokenBuckets(2, rate, clock) { 187 | s.Run(name, func() { 188 | clock.reset() 189 | wait, err := bucket.Limit(context.TODO()) 190 | s.Require().NoError(err) 191 | s.Equal(time.Duration(0), wait) 192 | wait, err = bucket.Limit(context.TODO()) 193 | s.Require().NoError(err) 194 | s.Equal(time.Duration(0), wait) 195 | // The third call should fail. 196 | wait, err = bucket.Limit(context.TODO()) 197 | s.Require().Equal(l.ErrLimitExhausted, err) 198 | s.Equal(rate, wait) 199 | err = bucket.Reset(context.TODO()) 200 | s.Require().NoError(err) 201 | // Retry the last call. 
202 | wait, err = bucket.Limit(context.TODO()) 203 | s.Require().NoError(err) 204 | s.Equal(time.Duration(0), wait) 205 | }) 206 | } 207 | } 208 | 209 | func (s *LimitersTestSuite) TestTokenBucketRefill() { 210 | for name, backend := range s.tokenBucketBackends() { 211 | s.Run(name, func() { 212 | clock := newFakeClock() 213 | 214 | bucket := l.NewTokenBucket(4, time.Millisecond*100, l.NewLockNoop(), backend, clock, s.logger) 215 | sleepDurations := []int{150, 90, 50, 70} 216 | desiredAvailable := []int64{3, 2, 2, 2} 217 | 218 | _, err := bucket.Limit(context.Background()) 219 | s.Require().NoError(err) 220 | 221 | _, err = backend.State(context.Background()) 222 | s.Require().NoError(err, "unable to retrieve backend state") 223 | 224 | for i := range sleepDurations { 225 | clock.Sleep(time.Millisecond * time.Duration(sleepDurations[i])) 226 | 227 | _, err := bucket.Limit(context.Background()) 228 | s.Require().NoError(err) 229 | 230 | state, err := backend.State(context.Background()) 231 | s.Require().NoError(err, "unable to retrieve backend state") 232 | 233 | s.Require().Equal(desiredAvailable[i], state.Available) 234 | } 235 | }) 236 | } 237 | } 238 | 239 | // setTokenBucketStateInOldFormat is a test utility method for writing state in the old format to Redis. 240 | func setTokenBucketStateInOldFormat(ctx context.Context, cli *redis.Client, prefix string, state l.TokenBucketState, ttl time.Duration) error { 241 | _, err := cli.TxPipelined(ctx, func(pipeliner redis.Pipeliner) error { 242 | if err := pipeliner.Set(ctx, fmt.Sprintf("{%s}last", prefix), state.Last, ttl).Err(); err != nil { 243 | return err 244 | } 245 | if err := pipeliner.Set(ctx, fmt.Sprintf("{%s}available", prefix), state.Available, ttl).Err(); err != nil { 246 | return err 247 | } 248 | 249 | return nil 250 | }) 251 | 252 | return err 253 | } 254 | 255 | // TestTokenBucketRedisBackwardCompatibility tests that the new State method can read data written in the old format. 
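// The "old format" refers to the deprecated layout where Last and Available were stored under the
// separate "{prefix}last" and "{prefix}available" keys; the current format keeps the whole state in a
// single JSON document under "{prefix}state".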
256 | func (s *LimitersTestSuite) TestTokenBucketRedisBackwardCompatibility() { 257 | // Create a new TokenBucketRedis instance 258 | prefix := uuid.New().String() 259 | backend := l.NewTokenBucketRedis(s.redisClient, prefix, time.Second, false) 260 | 261 | // Write state using old format 262 | ctx := context.Background() 263 | expectedState := l.TokenBucketState{ 264 | Last: 12345, 265 | Available: 67890, 266 | } 267 | 268 | // Write directly to Redis using old format 269 | err := setTokenBucketStateInOldFormat(ctx, s.redisClient, prefix, expectedState, time.Second) 270 | s.Require().NoError(err, "Failed to set state using old format") 271 | 272 | // Read state using new format (State) 273 | actualState, err := backend.State(ctx) 274 | s.Require().NoError(err, "Failed to get state using new format") 275 | 276 | // Verify the state is correctly read 277 | s.Equal(expectedState.Last, actualState.Last, "Last values should match") 278 | s.Equal(expectedState.Available, actualState.Available, "Available values should match") 279 | } 280 | 281 | func BenchmarkTokenBuckets(b *testing.B) { 282 | s := new(LimitersTestSuite) 283 | s.SetT(&testing.T{}) 284 | s.SetupSuite() 285 | capacity := int64(1) 286 | rate := time.Second 287 | clock := newFakeClock() 288 | buckets := s.tokenBuckets(capacity, rate, clock) 289 | for name, bucket := range buckets { 290 | b.Run(name, func(b *testing.B) { 291 | for i := 0; i < b.N; i++ { 292 | _, err := bucket.Limit(context.TODO()) 293 | s.Require().NoError(err) 294 | } 295 | }) 296 | } 297 | s.TearDownSuite() 298 | } 299 | --------------------------------------------------------------------------------