├── .gitignore ├── codecov.yml ├── .github ├── dependabot.yaml └── workflows │ ├── staticcheck.yml │ ├── blockwatch.yml │ ├── vet.yml │ ├── golangci-lint.yml │ └── tests.yml ├── .pre-commit-config.yaml ├── examples ├── gprc_service.go ├── helloworld │ ├── helloworld.proto │ ├── helloworld_grpc.pb.go │ └── helloworld.pb.go ├── example_grpc_simple_limiter_test.go └── example_grpc_ip_limiter_test.go ├── revive.toml ├── .golangci.yml ├── LICENSE ├── cosmos_test.go ├── Makefile ├── docker-compose.yml ├── locks_test.go ├── limiters.go ├── cosmosdb.go ├── registry_test.go ├── registry.go ├── dynamodb_test.go ├── go.mod ├── dynamodb.go ├── concurrent_buffer_test.go ├── fixedwindow_test.go ├── README.md ├── locks.go ├── slidingwindow_test.go ├── limiters_test.go ├── concurrent_buffer.go ├── fixedwindow.go ├── tokenbucket_test.go ├── leakybucket_test.go ├── slidingwindow.go └── leakybucket.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | coverage.txt 3 | .vscode -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | patch: off 4 | -------------------------------------------------------------------------------- /.github/dependabot.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: gomod 4 | directory: / 5 | schedule: 6 | interval: monthly 7 | groups: 8 | all-dependencies: 9 | patterns: 10 | - "*" 11 | -------------------------------------------------------------------------------- /.github/workflows/staticcheck.yml: -------------------------------------------------------------------------------- 1 | name: Static Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | staticcheck: 13 | name: Linter 14 | runs-on: 
ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: dominikh/staticcheck-action@v1.3.1 18 | with: 19 | version: "latest" -------------------------------------------------------------------------------- /.github/workflows/blockwatch.yml: -------------------------------------------------------------------------------- 1 | name: Blockwatch linter 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | staticcheck: 13 | name: Blockwatch 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v5 17 | with: 18 | fetch-depth: 2 # Required to diff against the base branch 19 | - uses: mennanov/blockwatch-action@v1 20 | -------------------------------------------------------------------------------- /.github/workflows/vet.yml: -------------------------------------------------------------------------------- 1 | name: Linter 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | vet: 13 | name: Go vet 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Setup Go 1.25 18 | uses: actions/setup-go@v5 19 | with: 20 | go-version: '1.25' 21 | - name: Go vet 22 | run: go vet ./... 
23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: blockwatch 5 | name: blockwatch 6 | entry: bash -c 'git diff --patch --cached --unified=0 | blockwatch' 7 | language: system 8 | stages: [ pre-commit ] 9 | pass_filenames: false 10 | - id: golangci-lint 11 | name: golangci-lint 12 | entry: golangci-lint run 13 | language: system 14 | types: [go] 15 | pass_filenames: false 16 | -------------------------------------------------------------------------------- /examples/gprc_service.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "context" 5 | 6 | pb "github.com/mennanov/limiters/examples/helloworld" 7 | ) 8 | 9 | const ( 10 | Port = ":50051" // Port is the ":port" address the example gRPC server listens on and the example client dials on localhost. 11 | ) 12 | 13 | // Server is used to implement helloworld.GreeterServer. 14 | type Server struct { 15 | pb.UnimplementedGreeterServer 16 | } 17 | 18 | // SayHello implements helloworld.GreeterServer by replying with "Hello " followed by the requested name. 19 | func (s *Server) SayHello(_ context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { 20 | return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil 21 | } 22 | 23 | var _ pb.GreeterServer = new(Server) // Compile-time check that *Server implements pb.GreeterServer. 24 | -------------------------------------------------------------------------------- /examples/helloworld/helloworld.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package helloworld; 4 | option go_package = "github.com/mennanov/limiters/examples/helloworld"; 5 | 6 | // The greeting service definition. 7 | service Greeter { 8 | // Sends a greeting 9 | rpc SayHello (HelloRequest) returns (HelloReply) { 10 | } 11 | } 12 | 13 | // The request message containing the user's name.
14 | message HelloRequest { 15 | string name = 1; 16 | } 17 | 18 | // The response message containing the greetings 19 | message HelloReply { 20 | string message = 1; 21 | } 22 | 23 | message GoodbyeRequest { 24 | string name = 1; 25 | } 26 | 27 | message GoodbyeReply { 28 | string message = 1; 29 | } -------------------------------------------------------------------------------- /revive.toml: -------------------------------------------------------------------------------- 1 | ignoreGeneratedHeader = false 2 | severity = "warning" 3 | confidence = 0.8 4 | errorCode = 0 5 | warningCode = 0 6 | 7 | [rule.blank-imports] 8 | [rule.context-as-argument] 9 | [rule.context-keys-type] 10 | [rule.dot-imports] 11 | [rule.error-return] 12 | [rule.error-strings] 13 | [rule.error-naming] 14 | [rule.exported] 15 | [rule.if-return] 16 | [rule.increment-decrement] 17 | [rule.var-naming] 18 | [rule.var-declaration] 19 | [rule.package-comments] 20 | [rule.range] 21 | [rule.receiver-naming] 22 | [rule.time-naming] 23 | [rule.unexported-return] 24 | [rule.indent-error-flow] 25 | [rule.errorf] 26 | [rule.empty-block] 27 | [rule.superfluous-else] 28 | [rule.unused-parameter] 29 | [rule.unreachable-code] 30 | [rule.redefines-builtin-id] 31 | -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | on: 3 | push: 4 | branches: 5 | - main 6 | - master 7 | pull_request: 8 | 9 | permissions: 10 | contents: read 11 | # Optional: allow read access to pull request. Use with `only-new-issues` option. 
12 | # pull-requests: read 13 | 14 | jobs: 15 | golangci: 16 | name: lint 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: actions/setup-go@v5 21 | with: 22 | go-version: stable 23 | - name: golangci-lint 24 | uses: golangci/golangci-lint-action@v8 25 | with: 26 | # 27 | version: "v2.5.0" 28 | # 29 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | linters: 3 | default: all 4 | disable: 5 | - cyclop 6 | - depguard 7 | - dupl 8 | - dupword 9 | - exhaustruct 10 | - forcetypeassert 11 | - funcorder 12 | - funlen 13 | - gochecknoglobals 14 | - goconst 15 | - gocritic 16 | - godox 17 | - gomoddirectives 18 | - gosec 19 | - inamedparam 20 | - intrange 21 | - lll 22 | - mnd 23 | - musttag 24 | - paralleltest 25 | - perfsprint 26 | - recvcheck 27 | - revive 28 | - tagliatelle 29 | - testifylint 30 | - unparam 31 | - varnamelen 32 | - wrapcheck 33 | - wsl 34 | exclusions: 35 | generated: lax 36 | presets: 37 | - comments 38 | - common-false-positives 39 | - legacy 40 | - std-error-handling 41 | paths: 42 | - third_party$ 43 | - builtin$ 44 | - examples$ 45 | formatters: 46 | enable: 47 | - gci 48 | - gofmt 49 | - gofumpt 50 | - goimports 51 | exclusions: 52 | generated: lax 53 | paths: 54 | - third_party$ 55 | - builtin$ 56 | - examples$ 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Renat Mennanov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 
| copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /cosmos_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 7 | ) 8 | 9 | const ( 10 | testCosmosDBName = "limiters-db-test" 11 | testCosmosContainerName = "limiters-container-test" 12 | ) 13 | 14 | var defaultTTL int32 = 86400 // 24 hours; Cosmos DB expresses TTLs in seconds. 15 | // CreateCosmosDBContainer creates the test database and, inside it, the test container partitioned by /partitionKey with defaultTTL as its default item TTL. 16 | func CreateCosmosDBContainer(ctx context.Context, client *azcosmos.Client) error { 17 | resp, err := client.CreateDatabase(ctx, azcosmos.DatabaseProperties{ 18 | ID: testCosmosDBName, 19 | }, &azcosmos.CreateDatabaseOptions{}) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | dbClient, err := client.NewDatabase(resp.DatabaseProperties.ID) 25 | if err != nil { 26 | return err 27 | } 28 | 29 | _, err = dbClient.CreateContainer(ctx, azcosmos.ContainerProperties{ 30 | ID: testCosmosContainerName, 31 | DefaultTimeToLive: &defaultTTL, 32 | PartitionKeyDefinition: azcosmos.PartitionKeyDefinition{ 33 | Paths: []string{`/partitionKey`}, 34 | }, 35 | }, &azcosmos.CreateContainerOptions{}) 36 | 37 | return err 38 | } 39 | // DeleteCosmosDBContainer deletes the whole test database (testCosmosDBName), which also removes the test container inside it. 40 | func
DeleteCosmosDBContainer(ctx context.Context, client *azcosmos.Client) error { 41 | dbClient, err := client.NewDatabase(testCosmosDBName) 42 | if err != nil { 43 | return err 44 | } 45 | 46 | _, err = dbClient.Delete(ctx, &azcosmos.DeleteDatabaseOptions{}) 47 | 48 | return err 49 | } 50 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # 2 | SERVICES = ETCD_ENDPOINTS="127.0.0.1:2379" REDIS_ADDR="127.0.0.1:6379" REDIS_NODES="127.0.0.1:11000,127.0.0.1:11001,127.0.0.1:11002,127.0.0.1:11003,127.0.0.1:11004,127.0.0.1:11005" ZOOKEEPER_ENDPOINTS="127.0.0.1" CONSUL_ADDR="127.0.0.1:8500" AWS_ADDR="127.0.0.1:8000" MEMCACHED_ADDR="127.0.0.1:11211" POSTGRES_URL="postgres://postgres:postgres@localhost:5432/?sslmode=disable" COSMOS_ADDR="127.0.0.1:8081" 3 | # 4 | 5 | all: gofumpt goimports lint test benchmark 6 | 7 | docker-compose-up: 8 | docker compose up -d 9 | 10 | test: docker-compose-up 11 | @$(SERVICES) go test -race -v -failfast 12 | 13 | benchmark: docker-compose-up 14 | @$(SERVICES) go test -race -run=nonexistent -bench=. 15 | 16 | lint: 17 | # 18 | @(which golangci-lint && golangci-lint --version | grep 2.5.0) || (curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/v2.5.0/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v2.5.0) 19 | # 20 | golangci-lint run --fix ./... 21 | 22 | goimports: 23 | @which goimports > /dev/null 2>&1 || go install golang.org/x/tools/cmd/goimports@latest 24 | goimports -w . 25 | 26 | gofumpt: 27 | @which gofumpt > /dev/null 2>&1 || go install mvdan.cc/gofumpt@latest 28 | gofumpt -l -w .
29 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | # 3 | etcd: 4 | image: gcr.io/etcd-development/etcd:v3.6.5 5 | environment: 6 | ETCD_LISTEN_CLIENT_URLS: "http://0.0.0.0:2379" 7 | ETCD_ADVERTISE_CLIENT_URLS: "http://0.0.0.0:2379" 8 | ports: 9 | - "2379:2379" 10 | 11 | redis: 12 | image: redis:8.2 13 | ports: 14 | - "6379:6379" 15 | 16 | redis-cluster: 17 | image: grokzen/redis-cluster:7.0.10 18 | environment: 19 | IP: "0.0.0.0" 20 | INITIAL_PORT: 11000 21 | ports: 22 | - "11000-11005:11000-11005" 23 | 24 | memcached: 25 | image: memcached:1.6 26 | ports: 27 | - "11211:11211" 28 | 29 | consul: 30 | image: hashicorp/consul:1.21 31 | ports: 32 | - "8500:8500" 33 | 34 | zookeeper: 35 | image: zookeeper:3.9 36 | ports: 37 | - "2181:2181" 38 | 39 | dynamodb-local: 40 | image: "amazon/dynamodb-local:latest" 41 | command: "-jar DynamoDBLocal.jar -inMemory" 42 | ports: 43 | - "8000:8000" 44 | 45 | postgresql: 46 | image: postgres:18 47 | environment: 48 | POSTGRES_PASSWORD: postgres 49 | ports: 50 | - "5432:5432" 51 | 52 | cosmos: 53 | image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview 54 | command: [ "--protocol", "http" ] 55 | environment: 56 | PROTOCOL: http 57 | COSMOS_HTTP_CONNECTION_WITHOUT_TLS_ALLOWED: "true" 58 | ports: 59 | - "8081:8081" 60 | # 61 | -------------------------------------------------------------------------------- /locks_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | "github.com/mennanov/limiters" 10 | ) 11 | // useLock acquires lock, checks that *shared is not modified by anyone else while the lock is held, increments *shared and releases the lock. 12 | func (s *LimitersTestSuite) useLock(lock limiters.DistLocker, shared *int, sleep time.Duration) { 13 | s.NoError(lock.Lock(context.TODO())) 14 | 15 | sh := *shared 16 | // Imitate heavy work...
17 | time.Sleep(sleep) 18 | // Check for the race condition. 19 | s.Equal(sh, *shared) 20 | *shared++ 21 | 22 | s.NoError(lock.Unlock(context.Background())) 23 | } 24 | // TestDistLockers runs two goroutines per round, each using the same-named locker obtained from a separate distLockers call, and checks that the shared counter is never modified concurrently and ends at rounds*2. 25 | func (s *LimitersTestSuite) TestDistLockers() { 26 | locks1 := s.distLockers(false) 27 | 28 | locks2 := s.distLockers(false) 29 | for name := range locks1 { 30 | s.Run(name, func() { 31 | var shared int 32 | 33 | rounds := 6 34 | sleep := time.Millisecond * 50 35 | 36 | for i := 0; i < rounds; i++ { 37 | wg := sync.WaitGroup{} 38 | wg.Add(2) 39 | 40 | go func(k string) { 41 | defer wg.Done() 42 | 43 | s.useLock(locks1[k], &shared, sleep) 44 | }(name) 45 | go func(k string) { 46 | defer wg.Done() 47 | 48 | s.useLock(locks2[k], &shared, sleep) 49 | }(name) 50 | 51 | wg.Wait() 52 | } 53 | 54 | s.Equal(rounds*2, shared) 55 | }) 56 | } 57 | } 58 | // BenchmarkDistLockers measures a Lock followed by an Unlock for every locker returned by distLockers. 59 | func BenchmarkDistLockers(b *testing.B) { 60 | s := new(LimitersTestSuite) 61 | s.SetT(&testing.T{}) // NOTE(review): a zero-value testing.T is not driven by the test runner, so failures reported via s.Require() may not surface as benchmark failures; confirm this is intended. 62 | s.SetupSuite() 63 | 64 | lockers := s.distLockers(false) 65 | for name, locker := range lockers { 66 | b.Run(name, func(b *testing.B) { 67 | for i := 0; i < b.N; i++ { 68 | s.Require().NoError(locker.Lock(context.Background())) 69 | s.Require().NoError(locker.Unlock(context.Background())) 70 | } 71 | }) 72 | } 73 | 74 | s.TearDownSuite() 75 | } 76 | -------------------------------------------------------------------------------- /limiters.go: -------------------------------------------------------------------------------- 1 | // Package limiters provides general purpose rate limiter implementations. 2 | package limiters 3 | 4 | import ( 5 | "errors" 6 | "log" 7 | "time" 8 | ) 9 | 10 | var ( 11 | // ErrLimitExhausted is returned by the Limiter in case the number of requests overflows the capacity of a Limiter. 12 | ErrLimitExhausted = errors.New("requests limit exhausted") 13 | 14 | // ErrRaceCondition is returned when there is a race condition while saving a state of a rate limiter.
15 | ErrRaceCondition = errors.New("race condition detected") 16 | ) 17 | 18 | // Logger wraps the Log method for logging. 19 | type Logger interface { 20 | // Log logs the given arguments. 21 | Log(v ...interface{}) 22 | } 23 | 24 | // StdLogger implements the Logger interface. 25 | type StdLogger struct{} 26 | 27 | // NewStdLogger creates a new instance of StdLogger. 28 | func NewStdLogger() *StdLogger { 29 | return &StdLogger{} 30 | } 31 | 32 | // Log delegates the logging to the std logger. 33 | func (l *StdLogger) Log(v ...interface{}) { 34 | log.Println(v...) 35 | } 36 | 37 | // Clock abstracts the source of the current time so that an implementation other than 38 | // the real system clock (see SystemClock) can be supplied. 39 | type Clock interface { 40 | // Now returns the current system time. 41 | Now() time.Time 42 | } 43 | 44 | // SystemClock implements the Clock interface by using the real system clock. 45 | type SystemClock struct{} 46 | 47 | // NewSystemClock creates a new instance of SystemClock. 48 | func NewSystemClock() *SystemClock { 49 | return &SystemClock{} 50 | } 51 | 52 | // Now returns the current system time. 53 | func (c *SystemClock) Now() time.Time { 54 | return time.Now() 55 | } 56 | 57 | // Sleep blocks (sleeps) for the given duration. 58 | func (c *SystemClock) Sleep(d time.Duration) { 59 | time.Sleep(d) 60 | } 61 | -------------------------------------------------------------------------------- /cosmosdb.go: -------------------------------------------------------------------------------- 1 | package limiters 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/Azure/azure-sdk-for-go/sdk/azcore" 10 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 11 | "github.com/pkg/errors" 12 | ) 13 | // cosmosItem is the JSON document that incrementCosmosItemRMW reads, updates and replaces in Cosmos DB. 14 | type cosmosItem struct { 15 | Count int64 `json:"Count"` 16 | PartitionKey string `json:"partitionKey"` 17 | ID string `json:"id"` 18 | TTL int32 `json:"ttl"` // Cosmos DB interprets the "ttl" property in seconds. 19 | } 20 | 21 | // incrementCosmosItemRMW increments the count of a cosmosItem in a read-modify-write loop.
22 | // It uses optimistic concurrency control (ETags): if the replace is rejected with HTTP 412 Precondition Failed because the item changed after it was read, the read-modify-write is retried until it succeeds, ctx is done or another error occurs. 23 | func incrementCosmosItemRMW(ctx context.Context, client *azcosmos.ContainerClient, partitionKey, id string, ttl time.Duration) (int64, error) { 24 | pk := azcosmos.NewPartitionKey().AppendString(partitionKey) 25 | 26 | for { 27 | err := ctx.Err() 28 | if err != nil { 29 | return 0, err 30 | } 31 | 32 | resp, err := client.ReadItem(ctx, pk, id, nil) 33 | if err != nil { 34 | return 0, errors.Wrap(err, "read of cosmos value failed") 35 | } 36 | 37 | var existing cosmosItem 38 | 39 | err = json.Unmarshal(resp.Value, &existing) 40 | if err != nil { 41 | return 0, errors.Wrap(err, "unmarshal of cosmos value failed") 42 | } 43 | 44 | existing.Count++ 45 | existing.TTL = int32(ttl) // NOTE(review): Cosmos DB interprets "ttl" in seconds, but this converts the raw time.Duration (nanoseconds); confirm that callers pass a seconds-valued ttl. 46 | 47 | updated, err := json.Marshal(existing) 48 | if err != nil { 49 | return 0, errors.Wrap(err, "marshal of cosmos value failed") 50 | } 51 | 52 | _, err = client.ReplaceItem(ctx, pk, id, updated, &azcosmos.ItemOptions{ 53 | IfMatchEtag: &resp.ETag, 54 | }) 55 | if err == nil { 56 | return existing.Count, nil 57 | } 58 | 59 | var respErr *azcore.ResponseError 60 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusPreconditionFailed { 61 | continue 62 | } 63 | 64 | return 0, errors.Wrap(err, "replace of cosmos value failed") 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /registry_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strconv" 7 | "testing" 8 | "time" 9 | 10 | "github.com/mennanov/limiters" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | type testingLimiter struct{} 16 | 17 | func newTestingLimiter() *testingLimiter { 18 | return &testingLimiter{} 19 | } 20 | 21 | func (l *testingLimiter) Limit(context.Context) (time.Duration, error) { 22 | return 0, nil 23 | } 24 | 25 | func TestRegistry_GetOrCreate(t *testing.T)
{ 26 | registry := limiters.NewRegistry() 27 | called := false 28 | clock := newFakeClock() 29 | limiter := newTestingLimiter() 30 | l := registry.GetOrCreate("key", func() interface{} { 31 | called = true 32 | 33 | return limiter 34 | }, time.Second, clock.Now()) 35 | assert.Equal(t, limiter, l) 36 | // Verify that the closure was called to create a value. 37 | assert.True(t, called) 38 | called = false 39 | l = registry.GetOrCreate("key", func() interface{} { 40 | called = true 41 | 42 | return newTestingLimiter() 43 | }, time.Second, clock.Now()) 44 | assert.Equal(t, limiter, l) 45 | // Verify that the closure was NOT called to create a value as it already exists. 46 | assert.False(t, called) 47 | } 48 | 49 | func TestRegistry_DeleteExpired(t *testing.T) { 50 | registry := limiters.NewRegistry() 51 | clock := newFakeClock() 52 | // Add limiters to the registry. 53 | for i := 1; i <= 10; i++ { 54 | registry.GetOrCreate(fmt.Sprintf("key%d", i), func() interface{} { 55 | return newTestingLimiter() 56 | }, time.Second*time.Duration(i), clock.Now()) 57 | } 58 | 59 | clock.Sleep(time.Second * 3) 60 | // "touch" the "key3" value that is about to be expired so that its expiration time is extended for 1s.
61 | registry.GetOrCreate("key3", func() interface{} { 62 | return newTestingLimiter() 63 | }, time.Second, clock.Now()) 64 | 65 | assert.Equal(t, 2, registry.DeleteExpired(clock.Now())) 66 | 67 | for i := 1; i <= 10; i++ { 68 | if i <= 2 { 69 | assert.False(t, registry.Exists(fmt.Sprintf("key%d", i))) 70 | } else { 71 | assert.True(t, registry.Exists(fmt.Sprintf("key%d", i))) 72 | } 73 | } 74 | } 75 | 76 | func TestRegistry_Delete(t *testing.T) { 77 | registry := limiters.NewRegistry() 78 | clock := newFakeClock() 79 | item := &struct{}{} 80 | require.Equal(t, item, registry.GetOrCreate("key", func() interface{} { 81 | return item 82 | }, time.Second, clock.Now())) 83 | require.Equal(t, item, registry.GetOrCreate("key", func() interface{} { 84 | return &struct{}{} 85 | }, time.Second, clock.Now())) 86 | registry.Delete("key") 87 | assert.False(t, registry.Exists("key")) 88 | } 89 | 90 | // TestRegistry_ConcurrentUsage calls GetOrCreate and DeleteExpired from concurrent goroutines. 91 | // Run it with the -race flag: the race detector fails the test if Registry is not safe for concurrent use. 92 | func TestRegistry_ConcurrentUsage(t *testing.T) { 93 | registry := limiters.NewRegistry() 94 | clock := newFakeClock() 95 | done := make(chan struct{}) 96 | 97 | for i := 0; i < 10; i++ { 98 | go func(i int) { 99 | defer func() { done <- struct{}{} }() 100 | 101 | registry.GetOrCreate(strconv.Itoa(i), func() interface{} { return &struct{}{} }, 0, clock.Now()) 102 | }(i) 103 | } 104 | 105 | for i := 0; i < 10; i++ { 106 | go func() { 107 | defer func() { done <- struct{}{} }() 108 | 109 | registry.DeleteExpired(clock.Now()) 110 | }() 111 | } 112 | 113 | // Wait for every goroutine so that none of them outlives the test. 114 | for i := 0; i < 20; i++ { 115 | <-done 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Go tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | go-test: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | # 18 | go-version: [ '1.23', '1.24', '1.25' ] 19 | # 20 | 21 | services: 22 | # 23 | etcd: 24 | image: gcr.io/etcd-development/etcd:v3.6.5 25 | env: 26 | ETCD_LISTEN_CLIENT_URLS:
"http://0.0.0.0:2379" 27 | ETCD_ADVERTISE_CLIENT_URLS: "http://0.0.0.0:2379" 28 | ports: 29 | - 2379:2379 30 | redis: 31 | image: redis:8.2 32 | ports: 33 | - 6379:6379 34 | redis-cluster: 35 | image: grokzen/redis-cluster:7.0.10 36 | env: 37 | IP: 0.0.0.0 38 | INITIAL_PORT: 11000 39 | ports: 40 | - 11000-11005:11000-11005 41 | memcached: 42 | image: memcached:1.6 43 | ports: 44 | - 11211:11211 45 | consul: 46 | image: hashicorp/consul:1.21 47 | ports: 48 | - 8500:8500 49 | zookeeper: 50 | image: zookeeper:3.9 51 | ports: 52 | - 2181:2181 53 | postgresql: 54 | image: postgres:18 55 | env: 56 | POSTGRES_PASSWORD: postgres 57 | ports: 58 | - 5432:5432 59 | cosmos: 60 | image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview 61 | env: 62 | PROTOCOL: http 63 | COSMOS_HTTP_CONNECTION_WITHOUT_TLS_ALLOWED: "true" 64 | ports: 65 | - "8081:8081" 66 | # 67 | 68 | steps: 69 | - uses: actions/checkout@v4 70 | - name: Setup Go ${{ matrix.go-version }} 71 | uses: actions/setup-go@v5 72 | with: 73 | go-version: ${{ matrix.go-version }} 74 | - name: Setup DynamoDB Local 75 | uses: rrainn/dynamodb-action@v4.0.0 76 | with: 77 | port: 8000 78 | - name: Run tests 79 | # 80 | env: 81 | ETCD_ENDPOINTS: 'localhost:2379' 82 | REDIS_ADDR: 'localhost:6379' 83 | REDIS_NODES: 'localhost:11000,localhost:11001,localhost:11002,localhost:11003,localhost:11004,localhost:11005' 84 | CONSUL_ADDR: 'localhost:8500' 85 | ZOOKEEPER_ENDPOINTS: 'localhost:2181' 86 | AWS_ADDR: 'localhost:8000' 87 | MEMCACHED_ADDR: '127.0.0.1:11211' 88 | POSTGRES_URL: postgres://postgres:postgres@localhost:5432/?sslmode=disable 89 | COSMOS_ADDR: '127.0.0.1:8081' 90 | # 91 | run: | 92 | if [ "${{ matrix.go-version }}" = "1.25" ]; then 93 | go test -race -coverprofile=coverage.txt -covermode=atomic ./... 94 | else 95 | go test -race ./... 
96 | fi 97 | - uses: codecov/codecov-action@v5 98 | if: matrix.go-version == '1.25' 99 | with: 100 | verbose: true 101 | 102 | -------------------------------------------------------------------------------- /registry.go: -------------------------------------------------------------------------------- 1 | package limiters 2 | 3 | import ( 4 | "container/heap" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // pqItem is an item in the priority queue. 10 | type pqItem struct { 11 | value interface{} 12 | exp time.Time 13 | index int // Position in gcPq, maintained by Swap/Push; set to -1 by Pop. 14 | key string 15 | } 16 | 17 | // gcPq is a min-heap of items ordered by expiration time (earliest first); it implements heap.Interface. 18 | type gcPq []*pqItem 19 | 20 | func (pq gcPq) Len() int { return len(pq) } 21 | 22 | func (pq gcPq) Less(i, j int) bool { 23 | return pq[i].exp.Before(pq[j].exp) 24 | } 25 | 26 | func (pq gcPq) Swap(i, j int) { 27 | pq[i], pq[j] = pq[j], pq[i] 28 | pq[i].index = i 29 | pq[j].index = j 30 | } 31 | 32 | func (pq *gcPq) Push(x interface{}) { 33 | n := len(*pq) 34 | item := x.(*pqItem) 35 | item.index = n 36 | *pq = append(*pq, item) 37 | } 38 | 39 | func (pq *gcPq) Pop() interface{} { 40 | old := *pq 41 | n := len(old) 42 | item := old[n-1] 43 | item.index = -1 // for safety 44 | *pq, old[n-1] = old[0 : n-1], nil // Shrink the queue and clear the vacated slot so the backing array does not keep the popped item (and its value) alive. 45 | 46 | return item 47 | } 48 | 49 | // Registry is a thread-safe garbage-collectable registry of values. 50 | type Registry struct { 51 | // Guards all the fields below it. 52 | mx sync.Mutex 53 | pq *gcPq 54 | m map[string]*pqItem 55 | } 56 | 57 | // NewRegistry creates a new instance of Registry. 58 | func NewRegistry() *Registry { 59 | pq := make(gcPq, 0) 60 | 61 | return &Registry{pq: &pq, m: make(map[string]*pqItem)} 62 | } 63 | 64 | // GetOrCreate gets an existing value by key and updates its expiration time to now+ttl. 65 | // If the key lookup fails it creates a new value by calling the provided value closure and puts it on the queue. The closure runs while the registry's mutex is held, so it must not call back into the Registry.
66 | func (r *Registry) GetOrCreate(key string, value func() interface{}, ttl time.Duration, now time.Time) interface{} { 67 | r.mx.Lock() 68 | defer r.mx.Unlock() 69 | 70 | item, ok := r.m[key] 71 | if ok { 72 | // Update the expiration time. 73 | item.exp = now.Add(ttl) 74 | heap.Fix(r.pq, item.index) 75 | } else { 76 | item = &pqItem{ 77 | value: value(), 78 | exp: now.Add(ttl), 79 | key: key, 80 | } 81 | heap.Push(r.pq, item) 82 | r.m[key] = item 83 | } 84 | 85 | return item.value 86 | } 87 | 88 | // DeleteExpired deletes the items whose expiration time is at or before now and returns the number of deleted items. 89 | func (r *Registry) DeleteExpired(now time.Time) int { 90 | r.mx.Lock() 91 | defer r.mx.Unlock() 92 | 93 | c := 0 94 | 95 | for len(*r.pq) != 0 { 96 | item := (*r.pq)[0] 97 | if now.Before(item.exp) { 98 | break 99 | } 100 | 101 | delete(r.m, item.key) 102 | heap.Pop(r.pq) 103 | 104 | c++ 105 | } 106 | 107 | return c 108 | } 109 | 110 | // Delete deletes an item from the registry. 111 | func (r *Registry) Delete(key string) { 112 | r.mx.Lock() 113 | defer r.mx.Unlock() 114 | 115 | item, ok := r.m[key] 116 | if !ok { 117 | return 118 | } 119 | 120 | delete(r.m, key) 121 | heap.Remove(r.pq, item.index) 122 | } 123 | 124 | // Exists returns true if an item with the given key exists in the registry. 125 | func (r *Registry) Exists(key string) bool { 126 | r.mx.Lock() 127 | defer r.mx.Unlock() 128 | 129 | _, ok := r.m[key] 130 | 131 | return ok 132 | } 133 | 134 | // Len returns the number of items in the registry.
135 | func (r *Registry) Len() int { 136 | r.mx.Lock() 137 | defer r.mx.Unlock() 138 | 139 | return len(*r.pq) 140 | } 141 | -------------------------------------------------------------------------------- /dynamodb_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/aws/aws-sdk-go-v2/aws" 8 | "github.com/aws/aws-sdk-go-v2/service/dynamodb" 9 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" 10 | "github.com/mennanov/limiters" 11 | "github.com/pkg/errors" 12 | ) 13 | 14 | const testDynamoDBTableName = "limiters-test" 15 | 16 | func CreateTestDynamoDBTable(ctx context.Context, client *dynamodb.Client) error { 17 | _, err := client.CreateTable(ctx, &dynamodb.CreateTableInput{ 18 | TableName: aws.String(testDynamoDBTableName), 19 | BillingMode: types.BillingModePayPerRequest, 20 | KeySchema: []types.KeySchemaElement{ 21 | { 22 | AttributeName: aws.String("PK"), 23 | KeyType: types.KeyTypeHash, 24 | }, 25 | { 26 | AttributeName: aws.String("SK"), 27 | KeyType: types.KeyTypeRange, 28 | }, 29 | }, 30 | AttributeDefinitions: []types.AttributeDefinition{ 31 | { 32 | AttributeName: aws.String("PK"), 33 | AttributeType: types.ScalarAttributeTypeS, 34 | }, 35 | { 36 | AttributeName: aws.String("SK"), 37 | AttributeType: types.ScalarAttributeTypeS, 38 | }, 39 | }, 40 | }) 41 | if err != nil { 42 | return errors.Wrap(err, "create test dynamodb table failed") 43 | } 44 | 45 | _, err = client.UpdateTimeToLive(ctx, &dynamodb.UpdateTimeToLiveInput{ 46 | TableName: aws.String(testDynamoDBTableName), 47 | TimeToLiveSpecification: &types.TimeToLiveSpecification{ 48 | AttributeName: aws.String("TTL"), 49 | Enabled: aws.Bool(true), 50 | }, 51 | }) 52 | if err != nil { 53 | return errors.Wrap(err, "set dynamodb ttl failed") 54 | } 55 | 56 | ctx, cancel := context.WithTimeout(ctx, time.Second*10) 57 | defer cancel() 58 | 59 | for { 60 | resp, err := 
client.DescribeTable(ctx, &dynamodb.DescribeTableInput{ 61 | TableName: aws.String(testDynamoDBTableName), 62 | }) 63 | if err == nil { 64 | return errors.Wrap(err, "failed to describe test table") 65 | } 66 | 67 | if resp.Table.TableStatus == types.TableStatusActive { 68 | return nil 69 | } 70 | 71 | select { 72 | case <-ctx.Done(): 73 | return errors.New("failed to verify dynamodb test table is created") 74 | default: 75 | } 76 | } 77 | } 78 | 79 | func DeleteTestDynamoDBTable(ctx context.Context, client *dynamodb.Client) error { 80 | _, err := client.DeleteTable(ctx, &dynamodb.DeleteTableInput{ 81 | TableName: aws.String(testDynamoDBTableName), 82 | }) 83 | if err != nil { 84 | return errors.Wrap(err, "delete test dynamodb table failed") 85 | } 86 | 87 | return nil 88 | } 89 | 90 | func (s *LimitersTestSuite) TestDynamoRaceCondition() { 91 | backend := limiters.NewLeakyBucketDynamoDB(s.dynamodbClient, "race-check", s.dynamoDBTableProps, time.Minute, true) 92 | 93 | err := backend.SetState(context.Background(), limiters.LeakyBucketState{}) 94 | s.Require().NoError(err) 95 | 96 | _, err = backend.State(context.Background()) 97 | s.Require().NoError(err) 98 | 99 | _, err = s.dynamodbClient.PutItem(context.Background(), &dynamodb.PutItemInput{ 100 | Item: map[string]types.AttributeValue{ 101 | s.dynamoDBTableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: "race-check"}, 102 | s.dynamoDBTableProps.SortKeyName: &types.AttributeValueMemberS{Value: "race-check"}, 103 | "Version": &types.AttributeValueMemberN{Value: "5"}, 104 | }, 105 | TableName: &s.dynamoDBTableProps.TableName, 106 | }) 107 | s.Require().NoError(err) 108 | 109 | err = backend.SetState(context.Background(), limiters.LeakyBucketState{}) 110 | s.Require().ErrorIs(err, limiters.ErrRaceCondition, err) 111 | } 112 | -------------------------------------------------------------------------------- /examples/example_grpc_simple_limiter_test.go: 
--------------------------------------------------------------------------------
1 | package examples_test
2 |
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"fmt"
7 | 	"log"
8 | 	"net"
9 | 	"os"
10 | 	"strings"
11 | 	"time"
12 |
13 | 	"github.com/mennanov/limiters"
14 | 	"github.com/mennanov/limiters/examples"
15 | 	pb "github.com/mennanov/limiters/examples/helloworld"
16 | 	"github.com/redis/go-redis/v9"
17 | 	clientv3 "go.etcd.io/etcd/client/v3"
18 | 	"google.golang.org/grpc"
19 | 	"google.golang.org/grpc/codes"
20 | 	"google.golang.org/grpc/credentials/insecure"
21 | 	"google.golang.org/grpc/status"
22 | )
23 |
24 | // Example_simpleGRPCLimiter rate limits all unary RPCs with one Redis-backed token bucket guarded by an etcd lock.
25 | func Example_simpleGRPCLimiter() {
26 | 	// Set up a gRPC server.
27 | 	lc := net.ListenConfig{}
28 |
29 | 	lis, err := lc.Listen(context.Background(), "tcp", examples.Port)
30 | 	if err != nil {
31 | 		log.Fatalf("failed to listen: %v", err)
32 | 	}
33 | 	defer lis.Close()
34 | 	// Connect to etcd.
35 | 	etcdClient, err := clientv3.New(clientv3.Config{
36 | 		Endpoints:   strings.Split(os.Getenv("ETCD_ENDPOINTS"), ","),
37 | 		DialTimeout: time.Second,
38 | 	})
39 | 	if err != nil {
40 | 		log.Fatalf("could not connect to etcd: %v", err)
41 | 	}
42 | 	defer etcdClient.Close()
43 | 	// Connect to Redis.
44 | 	redisClient := redis.NewClient(&redis.Options{
45 | 		Addr: os.Getenv("REDIS_ADDR"),
46 | 	})
47 | 	defer redisClient.Close()
48 |
49 | 	rate := time.Second * 3
50 | 	limiter := limiters.NewTokenBucket(
51 | 		2,
52 | 		rate,
53 | 		limiters.NewLockEtcd(etcdClient, "/ratelimiter_lock/simple/", limiters.NewStdLogger()),
54 | 		limiters.NewTokenBucketRedis(
55 | 			redisClient,
56 | 			"ratelimiter/simple",
57 | 			rate, false),
58 | 		limiters.NewSystemClock(), limiters.NewStdLogger(),
59 | 	)
60 |
61 | 	// Add a unary interceptor middleware to rate limit all requests.
62 | 	s := grpc.NewServer(grpc.UnaryInterceptor(
63 | 		func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
64 | 			w, err := limiter.Limit(ctx)
65 | 			if errors.Is(err, limiters.ErrLimitExhausted) {
66 | 				return nil, status.Errorf(codes.ResourceExhausted, "try again later in %s", w)
67 | 			} else if err != nil {
68 | 				// The limiter failed. This error should be logged and examined.
69 | 				log.Println(err)
70 |
71 | 				return nil, status.Error(codes.Internal, "internal error")
72 | 			}
73 |
74 | 			return handler(ctx, req)
75 | 		}))
76 |
77 | 	pb.RegisterGreeterServer(s, &examples.Server{})
78 |
79 | 	go func() {
80 | 		// Start serving. A goroutine-local err avoids a data race on the enclosing function's err.
81 | 		if err := s.Serve(lis); err != nil {
82 | 			log.Fatalf("failed to serve: %v", err)
83 | 		}
84 | 	}()
85 |
86 | 	defer s.GracefulStop()
87 |
88 | 	// Set up a client connection to the server.
89 | 	conn, err := grpc.NewClient(fmt.Sprintf("localhost%s", examples.Port), grpc.WithTransportCredentials(insecure.NewCredentials()))
90 | 	if err != nil {
91 | 		log.Fatalf("did not connect: %v", err)
92 | 	}
93 | 	defer conn.Close()
94 |
95 | 	c := pb.NewGreeterClient(conn)
96 |
97 | 	// Contact the server and print out its response.
98 | 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
99 | 	defer cancel()
100 |
101 | 	r, err := c.SayHello(ctx, &pb.HelloRequest{Name: "Alice"})
102 | 	if err != nil {
103 | 		log.Fatalf("could not greet: %v", err)
104 | 	}
105 |
106 | 	fmt.Println(r.GetMessage())
107 |
108 | 	r, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Bob"})
109 | 	if err != nil {
110 | 		log.Fatalf("could not greet: %v", err)
111 | 	}
112 |
113 | 	fmt.Println(r.GetMessage())
114 |
115 | 	_, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Peter"})
116 | 	if err == nil {
117 | 		log.Fatal("error expected, but got nil")
118 | 	}
119 |
120 | 	fmt.Println(err)
121 | 	// Output: Hello Alice
122 | 	// Hello Bob
123 | 	// rpc error: code = ResourceExhausted desc = try again later in 3s
124 | }
125 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/mennanov/limiters
2 |
3 | //
4 | go 1.25.3
5 |
6 | //
7 |
8 | require (
9 | 	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
10 | 	github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.4.1
11 | 	github.com/alessandro-c/gomemcached-lock v1.0.0
12 | 	github.com/aws/aws-sdk-go-v2 v1.40.1
13 | 	github.com/aws/aws-sdk-go-v2/config v1.32.3
14 | 	github.com/aws/aws-sdk-go-v2/credentials v1.19.3
15 | 	github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.27
16 | 	github.com/aws/aws-sdk-go-v2/service/dynamodb v1.53.3
17 | 	github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
18 | 	github.com/cenkalti/backoff/v3 v3.2.2
19 | 	github.com/go-redsync/redsync/v4 v4.14.1
20 | 	github.com/google/uuid v1.6.0
21 | 	github.com/hashicorp/consul/api v1.33.0
22 | 	github.com/lib/pq v1.10.9
23 | 	github.com/pkg/errors v0.9.1
24 | 	github.com/redis/go-redis/v9 v9.17.1
25 | 	github.com/samuel/go-zookeeper v0.0.0-20201211165307-7117e9ea2414
26 | 	github.com/stretchr/testify v1.11.1
27 | 	go.etcd.io/etcd/api/v3 v3.6.6
28 |
go.etcd.io/etcd/client/v3 v3.6.6 29 | google.golang.org/grpc v1.77.0 30 | google.golang.org/protobuf v1.36.10 31 | ) 32 | 33 | require ( 34 | github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect 35 | github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect 36 | github.com/armon/go-metrics v0.4.1 // indirect 37 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.15 // indirect 38 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.15 // indirect 39 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.15 // indirect 40 | github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect 41 | github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.7 // indirect 42 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect 43 | github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.15 // indirect 44 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.15 // indirect 45 | github.com/aws/aws-sdk-go-v2/service/signin v1.0.3 // indirect 46 | github.com/aws/aws-sdk-go-v2/service/sso v1.30.6 // indirect 47 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.11 // indirect 48 | github.com/aws/aws-sdk-go-v2/service/sts v1.41.3 // indirect 49 | github.com/aws/smithy-go v1.24.0 // indirect 50 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 51 | github.com/coreos/go-semver v0.3.1 // indirect 52 | github.com/coreos/go-systemd/v22 v22.5.0 // indirect 53 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 54 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 55 | github.com/fatih/color v1.16.0 // indirect 56 | github.com/go-viper/mapstructure/v2 v2.4.0 // indirect 57 | github.com/gogo/protobuf v1.3.2 // indirect 58 | github.com/golang/protobuf v1.5.4 // indirect 59 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect 60 | github.com/hashicorp/errwrap v1.1.0 // indirect 61 | github.com/hashicorp/go-cleanhttp v0.5.2 // 
indirect
62 | 	github.com/hashicorp/go-hclog v1.5.0 // indirect
63 | 	github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
64 | 	github.com/hashicorp/go-multierror v1.1.1 // indirect
65 | 	github.com/hashicorp/go-rootcerts v1.0.2 // indirect
66 | 	github.com/hashicorp/golang-lru v0.5.4 // indirect
67 | 	github.com/hashicorp/serf v0.10.1 // indirect
68 | 	github.com/mattn/go-colorable v0.1.13 // indirect
69 | 	github.com/mattn/go-isatty v0.0.20 // indirect
70 | 	github.com/mitchellh/go-homedir v1.1.0 // indirect
71 | 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
72 | 	github.com/thanhpk/randstr v1.0.4 // indirect
73 | 	go.etcd.io/etcd/client/pkg/v3 v3.6.6 // indirect
74 | 	go.uber.org/multierr v1.11.0 // indirect
75 | 	go.uber.org/zap v1.27.0 // indirect
76 | 	golang.org/x/exp v0.0.0-20250808145144-a408d31f581a // indirect
77 | 	golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 // indirect
78 | 	golang.org/x/sys v0.37.0 // indirect
79 | 	golang.org/x/text v0.30.0 // indirect
80 | 	google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 // indirect
81 | 	google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect
82 | 	gopkg.in/yaml.v3 v3.0.1 // indirect
83 | )
84 |
--------------------------------------------------------------------------------
/dynamodb.go:
--------------------------------------------------------------------------------
1 | package limiters
2 |
3 | import (
4 | 	"context"
5 |
6 | 	"github.com/aws/aws-sdk-go-v2/aws"
7 | 	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
8 | 	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
9 | 	"github.com/pkg/errors"
10 | )
11 |
12 | // DynamoDBTableProperties are supplied to DynamoDB limiter backends.
13 | // This struct informs the backend what the name of the table is and what the names of the key fields are.
14 | type DynamoDBTableProperties struct {
15 | 	// TableName is the name of the table.
16 | 	TableName string
17 | 	// PartitionKeyName is the name of the PartitionKey attribute.
18 | 	PartitionKeyName string
19 | 	// SortKeyName is the name of the SortKey attribute.
20 | 	SortKeyName string
21 | 	// SortKeyUsed indicates if a SortKey is present on the table.
22 | 	SortKeyUsed bool
23 | 	// TTLFieldName is the name of the attribute configured for TTL.
24 | 	TTLFieldName string
25 | }
26 |
27 | // LoadDynamoDBTableProperties fetches a table description with the supplied client and returns a DynamoDBTableProperties struct.
28 | // It fails if the table or its TTL settings cannot be described, or if the table's key attributes are not of type S.
29 | func LoadDynamoDBTableProperties(ctx context.Context, client *dynamodb.Client, tableName string) (DynamoDBTableProperties, error) {
30 | 	resp, err := client.DescribeTable(ctx, &dynamodb.DescribeTableInput{
31 | 		TableName: &tableName,
32 | 	})
33 | 	if err != nil {
34 | 		return DynamoDBTableProperties{}, errors.Wrap(err, "describe dynamodb table failed")
35 | 	}
36 |
37 | 	ttlResp, err := client.DescribeTimeToLive(ctx, &dynamodb.DescribeTimeToLiveInput{
38 | 		TableName: &tableName,
39 | 	})
40 | 	if err != nil {
41 | 		return DynamoDBTableProperties{}, errors.Wrap(err, "describe dynamodb table ttl failed")
42 | 	}
43 |
44 | 	return loadTableData(resp.Table, ttlResp.TimeToLiveDescription)
45 | }
46 |
47 | // loadTableData builds the table properties from a table description and the table's TTL description.
48 | func loadTableData(table *types.TableDescription, ttl *types.TimeToLiveDescription) (DynamoDBTableProperties, error) {
49 | 	data := DynamoDBTableProperties{
50 | 		TableName: *table.TableName,
51 | 	}
52 |
53 | 	data, err := loadTableKeys(data, table)
54 | 	if err != nil {
55 | 		return data, errors.Wrap(err, "invalid dynamodb table")
56 | 	}
57 |
58 | 	return populateTableTTL(data, ttl), nil
59 | }
60 |
61 | // loadTableKeys records the HASH key as the partition key and any other key as the sort key, then rejects
62 | // tables whose key attributes are not strings (type S).
63 | func loadTableKeys(data DynamoDBTableProperties, table *types.TableDescription) (DynamoDBTableProperties, error) {
64 | 	for _, key := range table.KeySchema {
65 | 		if key.KeyType == types.KeyTypeHash {
66 | 			data.PartitionKeyName = *key.AttributeName
67 |
68 | 			continue
69 | 		}
70 |
71 | 		data.SortKeyName = *key.AttributeName
72 | 		data.SortKeyUsed = true
73 | 	}
74 |
75 | 	for _, attribute := range table.AttributeDefinitions {
76 | 		name := *attribute.AttributeName
77 | 		if name != data.PartitionKeyName && name != data.SortKeyName {
78 | 			continue
79 | 		}
80 |
81 | 		if name == data.PartitionKeyName && attribute.AttributeType != types.ScalarAttributeTypeS {
82 | 			return data, errors.New("dynamodb partition key must be of type S")
83 | 		} else if data.SortKeyUsed && name == data.SortKeyName && attribute.AttributeType != types.ScalarAttributeTypeS {
84 | 			return data, errors.New("dynamodb sort key must be of type S")
85 | 		}
86 | 	}
87 |
88 | 	return data, nil
89 | }
90 |
91 | // populateTableTTL sets TTLFieldName only when TTL is reported as ENABLED. A nil TTL description or a missing
92 | // attribute name leaves TTLFieldName empty instead of causing a nil pointer dereference.
93 | func populateTableTTL(data DynamoDBTableProperties, ttl *types.TimeToLiveDescription) DynamoDBTableProperties {
94 | 	if ttl == nil || ttl.TimeToLiveStatus != types.TimeToLiveStatusEnabled || ttl.AttributeName == nil {
95 | 		return data
96 | 	}
97 |
98 | 	data.TTLFieldName = *ttl.AttributeName
99 |
100 | 	return data
101 | }
102 |
103 | // dynamoDBputItem calls PutItem and translates a failed condition check into ErrRaceCondition.
104 | func dynamoDBputItem(ctx context.Context, client *dynamodb.Client, input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
105 | 	resp, err := client.PutItem(ctx, input)
106 | 	if err != nil {
107 | 		var cErr *types.ConditionalCheckFailedException
108 | 		if errors.As(err, &cErr) {
109 | 			return nil, ErrRaceCondition
110 | 		}
111 |
112 | 		return nil, errors.Wrap(err, "unable to set dynamodb item")
113 | 	}
114 |
115 | 	return resp, nil
116 | }
117 |
118 | // dynamoDBGetItem performs a strongly consistent GetItem; note that it sets ConsistentRead on the caller's input.
119 | // The request runs in its own goroutine so this function returns ctx.Err() as soon as ctx is done, even if the
120 | // SDK call has not returned yet.
121 | func dynamoDBGetItem(ctx context.Context, client *dynamodb.Client, input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
122 | 	input.ConsistentRead = aws.Bool(true)
123 |
124 | 	var (
125 | 		resp *dynamodb.GetItemOutput
126 | 		err  error
127 | 	)
128 |
129 | 	done := make(chan struct{})
130 |
131 | 	go func() {
132 | 		defer close(done)
133 |
134 | 		resp, err = client.GetItem(ctx, input)
135 | 	}()
136 |
137 | 	select {
138 | 	case <-done:
139 | 	case <-ctx.Done():
140 | 		return nil, ctx.Err()
141 | 	}
142 |
143 | 	if err != nil {
144 | 		return nil, errors.Wrap(err, "unable to retrieve dynamodb item")
145 | 	}
146 |
147 | 	return resp, nil
148 | }
149 |
--------------------------------------------------------------------------------
/concurrent_buffer_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"sync"
7 | 	"testing"
8 | 	"time"
9 |
10 | 	"github.com/google/uuid"
11 | 	l "github.com/mennanov/limiters"
12 | )
13 |
14 | func (s *LimitersTestSuite) concurrentBuffers(capacity int64, ttl time.Duration, clock l.Clock) map[string]*l.ConcurrentBuffer {
15 | 	buffers := make(map[string]*l.ConcurrentBuffer)
16 |
17 | 	for lockerName, locker := range s.lockers(true) {
18 | 		for bName, b := range s.concurrentBufferBackends(ttl, clock) {
19 | 			buffers[lockerName+":"+bName] = l.NewConcurrentBuffer(locker, b, capacity, s.logger)
20 | 		}
21 | 	}
22 |
23 | 	return buffers
24 | }
25 |
26 | func (s *LimitersTestSuite) concurrentBufferBackends(ttl time.Duration, clock l.Clock) map[string]l.ConcurrentBufferBackend {
27 | 	return map[string]l.ConcurrentBufferBackend{
28 | 		"ConcurrentBufferInMemory":     l.NewConcurrentBufferInMemory(l.NewRegistry(), ttl, clock),
29 | 		"ConcurrentBufferRedis":        l.NewConcurrentBufferRedis(s.redisClient, uuid.New().String(), ttl, clock),
30 | 		"ConcurrentBufferRedisCluster": l.NewConcurrentBufferRedis(s.redisClusterClient, uuid.New().String(), ttl, clock),
31 | 		"ConcurrentBufferMemcached":    l.NewConcurrentBufferMemcached(s.memcacheClient, uuid.New().String(), ttl, clock),
32 | 	}
33 | }
34 |
35 | func (s *LimitersTestSuite) TestConcurrentBufferNoOverflow() {
36 | 	clock := newFakeClock()
37 | 	capacity := int64(10)
38 |
39 | 	ttl := time.Second
40 | 	for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) {
41 | 		s.Run(name, func() {
42 | 			wg := sync.WaitGroup{}
43 | 			for i := int64(0); i < capacity; i++ {
44 | 				wg.Add(1)
45 |
46 | 				go func(i int64, buffer *l.ConcurrentBuffer) {
47 | 					defer wg.Done()
48 |
49 | 					key := fmt.Sprintf("key%d", i)
50 | 					s.NoError(buffer.Limit(context.TODO(), key))
51 | 					s.NoError(buffer.Done(context.TODO(), key))
52 | 				}(i, buffer)
53 | 			}
54 |
55 | 			wg.Wait()
56 | 			s.NoError(buffer.Limit(context.TODO(), "last"))
57 | 			s.NoError(buffer.Done(context.TODO(), "last"))
58 | 		})
59 | 	}
60 | }
61 |
62 | func (s *LimitersTestSuite) TestConcurrentBufferOverflow() {
63 | 	clock := newFakeClock()
64 | 	capacity := int64(3)
65 |
66 | 	ttl := time.Second
67 | 	for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) {
68 | 		s.Run(name, func() {
69 | 			mu := sync.Mutex{}
70 |
71 | 			var errors []error
72 |
73 | 			wg := sync.WaitGroup{}
74 | 			for i := int64(0); i <= capacity; i++ {
75 | 				wg.Add(1)
76 |
77 | 				go func(i int64, buffer *l.ConcurrentBuffer) {
78 | 					defer wg.Done()
79 |
80 | 					err := buffer.Limit(context.TODO(), fmt.Sprintf("key%d", i))
81 | 					if err != nil {
82 | 						mu.Lock()
83 |
84 | 						errors = append(errors, err)
85 |
86 | 						mu.Unlock()
87 | 					}
88 | 				}(i, buffer)
89 | 			}
90 |
91 | 			wg.Wait()
92 | 			s.Equal([]error{l.ErrLimitExhausted}, errors)
93 | 		})
94 | 	}
95 | }
96 |
97 | func (s *LimitersTestSuite) TestConcurrentBufferExpiredKeys() {
98 | 	clock := newFakeClock()
99 | 	capacity := int64(2)
100 |
101 | 	ttl := time.Second
102 | 	for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) {
103 | 		s.Run(name, func() {
104 | 			s.Require().NoError(buffer.Limit(context.TODO(), "key1"))
105 | 			clock.Sleep(ttl / 2)
106 | 			s.Require().NoError(buffer.Limit(context.TODO(), "key2"))
107 | 			clock.Sleep(ttl / 2)
108 | 			// No error is expected (even though the following request overflows the capacity) as the first key has already
109 | 			// expired by this time.
110 | 			s.NoError(buffer.Limit(context.TODO(), "key3"))
111 | 		})
112 | 	}
113 | }
114 |
115 | func (s *LimitersTestSuite) TestConcurrentBufferDuplicateKeys() {
116 | 	clock := newFakeClock()
117 | 	capacity := int64(2)
118 |
119 | 	ttl := time.Second
120 | 	for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) {
121 | 		s.Run(name, func() {
122 | 			s.Require().NoError(buffer.Limit(context.TODO(), "key1"))
123 | 			s.Require().NoError(buffer.Limit(context.TODO(), "key2"))
124 | 			// No error is expected as it should just update the timestamp of the existing key.
125 | 			s.NoError(buffer.Limit(context.TODO(), "key1"))
126 | 		})
127 | 	}
128 | }
129 |
130 | func BenchmarkConcurrentBuffers(b *testing.B) {
131 | 	s := new(LimitersTestSuite)
132 | 	s.SetT(&testing.T{})
133 | 	s.SetupSuite()
134 |
135 | 	capacity := int64(1)
136 | 	ttl := time.Second
137 | 	clock := newFakeClock()
138 |
139 | 	buffers := s.concurrentBuffers(capacity, ttl, clock)
140 | 	for name, buffer := range buffers {
141 | 		b.Run(name, func(b *testing.B) {
142 | 			for i := 0; i < b.N; i++ {
143 | 				s.Require().NoError(buffer.Limit(context.TODO(), "key1"))
144 | 			}
145 | 		})
146 | 	}
147 |
148 | 	s.TearDownSuite()
149 | }
150 |
--------------------------------------------------------------------------------
/examples/example_grpc_ip_limiter_test.go:
--------------------------------------------------------------------------------
1 | // Package examples implements a gRPC server for Greeter service using rate limiters.
2 | package examples_test
3 |
4 | import (
5 | 	"context"
6 | 	"errors"
7 | 	"fmt"
8 | 	"log"
9 | 	"net"
10 | 	"os"
11 | 	"strings"
12 | 	"time"
13 |
14 | 	"github.com/mennanov/limiters"
15 | 	"github.com/mennanov/limiters/examples"
16 | 	pb "github.com/mennanov/limiters/examples/helloworld"
17 | 	"github.com/redis/go-redis/v9"
18 | 	clientv3 "go.etcd.io/etcd/client/v3"
19 | 	"google.golang.org/grpc"
20 | 	"google.golang.org/grpc/codes"
21 | 	"google.golang.org/grpc/credentials/insecure"
22 | 	"google.golang.org/grpc/peer"
23 | 	"google.golang.org/grpc/status"
24 | )
25 |
26 | func Example_ipGRPCLimiter() {
27 | 	// Set up a gRPC server.
28 | 	lc := net.ListenConfig{}
29 |
30 | 	lis, err := lc.Listen(context.Background(), "tcp", examples.Port)
31 | 	if err != nil {
32 | 		log.Fatalf("failed to listen: %v", err)
33 | 	}
34 | 	defer lis.Close()
35 | 	// Connect to etcd.
36 | 	etcdClient, err := clientv3.New(clientv3.Config{
37 | 		Endpoints:   strings.Split(os.Getenv("ETCD_ENDPOINTS"), ","),
38 | 		DialTimeout: time.Second,
39 | 	})
40 | 	if err != nil {
41 | 		log.Fatalf("could not connect to etcd: %v", err)
42 | 	}
43 | 	defer etcdClient.Close()
44 | 	// Connect to Redis.
45 | 	redisClient := redis.NewClient(&redis.Options{
46 | 		Addr: os.Getenv("REDIS_ADDR"),
47 | 	})
48 | 	defer redisClient.Close()
49 |
50 | 	logger := limiters.NewStdLogger()
51 | 	// Registry is needed to keep track of previously created limiters. It can remove the expired limiters to free up
52 | 	// memory.
53 | 	registry := limiters.NewRegistry()
54 | 	// The rate is used to define the token bucket refill rate and also the TTL for the limiters (both in Redis and in
55 | 	// the registry).
56 | 	rate := time.Second * 3
57 | 	clock := limiters.NewSystemClock()
58 |
59 | 	// Garbage collect the old limiters to prevent memory leaks. Closing stopGC (deferred right below) stops the
60 | 	// collector when the example returns, so the goroutine does not outlive it.
61 | 	stopGC := make(chan struct{})
62 | 	defer close(stopGC)
63 |
64 | 	go func() {
65 | 		ticker := time.NewTicker(rate)
66 | 		defer ticker.Stop()
67 |
68 | 		for {
69 | 			select {
70 | 			case <-ticker.C:
71 | 				registry.DeleteExpired(clock.Now())
72 | 			case <-stopGC:
73 | 				return
74 | 			}
75 | 		}
76 | 	}()
77 |
78 | 	// Add a unary interceptor middleware to rate limit requests.
79 | 	s := grpc.NewServer(grpc.UnaryInterceptor(
80 | 		func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
81 | 			ip := "unknown"
82 |
83 | 			p, ok := peer.FromContext(ctx)
84 | 			if ok {
85 | 				ip = p.Addr.String()
86 | 				// Key by host only: the client port changes with every connection and would otherwise give each
87 | 				// connection its own bucket. Keep the full address if it cannot be split into host and port.
88 | 				if host, _, splitErr := net.SplitHostPort(ip); splitErr == nil {
89 | 					ip = host
90 | 				}
91 | 			} else {
92 | 				log.Println("no peer info available")
93 | 			}
94 |
95 | 			// Create an IP address based rate limiter.
96 | 			bucket := registry.GetOrCreate(ip, func() interface{} {
97 | 				return limiters.NewTokenBucket(
98 | 					2,
99 | 					rate,
100 | 					limiters.NewLockEtcd(etcdClient, fmt.Sprintf("/lock/ip/%s", ip), logger),
101 | 					limiters.NewTokenBucketRedis(
102 | 						redisClient,
103 | 						fmt.Sprintf("/ratelimiter/ip/%s", ip),
104 | 						rate, false),
105 | 					clock, logger)
106 | 			}, rate, clock.Now())
107 |
108 | 			w, err := bucket.(*limiters.TokenBucket).Limit(ctx)
109 | 			if errors.Is(err, limiters.ErrLimitExhausted) {
110 | 				return nil, status.Errorf(codes.ResourceExhausted, "try again later in %s", w)
111 | 			} else if err != nil {
112 | 				// The limiter failed. This error should be logged and examined.
113 | 				log.Println(err)
114 |
115 | 				return nil, status.Error(codes.Internal, "internal error")
116 | 			}
117 |
118 | 			return handler(ctx, req)
119 | 		}))
120 |
121 | 	pb.RegisterGreeterServer(s, &examples.Server{})
122 |
123 | 	go func() {
124 | 		// Start serving. A goroutine-local err avoids a data race on the enclosing function's err.
125 | 		if err := s.Serve(lis); err != nil {
126 | 			log.Fatalf("failed to serve: %v", err)
127 | 		}
128 | 	}()
129 |
130 | 	defer s.GracefulStop()
131 |
132 | 	// Set up a client connection to the server.
133 | 	conn, err := grpc.NewClient(fmt.Sprintf("localhost%s", examples.Port), grpc.WithTransportCredentials(insecure.NewCredentials()))
134 | 	if err != nil {
135 | 		log.Fatalf("did not connect: %v", err)
136 | 	}
137 | 	defer conn.Close()
138 |
139 | 	c := pb.NewGreeterClient(conn)
140 |
141 | 	// Contact the server and print out its response.
142 | 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
143 | 	defer cancel()
144 |
145 | 	r, err := c.SayHello(ctx, &pb.HelloRequest{Name: "Alice"})
146 | 	if err != nil {
147 | 		log.Fatalf("could not greet: %v", err)
148 | 	}
149 |
150 | 	fmt.Println(r.GetMessage())
151 |
152 | 	r, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Bob"})
153 | 	if err != nil {
154 | 		log.Fatalf("could not greet: %v", err)
155 | 	}
156 |
157 | 	fmt.Println(r.GetMessage())
158 |
159 | 	_, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Peter"})
160 | 	if err == nil {
161 | 		log.Fatal("error expected, but got nil")
162 | 	}
163 |
164 | 	fmt.Println(err)
165 | 	// Output: Hello Alice
166 | 	// Hello Bob
167 | 	// rpc error: code = ResourceExhausted desc = try again later in 3s
168 | }
169 |
--------------------------------------------------------------------------------
/examples/helloworld/helloworld_grpc.pb.go:
--------------------------------------------------------------------------------
1 | // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
2 | // versions:
3 | // - protoc-gen-go-grpc v1.5.1
4 | // - protoc v5.28.3
5 | // source: helloworld.proto
6 |
7 | package helloworld
8 |
9 | import (
10 | 	context "context"
11 |
12 | 	grpc "google.golang.org/grpc"
13 | 	codes "google.golang.org/grpc/codes"
14 | 	status "google.golang.org/grpc/status"
15 | )
16 |
17 | // This is a compile-time assertion to ensure that this generated file
18 | // is compatible with the grpc package it is being compiled against.
19 | // Requires gRPC-Go v1.64.0 or later.
20 | const _ = grpc.SupportPackageIsVersion9
21 |
22 | const (
23 | 	Greeter_SayHello_FullMethodName = "/helloworld.Greeter/SayHello"
24 | )
25 |
26 | // GreeterClient is the client API for Greeter service.
27 | //
28 | // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
29 | //
30 | // The greeting service definition.
31 | type GreeterClient interface {
32 | 	// Sends a greeting
33 | 	SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error)
34 | }
35 |
36 | type greeterClient struct {
37 | 	cc grpc.ClientConnInterface
38 | }
39 |
40 | func NewGreeterClient(cc grpc.ClientConnInterface) GreeterClient {
41 | 	return &greeterClient{cc}
42 | }
43 |
44 | func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) {
45 | 	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
46 | 	out := new(HelloReply)
47 | 	err := c.cc.Invoke(ctx, Greeter_SayHello_FullMethodName, in, out, cOpts...)
48 | 	if err != nil {
49 | 		return nil, err
50 | 	}
51 | 	return out, nil
52 | }
53 |
54 | // GreeterServer is the server API for Greeter service.
55 | // All implementations must embed UnimplementedGreeterServer
56 | // for forward compatibility.
57 | //
58 | // The greeting service definition.
59 | type GreeterServer interface {
60 | 	// Sends a greeting
61 | 	SayHello(context.Context, *HelloRequest) (*HelloReply, error)
62 | 	mustEmbedUnimplementedGreeterServer()
63 | }
64 |
65 | // UnimplementedGreeterServer must be embedded to have
66 | // forward compatible implementations.
67 | //
68 | // NOTE: this should be embedded by value instead of pointer to avoid a nil
69 | // pointer dereference when methods are called.
70 | type UnimplementedGreeterServer struct{}
71 |
72 | func (UnimplementedGreeterServer) SayHello(context.Context, *HelloRequest) (*HelloReply, error) {
73 | 	return nil, status.Errorf(codes.Unimplemented, "method SayHello not implemented")
74 | }
75 | func (UnimplementedGreeterServer) mustEmbedUnimplementedGreeterServer() {}
76 | func (UnimplementedGreeterServer) testEmbeddedByValue()                 {}
77 |
78 | // UnsafeGreeterServer may be embedded to opt out of forward compatibility for this service.
79 | // Use of this interface is not recommended, as added methods to GreeterServer will
80 | // result in compilation errors.
81 | type UnsafeGreeterServer interface {
82 | 	mustEmbedUnimplementedGreeterServer()
83 | }
84 |
85 | func RegisterGreeterServer(s grpc.ServiceRegistrar, srv GreeterServer) {
86 | 	// If the following call pancis, it indicates UnimplementedGreeterServer was
87 | 	// embedded by pointer and is nil. This will cause panics if an
88 | 	// unimplemented method is ever invoked, so we test this at initialization
89 | 	// time to prevent it from happening at runtime later due to I/O.
90 | 	if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
91 | 		t.testEmbeddedByValue()
92 | 	}
93 | 	s.RegisterService(&Greeter_ServiceDesc, srv)
94 | }
95 |
96 | func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
97 | 	in := new(HelloRequest)
98 | 	if err := dec(in); err != nil {
99 | 		return nil, err
100 | 	}
101 | 	if interceptor == nil {
102 | 		return srv.(GreeterServer).SayHello(ctx, in)
103 | 	}
104 | 	info := &grpc.UnaryServerInfo{
105 | 		Server:     srv,
106 | 		FullMethod: Greeter_SayHello_FullMethodName,
107 | 	}
108 | 	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
109 | 		return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest))
110 | 	}
111 | 	return interceptor(ctx, in, info, handler)
112 | }
113 |
114 | // Greeter_ServiceDesc is the grpc.ServiceDesc for Greeter service.
115 | // It's only intended for direct use with grpc.RegisterService,
116 | // and not to be introspected or modified (even as a copy)
117 | var Greeter_ServiceDesc = grpc.ServiceDesc{
118 | 	ServiceName: "helloworld.Greeter",
119 | 	HandlerType: (*GreeterServer)(nil),
120 | 	Methods: []grpc.MethodDesc{
121 | 		{
122 | 			MethodName: "SayHello",
123 | 			Handler:    _Greeter_SayHello_Handler,
124 | 		},
125 | 	},
126 | 	Streams:  []grpc.StreamDesc{},
127 | 	Metadata: "helloworld.proto",
128 | }
129 |
--------------------------------------------------------------------------------
/fixedwindow_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | 	"context"
5 | 	"testing"
6 | 	"time"
7 |
8 | 	"github.com/google/uuid"
9 | 	l "github.com/mennanov/limiters"
10 | )
11 |
12 | // fixedWindows returns all the possible FixedWindow combinations.
13 | func (s *LimitersTestSuite) fixedWindows(capacity int64, rate time.Duration, clock l.Clock) map[string]*l.FixedWindow {
14 | 	windows := make(map[string]*l.FixedWindow)
15 | 	for name, inc := range s.fixedWindowIncrementers() {
16 | 		windows[name] = l.NewFixedWindow(capacity, rate, inc, clock)
17 | 	}
18 |
19 | 	return windows
20 | }
21 |
22 | func (s *LimitersTestSuite) fixedWindowIncrementers() map[string]l.FixedWindowIncrementer {
23 | 	return map[string]l.FixedWindowIncrementer{
24 | 		"FixedWindowInMemory":     l.NewFixedWindowInMemory(),
25 | 		"FixedWindowRedis":        l.NewFixedWindowRedis(s.redisClient, uuid.New().String()),
26 | 		"FixedWindowRedisCluster": l.NewFixedWindowRedis(s.redisClusterClient, uuid.New().String()),
27 | 		"FixedWindowMemcached":    l.NewFixedWindowMemcached(s.memcacheClient, uuid.New().String()),
28 | 		"FixedWindowDynamoDB":     l.NewFixedWindowDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps),
29 | 		"FixedWindowCosmosDB":     l.NewFixedWindowCosmosDB(s.cosmosContainerClient, uuid.New().String()),
30 | 	}
31 | }
32 |
33 | var fixedWindowTestCases = []struct {
34 | 	capacity     int64
35 | 	rate         time.Duration
36 | 	requestCount int
37 | 	requestRate  time.Duration
38 | 	missExpected int
39 | }{
40 | 	{
41 | 		capacity:     2,
42 | 		rate:         time.Millisecond * 100,
43 | 		requestCount: 20,
44 | 		requestRate:  time.Millisecond * 25,
45 | 		missExpected: 10,
46 | 	},
47 | 	{
48 | 		capacity:     4,
49 | 		rate:         time.Millisecond * 100,
50 | 		requestCount: 20,
51 | 		requestRate:  time.Millisecond * 25,
52 | 		missExpected: 0,
53 | 	},
54 | 	{
55 | 		capacity:     2,
56 | 		rate:         time.Millisecond * 100,
57 | 		requestCount: 15,
58 | 		requestRate:  time.Millisecond * 33,
59 | 		missExpected: 5,
60 | 	},
61 | }
62 |
63 | func (s *LimitersTestSuite) TestFixedWindowFakeClock() {
64 | 	clock := newFakeClockWithTime(time.Date(2019, 8, 30, 0, 0, 0, 0, time.UTC))
65 | 	for _, testCase := range fixedWindowTestCases {
66 | 		for name, bucket := range s.fixedWindows(testCase.capacity, testCase.rate, clock) {
67 | 			s.Run(name, func() {
68 | 				clock.reset()
69 |
70 | 				miss := 0
71 |
72 | 				for i := 0; i < testCase.requestCount; i++ {
73 | 					// No pause for the first request.
74 | 					if i > 0 {
75 | 						clock.Sleep(testCase.requestRate)
76 | 					}
77 |
78 | 					_, err := bucket.Limit(context.TODO())
79 | 					if err != nil {
80 | 						s.Equal(l.ErrLimitExhausted, err)
81 |
82 | 						miss++
83 | 					}
84 | 				}
85 |
86 | 				s.Equal(testCase.missExpected, miss, testCase)
87 | 			})
88 | 		}
89 | 	}
90 | }
91 |
92 | func (s *LimitersTestSuite) TestFixedWindowOverflow() {
93 | 	clock := newFakeClockWithTime(time.Date(2019, 8, 30, 0, 0, 0, 0, time.UTC))
94 | 	for name, bucket := range s.fixedWindows(2, time.Second, clock) {
95 | 		s.Run(name, func() {
96 | 			clock.reset()
97 |
98 | 			w, err := bucket.Limit(context.TODO())
99 | 			s.Require().NoError(err)
100 | 			s.Equal(time.Duration(0), w)
101 | 			w, err = bucket.Limit(context.TODO())
102 | 			s.Require().NoError(err)
103 | 			s.Equal(time.Duration(0), w)
104 | 			w, err = bucket.Limit(context.TODO())
105 | 			s.Require().Equal(l.ErrLimitExhausted, err)
106 | 			s.Equal(time.Second, w)
107 | 			clock.Sleep(time.Second)
108 |
109 | 			w, err = bucket.Limit(context.TODO())
110 | 			s.Require().NoError(err)
111 | 			s.Equal(time.Duration(0), w)
112 | 		})
113 | 	}
114 | }
115 |
116 | func (s *LimitersTestSuite) TestFixedWindowDynamoDBPartitionKey() {
117 | 	clock := newFakeClockWithTime(time.Date(2019, 8, 30, 0, 0, 0, 0, time.UTC))
118 | 	incrementor := l.NewFixedWindowDynamoDB(s.dynamodbClient, "partitionKey1", s.dynamoDBTableProps)
119 | 	window := l.NewFixedWindow(2, time.Millisecond*100, incrementor, clock)
120 |
121 | 	w, err := window.Limit(context.TODO())
122 | 	s.Require().NoError(err)
123 | 	s.Equal(time.Duration(0), w)
124 | 	w, err = window.Limit(context.TODO())
125 | 	s.Require().NoError(err)
126 | 	s.Equal(time.Duration(0), w)
127 | 	// The third call should fail for the "partitionKey1", but succeed for "partitionKey2".
128 | 	w, err = window.Limit(l.NewFixedWindowDynamoDBContext(context.Background(), "partitionKey2"))
129 | 	s.Require().NoError(err)
130 | 	s.Equal(time.Duration(0), w)
131 | }
132 |
133 | // BenchmarkFixedWindows measures FixedWindow.Limit against every incrementer backend.
134 | func BenchmarkFixedWindows(b *testing.B) {
135 | 	s := new(LimitersTestSuite)
136 | 	s.SetT(&testing.T{})
137 | 	s.SetupSuite()
138 |
139 | 	capacity := int64(1)
140 | 	rate := time.Second
141 | 	clock := newFakeClock()
142 |
143 | 	windows := s.fixedWindows(capacity, rate, clock)
144 | 	for name, window := range windows {
145 | 		b.Run(name, func(b *testing.B) {
146 | 			for i := 0; i < b.N; i++ {
147 | 				_, err := window.Limit(context.TODO())
148 | 				// With capacity 1 and a fake clock that never advances, every call after the first one is
149 | 				// rejected; only errors other than ErrLimitExhausted should fail the benchmark.
150 | 				if err != nil {
151 | 					s.Require().Equal(l.ErrLimitExhausted, err)
152 | 				}
153 | 			}
154 | 		})
155 | 	}
156 |
157 | 	s.TearDownSuite()
158 | }
159 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Distributed rate limiters for Golang
2 | [![Build Status](https://github.com/mennanov/limiters/actions/workflows/tests.yml/badge.svg)](https://github.com/mennanov/limiters/actions/workflows/tests.yml)
3 | [![codecov](https://codecov.io/gh/mennanov/limiters/branch/master/graph/badge.svg?token=LZULu4i7B6)](https://codecov.io/gh/mennanov/limiters)
4 | [![Go Report Card](https://goreportcard.com/badge/github.com/mennanov/limiters)](https://goreportcard.com/report/github.com/mennanov/limiters)
5 | [![GoDoc](https://godoc.org/github.com/mennanov/limiters?status.svg)](https://godoc.org/github.com/mennanov/limiters)
6 |
7 | Rate limiters for distributed applications in Golang with configurable back-ends and distributed locks.
8 | Any types of back-ends and locks can be used that implement certain minimalistic interfaces.
9 | Most common implementations are already provided.
10 | 11 | - [`Token bucket`](https://en.wikipedia.org/wiki/Token_bucket) 12 | - in-memory (local) 13 | - redis 14 | - memcached 15 | - etcd 16 | - dynamodb 17 | - cosmos db 18 | 19 | Allows requests at a certain input rate with possible bursts configured by the capacity parameter. 20 | The output rate equals the input rate. 21 | Precise (no over- or under-limiting), but requires a lock (provided). 22 | 23 | - [`Leaky bucket`](https://en.wikipedia.org/wiki/Leaky_bucket#As_a_queue) 24 | - in-memory (local) 25 | - redis 26 | - memcached 27 | - etcd 28 | - dynamodb 29 | - cosmos db 30 | 31 | Puts requests in a FIFO queue to be processed at a constant rate. 32 | There are no restrictions on the input rate except for the capacity of the queue. 33 | Requires a lock (provided). 34 | 35 | - [`Fixed window counter`](https://konghq.com/blog/how-to-design-a-scalable-rate-limiting-algorithm/) 36 | - in-memory (local) 37 | - redis 38 | - memcached 39 | - dynamodb 40 | - cosmos db 41 | 42 | A simple and resource-efficient algorithm that does not need a lock. 43 | Precision may be adjusted by the size of the window. 44 | May be lenient when there are many requests around the boundary between 2 adjacent windows. 45 | 46 | - [`Sliding window counter`](https://konghq.com/blog/how-to-design-a-scalable-rate-limiting-algorithm/) 47 | - in-memory (local) 48 | - redis 49 | - memcached 50 | - dynamodb 51 | - cosmos db 52 | 53 | Smooths out the bursts around the boundary between 2 adjacent windows. 54 | Needs twice as much memory as the `Fixed Window` algorithm (2 windows instead of 1 at a time). 55 | It will disallow _all_ the requests when a client is flooding the service with requests. 56 | It's the client's responsibility to handle a disallowed request properly: wait before making a new one. 57 | 58 | - `Concurrent buffer` 59 | - in-memory (local) 60 | - redis 61 | - memcached 62 | 63 | Allows concurrent requests up to the given capacity. 64 | Requires a lock (provided).
65 | 66 | ## gRPC example 67 | 68 | Global token bucket rate limiter for a gRPC service example: 69 | ```go 70 | // examples/example_grpc_simple_limiter_test.go 71 | rate := time.Second * 3 72 | limiter := limiters.NewTokenBucket( 73 | 2, 74 | rate, 75 | limiters.NewLockEtcd(etcdClient, "/ratelimiter_lock/simple/", limiters.NewStdLogger()), 76 | limiters.NewTokenBucketRedis( 77 | redisClient, 78 | "ratelimiter/simple", 79 | rate, false), 80 | limiters.NewSystemClock(), limiters.NewStdLogger(), 81 | ) 82 | 83 | // Add a unary interceptor middleware to rate limit all requests. 84 | s := grpc.NewServer(grpc.UnaryInterceptor( 85 | func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 86 | w, err := limiter.Limit(ctx) 87 | if err == limiters.ErrLimitExhausted { 88 | return nil, status.Errorf(codes.ResourceExhausted, "try again later in %s", w) 89 | } else if err != nil { 90 | // The limiter failed. This error should be logged and examined. 91 | log.Println(err) 92 | return nil, status.Error(codes.Internal, "internal error") 93 | } 94 | return handler(ctx, req) 95 | })) 96 | ``` 97 | 98 | For an example closer to real-world usage, see the IP-address-based gRPC global rate limiter in the 99 | [examples](examples/example_grpc_ip_limiter_test.go) directory. 100 | 101 | ## DynamoDB 102 | 103 | Using DynamoDB requires creating a DynamoDB table beforehand. An existing table can be used, or a new one can be created. Depending on the limiter backend, the table needs: 104 | 105 | * Partition Key 106 | - String 107 | - Required for all Backends 108 | * Sort Key 109 | - String 110 | - Backends: 111 | - FixedWindow 112 | - SlidingWindow 113 | * TTL 114 | - Number 115 | - Backends: 116 | - FixedWindow 117 | - SlidingWindow 118 | - LeakyBucket 119 | - TokenBucket 120 | 121 | All DynamoDB backends accept a `DynamoDBTableProperties` struct as a parameter.
It can be created manually, or obtained by calling `LoadDynamoDBTableProperties` with the table name. `LoadDynamoDBTableProperties` fetches the table description from AWS and verifies that the table can be used by the Limiter backends. Results of `LoadDynamoDBTableProperties` are cached. 122 | 123 | ## Azure Cosmos DB for NoSQL 124 | 125 | Using Cosmos DB requires creating a database and a container beforehand. 126 | 127 | The container must have a default TTL set, otherwise TTL [will not take effect](https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/time-to-live#time-to-live-configurations). 128 | 129 | The partition key should be `/partitionKey`. 130 | 131 | ## Distributed locks 132 | 133 | Some algorithms require a distributed lock to guarantee consistency during concurrent requests. 134 | If only one application instance is running, no distributed lock is needed, 135 | as all the algorithms are thread-safe (use `LockNoop`). 136 | 137 | Supported backends: 138 | - [etcd](https://etcd.io/) 139 | - [Consul](https://www.consul.io/) 140 | - [Zookeeper](https://zookeeper.apache.org/) 141 | - [Redis](https://redis.io/) 142 | - [Memcached](https://memcached.org/) 143 | - [PostgreSQL](https://www.postgresql.org/) 144 | 145 | ## Memcached 146 | 147 | It's important to understand that Memcached is not ideal for implementing reliable locks or data persistence due to its inherent limitations: 148 | 149 | - No guaranteed data retention: Memcached can evict data at any point due to memory pressure, even if it appears to have space available. This can lead to unexpected lock releases or data loss. 150 | - Lack of distributed locking features: Memcached doesn't offer features such as the distributed coordination required for consistent locking across multiple servers.
If Memcached is already deployed and the application can tolerate burst traffic caused by unexpectedly evicted data, the Memcached-based implementations are convenient; otherwise the Redis-based implementations are a better choice.

## Testing

Run tests locally:
```bash
make test
```
Run benchmarks locally:
```bash
make benchmark
```
Run both locally:
```bash
make
```
--------------------------------------------------------------------------------
/locks.go:
--------------------------------------------------------------------------------
package limiters

import (
	"context"
	"database/sql"
	"time"

	lock "github.com/alessandro-c/gomemcached-lock"
	"github.com/alessandro-c/gomemcached-lock/adapters/gomemcache"
	"github.com/bradfitz/gomemcache/memcache"
	"github.com/cenkalti/backoff/v3"
	"github.com/go-redsync/redsync/v4"
	redsyncredis "github.com/go-redsync/redsync/v4/redis"
	"github.com/hashicorp/consul/api"
	// Blank import: registers the "postgres" database/sql driver so callers of LockPostgreSQL can sql.Open it.
	_ "github.com/lib/pq"
	"github.com/pkg/errors"
	"github.com/samuel/go-zookeeper/zk"
	clientv3 "go.etcd.io/etcd/client/v3"
	"go.etcd.io/etcd/client/v3/concurrency"
)

// DistLocker is a context aware distributed locker (interface is similar to sync.Locker).
// Implementations in this file keep per-acquisition state (sessions, transactions), so a single instance is meant
// to be locked and unlocked in strict Lock/Unlock pairs rather than concurrently.
type DistLocker interface {
	// Lock locks the locker. It returns a non-nil error if the lock was not acquired.
	Lock(ctx context.Context) error
	// Unlock unlocks the previously successfully locked lock.
	Unlock(ctx context.Context) error
}

// LockNoop is a no-op implementation of the DistLocker interface.
// It should only be used with the in-memory backends as they are already thread-safe and don't need distributed locks.
type LockNoop struct{}

// NewLockNoop creates a new LockNoop.
func NewLockNoop() *LockNoop {
	return &LockNoop{}
}

// Lock imitates locking: nothing is acquired, it only reports ctx.Err(), so a done context still fails.
40 | func (n LockNoop) Lock(ctx context.Context) error { 41 | return ctx.Err() 42 | } 43 | 44 | // Unlock does nothing. 45 | func (n LockNoop) Unlock(_ context.Context) error { 46 | return nil 47 | } 48 | 49 | // LockEtcd implements the DistLocker interface using etcd. 50 | // 51 | // See https://github.com/etcd-io/etcd/blob/master/Documentation/learning/why.md#using-etcd-for-distributed-coordination 52 | type LockEtcd struct { 53 | cli *clientv3.Client 54 | prefix string 55 | logger Logger 56 | mu *concurrency.Mutex 57 | session *concurrency.Session 58 | } 59 | 60 | // NewLockEtcd creates a new instance of LockEtcd. 61 | func NewLockEtcd(cli *clientv3.Client, prefix string, logger Logger) *LockEtcd { 62 | return &LockEtcd{cli: cli, prefix: prefix, logger: logger} 63 | } 64 | 65 | // Lock creates a new session-based lock in etcd and locks it. 66 | func (l *LockEtcd) Lock(ctx context.Context) error { 67 | var err error 68 | 69 | l.session, err = concurrency.NewSession(l.cli, concurrency.WithTTL(1)) 70 | if err != nil { 71 | return errors.Wrap(err, "failed to create an etcd session") 72 | } 73 | 74 | l.mu = concurrency.NewMutex(l.session, l.prefix) 75 | 76 | return errors.Wrap(l.mu.Lock(ctx), "failed to lock a mutex in etcd") 77 | } 78 | 79 | // Unlock unlocks the previously locked lock. 80 | func (l *LockEtcd) Unlock(ctx context.Context) error { 81 | defer func() { 82 | err := l.session.Close() 83 | if err != nil { 84 | l.logger.Log(err) 85 | } 86 | }() 87 | 88 | return errors.Wrap(l.mu.Unlock(ctx), "failed to unlock a mutex in etcd") 89 | } 90 | 91 | // LockConsul is a wrapper around github.com/hashicorp/consul/api.Lock that implements the DistLocker interface. 92 | type LockConsul struct { 93 | lock *api.Lock 94 | } 95 | 96 | // NewLockConsul creates a new LockConsul instance. 97 | func NewLockConsul(lock *api.Lock) *LockConsul { 98 | return &LockConsul{lock: lock} 99 | } 100 | 101 | // Lock locks the lock in Consul. 
102 | func (l *LockConsul) Lock(ctx context.Context) error { 103 | _, err := l.lock.Lock(ctx.Done()) 104 | 105 | return errors.Wrap(err, "failed to lock a mutex in consul") 106 | } 107 | 108 | // Unlock unlocks the lock in Consul. 109 | func (l *LockConsul) Unlock(_ context.Context) error { 110 | return l.lock.Unlock() 111 | } 112 | 113 | // LockZookeeper is a wrapper around github.com/samuel/go-zookeeper/zk.Lock that implements the DistLocker interface. 114 | type LockZookeeper struct { 115 | lock *zk.Lock 116 | } 117 | 118 | // NewLockZookeeper creates a new instance of LockZookeeper. 119 | func NewLockZookeeper(lock *zk.Lock) *LockZookeeper { 120 | return &LockZookeeper{lock: lock} 121 | } 122 | 123 | // Lock locks the lock in Zookeeper. 124 | // TODO: add context aware support once https://github.com/samuel/go-zookeeper/pull/168 is merged. 125 | func (l *LockZookeeper) Lock(_ context.Context) error { 126 | return l.lock.Lock() 127 | } 128 | 129 | // Unlock unlocks the lock in Zookeeper. 130 | func (l *LockZookeeper) Unlock(_ context.Context) error { 131 | return l.lock.Unlock() 132 | } 133 | 134 | // LockRedis is a wrapper around github.com/go-redsync/redsync that implements the DistLocker interface. 135 | type LockRedis struct { 136 | mutex *redsync.Mutex 137 | } 138 | 139 | // NewLockRedis creates a new instance of LockRedis. 140 | func NewLockRedis(pool redsyncredis.Pool, mutexName string, options ...redsync.Option) *LockRedis { 141 | rs := redsync.New(pool) 142 | mutex := rs.NewMutex(mutexName, options...) 143 | 144 | return &LockRedis{mutex: mutex} 145 | } 146 | 147 | // Lock locks the lock in Redis. 148 | func (l *LockRedis) Lock(ctx context.Context) error { 149 | err := l.mutex.LockContext(ctx) 150 | 151 | return errors.Wrap(err, "failed to lock a mutex in redis") 152 | } 153 | 154 | // Unlock unlocks the lock in Redis. 
155 | func (l *LockRedis) Unlock(ctx context.Context) error { 156 | ok, err := l.mutex.UnlockContext(ctx) 157 | if !ok || err != nil { 158 | return errors.Wrap(err, "failed to unlock a mutex in redis") 159 | } 160 | 161 | return nil 162 | } 163 | 164 | // LockMemcached is a wrapper around github.com/alessandro-c/gomemcached-lock that implements the DistLocker interface. 165 | // It is caller's responsibility to make sure the uniqueness of mutexName, and not to use the same key in multiple 166 | // Memcached-based implementations. 167 | type LockMemcached struct { 168 | locker *lock.Locker 169 | mutexName string 170 | backoff backoff.BackOff 171 | } 172 | 173 | // NewLockMemcached creates a new instance of LockMemcached. 174 | // Default backoff is to retry every 100ms for 100 times (10 seconds). 175 | func NewLockMemcached(client *memcache.Client, mutexName string) *LockMemcached { 176 | adapter := gomemcache.New(client) 177 | locker := lock.New(adapter, mutexName, "") 178 | b := backoff.WithMaxRetries(backoff.NewConstantBackOff(100*time.Millisecond), 100) 179 | 180 | return &LockMemcached{ 181 | locker: locker, 182 | mutexName: mutexName, 183 | backoff: b, 184 | } 185 | } 186 | 187 | // WithLockAcquireBackoff sets the backoff policy for retrying an operation. 188 | func (l *LockMemcached) WithLockAcquireBackoff(b backoff.BackOff) *LockMemcached { 189 | l.backoff = b 190 | 191 | return l 192 | } 193 | 194 | // Lock locks the lock in Memcached. 195 | func (l *LockMemcached) Lock(ctx context.Context) error { 196 | o := func() error { return l.locker.Lock(time.Minute) } 197 | 198 | return backoff.Retry(o, l.backoff) 199 | } 200 | 201 | // Unlock unlocks the lock in Memcached. 202 | func (l *LockMemcached) Unlock(ctx context.Context) error { 203 | return l.locker.Release() 204 | } 205 | 206 | // LockPostgreSQL is an implementation of the DistLocker interface using PostgreSQL's advisory lock. 
207 | type LockPostgreSQL struct { 208 | db *sql.DB 209 | id int64 210 | tx *sql.Tx 211 | } 212 | 213 | // NewLockPostgreSQL creates a new LockPostgreSQL. 214 | func NewLockPostgreSQL(db *sql.DB, id int64) *LockPostgreSQL { 215 | return &LockPostgreSQL{db, id, nil} 216 | } 217 | 218 | // Make sure LockPostgreSQL implements DistLocker interface. 219 | var _ DistLocker = (*LockPostgreSQL)(nil) 220 | 221 | // Lock acquire an advisory lock in PostgreSQL. 222 | func (l *LockPostgreSQL) Lock(ctx context.Context) error { 223 | var err error 224 | 225 | l.tx, err = l.db.BeginTx(ctx, &sql.TxOptions{}) 226 | if err != nil { 227 | return err 228 | } 229 | 230 | _, err = l.tx.ExecContext(ctx, "SELECT pg_advisory_xact_lock($1)", l.id) 231 | 232 | return err 233 | } 234 | 235 | // Unlock releases an advisory lock in PostgreSQL. 236 | func (l *LockPostgreSQL) Unlock(ctx context.Context) error { 237 | return l.tx.Rollback() 238 | } 239 | -------------------------------------------------------------------------------- /slidingwindow_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | l "github.com/mennanov/limiters" 10 | ) 11 | 12 | // slidingWindows returns all the possible SlidingWindow combinations. 
13 | func (s *LimitersTestSuite) slidingWindows(capacity int64, rate time.Duration, clock l.Clock, epsilon float64) map[string]*l.SlidingWindow { 14 | windows := make(map[string]*l.SlidingWindow) 15 | for name, inc := range s.slidingWindowIncrementers() { 16 | windows[name] = l.NewSlidingWindow(capacity, rate, inc, clock, epsilon) 17 | } 18 | 19 | return windows 20 | } 21 | 22 | func (s *LimitersTestSuite) slidingWindowIncrementers() map[string]l.SlidingWindowIncrementer { 23 | return map[string]l.SlidingWindowIncrementer{ 24 | "SlidingWindowInMemory": l.NewSlidingWindowInMemory(), 25 | "SlidingWindowRedis": l.NewSlidingWindowRedis(s.redisClient, uuid.New().String()), 26 | "SlidingWindowRedisCluster": l.NewSlidingWindowRedis(s.redisClusterClient, uuid.New().String()), 27 | "SlidingWindowMemcached": l.NewSlidingWindowMemcached(s.memcacheClient, uuid.New().String()), 28 | "SlidingWindowDynamoDB": l.NewSlidingWindowDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps), 29 | "SlidingWindowCosmosDB": l.NewSlidingWindowCosmosDB(s.cosmosContainerClient, uuid.New().String()), 30 | } 31 | } 32 | 33 | var slidingWindowTestCases = []struct { 34 | capacity int64 35 | rate time.Duration 36 | epsilon float64 37 | results []struct { 38 | w time.Duration 39 | e error 40 | } 41 | requests int 42 | delta float64 43 | }{ 44 | { 45 | capacity: 1, 46 | rate: time.Second, 47 | epsilon: 1e-9, 48 | requests: 6, 49 | results: []struct { 50 | w time.Duration 51 | e error 52 | }{ 53 | { 54 | 0, nil, 55 | }, 56 | { 57 | time.Second * 2, l.ErrLimitExhausted, 58 | }, 59 | { 60 | 0, nil, 61 | }, 62 | { 63 | time.Second * 2, l.ErrLimitExhausted, 64 | }, 65 | { 66 | 0, nil, 67 | }, 68 | { 69 | time.Second * 2, l.ErrLimitExhausted, 70 | }, 71 | }, 72 | }, 73 | { 74 | capacity: 2, 75 | rate: time.Second, 76 | epsilon: 3e-9, 77 | requests: 10, 78 | delta: 1, 79 | results: []struct { 80 | w time.Duration 81 | e error 82 | }{ 83 | { 84 | 0, nil, 85 | }, 86 | { 87 | 0, nil, 88 | }, 89 | 
{ 90 | time.Second + time.Second*2/3, l.ErrLimitExhausted, 91 | }, 92 | { 93 | 0, nil, 94 | }, 95 | { 96 | time.Second/3 + time.Second/2, l.ErrLimitExhausted, 97 | }, 98 | { 99 | 0, nil, 100 | }, 101 | { 102 | time.Second, l.ErrLimitExhausted, 103 | }, 104 | { 105 | 0, nil, 106 | }, 107 | { 108 | time.Second, l.ErrLimitExhausted, 109 | }, 110 | { 111 | 0, nil, 112 | }, 113 | }, 114 | }, 115 | { 116 | capacity: 3, 117 | rate: time.Second, 118 | epsilon: 1e-9, 119 | requests: 11, 120 | delta: 0, 121 | results: []struct { 122 | w time.Duration 123 | e error 124 | }{ 125 | { 126 | 0, nil, 127 | }, 128 | { 129 | 0, nil, 130 | }, 131 | { 132 | 0, nil, 133 | }, 134 | { 135 | time.Second + time.Second/2, l.ErrLimitExhausted, 136 | }, 137 | { 138 | 0, nil, 139 | }, 140 | { 141 | time.Second / 2, l.ErrLimitExhausted, 142 | }, 143 | { 144 | 0, nil, 145 | }, 146 | { 147 | time.Second, l.ErrLimitExhausted, 148 | }, 149 | { 150 | 0, nil, 151 | }, 152 | { 153 | time.Second, l.ErrLimitExhausted, 154 | }, 155 | { 156 | 0, nil, 157 | }, 158 | }, 159 | }, 160 | { 161 | capacity: 4, 162 | rate: time.Second, 163 | epsilon: 1e-9, 164 | requests: 17, 165 | delta: 0, 166 | results: []struct { 167 | w time.Duration 168 | e error 169 | }{ 170 | { 171 | 0, nil, 172 | }, 173 | { 174 | 0, nil, 175 | }, 176 | { 177 | 0, nil, 178 | }, 179 | { 180 | 0, nil, 181 | }, 182 | { 183 | time.Second + time.Second*2/5, l.ErrLimitExhausted, 184 | }, 185 | { 186 | 0, nil, 187 | }, 188 | { 189 | time.Second * 2 / 5, l.ErrLimitExhausted, 190 | }, 191 | { 192 | 0, nil, 193 | }, 194 | { 195 | time.Second/5 + time.Second/4, l.ErrLimitExhausted, 196 | }, 197 | { 198 | 0, nil, 199 | }, 200 | { 201 | time.Second / 2, l.ErrLimitExhausted, 202 | }, 203 | { 204 | 0, nil, 205 | }, 206 | { 207 | time.Second / 2, l.ErrLimitExhausted, 208 | }, 209 | { 210 | 0, nil, 211 | }, 212 | { 213 | time.Second / 2, l.ErrLimitExhausted, 214 | }, 215 | { 216 | 0, nil, 217 | }, 218 | { 219 | time.Second / 2, l.ErrLimitExhausted, 220 | 
}, 221 | }, 222 | }, 223 | { 224 | capacity: 5, 225 | rate: time.Second, 226 | epsilon: 3e-9, 227 | requests: 18, 228 | delta: 1, 229 | results: []struct { 230 | w time.Duration 231 | e error 232 | }{ 233 | { 234 | 0, nil, 235 | }, 236 | { 237 | 0, nil, 238 | }, 239 | { 240 | 0, nil, 241 | }, 242 | { 243 | 0, nil, 244 | }, 245 | { 246 | 0, nil, 247 | }, 248 | { 249 | time.Second + time.Second/3, l.ErrLimitExhausted, 250 | }, 251 | { 252 | 0, nil, 253 | }, 254 | { 255 | time.Second * 2 / 6, l.ErrLimitExhausted, 256 | }, 257 | { 258 | 0, nil, 259 | }, 260 | { 261 | time.Second * 2 / 6, l.ErrLimitExhausted, 262 | }, 263 | { 264 | 0, nil, 265 | }, 266 | { 267 | time.Second / 2, l.ErrLimitExhausted, 268 | }, 269 | { 270 | 0, nil, 271 | }, 272 | { 273 | time.Second / 2, l.ErrLimitExhausted, 274 | }, 275 | { 276 | 0, nil, 277 | }, 278 | { 279 | time.Second / 2, l.ErrLimitExhausted, 280 | }, 281 | { 282 | 0, nil, 283 | }, 284 | { 285 | time.Second / 2, l.ErrLimitExhausted, 286 | }, 287 | { 288 | 0, nil, 289 | }, 290 | }, 291 | }, 292 | } 293 | 294 | func (s *LimitersTestSuite) TestSlidingWindowOverflowAndWait() { 295 | clock := newFakeClockWithTime(time.Date(2019, 9, 3, 0, 0, 0, 0, time.UTC)) 296 | for _, testCase := range slidingWindowTestCases { 297 | for name, bucket := range s.slidingWindows(testCase.capacity, testCase.rate, clock, testCase.epsilon) { 298 | s.Run(name, func() { 299 | clock.reset() 300 | 301 | for i := 0; i < testCase.requests; i++ { 302 | w, err := bucket.Limit(context.TODO()) 303 | 304 | s.Require().LessOrEqual(i, len(testCase.results)-1) 305 | s.InDelta(testCase.results[i].w, w, testCase.delta, i) 306 | s.Equal(testCase.results[i].e, err, i) 307 | clock.Sleep(w) 308 | } 309 | }) 310 | } 311 | } 312 | } 313 | 314 | func (s *LimitersTestSuite) TestSlidingWindowOverflowAndNoWait() { 315 | capacity := int64(3) 316 | 317 | clock := newFakeClock() 318 | for name, bucket := range s.slidingWindows(capacity, time.Second, clock, 1e-9) { 319 | s.Run(name, 
func() { 320 | clock.reset() 321 | 322 | // Keep sending requests until it reaches the capacity. 323 | for i := int64(0); i < capacity; i++ { 324 | w, err := bucket.Limit(context.TODO()) 325 | s.Require().NoError(err) 326 | s.Require().Equal(time.Duration(0), w) 327 | clock.Sleep(time.Millisecond) 328 | } 329 | 330 | // The next request will be the first one to be rejected. 331 | w, err := bucket.Limit(context.TODO()) 332 | s.Require().Equal(l.ErrLimitExhausted, err) 333 | 334 | expected := clock.Now().Add(w) 335 | 336 | // Send a few more requests, all of them should be told to come back at the same time. 337 | for i := int64(0); i < capacity; i++ { 338 | w, err = bucket.Limit(context.TODO()) 339 | s.Require().Equal(l.ErrLimitExhausted, err) 340 | 341 | actual := clock.Now().Add(w) 342 | s.Require().Equal(expected, actual, i) 343 | clock.Sleep(time.Millisecond) 344 | } 345 | 346 | // Wait until it is ready. 347 | clock.Sleep(expected.Sub(clock.Now())) 348 | 349 | w, err = bucket.Limit(context.TODO()) 350 | s.Require().NoError(err) 351 | s.Require().Equal(time.Duration(0), w) 352 | }) 353 | } 354 | } 355 | 356 | func BenchmarkSlidingWindows(b *testing.B) { 357 | s := new(LimitersTestSuite) 358 | s.SetT(&testing.T{}) 359 | s.SetupSuite() 360 | 361 | capacity := int64(1) 362 | rate := time.Second 363 | clock := newFakeClock() 364 | epsilon := 1e-9 365 | 366 | windows := s.slidingWindows(capacity, rate, clock, epsilon) 367 | for name, window := range windows { 368 | b.Run(name, func(b *testing.B) { 369 | for i := 0; i < b.N; i++ { 370 | _, err := window.Limit(context.TODO()) 371 | s.Require().NoError(err) 372 | } 373 | }) 374 | } 375 | 376 | s.TearDownSuite() 377 | } 378 | -------------------------------------------------------------------------------- /limiters_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "hash/fnv" 8 | "os" 9 | 
"strings" 10 | "sync" 11 | "testing" 12 | "time" 13 | 14 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 15 | "github.com/aws/aws-sdk-go-v2/aws" 16 | "github.com/aws/aws-sdk-go-v2/config" 17 | "github.com/aws/aws-sdk-go-v2/credentials" 18 | "github.com/aws/aws-sdk-go-v2/service/dynamodb" 19 | "github.com/bradfitz/gomemcache/memcache" 20 | "github.com/go-redsync/redsync/v4/redis/goredis/v9" 21 | "github.com/google/uuid" 22 | "github.com/hashicorp/consul/api" 23 | l "github.com/mennanov/limiters" 24 | "github.com/redis/go-redis/v9" 25 | "github.com/samuel/go-zookeeper/zk" 26 | "github.com/stretchr/testify/suite" 27 | clientv3 "go.etcd.io/etcd/client/v3" 28 | ) 29 | 30 | type fakeClock struct { 31 | mu sync.Mutex 32 | initial time.Time 33 | t time.Time 34 | } 35 | 36 | func newFakeClock() *fakeClock { 37 | now := time.Now() 38 | 39 | return &fakeClock{t: now, initial: now} 40 | } 41 | 42 | func newFakeClockWithTime(t time.Time) *fakeClock { 43 | return &fakeClock{t: t, initial: t} 44 | } 45 | 46 | func (c *fakeClock) Now() time.Time { 47 | c.mu.Lock() 48 | defer c.mu.Unlock() 49 | 50 | return c.t 51 | } 52 | 53 | func (c *fakeClock) Sleep(d time.Duration) { 54 | if d == 0 { 55 | return 56 | } 57 | 58 | c.mu.Lock() 59 | defer c.mu.Unlock() 60 | 61 | c.t = c.t.Add(d) 62 | } 63 | 64 | func (c *fakeClock) reset() { 65 | c.mu.Lock() 66 | defer c.mu.Unlock() 67 | 68 | c.t = c.initial 69 | } 70 | 71 | type LimitersTestSuite struct { 72 | suite.Suite 73 | 74 | etcdClient *clientv3.Client 75 | redisClient *redis.Client 76 | redisClusterClient *redis.ClusterClient 77 | consulClient *api.Client 78 | zkConn *zk.Conn 79 | logger *l.StdLogger 80 | dynamodbClient *dynamodb.Client 81 | dynamoDBTableProps l.DynamoDBTableProperties 82 | memcacheClient *memcache.Client 83 | pgDb *sql.DB 84 | cosmosClient *azcosmos.Client 85 | cosmosContainerClient *azcosmos.ContainerClient 86 | } 87 | 88 | func (s *LimitersTestSuite) SetupSuite() { 89 | var err error 90 | 91 | s.etcdClient, err 
= clientv3.New(clientv3.Config{ 92 | Endpoints: strings.Split(os.Getenv("ETCD_ENDPOINTS"), ","), 93 | DialTimeout: time.Second, 94 | }) 95 | s.Require().NoError(err) 96 | _, err = s.etcdClient.Status(context.Background(), s.etcdClient.Endpoints()[0]) 97 | s.Require().NoError(err) 98 | 99 | s.redisClient = redis.NewClient(&redis.Options{ 100 | Addr: os.Getenv("REDIS_ADDR"), 101 | }) 102 | s.Require().NoError(s.redisClient.Ping(context.Background()).Err()) 103 | 104 | s.redisClusterClient = redis.NewClusterClient(&redis.ClusterOptions{ 105 | Addrs: strings.Split(os.Getenv("REDIS_NODES"), ","), 106 | }) 107 | s.Require().NoError(s.redisClusterClient.Ping(context.Background()).Err()) 108 | 109 | s.consulClient, err = api.NewClient(&api.Config{Address: os.Getenv("CONSUL_ADDR")}) 110 | s.Require().NoError(err) 111 | _, err = s.consulClient.Status().Leader() 112 | s.Require().NoError(err) 113 | 114 | s.zkConn, _, err = zk.Connect(strings.Split(os.Getenv("ZOOKEEPER_ENDPOINTS"), ","), time.Second) 115 | s.Require().NoError(err) 116 | _, _, err = s.zkConn.Get("/") 117 | s.Require().NoError(err) 118 | s.logger = l.NewStdLogger() 119 | 120 | awsCfg, err := config.LoadDefaultConfig(context.Background(), 121 | config.WithRegion("us-east-1"), 122 | config.WithCredentialsProvider(credentials.StaticCredentialsProvider{ 123 | Value: aws.Credentials{ 124 | AccessKeyID: "dummy", SecretAccessKey: "dummy", SessionToken: "dummy", 125 | Source: "Hard-coded credentials; values are irrelevant for local DynamoDB", 126 | }, 127 | }), 128 | ) 129 | s.Require().NoError(err) 130 | 131 | endpoint := fmt.Sprintf("http://%s", os.Getenv("AWS_ADDR")) 132 | s.dynamodbClient = dynamodb.NewFromConfig(awsCfg, func(options *dynamodb.Options) { 133 | options.BaseEndpoint = &endpoint 134 | }) 135 | 136 | _ = DeleteTestDynamoDBTable(context.Background(), s.dynamodbClient) 137 | s.Require().NoError(CreateTestDynamoDBTable(context.Background(), s.dynamodbClient)) 138 | s.dynamoDBTableProps, err = 
l.LoadDynamoDBTableProperties(context.Background(), s.dynamodbClient, testDynamoDBTableName) 139 | s.Require().NoError(err) 140 | 141 | s.memcacheClient = memcache.New(strings.Split(os.Getenv("MEMCACHED_ADDR"), ",")...) 142 | s.Require().NoError(s.memcacheClient.Ping()) 143 | 144 | s.pgDb, err = sql.Open("postgres", os.Getenv("POSTGRES_URL")) 145 | s.Require().NoError(err) 146 | s.Require().NoError(s.pgDb.PingContext(context.Background())) 147 | 148 | // https://learn.microsoft.com/en-us/azure/cosmos-db/emulator#authentication 149 | connString := fmt.Sprintf("AccountEndpoint=http://%s/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;", os.Getenv("COSMOS_ADDR")) 150 | s.cosmosClient, err = azcosmos.NewClientFromConnectionString(connString, &azcosmos.ClientOptions{}) 151 | s.Require().NoError(err) 152 | 153 | err = CreateCosmosDBContainer(context.Background(), s.cosmosClient) 154 | if err != nil { 155 | s.Require().ErrorContains(err, "Database 'limiters-db-test' already exists") 156 | } 157 | 158 | s.cosmosContainerClient, err = s.cosmosClient.NewContainer(testCosmosDBName, testCosmosContainerName) 159 | s.Require().NoError(err) 160 | } 161 | 162 | func (s *LimitersTestSuite) TearDownSuite() { 163 | s.Assert().NoError(s.etcdClient.Close()) 164 | s.Assert().NoError(s.redisClient.Close()) 165 | s.Assert().NoError(s.redisClusterClient.Close()) 166 | s.Assert().NoError(DeleteTestDynamoDBTable(context.Background(), s.dynamodbClient)) 167 | s.Assert().NoError(s.memcacheClient.Close()) 168 | s.Assert().NoError(s.pgDb.Close()) 169 | s.Assert().NoError(DeleteCosmosDBContainer(context.Background(), s.cosmosClient)) 170 | } 171 | 172 | func TestLimitersTestSuite(t *testing.T) { 173 | suite.Run(t, new(LimitersTestSuite)) 174 | } 175 | 176 | // lockers returns all possible lockers (including noop). 
func (s *LimitersTestSuite) lockers(generateKeys bool) map[string]l.DistLocker {
	lockers := s.distLockers(generateKeys)
	lockers["LockNoop"] = l.NewLockNoop()

	return lockers
}

// hash maps a string key to the int64 id required by LockPostgreSQL (32-bit FNV-1a, widened to int64).
// hash.Hash's Write is documented never to return an error, so the panic is only a safeguard.
func hash(s string) int64 {
	h := fnv.New32a()

	_, err := h.Write([]byte(s))
	if err != nil {
		panic(err)
	}

	return int64(h.Sum32())
}

// distLockers returns distributed lockers only.
// With generateKeys every locker gets a fresh random key; otherwise all of them share the fixed "dist_locker" key.
func (s *LimitersTestSuite) distLockers(generateKeys bool) map[string]l.DistLocker {
	randomKey := uuid.New().String()
	consulKey := randomKey
	etcdKey := randomKey
	zkKey := "/" + randomKey
	redisKey := randomKey
	memcacheKey := randomKey
	pgKey := randomKey

	if !generateKeys {
		consulKey = "dist_locker"
		etcdKey = "dist_locker"
		zkKey = "/dist_locker"
		redisKey = "dist_locker"
		memcacheKey = "dist_locker"
		pgKey = "dist_locker"
	}

	consulLock, err := s.consulClient.LockKey(consulKey)
	s.Require().NoError(err)

	return map[string]l.DistLocker{
		"LockEtcd":         l.NewLockEtcd(s.etcdClient, etcdKey, s.logger),
		"LockConsul":       l.NewLockConsul(consulLock),
		"LockZookeeper":    l.NewLockZookeeper(zk.NewLock(s.zkConn, zkKey, zk.WorldACL(zk.PermAll))),
		"LockRedis":        l.NewLockRedis(goredis.NewPool(s.redisClient), redisKey),
		"LockRedisCluster": l.NewLockRedis(goredis.NewPool(s.redisClusterClient), redisKey),
		"LockMemcached":    l.NewLockMemcached(s.memcacheClient, memcacheKey),
		"LockPostgreSQL":   l.NewLockPostgreSQL(s.pgDb, hash(pgKey)),
	}
}

// TestLimitContextCancelled checks, for every limiter type, that Limit fails with an already-cancelled context and
// that a later Limit call with a live context still succeeds (i.e. the failed call did not leave anything locked).
func (s *LimitersTestSuite) TestLimitContextCancelled() {
	clock := newFakeClock()
	capacity := int64(2)
	rate := time.Second

	limiters := make(map[string]interface{})
	for n, b := range s.tokenBuckets(capacity, rate, time.Second, clock) {
		limiters[n] = b
	}

	for n, b := range s.leakyBuckets(capacity, rate, time.Second, clock) {
		limiters[n] = b
	}

	for n, w := range s.fixedWindows(capacity, rate, clock) {
		limiters[n] = w
	}

	for n, w := range s.slidingWindows(capacity, rate, clock, 1e-9) {
		limiters[n] = w
	}

	for n, b := range s.concurrentBuffers(capacity, rate, clock) {
		limiters[n] = b
	}

	// The two Limit signatures in use: rate limiters return a wait, concurrent buffers take a request key.
	type rateLimiter interface {
		Limit(context.Context) (time.Duration, error)
	}

	type concurrentLimiter interface {
		Limit(context.Context, string) error
	}

	for name, limiter := range limiters {
		s.Run(name, func() {
			done1 := make(chan struct{})

			go func(limiter interface{}) {
				defer close(done1)
				// The context is cancelled before Limit is called.
				ctx, cancel := context.WithCancel(context.Background())
				cancel()

				switch lim := limiter.(type) {
				case rateLimiter:
					_, err := lim.Limit(ctx)
					s.Error(err, "%T", limiter)

				case concurrentLimiter:
					s.Error(lim.Limit(ctx, "key"), "%T", limiter)
				}
			}(limiter)

			done2 := make(chan struct{})

			go func(limiter interface{}) {
				defer close(done2)

				// Runs only after the cancelled call above has returned.
				<-done1

				ctx := context.Background()

				switch lim := limiter.(type) {
				case rateLimiter:
					_, err := lim.Limit(ctx)
					s.NoError(err, "%T", limiter)

				case concurrentLimiter:
					s.NoError(lim.Limit(ctx, "key"), "%T", limiter)
				}
			}(limiter)
			// Verify that the second go routine succeeded calling the Limit() method.
			<-done2
		})
	}
}
--------------------------------------------------------------------------------
/concurrent_buffer.go:
--------------------------------------------------------------------------------
package limiters

import (
	"bytes"
	"context"
	"encoding/gob"
	"fmt"
	"sync"
	"time"

	"github.com/bradfitz/gomemcache/memcache"
	"github.com/pkg/errors"
	"github.com/redis/go-redis/v9"
)

// ConcurrentBufferBackend wraps the Add and Remove methods.
type ConcurrentBufferBackend interface {
	// Add adds the request with the given key to the buffer and returns the total number of requests in it.
	Add(ctx context.Context, key string) (int64, error)
	// Remove removes the request from the buffer.
	Remove(ctx context.Context, key string) error
}

// ConcurrentBuffer implements a limiter that allows concurrent requests up to the given capacity.
type ConcurrentBuffer struct {
	locker   DistLocker
	backend  ConcurrentBufferBackend
	logger   Logger
	capacity int64
	// mu serializes Limit calls within this process; locker serializes them across processes.
	mu sync.Mutex
}

// NewConcurrentBuffer creates a new ConcurrentBuffer instance.
func NewConcurrentBuffer(locker DistLocker, concurrentStateBackend ConcurrentBufferBackend, capacity int64, logger Logger) *ConcurrentBuffer {
	return &ConcurrentBuffer{locker: locker, backend: concurrentStateBackend, capacity: capacity, logger: logger}
}

// Limit puts the request identified by the key in a buffer.
func (c *ConcurrentBuffer) Limit(ctx context.Context, key string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	err := c.locker.Lock(ctx)
	if err != nil {
		return err
	}

	// Unlock failures are only logged: the outcome of Limit does not depend on them.
	// NOTE(review): Unlock reuses ctx; if ctx is cancelled while the lock is held, a context-aware locker may fail
	// to release it (leaving it to expire) — consider a context detached from cancellation. TODO confirm.
	defer func() {
		err := c.locker.Unlock(ctx)
		if err != nil {
			c.logger.Log(err)
		}
	}()
	// Optimistically add the new request.
55 | counter, err := c.backend.Add(ctx, key) 56 | if err != nil { 57 | return err 58 | } 59 | 60 | if counter > c.capacity { 61 | // Rollback the Add() operation. 62 | err = c.backend.Remove(ctx, key) 63 | if err != nil { 64 | c.logger.Log(err) 65 | } 66 | 67 | return ErrLimitExhausted 68 | } 69 | 70 | return nil 71 | } 72 | 73 | // Done removes the request identified by the key from the buffer. 74 | func (c *ConcurrentBuffer) Done(ctx context.Context, key string) error { 75 | return c.backend.Remove(ctx, key) 76 | } 77 | 78 | // ConcurrentBufferInMemory is an in-memory implementation of ConcurrentBufferBackend. 79 | type ConcurrentBufferInMemory struct { 80 | clock Clock 81 | ttl time.Duration 82 | mu sync.Mutex 83 | registry *Registry 84 | } 85 | 86 | // NewConcurrentBufferInMemory creates a new instance of ConcurrentBufferInMemory. 87 | // When the TTL of a key exceeds the key is removed from the buffer. This is needed in case if the process that added 88 | // that key to the buffer did not call Done() for some reason. 89 | func NewConcurrentBufferInMemory(registry *Registry, ttl time.Duration, clock Clock) *ConcurrentBufferInMemory { 90 | return &ConcurrentBufferInMemory{clock: clock, ttl: ttl, registry: registry} 91 | } 92 | 93 | // Add adds the request with the given key to the buffer and returns the total number of requests in it. 94 | // It also removes the keys with expired TTL. 95 | func (c *ConcurrentBufferInMemory) Add(ctx context.Context, key string) (int64, error) { 96 | c.mu.Lock() 97 | defer c.mu.Unlock() 98 | 99 | now := c.clock.Now() 100 | c.registry.DeleteExpired(now) 101 | c.registry.GetOrCreate(key, func() interface{} { 102 | return struct{}{} 103 | }, c.ttl, now) 104 | 105 | return int64(c.registry.Len()), ctx.Err() 106 | } 107 | 108 | // Remove removes the request from the buffer. 
func (c *ConcurrentBufferInMemory) Remove(_ context.Context, key string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Deleting a key that is not in the registry is a no-op.
	c.registry.Delete(key)

	return nil
}

// ConcurrentBufferRedis implements ConcurrentBufferBackend in Redis.
//
// All requests share a single sorted set stored under key; each member is a request key scored by the UnixNano
// time at which it was added.
type ConcurrentBufferRedis struct {
	clock Clock
	cli   redis.UniversalClient
	key   string
	ttl   time.Duration
}

// NewConcurrentBufferRedis creates a new instance of ConcurrentBufferRedis.
// Once a key's TTL expires, the key is removed from the buffer. This is needed in case the process that added
// that key to the buffer did not call Done() for some reason.
func NewConcurrentBufferRedis(cli redis.UniversalClient, key string, ttl time.Duration, clock Clock) *ConcurrentBufferRedis {
	return &ConcurrentBufferRedis{
		clock: clock,
		cli:   cli,
		key:   key,
		ttl:   ttl,
	}
}

// Add adds the request with the given key to the sorted set in Redis and returns the total number of requests in it.
// It also removes the keys with expired TTL.
func (c *ConcurrentBufferRedis) Add(ctx context.Context, key string) (int64, error) {
	var (
		cardinality *redis.IntCmd
		pipeErr     error
	)

	finished := make(chan struct{})

	go func() {
		defer close(finished)

		_, pipeErr = c.cli.Pipelined(ctx, func(p redis.Pipeliner) error {
			now := c.clock.Now()
			expiredUpTo := fmt.Sprintf("%d", now.Add(-c.ttl).UnixNano())

			// Evict requests older than the TTL, register (or refresh) this one, extend the set's own TTL and
			// read the resulting size, all in one round trip.
			p.ZRemRangeByScore(ctx, c.key, "-inf", expiredUpTo)
			p.ZAdd(ctx, c.key, redis.Z{Score: float64(now.UnixNano()), Member: key})
			p.Expire(ctx, c.key, c.ttl)
			cardinality = p.ZCard(ctx, c.key)

			return nil
		})
	}()

	select {
	case <-finished:
		if pipeErr != nil {
			return 0, errors.Wrap(pipeErr, "failed to add an item to redis set")
		}

		return cardinality.Val(), nil

	case <-ctx.Done():
		return 0, ctx.Err()
	}
}

// Remove removes the request identified by the key from the sorted set in Redis.
func (c *ConcurrentBufferRedis) Remove(ctx context.Context, key string) error {
	zremErr := c.cli.ZRem(ctx, c.key, key).Err()

	return errors.Wrap(zremErr, "failed to remove an item from redis set")
}

// ConcurrentBufferMemcached implements ConcurrentBufferBackend in Memcached.
//
// All requests are kept as a gob-encoded []SortedSetNode under a single Memcached key, updated with CAS.
type ConcurrentBufferMemcached struct {
	clock Clock
	cli   *memcache.Client
	key   string
	ttl   time.Duration
}

// NewConcurrentBufferMemcached creates a new instance of ConcurrentBufferMemcached.
// Once a key's TTL expires, the key is removed from the buffer. This is needed in case the process that added
// that key to the buffer did not call Done() for some reason.
func NewConcurrentBufferMemcached(cli *memcache.Client, key string, ttl time.Duration, clock Clock) *ConcurrentBufferMemcached {
	return &ConcurrentBufferMemcached{
		clock: clock,
		cli:   cli,
		key:   key,
		ttl:   ttl,
	}
}

// SortedSetNode is a single request entry stored by ConcurrentBufferMemcached.
//
// Despite its name, CreatedAt holds the UnixNano time at which the entry expires (the time it was added plus the
// buffer's TTL); entries whose CreatedAt is not after the current time are dropped. Value is the request key.
// The struct is gob-encoded into Memcached, so renaming its fields breaks previously stored data.
type SortedSetNode struct {
	CreatedAt int64
	Value     string
}

// Add adds the request with the given key to the slice in Memcached and returns the total number of requests in it.
// It also removes the keys with expired TTL.
func (c *ConcurrentBufferMemcached) Add(ctx context.Context, element string) (int64, error) {
	type addResult struct {
		count int64
		err   error
	}

	for {
		// Check before every attempt (including CAS retries): starting a write on behalf of a caller that has already
		// given up could store its entry in the background while an error is returned, leaving a phantom request
		// that nobody will ever remove with Done().
		if err := ctx.Err(); err != nil {
			return 0, err
		}

		now := c.clock.Now()
		// Buffered so the attempt goroutine can always deliver its result and exit, even if ctx wins the select.
		results := make(chan addResult, 1)

		go func() {
			count, err := c.addOnce(element, now)
			results <- addResult{count: count, err: err}
		}()

		select {
		case <-ctx.Done():
			return 0, ctx.Err()

		case res := <-results:
			if res.err == nil {
				return res.count, nil
			}

			// The set was modified, created or deleted concurrently between the read and the write: re-read and retry.
			if errors.Is(res.err, memcache.ErrCASConflict) || errors.Is(res.err, memcache.ErrNotStored) ||
				errors.Is(res.err, memcache.ErrCacheMiss) {
				continue
			}

			return 0, errors.Wrap(res.err, "failed to add in memcached")
		}
	}
}

// addOnce performs a single read-modify-write of the stored set: it drops expired entries and any previous entry
// for element, appends element with a fresh expiration and writes the set back (CAS if it existed, Add otherwise).
// It returns the new number of entries, or the raw Memcached error so that the caller can decide whether to retry.
func (c *ConcurrentBufferMemcached) addOnce(element string, now time.Time) (int64, error) {
	var (
		nodes []SortedSetNode
		casID uint64
	)

	item, err := c.cli.Get(c.key)

	switch {
	case err == nil:
		casID = item.CasID

		var oldNodes []SortedSetNode
		// An undecodable value is treated as an empty set and overwritten below.
		_ = gob.NewDecoder(bytes.NewBuffer(item.Value)).Decode(&oldNodes)

		for _, node := range oldNodes {
			if node.CreatedAt > now.UnixNano() && node.Value != element {
				nodes = append(nodes, node)
			}
		}

	case !errors.Is(err, memcache.ErrCacheMiss):
		return 0, err
	}

	nodes = append(nodes, SortedSetNode{CreatedAt: now.Add(c.ttl).UnixNano(), Value: element})

	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(nodes); err != nil {
		return 0, err
	}

	newItem := &memcache.Item{
		Key:   c.key,
		Value: buf.Bytes(),
		CasID: casID,
	}

	if casID > 0 {
		err = c.cli.CompareAndSwap(newItem)
	} else {
		err = c.cli.Add(newItem)
	}

	if err != nil {
		return 0, err
	}

	return int64(len(nodes)), nil
}

// Remove removes the request identified by the key from the slice in Memcached.
272 | func (c *ConcurrentBufferMemcached) Remove(ctx context.Context, key string) error { 273 | var err error 274 | 275 | now := c.clock.Now() 276 | 277 | var ( 278 | newNodes []SortedSetNode 279 | casID uint64 280 | ) 281 | 282 | item, err := c.cli.Get(c.key) 283 | if err != nil { 284 | if errors.Is(err, memcache.ErrCacheMiss) { 285 | return nil 286 | } 287 | 288 | return errors.Wrap(err, "failed to Get") 289 | } 290 | 291 | casID = item.CasID 292 | 293 | var oldNodes []SortedSetNode 294 | 295 | _ = gob.NewDecoder(bytes.NewBuffer(item.Value)).Decode(&oldNodes) 296 | for _, node := range oldNodes { 297 | if node.CreatedAt > now.UnixNano() && node.Value != key { 298 | newNodes = append(newNodes, node) 299 | } 300 | } 301 | 302 | var b bytes.Buffer 303 | 304 | _ = gob.NewEncoder(&b).Encode(newNodes) 305 | item = &memcache.Item{ 306 | Key: c.key, 307 | Value: b.Bytes(), 308 | CasID: casID, 309 | } 310 | 311 | err = c.cli.CompareAndSwap(item) 312 | if err != nil && (errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss)) { 313 | return c.Remove(ctx, key) 314 | } 315 | 316 | return errors.Wrap(err, "failed to CompareAndSwap") 317 | } 318 | -------------------------------------------------------------------------------- /examples/helloworld/helloworld.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.35.1 4 | // protoc v5.28.3 5 | // source: helloworld.proto 6 | 7 | package helloworld 8 | 9 | import ( 10 | reflect "reflect" 11 | sync "sync" 12 | 13 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 14 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 15 | ) 16 | 17 | const ( 18 | // Verify that this generated code is sufficiently up-to-date. 
19 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 20 | // Verify that runtime/protoimpl is sufficiently up-to-date. 21 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 22 | ) 23 | 24 | // The request message containing the user's name. 25 | type HelloRequest struct { 26 | state protoimpl.MessageState 27 | sizeCache protoimpl.SizeCache 28 | unknownFields protoimpl.UnknownFields 29 | 30 | Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` 31 | } 32 | 33 | func (x *HelloRequest) Reset() { 34 | *x = HelloRequest{} 35 | mi := &file_helloworld_proto_msgTypes[0] 36 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 37 | ms.StoreMessageInfo(mi) 38 | } 39 | 40 | func (x *HelloRequest) String() string { 41 | return protoimpl.X.MessageStringOf(x) 42 | } 43 | 44 | func (*HelloRequest) ProtoMessage() {} 45 | 46 | func (x *HelloRequest) ProtoReflect() protoreflect.Message { 47 | mi := &file_helloworld_proto_msgTypes[0] 48 | if x != nil { 49 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 50 | if ms.LoadMessageInfo() == nil { 51 | ms.StoreMessageInfo(mi) 52 | } 53 | return ms 54 | } 55 | return mi.MessageOf(x) 56 | } 57 | 58 | // Deprecated: Use HelloRequest.ProtoReflect.Descriptor instead. 
func (*HelloRequest) Descriptor() ([]byte, []int) {
	return file_helloworld_proto_rawDescGZIP(), []int{0}
}

func (x *HelloRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// The response message containing the greetings
type HelloReply struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
}

func (x *HelloReply) Reset() {
	*x = HelloReply{}
	mi := &file_helloworld_proto_msgTypes[1]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *HelloReply) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*HelloReply) ProtoMessage() {}

func (x *HelloReply) ProtoReflect() protoreflect.Message {
	mi := &file_helloworld_proto_msgTypes[1]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use HelloReply.ProtoReflect.Descriptor instead.
func (*HelloReply) Descriptor() ([]byte, []int) {
	return file_helloworld_proto_rawDescGZIP(), []int{1}
}

func (x *HelloReply) GetMessage() string {
	if x != nil {
		return x.Message
	}
	return ""
}

type GoodbyeRequest struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
}

func (x *GoodbyeRequest) Reset() {
	*x = GoodbyeRequest{}
	mi := &file_helloworld_proto_msgTypes[2]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *GoodbyeRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*GoodbyeRequest) ProtoMessage() {}

func (x *GoodbyeRequest) ProtoReflect() protoreflect.Message {
	mi := &file_helloworld_proto_msgTypes[2]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GoodbyeRequest.ProtoReflect.Descriptor instead.
func (*GoodbyeRequest) Descriptor() ([]byte, []int) {
	return file_helloworld_proto_rawDescGZIP(), []int{2}
}

func (x *GoodbyeRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

type GoodbyeReply struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
}

func (x *GoodbyeReply) Reset() {
	*x = GoodbyeReply{}
	mi := &file_helloworld_proto_msgTypes[3]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

func (x *GoodbyeReply) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*GoodbyeReply) ProtoMessage() {}

func (x *GoodbyeReply) ProtoReflect() protoreflect.Message {
	mi := &file_helloworld_proto_msgTypes[3]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GoodbyeReply.ProtoReflect.Descriptor instead.
func (*GoodbyeReply) Descriptor() ([]byte, []int) {
	return file_helloworld_proto_rawDescGZIP(), []int{3}
}

func (x *GoodbyeReply) GetMessage() string {
	if x != nil {
		return x.Message
	}
	return ""
}

var File_helloworld_proto protoreflect.FileDescriptor

var file_helloworld_proto_rawDesc = []byte{
	0x0a, 0x10, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x70, 0x72, 0x6f,
	0x74, 0x6f, 0x12, 0x0a, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x22, 0x22,
	0x0a, 0x0c, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12,
	0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61,
	0x6d, 0x65, 0x22, 0x26, 0x0a, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x70, 0x6c, 0x79,
	0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28,
	0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x24, 0x0a, 0x0e, 0x47, 0x6f,
	0x6f, 0x64, 0x62, 0x79, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04,
	0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65,
	0x22, 0x28, 0x0a, 0x0c, 0x47, 0x6f, 0x6f, 0x64, 0x62, 0x79, 0x65, 0x52, 0x65, 0x70, 0x6c, 0x79,
	0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28,
	0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0x49, 0x0a, 0x07, 0x47, 0x72,
	0x65, 0x65, 0x74, 0x65, 0x72, 0x12, 0x3e, 0x0a, 0x08, 0x53, 0x61, 0x79, 0x48, 0x65, 0x6c, 0x6c,
	0x6f, 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48,
	0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x68, 0x65,
	0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65,
	0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
	0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x65, 0x6e, 0x6e, 0x61, 0x6e, 0x6f, 0x76, 0x2f, 0x6c, 0x69, 0x6d,
	0x69, 0x74, 0x65, 0x72, 0x73, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x68,
	0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
	0x33,
}

var (
	file_helloworld_proto_rawDescOnce sync.Once
	file_helloworld_proto_rawDescData = file_helloworld_proto_rawDesc
)

func file_helloworld_proto_rawDescGZIP() []byte {
	file_helloworld_proto_rawDescOnce.Do(func() {
		file_helloworld_proto_rawDescData = protoimpl.X.CompressGZIP(file_helloworld_proto_rawDescData)
	})
	return file_helloworld_proto_rawDescData
}

var file_helloworld_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_helloworld_proto_goTypes = []any{
	(*HelloRequest)(nil),   // 0: helloworld.HelloRequest
	(*HelloReply)(nil),     // 1: helloworld.HelloReply
	(*GoodbyeRequest)(nil), // 2: helloworld.GoodbyeRequest
	(*GoodbyeReply)(nil),   // 3: helloworld.GoodbyeReply
}
var file_helloworld_proto_depIdxs = []int32{
	0, // 0: helloworld.Greeter.SayHello:input_type -> helloworld.HelloRequest
	1, // 1: helloworld.Greeter.SayHello:output_type -> helloworld.HelloReply
	1, // [1:2] is the sub-list for method output_type
	0, // [0:1] is the sub-list for method input_type
	0, // [0:0] is the sub-list for extension type_name
	0, // [0:0] is the sub-list for extension extendee
	0, // [0:0] is the sub-list for field type_name
}

func init() { file_helloworld_proto_init() }
func file_helloworld_proto_init() {
	if File_helloworld_proto != nil {
		return
	}
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: file_helloworld_proto_rawDesc,
			NumEnums:      0,
			NumMessages:   4,
			NumExtensions: 0,
			NumServices:   1,
		},
		GoTypes:           file_helloworld_proto_goTypes,
		DependencyIndexes: file_helloworld_proto_depIdxs,
		MessageInfos:      file_helloworld_proto_msgTypes,
	}.Build()
	File_helloworld_proto = out.File
	file_helloworld_proto_rawDesc = nil
	file_helloworld_proto_goTypes = nil
	file_helloworld_proto_depIdxs = nil
}

--------------------------------------------------------------------------------
/fixedwindow.go:
--------------------------------------------------------------------------------
package limiters

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
	"github.com/bradfitz/gomemcache/memcache"
	"github.com/pkg/errors"
	"github.com/redis/go-redis/v9"
)

// FixedWindowIncrementer wraps the Increment method.
type FixedWindowIncrementer interface {
	// Increment increments the request counter for the window and returns the updated counter value.
	// The window is identified by its start time (FixedWindow passes the current time truncated to the window size).
	// TTL is the time duration before the next window; backends may use it to expire the counter.
	Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error)
}

// FixedWindow implements a Fixed Window rate limiting algorithm.
//
// It is a simple and memory-efficient algorithm that does not need a distributed lock.
// However, it may be lenient when there are many requests around the boundary between 2 adjacent windows.
35 | type FixedWindow struct { 36 | backend FixedWindowIncrementer 37 | clock Clock 38 | rate time.Duration 39 | capacity int64 40 | mu sync.Mutex 41 | window time.Time 42 | overflow bool 43 | } 44 | 45 | // NewFixedWindow creates a new instance of FixedWindow. 46 | // Capacity is the maximum amount of requests allowed per window. 47 | // Rate is the window size. 48 | func NewFixedWindow(capacity int64, rate time.Duration, fixedWindowIncrementer FixedWindowIncrementer, clock Clock) *FixedWindow { 49 | return &FixedWindow{backend: fixedWindowIncrementer, clock: clock, rate: rate, capacity: capacity} 50 | } 51 | 52 | // Limit returns the time duration to wait before the request can be processed. 53 | // It returns ErrLimitExhausted if the request overflows the window's capacity. 54 | func (f *FixedWindow) Limit(ctx context.Context) (time.Duration, error) { 55 | f.mu.Lock() 56 | defer f.mu.Unlock() 57 | 58 | now := f.clock.Now() 59 | 60 | window := now.Truncate(f.rate) 61 | if f.window != window { 62 | f.window = window 63 | f.overflow = false 64 | } 65 | 66 | ttl := f.rate - now.Sub(window) 67 | if f.overflow { 68 | // If the window is already overflowed don't increment the counter. 69 | return ttl, ErrLimitExhausted 70 | } 71 | 72 | c, err := f.backend.Increment(ctx, window, ttl) 73 | if err != nil { 74 | return 0, err 75 | } 76 | 77 | if c > f.capacity { 78 | f.overflow = true 79 | 80 | return ttl, ErrLimitExhausted 81 | } 82 | 83 | return 0, nil 84 | } 85 | 86 | // FixedWindowInMemory is an in-memory implementation of FixedWindowIncrementer. 87 | type FixedWindowInMemory struct { 88 | mu sync.Mutex 89 | c int64 90 | window time.Time 91 | } 92 | 93 | // NewFixedWindowInMemory creates a new instance of FixedWindowInMemory. 94 | func NewFixedWindowInMemory() *FixedWindowInMemory { 95 | return &FixedWindowInMemory{} 96 | } 97 | 98 | // Increment increments the window's counter. 
func (f *FixedWindowInMemory) Increment(ctx context.Context, window time.Time, _ time.Duration) (int64, error) {
	// Do not count a request whose context is already done: the caller treats the call as failed, so counting it
	// would consume window capacity for a request that was never admitted.
	if err := ctx.Err(); err != nil {
		return 0, err
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	// Equal compares instants; == would also compare the Location and reset the counter mid-window.
	if !f.window.Equal(window) {
		f.c = 0
		f.window = window
	}

	f.c++

	return f.c, nil
}

// FixedWindowRedis implements FixedWindow in Redis.
type FixedWindowRedis struct {
	cli    redis.UniversalClient
	prefix string
}

// NewFixedWindowRedis returns a new instance of FixedWindowRedis.
// Prefix is the key prefix used to store all the keys used in this implementation in Redis.
func NewFixedWindowRedis(cli redis.UniversalClient, prefix string) *FixedWindowRedis {
	return &FixedWindowRedis{cli: cli, prefix: prefix}
}

// Increment increments the window's counter in Redis.
//
// Each window has its own key, which expires after ttl so counters of past windows do not accumulate.
func (f *FixedWindowRedis) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) {
	var (
		incr *redis.IntCmd
		err  error
	)

	key := redisKey(f.prefix, fmt.Sprintf("%d", window.UnixNano()))
	done := make(chan struct{})

	go func() {
		defer close(done)

		_, err = f.cli.Pipelined(ctx, func(pipeliner redis.Pipeliner) error {
			incr = pipeliner.Incr(ctx, key)
			pipeliner.PExpire(ctx, key, ttl)

			return nil
		})
	}()

	select {
	case <-done:
		if err != nil {
			return 0, errors.Wrap(err, "redis transaction failed")
		}

		return incr.Val(), incr.Err()
	case <-ctx.Done():
		return 0, ctx.Err()
	}
}

// FixedWindowMemcached implements FixedWindow in Memcached.
type FixedWindowMemcached struct {
	cli    *memcache.Client
	prefix string
}

// NewFixedWindowMemcached returns a new instance of FixedWindowMemcached.
165 | // Prefix is the key prefix used to store all the keys used in this implementation in Memcached. 166 | func NewFixedWindowMemcached(cli *memcache.Client, prefix string) *FixedWindowMemcached { 167 | return &FixedWindowMemcached{cli: cli, prefix: prefix} 168 | } 169 | 170 | // Increment increments the window's counter in Memcached. 171 | func (f *FixedWindowMemcached) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) { 172 | var ( 173 | newValue uint64 174 | err error 175 | ) 176 | 177 | done := make(chan struct{}) 178 | 179 | go func() { 180 | defer close(done) 181 | 182 | key := fmt.Sprintf("%s:%d", f.prefix, window.UnixNano()) 183 | 184 | newValue, err = f.cli.Increment(key, 1) 185 | if err != nil && errors.Is(err, memcache.ErrCacheMiss) { 186 | newValue = 1 187 | item := &memcache.Item{ 188 | Key: key, 189 | Value: []byte(strconv.FormatUint(newValue, 10)), 190 | } 191 | err = f.cli.Add(item) 192 | } 193 | }() 194 | 195 | select { 196 | case <-done: 197 | if err != nil { 198 | if errors.Is(err, memcache.ErrNotStored) { 199 | return f.Increment(ctx, window, ttl) 200 | } 201 | 202 | return 0, errors.Wrap(err, "failed to Increment or Add") 203 | } 204 | 205 | return int64(newValue), err 206 | case <-ctx.Done(): 207 | return 0, ctx.Err() 208 | } 209 | } 210 | 211 | // FixedWindowDynamoDB implements FixedWindow in DynamoDB. 212 | type FixedWindowDynamoDB struct { 213 | client *dynamodb.Client 214 | partitionKey string 215 | tableProps DynamoDBTableProperties 216 | } 217 | 218 | // NewFixedWindowDynamoDB creates a new instance of FixedWindowDynamoDB. 219 | // PartitionKey is the key used to store all the this implementation in DynamoDB. 220 | // 221 | // TableProps describe the table that this backend should work with. This backend requires the following on the table: 222 | // * SortKey 223 | // * TTL. 
224 | func NewFixedWindowDynamoDB(client *dynamodb.Client, partitionKey string, props DynamoDBTableProperties) *FixedWindowDynamoDB { 225 | return &FixedWindowDynamoDB{ 226 | client: client, 227 | partitionKey: partitionKey, 228 | tableProps: props, 229 | } 230 | } 231 | 232 | type contextKey int 233 | 234 | var fixedWindowDynamoDBPartitionKey contextKey 235 | 236 | // NewFixedWindowDynamoDBContext creates a context for FixedWindowDynamoDB with a partition key. 237 | // 238 | // This context can be used to control the partition key per-request. 239 | // 240 | // Deprecated: NewFixedWindowDynamoDBContext is deprecated and will be removed in future versions. 241 | // Separate FixedWindow rate limiters should be used for different partition keys instead. 242 | // Consider using the `Registry` to manage multiple FixedWindow instances with different partition keys. 243 | func NewFixedWindowDynamoDBContext(ctx context.Context, partitionKey string) context.Context { 244 | log.Printf("DEPRECATED: NewFixedWindowDynamoDBContext is deprecated and will be removed in future versions.") 245 | 246 | return context.WithValue(ctx, fixedWindowDynamoDBPartitionKey, partitionKey) 247 | } 248 | 249 | const ( 250 | fixedWindowDynamoDBUpdateExpression = "SET #C = if_not_exists(#C, :def) + :inc, #TTL = :ttl" 251 | dynamodbWindowCountKey = "Count" 252 | ) 253 | 254 | // Increment increments the window's counter in DynamoDB. 
func (f *FixedWindowDynamoDB) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) {
	var (
		out       *dynamodb.UpdateItemOutput
		updateErr error
	)

	finished := make(chan struct{})

	go func() {
		defer close(finished)

		// A partition key stored in ctx by the deprecated NewFixedWindowDynamoDBContext overrides the configured one.
		pk := f.partitionKey
		if fromCtx, ok := ctx.Value(fixedWindowDynamoDBPartitionKey).(string); ok {
			pk = fromCtx
		}

		keys := map[string]types.AttributeValue{
			f.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: pk},
			f.tableProps.SortKeyName:      &types.AttributeValueMemberS{Value: strconv.FormatInt(window.UnixNano(), 10)},
		}
		names := map[string]string{
			"#TTL": f.tableProps.TTLFieldName,
			"#C":   dynamodbWindowCountKey,
		}
		values := map[string]types.AttributeValue{
			":ttl": &types.AttributeValueMemberN{Value: strconv.FormatInt(time.Now().Add(ttl).Unix(), 10)},
			":def": &types.AttributeValueMemberN{Value: "0"},
			":inc": &types.AttributeValueMemberN{Value: "1"},
		}

		out, updateErr = f.client.UpdateItem(ctx, &dynamodb.UpdateItemInput{
			Key:                       keys,
			UpdateExpression:          aws.String(fixedWindowDynamoDBUpdateExpression),
			ExpressionAttributeNames:  names,
			ExpressionAttributeValues: values,
			TableName:                 &f.tableProps.TableName,
			ReturnValues:              types.ReturnValueAllNew,
		})
	}()

	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case <-finished:
	}

	if updateErr != nil {
		return 0, errors.Wrap(updateErr, "dynamodb update item failed")
	}

	// The updated counter comes back as a DynamoDB number attribute.
	var count float64
	if err := attributevalue.Unmarshal(out.Attributes[dynamodbWindowCountKey], &count); err != nil {
		return 0, errors.Wrap(err, "unmarshal of dynamodb attribute value failed")
	}

	return int64(count), nil
}

// FixedWindowCosmosDB implements FixedWindow in CosmosDB.
312 | type FixedWindowCosmosDB struct { 313 | client *azcosmos.ContainerClient 314 | partitionKey string 315 | } 316 | 317 | // NewFixedWindowCosmosDB creates a new instance of FixedWindowCosmosDB. 318 | // PartitionKey is the key used for partitioning data into multiple partitions. 319 | func NewFixedWindowCosmosDB(client *azcosmos.ContainerClient, partitionKey string) *FixedWindowCosmosDB { 320 | return &FixedWindowCosmosDB{ 321 | client: client, 322 | partitionKey: partitionKey, 323 | } 324 | } 325 | 326 | func (f *FixedWindowCosmosDB) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) { 327 | id := strconv.FormatInt(window.UnixNano(), 10) 328 | tmp := cosmosItem{ 329 | ID: id, 330 | PartitionKey: f.partitionKey, 331 | Count: 1, 332 | TTL: int32(ttl), 333 | } 334 | 335 | ops := azcosmos.PatchOperations{} 336 | ops.AppendIncrement(`/Count`, 1) 337 | 338 | patchResp, err := f.client.PatchItem(ctx, azcosmos.NewPartitionKey().AppendString(f.partitionKey), id, ops, &azcosmos.ItemOptions{ 339 | EnableContentResponseOnWrite: true, 340 | }) 341 | if err == nil { 342 | // value exists and was updated 343 | err = json.Unmarshal(patchResp.Value, &tmp) 344 | if err != nil { 345 | return 0, errors.Wrap(err, "unmarshal of cosmos value failed") 346 | } 347 | 348 | return tmp.Count, nil 349 | } 350 | 351 | var respErr *azcore.ResponseError 352 | if !errors.As(err, &respErr) || (respErr.StatusCode != http.StatusNotFound && respErr.StatusCode != http.StatusBadRequest) { 353 | return 0, errors.Wrap(err, `patch of cosmos value failed`) 354 | } 355 | 356 | newValue, err := json.Marshal(tmp) 357 | if err != nil { 358 | return 0, errors.Wrap(err, "marshal of cosmos value failed") 359 | } 360 | 361 | // Try to CreateItem. If it fails with Conflict, it means the item exists (and Patch failed with 400), so we fallback to RMW. 
362 | _, err = f.client.CreateItem(ctx, azcosmos.NewPartitionKey().AppendString(f.partitionKey), newValue, &azcosmos.ItemOptions{ 363 | SessionToken: patchResp.SessionToken, 364 | }) 365 | if err != nil { 366 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { 367 | // Fallback to Read-Modify-Write 368 | return incrementCosmosItemRMW(ctx, f.client, f.partitionKey, id, ttl) 369 | } 370 | 371 | return 0, errors.Wrap(err, "upsert of cosmos value failed") 372 | } 373 | 374 | return tmp.Count, nil 375 | } 376 | -------------------------------------------------------------------------------- /tokenbucket_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/google/uuid" 12 | l "github.com/mennanov/limiters" 13 | "github.com/redis/go-redis/v9" 14 | ) 15 | 16 | var tokenBucketUniformTestCases = []struct { 17 | capacity int64 18 | refillRate time.Duration 19 | requestCount int64 20 | requestRate time.Duration 21 | missExpected int 22 | delta float64 23 | }{ 24 | { 25 | 5, 26 | time.Millisecond * 100, 27 | 10, 28 | time.Millisecond * 25, 29 | 3, 30 | 1, 31 | }, 32 | { 33 | 10, 34 | time.Millisecond * 100, 35 | 13, 36 | time.Millisecond * 25, 37 | 0, 38 | 1, 39 | }, 40 | { 41 | 10, 42 | time.Millisecond * 100, 43 | 15, 44 | time.Millisecond * 33, 45 | 2, 46 | 1, 47 | }, 48 | { 49 | 10, 50 | time.Millisecond * 100, 51 | 16, 52 | time.Millisecond * 33, 53 | 2, 54 | 1, 55 | }, 56 | { 57 | 10, 58 | time.Millisecond * 10, 59 | 20, 60 | time.Millisecond * 10, 61 | 0, 62 | 1, 63 | }, 64 | { 65 | 1, 66 | time.Millisecond * 200, 67 | 10, 68 | time.Millisecond * 100, 69 | 5, 70 | 2, 71 | }, 72 | } 73 | 74 | // tokenBuckets returns all the possible TokenBucket combinations. 
75 | func (s *LimitersTestSuite) tokenBuckets(capacity int64, refillRate, ttl time.Duration, clock l.Clock) map[string]*l.TokenBucket { 76 | buckets := make(map[string]*l.TokenBucket) 77 | 78 | for lockerName, locker := range s.lockers(true) { 79 | for backendName, backend := range s.tokenBucketBackends(ttl) { 80 | buckets[lockerName+":"+backendName] = l.NewTokenBucket(capacity, refillRate, locker, backend, clock, s.logger) 81 | } 82 | } 83 | 84 | return buckets 85 | } 86 | 87 | func (s *LimitersTestSuite) tokenBucketBackends(ttl time.Duration) map[string]l.TokenBucketStateBackend { 88 | return map[string]l.TokenBucketStateBackend{ 89 | "TokenBucketInMemory": l.NewTokenBucketInMemory(), 90 | "TokenBucketEtcdNoRaceCheck": l.NewTokenBucketEtcd(s.etcdClient, uuid.New().String(), ttl, false), 91 | "TokenBucketEtcdWithRaceCheck": l.NewTokenBucketEtcd(s.etcdClient, uuid.New().String(), ttl, true), 92 | "TokenBucketRedisNoRaceCheck": l.NewTokenBucketRedis(s.redisClient, uuid.New().String(), ttl, false), 93 | "TokenBucketRedisWithRaceCheck": l.NewTokenBucketRedis(s.redisClient, uuid.New().String(), ttl, true), 94 | "TokenBucketRedisClusterNoRaceCheck": l.NewTokenBucketRedis(s.redisClusterClient, uuid.New().String(), ttl, false), 95 | "TokenBucketRedisClusterWithRaceCheck": l.NewTokenBucketRedis(s.redisClusterClient, uuid.New().String(), ttl, true), 96 | "TokenBucketMemcachedNoRaceCheck": l.NewTokenBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, false), 97 | "TokenBucketMemcachedWithRaceCheck": l.NewTokenBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, true), 98 | "TokenBucketDynamoDBNoRaceCheck": l.NewTokenBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, ttl, false), 99 | "TokenBucketDynamoDBWithRaceCheck": l.NewTokenBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, ttl, true), 100 | "TokenBucketCosmosDBNoRaceCheck": l.NewTokenBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), ttl, 
false), 101 | "TokenBucketCosmosDBWithRaceCheck": l.NewTokenBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), ttl, true), 102 | } 103 | } 104 | 105 | func (s *LimitersTestSuite) TestTokenBucketRealClock() { 106 | clock := l.NewSystemClock() 107 | for _, testCase := range tokenBucketUniformTestCases { 108 | for name, bucket := range s.tokenBuckets(testCase.capacity, testCase.refillRate, 0, clock) { 109 | s.Run(name, func() { 110 | wg := sync.WaitGroup{} 111 | // mu guards the miss variable below. 112 | var mu sync.Mutex 113 | 114 | miss := 0 115 | 116 | for i := int64(0); i < testCase.requestCount; i++ { 117 | // No pause for the first request. 118 | if i > 0 { 119 | clock.Sleep(testCase.requestRate) 120 | } 121 | 122 | wg.Add(1) 123 | 124 | go func(bucket *l.TokenBucket) { 125 | defer wg.Done() 126 | 127 | _, err := bucket.Limit(context.TODO()) 128 | if err != nil { 129 | s.Equal(l.ErrLimitExhausted, err, "%T %v", bucket, bucket) 130 | mu.Lock() 131 | 132 | miss++ 133 | 134 | mu.Unlock() 135 | } 136 | }(bucket) 137 | } 138 | 139 | wg.Wait() 140 | s.InDelta(testCase.missExpected, miss, testCase.delta, testCase) 141 | }) 142 | } 143 | } 144 | } 145 | 146 | func (s *LimitersTestSuite) TestTokenBucketFakeClock() { 147 | for _, testCase := range tokenBucketUniformTestCases { 148 | clock := newFakeClock() 149 | for name, bucket := range s.tokenBuckets(testCase.capacity, testCase.refillRate, 0, clock) { 150 | s.Run(name, func() { 151 | clock.reset() 152 | 153 | miss := 0 154 | 155 | for i := int64(0); i < testCase.requestCount; i++ { 156 | // No pause for the first request. 
157 | if i > 0 { 158 | clock.Sleep(testCase.requestRate) 159 | } 160 | 161 | _, err := bucket.Limit(context.TODO()) 162 | if err != nil { 163 | s.Equal(l.ErrLimitExhausted, err) 164 | 165 | miss++ 166 | } 167 | } 168 | 169 | s.InDelta(testCase.missExpected, miss, testCase.delta, testCase) 170 | }) 171 | } 172 | } 173 | } 174 | 175 | func (s *LimitersTestSuite) TestTokenBucketOverflow() { 176 | clock := newFakeClock() 177 | 178 | rate := 100 * time.Millisecond 179 | for name, bucket := range s.tokenBuckets(2, rate, 0, clock) { 180 | s.Run(name, func() { 181 | clock.reset() 182 | 183 | wait, err := bucket.Limit(context.TODO()) 184 | s.Require().NoError(err) 185 | s.Equal(time.Duration(0), wait) 186 | wait, err = bucket.Limit(context.TODO()) 187 | s.Require().NoError(err) 188 | s.Equal(time.Duration(0), wait) 189 | // The third call should fail. 190 | wait, err = bucket.Limit(context.TODO()) 191 | s.Require().Equal(l.ErrLimitExhausted, err) 192 | s.Equal(rate, wait) 193 | clock.Sleep(wait) 194 | // Retry the last call. 195 | wait, err = bucket.Limit(context.TODO()) 196 | s.Require().NoError(err) 197 | s.Equal(time.Duration(0), wait) 198 | }) 199 | } 200 | } 201 | 202 | func (s *LimitersTestSuite) TestTokenTakeMaxOverflow() { 203 | clock := newFakeClock() 204 | 205 | rate := 100 * time.Millisecond 206 | for name, bucket := range s.tokenBuckets(3, rate, 0, clock) { 207 | s.Run(name, func() { 208 | clock.reset() 209 | // Take max, as it's below capacity. 210 | taken, err := bucket.TakeMax(context.TODO(), 2) 211 | s.Require().NoError(err) 212 | s.Equal(int64(2), taken) 213 | // Take as much as it's available 214 | taken, err = bucket.TakeMax(context.TODO(), 2) 215 | s.Require().NoError(err) 216 | s.Equal(int64(1), taken) 217 | // Now it is not able to take any. 218 | taken, err = bucket.TakeMax(context.TODO(), 2) 219 | s.Require().NoError(err) 220 | s.Equal(int64(0), taken) 221 | clock.Sleep(rate) 222 | // Retry the last call. 
223 | taken, err = bucket.TakeMax(context.TODO(), 2) 224 | s.Require().NoError(err) 225 | s.Equal(int64(1), taken) 226 | }) 227 | } 228 | } 229 | 230 | func (s *LimitersTestSuite) TestTokenBucketReset() { 231 | clock := newFakeClock() 232 | 233 | rate := 100 * time.Millisecond 234 | for name, bucket := range s.tokenBuckets(2, rate, 0, clock) { 235 | s.Run(name, func() { 236 | clock.reset() 237 | 238 | wait, err := bucket.Limit(context.TODO()) 239 | s.Require().NoError(err) 240 | s.Equal(time.Duration(0), wait) 241 | wait, err = bucket.Limit(context.TODO()) 242 | s.Require().NoError(err) 243 | s.Equal(time.Duration(0), wait) 244 | // The third call should fail. 245 | wait, err = bucket.Limit(context.TODO()) 246 | s.Require().Equal(l.ErrLimitExhausted, err) 247 | s.Equal(rate, wait) 248 | 249 | err = bucket.Reset(context.TODO()) 250 | s.Require().NoError(err) 251 | // Retry the last call. 252 | wait, err = bucket.Limit(context.TODO()) 253 | s.Require().NoError(err) 254 | s.Equal(time.Duration(0), wait) 255 | }) 256 | } 257 | } 258 | 259 | func (s *LimitersTestSuite) TestTokenBucketRefill() { 260 | for name, backend := range s.tokenBucketBackends(0) { 261 | s.Run(name, func() { 262 | clock := newFakeClock() 263 | 264 | bucket := l.NewTokenBucket(4, time.Millisecond*100, l.NewLockNoop(), backend, clock, s.logger) 265 | sleepDurations := []int{150, 90, 50, 70} 266 | desiredAvailable := []int64{3, 2, 2, 2} 267 | 268 | _, err := bucket.Limit(context.Background()) 269 | s.Require().NoError(err) 270 | 271 | _, err = backend.State(context.Background()) 272 | s.Require().NoError(err, "unable to retrieve backend state") 273 | 274 | for i := range sleepDurations { 275 | clock.Sleep(time.Millisecond * time.Duration(sleepDurations[i])) 276 | 277 | _, err := bucket.Limit(context.Background()) 278 | s.Require().NoError(err) 279 | 280 | state, err := backend.State(context.Background()) 281 | s.Require().NoError(err, "unable to retrieve backend state") 282 | 283 | 
s.Require().Equal(desiredAvailable[i], state.Available) 284 | } 285 | }) 286 | } 287 | } 288 | 289 | // setTokenBucketStateInOldFormat is a test utility method for writing state in the old format to Redis. 290 | func setTokenBucketStateInOldFormat(ctx context.Context, cli *redis.Client, prefix string, state l.TokenBucketState, ttl time.Duration) error { 291 | _, err := cli.TxPipelined(ctx, func(pipeliner redis.Pipeliner) error { 292 | err := pipeliner.Set(ctx, fmt.Sprintf("{%s}last", prefix), state.Last, ttl).Err() 293 | if err != nil { 294 | return err 295 | } 296 | 297 | err = pipeliner.Set(ctx, fmt.Sprintf("{%s}available", prefix), state.Available, ttl).Err() 298 | if err != nil { 299 | return err 300 | } 301 | 302 | return nil 303 | }) 304 | 305 | return err 306 | } 307 | 308 | // TestTokenBucketRedisBackwardCompatibility tests that the new State method can read data written in the old format. 309 | func (s *LimitersTestSuite) TestTokenBucketRedisBackwardCompatibility() { 310 | // Create a new TokenBucketRedis instance 311 | prefix := uuid.New().String() 312 | backend := l.NewTokenBucketRedis(s.redisClient, prefix, time.Minute, false) 313 | 314 | // Write state using old format 315 | ctx := context.Background() 316 | expectedState := l.TokenBucketState{ 317 | Last: 12345, 318 | Available: 67890, 319 | } 320 | 321 | // Write directly to Redis using old format 322 | err := setTokenBucketStateInOldFormat(ctx, s.redisClient, prefix, expectedState, time.Minute) 323 | s.Require().NoError(err, "Failed to set state using old format") 324 | 325 | // Read state using new format (State) 326 | actualState, err := backend.State(ctx) 327 | s.Require().NoError(err, "Failed to get state using new format") 328 | 329 | // Verify the state is correctly read 330 | s.Equal(expectedState.Last, actualState.Last, "Last values should match") 331 | s.Equal(expectedState.Available, actualState.Available, "Available values should match") 332 | } 333 | 334 | func (s *LimitersTestSuite) 
TestTokenBucketNoExpiration() { 335 | clock := l.NewSystemClock() 336 | buckets := s.tokenBuckets(1, time.Minute, 0, clock) 337 | 338 | // Take all capacity from all buckets 339 | for _, bucket := range buckets { 340 | _, _ = bucket.Limit(context.TODO()) 341 | } 342 | 343 | // Wait for 2 seconds to check if it treats 0 TTL as infinite 344 | clock.Sleep(2 * time.Second) 345 | 346 | // Expect all buckets to be still filled 347 | for name, bucket := range buckets { 348 | s.Run(name, func() { 349 | _, err := bucket.Limit(context.TODO()) 350 | s.Require().Equal(l.ErrLimitExhausted, err) 351 | }) 352 | } 353 | } 354 | 355 | func (s *LimitersTestSuite) TestTokenBucketTTLExpiration() { 356 | clock := l.NewSystemClock() 357 | buckets := s.tokenBuckets(1, time.Minute, time.Second, clock) 358 | 359 | // Ignore in-memory bucket, as it has no expiration, 360 | // ignore DynamoDB, as amazon/dynamodb-local doesn't support TTLs. 361 | for k := range buckets { 362 | if strings.Contains(k, "BucketInMemory") || strings.Contains(k, "BucketDynamoDB") { 363 | delete(buckets, k) 364 | } 365 | } 366 | 367 | // Take all capacity from all buckets 368 | for _, bucket := range buckets { 369 | _, _ = bucket.Limit(context.TODO()) 370 | } 371 | 372 | // Wait for 3 seconds to check if the items have been deleted successfully 373 | clock.Sleep(3 * time.Second) 374 | 375 | // Expect all buckets to be still empty (as the data expired) 376 | for name, bucket := range buckets { 377 | s.Run(name, func() { 378 | wait, err := bucket.Limit(context.TODO()) 379 | s.Require().Equal(time.Duration(0), wait) 380 | s.Require().NoError(err) 381 | }) 382 | } 383 | } 384 | 385 | func BenchmarkTokenBuckets(b *testing.B) { 386 | s := new(LimitersTestSuite) 387 | s.SetT(&testing.T{}) 388 | s.SetupSuite() 389 | 390 | capacity := int64(1) 391 | rate := 100 * time.Millisecond 392 | clock := newFakeClock() 393 | 394 | buckets := s.tokenBuckets(capacity, rate, 0, clock) 395 | for name, bucket := range buckets { 396 | 
b.Run(name, func(b *testing.B) { 397 | for i := 0; i < b.N; i++ { 398 | _, err := bucket.Limit(context.TODO()) 399 | s.Require().NoError(err) 400 | } 401 | }) 402 | } 403 | 404 | s.TearDownSuite() 405 | } 406 | -------------------------------------------------------------------------------- /leakybucket_test.go: -------------------------------------------------------------------------------- 1 | package limiters_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/google/uuid" 12 | l "github.com/mennanov/limiters" 13 | "github.com/redis/go-redis/v9" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | // leakyBuckets returns all the possible leakyBuckets combinations. 18 | func (s *LimitersTestSuite) leakyBuckets(capacity int64, rate, ttl time.Duration, clock l.Clock) map[string]*l.LeakyBucket { 19 | buckets := make(map[string]*l.LeakyBucket) 20 | 21 | for lockerName, locker := range s.lockers(true) { 22 | for backendName, backend := range s.leakyBucketBackends(ttl) { 23 | buckets[lockerName+":"+backendName] = l.NewLeakyBucket(capacity, rate, locker, backend, clock, s.logger) 24 | } 25 | } 26 | 27 | return buckets 28 | } 29 | 30 | func (s *LimitersTestSuite) leakyBucketBackends(ttl time.Duration) map[string]l.LeakyBucketStateBackend { 31 | return map[string]l.LeakyBucketStateBackend{ 32 | "LeakyBucketInMemory": l.NewLeakyBucketInMemory(), 33 | "LeakyBucketEtcdNoRaceCheck": l.NewLeakyBucketEtcd(s.etcdClient, uuid.New().String(), ttl, false), 34 | "LeakyBucketEtcdWithRaceCheck": l.NewLeakyBucketEtcd(s.etcdClient, uuid.New().String(), ttl, true), 35 | "LeakyBucketRedisNoRaceCheck": l.NewLeakyBucketRedis(s.redisClient, uuid.New().String(), ttl, false), 36 | "LeakyBucketRedisWithRaceCheck": l.NewLeakyBucketRedis(s.redisClient, uuid.New().String(), ttl, true), 37 | "LeakyBucketRedisClusterNoRaceCheck": l.NewLeakyBucketRedis(s.redisClusterClient, uuid.New().String(), ttl, false), 38 | 
"LeakyBucketRedisClusterWithRaceCheck": l.NewLeakyBucketRedis(s.redisClusterClient, uuid.New().String(), ttl, true), 39 | "LeakyBucketMemcachedNoRaceCheck": l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, false), 40 | "LeakyBucketMemcachedWithRaceCheck": l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, true), 41 | "LeakyBucketDynamoDBNoRaceCheck": l.NewLeakyBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, ttl, false), 42 | "LeakyBucketDynamoDBWithRaceCheck": l.NewLeakyBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, ttl, true), 43 | "LeakyBucketCosmosDBNoRaceCheck": l.NewLeakyBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), ttl, false), 44 | "LeakyBucketCosmosDBWithRaceCheck": l.NewLeakyBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), ttl, true), 45 | } 46 | } 47 | 48 | func (s *LimitersTestSuite) TestLeakyBucketRealClock() { 49 | capacity := int64(10) 50 | rate := time.Millisecond * 10 51 | 52 | clock := l.NewSystemClock() 53 | for _, requestRate := range []time.Duration{rate / 2} { 54 | for name, bucket := range s.leakyBuckets(capacity, rate, 0, clock) { 55 | s.Run(name, func() { 56 | wg := sync.WaitGroup{} 57 | mu := sync.Mutex{} 58 | 59 | var totalWait time.Duration 60 | 61 | for i := int64(0); i < capacity; i++ { 62 | // No pause for the first request. 
63 | if i > 0 { 64 | clock.Sleep(requestRate) 65 | } 66 | 67 | wg.Add(1) 68 | 69 | go func(bucket *l.LeakyBucket) { 70 | defer wg.Done() 71 | 72 | wait, err := bucket.Limit(context.TODO()) 73 | s.Require().NoError(err) 74 | 75 | if wait > 0 { 76 | mu.Lock() 77 | 78 | totalWait += wait 79 | 80 | mu.Unlock() 81 | clock.Sleep(wait) 82 | } 83 | }(bucket) 84 | } 85 | 86 | wg.Wait() 87 | 88 | expectedWait := time.Duration(0) 89 | if rate > requestRate { 90 | expectedWait = time.Duration(float64(rate-requestRate) * float64(capacity-1) / 2 * float64(capacity)) 91 | } 92 | 93 | // Allow 5ms lag for each request. 94 | // TODO: figure out if this is enough for slow PCs and possibly avoid hard-coding it. 95 | delta := float64(time.Duration(capacity) * time.Millisecond * 25) 96 | s.InDelta(expectedWait, totalWait, delta, "request rate: %d, bucket: %v", requestRate, bucket) 97 | }) 98 | } 99 | } 100 | } 101 | 102 | func (s *LimitersTestSuite) TestLeakyBucketFakeClock() { 103 | capacity := int64(10) 104 | rate := time.Millisecond * 100 105 | 106 | clock := newFakeClock() 107 | for _, requestRate := range []time.Duration{rate * 2, rate, rate / 2, rate / 3, rate / 4, 0} { 108 | for name, bucket := range s.leakyBuckets(capacity, rate, 0, clock) { 109 | s.Run(name, func() { 110 | clock.reset() 111 | start := clock.Now() 112 | 113 | for i := int64(0); i < capacity; i++ { 114 | // No pause for the first request. 
115 | if i > 0 { 116 | clock.Sleep(requestRate) 117 | } 118 | 119 | wait, err := bucket.Limit(context.TODO()) 120 | s.Require().NoError(err) 121 | clock.Sleep(wait) 122 | } 123 | 124 | interval := rate 125 | if requestRate > rate { 126 | interval = requestRate 127 | } 128 | 129 | s.Equal(interval*time.Duration(capacity-1), clock.Now().Sub(start), "request rate: %d, bucket: %v", requestRate, bucket) 130 | }) 131 | } 132 | } 133 | } 134 | 135 | func (s *LimitersTestSuite) TestLeakyBucketOverflow() { 136 | rate := 100 * time.Millisecond 137 | capacity := int64(2) 138 | 139 | clock := newFakeClock() 140 | for name, bucket := range s.leakyBuckets(capacity, rate, 0, clock) { 141 | s.Run(name, func() { 142 | clock.reset() 143 | // The first call has no wait since there were no calls before. 144 | wait, err := bucket.Limit(context.TODO()) 145 | s.Require().NoError(err) 146 | s.Equal(time.Duration(0), wait) 147 | // The second call increments the queue size by 1. 148 | wait, err = bucket.Limit(context.TODO()) 149 | s.Require().NoError(err) 150 | s.Equal(rate, wait) 151 | // The third call overflows the bucket capacity. 152 | wait, err = bucket.Limit(context.TODO()) 153 | s.Require().Equal(l.ErrLimitExhausted, err) 154 | s.Equal(rate*2, wait) 155 | // Move the Clock 1 position forward. 156 | clock.Sleep(rate) 157 | // Retry the last call. This time it should succeed. 
158 | wait, err = bucket.Limit(context.TODO()) 159 | s.Require().NoError(err) 160 | s.Equal(rate, wait) 161 | }) 162 | } 163 | } 164 | 165 | func (s *LimitersTestSuite) TestLeakyBucketNoExpiration() { 166 | clock := l.NewSystemClock() 167 | buckets := s.leakyBuckets(1, time.Minute, 0, clock) 168 | 169 | // Take all capacity from all buckets 170 | for _, bucket := range buckets { 171 | _, _ = bucket.Limit(context.TODO()) 172 | _, _ = bucket.Limit(context.TODO()) 173 | } 174 | 175 | // Wait for 2 seconds to check if it treats 0 TTL as infinite 176 | clock.Sleep(2 * time.Second) 177 | 178 | // Expect all buckets to be still filled 179 | for name, bucket := range buckets { 180 | s.Run(name, func() { 181 | _, err := bucket.Limit(context.TODO()) 182 | s.Require().Equal(l.ErrLimitExhausted, err) 183 | }) 184 | } 185 | } 186 | 187 | func (s *LimitersTestSuite) TestLeakyBucketTTLExpiration() { 188 | clock := l.NewSystemClock() 189 | buckets := s.leakyBuckets(1, time.Minute, time.Second, clock) 190 | 191 | // Ignore in-memory bucket, as it has no expiration, 192 | // ignore DynamoDB, as amazon/dynamodb-local doesn't support TTLs. 
193 | for k := range buckets { 194 | if strings.Contains(k, "BucketInMemory") || strings.Contains(k, "BucketDynamoDB") { 195 | delete(buckets, k) 196 | } 197 | } 198 | 199 | // Take all capacity from all buckets 200 | for _, bucket := range buckets { 201 | _, _ = bucket.Limit(context.TODO()) 202 | _, _ = bucket.Limit(context.TODO()) 203 | } 204 | 205 | // Wait for 3 seconds to check if the items have been deleted successfully 206 | clock.Sleep(3 * time.Second) 207 | 208 | // Expect all buckets to be empty (as the data expired) 209 | for name, bucket := range buckets { 210 | s.Run(name, func() { 211 | wait, err := bucket.Limit(context.TODO()) 212 | s.Require().Equal(time.Duration(0), wait) 213 | s.Require().NoError(err) 214 | }) 215 | } 216 | } 217 | 218 | func (s *LimitersTestSuite) TestLeakyBucketReset() { 219 | rate := 100 * time.Millisecond 220 | capacity := int64(2) 221 | 222 | clock := newFakeClock() 223 | for name, bucket := range s.leakyBuckets(capacity, rate, 0, clock) { 224 | s.Run(name, func() { 225 | clock.reset() 226 | // The first call has no wait since there were no calls before. 227 | wait, err := bucket.Limit(context.TODO()) 228 | s.Require().NoError(err) 229 | s.Equal(time.Duration(0), wait) 230 | // The second call increments the queue size by 1. 231 | wait, err = bucket.Limit(context.TODO()) 232 | s.Require().NoError(err) 233 | s.Equal(rate, wait) 234 | // The third call overflows the bucket capacity. 235 | wait, err = bucket.Limit(context.TODO()) 236 | s.Require().Equal(l.ErrLimitExhausted, err) 237 | s.Equal(rate*2, wait) 238 | // Reset the bucket 239 | err = bucket.Reset(context.TODO()) 240 | s.Require().NoError(err) 241 | // Retry the last call. This time it should succeed. 242 | wait, err = bucket.Limit(context.TODO()) 243 | s.Require().NoError(err) 244 | s.Equal(time.Duration(0), wait) 245 | }) 246 | } 247 | } 248 | 249 | // TestLeakyBucketMemcachedExpiry verifies Memcached TTL expiry for LeakyBucketMemcached (no race check). 
250 | func (s *LimitersTestSuite) TestLeakyBucketMemcachedExpiry() { 251 | ttl := time.Second 252 | backend := l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, false) 253 | bucket := l.NewLeakyBucket(2, time.Second, l.NewLockNoop(), backend, l.NewSystemClock(), s.logger) 254 | ctx := context.Background() 255 | _, err := bucket.Limit(ctx) 256 | s.Require().NoError(err) 257 | _, err = bucket.Limit(ctx) 258 | s.Require().NoError(err) 259 | state, err := backend.State(ctx) 260 | s.Require().NoError(err) 261 | s.NotEqual(int64(0), state.Last, "Last should be set after token takes") 262 | time.Sleep(ttl + 1500*time.Millisecond) 263 | 264 | state, err = backend.State(ctx) 265 | s.Require().NoError(err) 266 | s.Equal(int64(0), state.Last, "State should be zero after expiry, got: %+v", state) 267 | } 268 | 269 | // TestLeakyBucketMemcachedExpiryWithRaceCheck verifies Memcached TTL expiry for LeakyBucketMemcached (race check enabled). 270 | func (s *LimitersTestSuite) TestLeakyBucketMemcachedExpiryWithRaceCheck() { 271 | ttl := time.Second 272 | backend := l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, true) 273 | bucket := l.NewLeakyBucket(2, time.Second, l.NewLockNoop(), backend, l.NewSystemClock(), s.logger) 274 | ctx := context.Background() 275 | _, err := bucket.Limit(ctx) 276 | s.Require().NoError(err) 277 | _, err = bucket.Limit(ctx) 278 | s.Require().NoError(err) 279 | state, err := backend.State(ctx) 280 | s.Require().NoError(err) 281 | s.NotEqual(int64(0), state.Last, "Last should be set after token takes") 282 | time.Sleep(ttl + 1500*time.Millisecond) 283 | 284 | state, err = backend.State(ctx) 285 | s.Require().NoError(err) 286 | s.Equal(int64(0), state.Last, "State should be zero after expiry, got: %+v", state) 287 | } 288 | 289 | func TestLeakyBucket_ZeroCapacity_ReturnsError(t *testing.T) { 290 | capacity := int64(0) 291 | rate := time.Hour 292 | logger := l.NewStdLogger() 293 | bucket := l.NewLeakyBucket(capacity, 
rate, l.NewLockNoop(), l.NewLeakyBucketInMemory(), newFakeClock(), logger) 294 | wait, err := bucket.Limit(context.TODO()) 295 | require.Equal(t, l.ErrLimitExhausted, err) 296 | require.Equal(t, time.Duration(0), wait) 297 | } 298 | 299 | func BenchmarkLeakyBuckets(b *testing.B) { 300 | s := new(LimitersTestSuite) 301 | s.SetT(&testing.T{}) 302 | s.SetupSuite() 303 | 304 | capacity := int64(1) 305 | rate := 100 * time.Millisecond 306 | clock := newFakeClock() 307 | 308 | buckets := s.leakyBuckets(capacity, rate, 0, clock) 309 | for name, bucket := range buckets { 310 | b.Run(name, func(b *testing.B) { 311 | for i := 0; i < b.N; i++ { 312 | _, err := bucket.Limit(context.TODO()) 313 | s.Require().NoError(err) 314 | } 315 | }) 316 | } 317 | 318 | s.TearDownSuite() 319 | } 320 | 321 | // setStateInOldFormat is a test utility method for writing state in the old format to Redis. 322 | func setStateInOldFormat(ctx context.Context, cli *redis.Client, prefix string, state l.LeakyBucketState, ttl time.Duration) error { 323 | _, err := cli.TxPipelined(ctx, func(pipeliner redis.Pipeliner) error { 324 | err := pipeliner.Set(ctx, fmt.Sprintf("{%s}last", prefix), state.Last, ttl).Err() 325 | if err != nil { 326 | return err 327 | } 328 | 329 | return nil 330 | }) 331 | 332 | return err 333 | } 334 | 335 | // TestLeakyBucketRedisBackwardCompatibility tests that the new State method can read data written in the old format. 
336 | func (s *LimitersTestSuite) TestLeakyBucketRedisBackwardCompatibility() { 337 | // Create a new LeakyBucketRedis instance 338 | prefix := uuid.New().String() 339 | backend := l.NewLeakyBucketRedis(s.redisClient, prefix, time.Minute, false) 340 | 341 | // Write state using old format 342 | ctx := context.Background() 343 | expectedState := l.LeakyBucketState{ 344 | Last: 12345, 345 | } 346 | 347 | // Write directly to Redis using old format 348 | err := setStateInOldFormat(ctx, s.redisClient, prefix, expectedState, time.Minute) 349 | s.Require().NoError(err, "Failed to set state using old format") 350 | 351 | // Read state using new format (State) 352 | actualState, err := backend.State(ctx) 353 | s.Require().NoError(err, "Failed to get state using new format") 354 | 355 | // Verify the state is correctly read 356 | s.Equal(expectedState.Last, actualState.Last, "Last values should match") 357 | } 358 | -------------------------------------------------------------------------------- /slidingwindow.go: -------------------------------------------------------------------------------- 1 | package limiters 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "math" 8 | "net/http" 9 | "strconv" 10 | "sync" 11 | "time" 12 | 13 | "github.com/Azure/azure-sdk-for-go/sdk/azcore" 14 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 15 | "github.com/aws/aws-sdk-go-v2/aws" 16 | "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" 17 | "github.com/aws/aws-sdk-go-v2/service/dynamodb" 18 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" 19 | "github.com/bradfitz/gomemcache/memcache" 20 | "github.com/pkg/errors" 21 | "github.com/redis/go-redis/v9" 22 | ) 23 | 24 | // SlidingWindowIncrementer wraps the Increment method. 25 | type SlidingWindowIncrementer interface { 26 | // Increment increments the request counter for the current window and returns the counter values for the previous 27 | // window and the current one. 
28 | // TTL is the time duration before the next window. 29 | Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (prevCount, currCount int64, err error) 30 | } 31 | 32 | // SlidingWindow implements a Sliding Window rate limiting algorithm. 33 | // 34 | // It does not require a distributed lock and uses a minimum amount of memory, however it will disallow all the requests 35 | // in case when a client is flooding the service with requests. 36 | // It's the client's responsibility to handle the disallowed request and wait before making a new request again. 37 | type SlidingWindow struct { 38 | backend SlidingWindowIncrementer 39 | clock Clock 40 | rate time.Duration 41 | capacity int64 42 | epsilon float64 43 | } 44 | 45 | // NewSlidingWindow creates a new instance of SlidingWindow. 46 | // Capacity is the maximum amount of requests allowed per window. 47 | // Rate is the window size. 48 | // Epsilon is the max-allowed range of difference when comparing the current weighted number of requests with capacity. 49 | func NewSlidingWindow(capacity int64, rate time.Duration, slidingWindowIncrementer SlidingWindowIncrementer, clock Clock, epsilon float64) *SlidingWindow { 50 | return &SlidingWindow{backend: slidingWindowIncrementer, clock: clock, rate: rate, capacity: capacity, epsilon: epsilon} 51 | } 52 | 53 | // Limit returns the time duration to wait before the request can be processed. 54 | // It returns ErrLimitExhausted if the request overflows the capacity. 
55 | func (s *SlidingWindow) Limit(ctx context.Context) (time.Duration, error) { 56 | now := s.clock.Now() 57 | currWindow := now.Truncate(s.rate) 58 | prevWindow := currWindow.Add(-s.rate) 59 | ttl := s.rate - now.Sub(currWindow) 60 | 61 | prev, curr, err := s.backend.Increment(ctx, prevWindow, currWindow, ttl+s.rate) 62 | if err != nil { 63 | return 0, err 64 | } 65 | 66 | // "prev" and "curr" are capped at "s.capacity + s.epsilon" using math.Ceil to round up any fractional values, 67 | // ensuring that in the worst case, "total" can be slightly greater than "s.capacity". 68 | prev = int64(math.Min(float64(prev), math.Ceil(float64(s.capacity)+s.epsilon))) 69 | curr = int64(math.Min(float64(curr), math.Ceil(float64(s.capacity)+s.epsilon))) 70 | 71 | total := float64(prev*int64(ttl))/float64(s.rate) + float64(curr) 72 | if total-float64(s.capacity) >= s.epsilon { 73 | var wait time.Duration 74 | if curr <= s.capacity-1 && prev > 0 { 75 | wait = ttl - time.Duration(float64(s.capacity-1-curr)/float64(prev)*float64(s.rate)) 76 | } else { 77 | // If prev == 0. 78 | wait = ttl + time.Duration((1-float64(s.capacity-1)/float64(curr))*float64(s.rate)) 79 | } 80 | 81 | return wait, ErrLimitExhausted 82 | } 83 | 84 | return 0, nil 85 | } 86 | 87 | // SlidingWindowInMemory is an in-memory implementation of SlidingWindowIncrementer. 88 | type SlidingWindowInMemory struct { 89 | mu sync.Mutex 90 | prevC, currC int64 91 | prevW, currW time.Time 92 | } 93 | 94 | // NewSlidingWindowInMemory creates a new instance of SlidingWindowInMemory. 95 | func NewSlidingWindowInMemory() *SlidingWindowInMemory { 96 | return &SlidingWindowInMemory{} 97 | } 98 | 99 | // Increment increments the current window's counter and returns the number of requests in the previous window and the 100 | // current one. 
101 | func (s *SlidingWindowInMemory) Increment(ctx context.Context, prev, curr time.Time, _ time.Duration) (int64, int64, error) { 102 | s.mu.Lock() 103 | defer s.mu.Unlock() 104 | 105 | if curr != s.currW { 106 | if prev.Equal(s.currW) { 107 | s.prevW = s.currW 108 | s.prevC = s.currC 109 | } else { 110 | s.prevW = time.Time{} 111 | s.prevC = 0 112 | } 113 | 114 | s.currW = curr 115 | s.currC = 0 116 | } 117 | 118 | s.currC++ 119 | 120 | return s.prevC, s.currC, ctx.Err() 121 | } 122 | 123 | // SlidingWindowRedis implements SlidingWindow in Redis. 124 | type SlidingWindowRedis struct { 125 | cli redis.UniversalClient 126 | prefix string 127 | } 128 | 129 | // NewSlidingWindowRedis creates a new instance of SlidingWindowRedis. 130 | func NewSlidingWindowRedis(cli redis.UniversalClient, prefix string) *SlidingWindowRedis { 131 | return &SlidingWindowRedis{cli: cli, prefix: prefix} 132 | } 133 | 134 | // Increment increments the current window's counter in Redis and returns the number of requests in the previous window 135 | // and the current one. 
136 | func (s *SlidingWindowRedis) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) { 137 | var ( 138 | incr *redis.IntCmd 139 | prevCountCmd *redis.StringCmd 140 | err error 141 | ) 142 | 143 | done := make(chan struct{}) 144 | 145 | go func() { 146 | defer close(done) 147 | 148 | _, err = s.cli.Pipelined(ctx, func(pipeliner redis.Pipeliner) error { 149 | currKey := fmt.Sprintf("%d", curr.UnixNano()) 150 | incr = pipeliner.Incr(ctx, redisKey(s.prefix, currKey)) 151 | pipeliner.PExpire(ctx, redisKey(s.prefix, currKey), ttl) 152 | prevCountCmd = pipeliner.Get(ctx, redisKey(s.prefix, fmt.Sprintf("%d", prev.UnixNano()))) 153 | 154 | return nil 155 | }) 156 | }() 157 | 158 | var prevCount int64 159 | 160 | select { 161 | case <-done: 162 | if errors.Is(err, redis.TxFailedErr) { 163 | return 0, 0, errors.Wrap(err, "redis transaction failed") 164 | } else if errors.Is(err, redis.Nil) { 165 | prevCount = 0 166 | } else if err != nil { 167 | return 0, 0, errors.Wrap(err, "unexpected error from redis") 168 | } else { 169 | prevCount, err = strconv.ParseInt(prevCountCmd.Val(), 10, 64) 170 | if err != nil { 171 | return 0, 0, errors.Wrap(err, "failed to parse response from redis") 172 | } 173 | } 174 | 175 | return prevCount, incr.Val(), nil 176 | case <-ctx.Done(): 177 | return 0, 0, ctx.Err() 178 | } 179 | } 180 | 181 | // SlidingWindowMemcached implements SlidingWindow in Memcached. 182 | type SlidingWindowMemcached struct { 183 | cli *memcache.Client 184 | prefix string 185 | } 186 | 187 | // NewSlidingWindowMemcached creates a new instance of SlidingWindowMemcached. 188 | func NewSlidingWindowMemcached(cli *memcache.Client, prefix string) *SlidingWindowMemcached { 189 | return &SlidingWindowMemcached{cli: cli, prefix: prefix} 190 | } 191 | 192 | // Increment increments the current window's counter in Memcached and returns the number of requests in the previous window 193 | // and the current one. 
194 | func (s *SlidingWindowMemcached) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) { 195 | var ( 196 | prevCount uint64 197 | currCount uint64 198 | err error 199 | ) 200 | 201 | done := make(chan struct{}) 202 | 203 | go func() { 204 | defer close(done) 205 | 206 | var item *memcache.Item 207 | 208 | prevKey := fmt.Sprintf("%s:%d", s.prefix, prev.UnixNano()) 209 | 210 | item, err = s.cli.Get(prevKey) 211 | if err != nil { 212 | if errors.Is(err, memcache.ErrCacheMiss) { 213 | err = nil 214 | prevCount = 0 215 | } else { 216 | return 217 | } 218 | } else { 219 | prevCount, err = strconv.ParseUint(string(item.Value), 10, 64) 220 | if err != nil { 221 | return 222 | } 223 | } 224 | 225 | currKey := fmt.Sprintf("%s:%d", s.prefix, curr.UnixNano()) 226 | 227 | currCount, err = s.cli.Increment(currKey, 1) 228 | if err != nil && errors.Is(err, memcache.ErrCacheMiss) { 229 | currCount = 1 230 | item = &memcache.Item{ 231 | Key: currKey, 232 | Value: []byte(strconv.FormatUint(currCount, 10)), 233 | } 234 | err = s.cli.Add(item) 235 | } 236 | }() 237 | 238 | select { 239 | case <-done: 240 | if err != nil { 241 | if errors.Is(err, memcache.ErrNotStored) { 242 | return s.Increment(ctx, prev, curr, ttl) 243 | } 244 | 245 | return 0, 0, err 246 | } 247 | 248 | return int64(prevCount), int64(currCount), nil 249 | case <-ctx.Done(): 250 | return 0, 0, ctx.Err() 251 | } 252 | } 253 | 254 | // SlidingWindowDynamoDB implements SlidingWindow in DynamoDB. 255 | type SlidingWindowDynamoDB struct { 256 | client *dynamodb.Client 257 | partitionKey string 258 | tableProps DynamoDBTableProperties 259 | } 260 | 261 | // NewSlidingWindowDynamoDB creates a new instance of SlidingWindowDynamoDB. 262 | // PartitionKey is the key used to store all the this implementation in DynamoDB. 263 | // 264 | // TableProps describe the table that this backend should work with. 
This backend requires the following on the table: 265 | // * SortKey 266 | // * TTL. 267 | func NewSlidingWindowDynamoDB(client *dynamodb.Client, partitionKey string, props DynamoDBTableProperties) *SlidingWindowDynamoDB { 268 | return &SlidingWindowDynamoDB{ 269 | client: client, 270 | partitionKey: partitionKey, 271 | tableProps: props, 272 | } 273 | } 274 | 275 | // Increment increments the current window's counter in DynamoDB and returns the number of requests in the previous window 276 | // and the current one. 277 | func (s *SlidingWindowDynamoDB) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) { 278 | wg := &sync.WaitGroup{} 279 | wg.Add(2) 280 | 281 | done := make(chan struct{}) 282 | 283 | go func() { 284 | defer close(done) 285 | 286 | wg.Wait() 287 | }() 288 | 289 | var ( 290 | currentCount int64 291 | currentErr error 292 | ) 293 | 294 | go func() { 295 | defer wg.Done() 296 | 297 | resp, err := s.client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ 298 | Key: map[string]types.AttributeValue{ 299 | s.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: s.partitionKey}, 300 | s.tableProps.SortKeyName: &types.AttributeValueMemberS{Value: strconv.FormatInt(curr.UnixNano(), 10)}, 301 | }, 302 | UpdateExpression: aws.String(fixedWindowDynamoDBUpdateExpression), 303 | ExpressionAttributeNames: map[string]string{ 304 | "#TTL": s.tableProps.TTLFieldName, 305 | "#C": dynamodbWindowCountKey, 306 | }, 307 | ExpressionAttributeValues: map[string]types.AttributeValue{ 308 | ":ttl": &types.AttributeValueMemberN{Value: strconv.FormatInt(time.Now().Add(ttl).Unix(), 10)}, 309 | ":def": &types.AttributeValueMemberN{Value: "0"}, 310 | ":inc": &types.AttributeValueMemberN{Value: "1"}, 311 | }, 312 | TableName: &s.tableProps.TableName, 313 | ReturnValues: types.ReturnValueAllNew, 314 | }) 315 | if err != nil { 316 | currentErr = errors.Wrap(err, "dynamodb get item failed") 317 | 318 | return 319 | } 320 | 321 | var tmp 
float64 322 | 323 | err = attributevalue.Unmarshal(resp.Attributes[dynamodbWindowCountKey], &tmp) 324 | if err != nil { 325 | currentErr = errors.Wrap(err, "unmarshal of dynamodb attribute value failed") 326 | 327 | return 328 | } 329 | 330 | currentCount = int64(tmp) 331 | }() 332 | 333 | var ( 334 | priorCount int64 335 | priorErr error 336 | ) 337 | 338 | go func() { 339 | defer wg.Done() 340 | 341 | resp, err := s.client.GetItem(ctx, &dynamodb.GetItemInput{ 342 | TableName: aws.String(s.tableProps.TableName), 343 | Key: map[string]types.AttributeValue{ 344 | s.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: s.partitionKey}, 345 | s.tableProps.SortKeyName: &types.AttributeValueMemberS{Value: strconv.FormatInt(prev.UnixNano(), 10)}, 346 | }, 347 | ConsistentRead: aws.Bool(true), 348 | }) 349 | if err != nil { 350 | priorCount, priorErr = 0, errors.Wrap(err, "dynamodb get item failed") 351 | 352 | return 353 | } 354 | 355 | if len(resp.Item) == 0 { 356 | priorCount = 0 357 | 358 | return 359 | } 360 | 361 | var count float64 362 | 363 | err = attributevalue.Unmarshal(resp.Item[dynamodbWindowCountKey], &count) 364 | if err != nil { 365 | priorCount, priorErr = 0, errors.Wrap(err, "unmarshal of dynamodb attribute value failed") 366 | 367 | return 368 | } 369 | 370 | priorCount = int64(count) 371 | }() 372 | 373 | select { 374 | case <-done: 375 | case <-ctx.Done(): 376 | return 0, 0, ctx.Err() 377 | } 378 | 379 | if currentErr != nil { 380 | return 0, 0, errors.Wrap(currentErr, "failed to update current count") 381 | } else if priorErr != nil { 382 | return 0, 0, errors.Wrap(priorErr, "failed to get previous count") 383 | } 384 | 385 | return priorCount, currentCount, nil 386 | } 387 | 388 | // SlidingWindowCosmosDB implements SlidingWindow in Azure Cosmos DB. 
389 | type SlidingWindowCosmosDB struct { 390 | client *azcosmos.ContainerClient 391 | partitionKey string 392 | } 393 | 394 | // NewSlidingWindowCosmosDB creates a new instance of SlidingWindowCosmosDB. 395 | // PartitionKey is the key used to store all the this implementation in Cosmos. 396 | func NewSlidingWindowCosmosDB(client *azcosmos.ContainerClient, partitionKey string) *SlidingWindowCosmosDB { 397 | return &SlidingWindowCosmosDB{ 398 | client: client, 399 | partitionKey: partitionKey, 400 | } 401 | } 402 | 403 | // Increment increments the current window's counter in Cosmos and returns the number of requests in the previous window 404 | // and the current one. 405 | func (s *SlidingWindowCosmosDB) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) { 406 | wg := &sync.WaitGroup{} 407 | wg.Add(2) 408 | 409 | done := make(chan struct{}) 410 | 411 | go func() { 412 | defer close(done) 413 | 414 | wg.Wait() 415 | }() 416 | 417 | var ( 418 | currentCount int64 419 | currentErr error 420 | ) 421 | 422 | go func() { 423 | defer wg.Done() 424 | 425 | id := strconv.FormatInt(curr.UnixNano(), 10) 426 | tmp := cosmosItem{ 427 | ID: id, 428 | PartitionKey: s.partitionKey, 429 | Count: 1, 430 | TTL: int32(ttl), 431 | } 432 | 433 | ops := azcosmos.PatchOperations{} 434 | ops.AppendIncrement(`/Count`, 1) 435 | 436 | patchResp, err := s.client.PatchItem(ctx, azcosmos.NewPartitionKey().AppendString(s.partitionKey), id, ops, &azcosmos.ItemOptions{ 437 | EnableContentResponseOnWrite: true, 438 | }) 439 | if err == nil { 440 | // value exists and was updated 441 | err = json.Unmarshal(patchResp.Value, &tmp) 442 | if err != nil { 443 | currentErr = errors.Wrap(err, "unmarshal of cosmos value current failed") 444 | 445 | return 446 | } 447 | 448 | currentCount = tmp.Count 449 | 450 | return 451 | } 452 | 453 | var respErr *azcore.ResponseError 454 | if !errors.As(err, &respErr) || (respErr.StatusCode != http.StatusNotFound && 
respErr.StatusCode != http.StatusBadRequest) { 455 | currentErr = errors.Wrap(err, `patch of cosmos value current failed`) 456 | 457 | return 458 | } 459 | 460 | newValue, err := json.Marshal(tmp) 461 | if err != nil { 462 | currentErr = errors.Wrap(err, "marshal of cosmos value current failed") 463 | 464 | return 465 | } 466 | 467 | _, err = s.client.CreateItem(ctx, azcosmos.NewPartitionKey().AppendString(s.partitionKey), newValue, &azcosmos.ItemOptions{ 468 | SessionToken: patchResp.SessionToken, 469 | }) 470 | if err != nil { 471 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { 472 | // Fallback to Read-Modify-Write 473 | currentCount, currentErr = incrementCosmosItemRMW(ctx, s.client, s.partitionKey, id, ttl) 474 | 475 | return 476 | } 477 | 478 | currentErr = errors.Wrap(err, "upsert of cosmos value current failed") 479 | 480 | return 481 | } 482 | 483 | currentCount = tmp.Count 484 | }() 485 | 486 | var ( 487 | priorCount int64 488 | priorErr error 489 | ) 490 | 491 | go func() { 492 | defer wg.Done() 493 | 494 | id := strconv.FormatInt(prev.UnixNano(), 10) 495 | 496 | resp, err := s.client.ReadItem(ctx, azcosmos.NewPartitionKey().AppendString(s.partitionKey), id, &azcosmos.ItemOptions{}) 497 | if err != nil { 498 | var azerr *azcore.ResponseError 499 | if errors.As(err, &azerr) && azerr.StatusCode == http.StatusNotFound { 500 | priorCount, priorErr = 0, nil 501 | 502 | return 503 | } 504 | 505 | priorErr = errors.Wrap(err, "cosmos get item prior failed") 506 | 507 | return 508 | } 509 | 510 | var tmp cosmosItem 511 | 512 | err = json.Unmarshal(resp.Value, &tmp) 513 | if err != nil { 514 | priorErr = errors.Wrap(err, "unmarshal of cosmos value prior failed") 515 | 516 | return 517 | } 518 | 519 | priorCount = tmp.Count 520 | }() 521 | 522 | select { 523 | case <-done: 524 | case <-ctx.Done(): 525 | return 0, 0, ctx.Err() 526 | } 527 | 528 | if currentErr != nil { 529 | return 0, 0, errors.Wrap(currentErr, "failed to update current 
count") 530 | } else if priorErr != nil { 531 | return 0, 0, errors.Wrap(priorErr, "failed to get previous count") 532 | } 533 | 534 | return priorCount, currentCount, nil 535 | } 536 | -------------------------------------------------------------------------------- /leakybucket.go: -------------------------------------------------------------------------------- 1 | package limiters 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/gob" 7 | "encoding/json" 8 | "fmt" 9 | "math" 10 | "net/http" 11 | "strconv" 12 | "sync" 13 | "time" 14 | 15 | "github.com/Azure/azure-sdk-for-go/sdk/azcore" 16 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" 17 | "github.com/aws/aws-sdk-go-v2/aws" 18 | "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" 19 | "github.com/aws/aws-sdk-go-v2/service/dynamodb" 20 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" 21 | "github.com/bradfitz/gomemcache/memcache" 22 | "github.com/pkg/errors" 23 | "github.com/redis/go-redis/v9" 24 | "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" 25 | clientv3 "go.etcd.io/etcd/client/v3" 26 | ) 27 | 28 | // LeakyBucketState represents the state of a LeakyBucket. 29 | type LeakyBucketState struct { 30 | // Last is the Unix timestamp in nanoseconds of the most recent request. 31 | Last int64 32 | } 33 | 34 | // IzZero returns true if the bucket state is zero valued. 35 | func (s LeakyBucketState) IzZero() bool { 36 | return s.Last == 0 37 | } 38 | 39 | // LeakyBucketStateBackend interface encapsulates the logic of retrieving and persisting the state of a LeakyBucket. 40 | type LeakyBucketStateBackend interface { 41 | // State gets the current state of the LeakyBucket. 42 | State(ctx context.Context) (LeakyBucketState, error) 43 | // SetState sets (persists) the current state of the LeakyBucket. 44 | SetState(ctx context.Context, state LeakyBucketState) error 45 | // Reset resets (persists) the current state of the LeakyBucket. 
46 | Reset(ctx context.Context) error 47 | } 48 | 49 | // LeakyBucket implements the https://en.wikipedia.org/wiki/Leaky_bucket#As_a_queue algorithm. 50 | type LeakyBucket struct { 51 | locker DistLocker 52 | backend LeakyBucketStateBackend 53 | clock Clock 54 | logger Logger 55 | // Capacity is the maximum allowed number of tokens in the bucket. 56 | capacity int64 57 | // Rate is the output rate: 1 request per the rate duration (in nanoseconds). 58 | rate int64 59 | mu sync.Mutex 60 | } 61 | 62 | // NewLeakyBucket creates a new instance of LeakyBucket. 63 | func NewLeakyBucket(capacity int64, rate time.Duration, locker DistLocker, leakyBucketStateBackend LeakyBucketStateBackend, clock Clock, logger Logger) *LeakyBucket { 64 | return &LeakyBucket{ 65 | locker: locker, 66 | backend: leakyBucketStateBackend, 67 | clock: clock, 68 | logger: logger, 69 | capacity: capacity, 70 | rate: int64(rate), 71 | } 72 | } 73 | 74 | // Limit returns the time duration to wait before the request can be processed. 75 | // It returns ErrLimitExhausted if the request overflows the bucket's capacity. In this case the returned duration 76 | // means how long it would have taken to wait for the request to be processed if the bucket was not overflowed. 77 | func (t *LeakyBucket) Limit(ctx context.Context) (time.Duration, error) { 78 | t.mu.Lock() 79 | defer t.mu.Unlock() 80 | 81 | err := t.locker.Lock(ctx) 82 | if err != nil { 83 | return 0, err 84 | } 85 | 86 | defer func() { 87 | err := t.locker.Unlock(ctx) 88 | if err != nil { 89 | t.logger.Log(err) 90 | } 91 | }() 92 | 93 | state, err := t.backend.State(ctx) 94 | if err != nil { 95 | return 0, err 96 | } 97 | 98 | now := t.clock.Now().UnixNano() 99 | if now < state.Last { 100 | // The queue has requests in it: move the current request to the last position + 1. 101 | state.Last += t.rate 102 | } else { 103 | // The queue is empty. 104 | // The offset is the duration to wait in case the last request happened less than rate duration ago. 
105 | var offset int64 106 | 107 | delta := now - state.Last 108 | if delta < t.rate { 109 | offset = t.rate - delta 110 | } 111 | 112 | state.Last = now + offset 113 | } 114 | 115 | wait := state.Last - now 116 | if wait/t.rate >= t.capacity { 117 | return time.Duration(wait), ErrLimitExhausted 118 | } 119 | 120 | err = t.backend.SetState(ctx, state) 121 | if err != nil { 122 | return 0, err 123 | } 124 | 125 | return time.Duration(wait), nil 126 | } 127 | 128 | // Reset resets the bucket. 129 | func (t *LeakyBucket) Reset(ctx context.Context) error { 130 | return t.backend.Reset(ctx) 131 | } 132 | 133 | // LeakyBucketInMemory is an in-memory implementation of LeakyBucketStateBackend. 134 | type LeakyBucketInMemory struct { 135 | state LeakyBucketState 136 | } 137 | 138 | // NewLeakyBucketInMemory creates a new instance of LeakyBucketInMemory. 139 | func NewLeakyBucketInMemory() *LeakyBucketInMemory { 140 | return &LeakyBucketInMemory{} 141 | } 142 | 143 | // State gets the current state of the bucket. 144 | func (l *LeakyBucketInMemory) State(ctx context.Context) (LeakyBucketState, error) { 145 | return l.state, ctx.Err() 146 | } 147 | 148 | // SetState sets the current state of the bucket. 149 | func (l *LeakyBucketInMemory) SetState(ctx context.Context, state LeakyBucketState) error { 150 | l.state = state 151 | 152 | return ctx.Err() 153 | } 154 | 155 | // Reset resets the current state of the bucket. 156 | func (l *LeakyBucketInMemory) Reset(ctx context.Context) error { 157 | state := LeakyBucketState{ 158 | Last: 0, 159 | } 160 | 161 | return l.SetState(ctx, state) 162 | } 163 | 164 | const ( 165 | etcdKeyLBLease = "lease" 166 | etcdKeyLBLast = "last" 167 | ) 168 | 169 | // LeakyBucketEtcd is an etcd implementation of a LeakyBucketStateBackend. 170 | // See the TokenBucketEtcd description for the details on etcd usage. 171 | type LeakyBucketEtcd struct { 172 | // prefix is the etcd key prefix. 
173 | prefix string 174 | cli *clientv3.Client 175 | leaseID clientv3.LeaseID 176 | ttl time.Duration 177 | raceCheck bool 178 | lastVersion int64 179 | } 180 | 181 | // NewLeakyBucketEtcd creates a new LeakyBucketEtcd instance. 182 | // Prefix is used as an etcd key prefix for all keys stored in etcd by this algorithm. 183 | // TTL is a TTL of the etcd lease used to store all the keys. 184 | // 185 | // If raceCheck is true and the keys in etcd are modified in between State() and SetState() calls then 186 | // ErrRaceCondition is returned. 187 | func NewLeakyBucketEtcd(cli *clientv3.Client, prefix string, ttl time.Duration, raceCheck bool) *LeakyBucketEtcd { 188 | return &LeakyBucketEtcd{ 189 | prefix: prefix, 190 | cli: cli, 191 | ttl: ttl, 192 | raceCheck: raceCheck, 193 | } 194 | } 195 | 196 | // State gets the bucket's current state from etcd. 197 | // If there is no state available in etcd then the initial bucket's state is returned. 198 | func (l *LeakyBucketEtcd) State(ctx context.Context) (LeakyBucketState, error) { 199 | // Reset the lease ID as it will be updated by the successful Get operation below. 200 | l.leaseID = 0 201 | // Get all the keys under the prefix in a single request. 
202 | r, err := l.cli.Get(ctx, l.prefix, clientv3.WithRange(incPrefix(l.prefix))) 203 | if err != nil { 204 | return LeakyBucketState{}, errors.Wrapf(err, "failed to get keys in range ['%s', '%s') from etcd", l.prefix, incPrefix(l.prefix)) 205 | } 206 | 207 | if len(r.Kvs) == 0 { 208 | return LeakyBucketState{}, nil 209 | } 210 | 211 | state := LeakyBucketState{} 212 | 213 | parsed := 0 214 | if l.ttl == 0 { 215 | // Ignore lease when there is no expiration 216 | parsed |= 2 217 | } 218 | 219 | var v int64 220 | 221 | for _, kv := range r.Kvs { 222 | switch string(kv.Key) { 223 | case etcdKey(l.prefix, etcdKeyLBLast): 224 | v, err = parseEtcdInt64(kv) 225 | if err != nil { 226 | return LeakyBucketState{}, err 227 | } 228 | 229 | state.Last = v 230 | parsed |= 1 231 | l.lastVersion = kv.Version 232 | 233 | case etcdKey(l.prefix, etcdKeyLBLease): 234 | v, err = parseEtcdInt64(kv) 235 | if err != nil { 236 | return LeakyBucketState{}, err 237 | } 238 | 239 | l.leaseID = clientv3.LeaseID(v) 240 | parsed |= 2 241 | } 242 | } 243 | 244 | if parsed != 3 { 245 | return LeakyBucketState{}, errors.New("failed to get state from etcd: some keys are missing") 246 | } 247 | 248 | return state, nil 249 | } 250 | 251 | // createLease creates a new lease in etcd and updates the t.leaseID value. 252 | func (l *LeakyBucketEtcd) createLease(ctx context.Context) error { 253 | lease, err := l.cli.Grant(ctx, int64(math.Ceil(l.ttl.Seconds()))) 254 | if err != nil { 255 | return errors.Wrap(err, "failed to create a new lease in etcd") 256 | } 257 | 258 | l.leaseID = lease.ID 259 | 260 | return nil 261 | } 262 | 263 | // save saves the state to etcd using the existing lease. 
264 | func (l *LeakyBucketEtcd) save(ctx context.Context, state LeakyBucketState) error { 265 | var opts []clientv3.OpOption 266 | if l.ttl > 0 { 267 | opts = append(opts, clientv3.WithLease(l.leaseID)) 268 | } 269 | 270 | ops := []clientv3.Op{ 271 | clientv3.OpPut(etcdKey(l.prefix, etcdKeyLBLast), fmt.Sprintf("%d", state.Last), opts...), 272 | } 273 | if l.ttl > 0 { 274 | ops = append(ops, clientv3.OpPut(etcdKey(l.prefix, etcdKeyLBLease), fmt.Sprintf("%d", l.leaseID), opts...)) 275 | } 276 | 277 | if !l.raceCheck { 278 | _, err := l.cli.Txn(ctx).Then(ops...).Commit() 279 | if err != nil { 280 | return errors.Wrap(err, "failed to commit a transaction to etcd") 281 | } 282 | 283 | return nil 284 | } 285 | // Put the keys only if they have not been modified since the most recent read. 286 | r, err := l.cli.Txn(ctx).If( 287 | clientv3.Compare(clientv3.Version(etcdKey(l.prefix, etcdKeyLBLast)), ">", l.lastVersion), 288 | ).Else(ops...).Commit() 289 | if err != nil { 290 | return errors.Wrap(err, "failed to commit a transaction to etcd") 291 | } 292 | 293 | if !r.Succeeded { 294 | return nil 295 | } 296 | 297 | return ErrRaceCondition 298 | } 299 | 300 | // SetState updates the state of the bucket in etcd. 301 | func (l *LeakyBucketEtcd) SetState(ctx context.Context, state LeakyBucketState) error { 302 | if l.ttl == 0 { 303 | // Avoid maintaining the lease when it has no TTL 304 | return l.save(ctx, state) 305 | } 306 | 307 | if l.leaseID == 0 { 308 | // Lease does not exist, create one. 309 | err := l.createLease(ctx) 310 | if err != nil { 311 | return err 312 | } 313 | // No need to send KeepAlive for the newly creates lease: save the state immediately. 314 | return l.save(ctx, state) 315 | } 316 | // Send the KeepAlive request to extend the existing lease. 317 | _, err := l.cli.KeepAliveOnce(ctx, l.leaseID) 318 | if errors.Is(err, rpctypes.ErrLeaseNotFound) { 319 | // Create a new lease since the current one has expired. 
320 | err = l.createLease(ctx) 321 | if err != nil { 322 | return err 323 | } 324 | } else if err != nil { 325 | return errors.Wrapf(err, "failed to extend the lease '%d'", l.leaseID) 326 | } 327 | 328 | return l.save(ctx, state) 329 | } 330 | 331 | // Reset resets the state of the bucket in etcd. 332 | func (l *LeakyBucketEtcd) Reset(ctx context.Context) error { 333 | state := LeakyBucketState{ 334 | Last: 0, 335 | } 336 | 337 | return l.SetState(ctx, state) 338 | } 339 | 340 | // Deprecated: These legacy keys will be removed in a future version. 341 | // The state is now stored in a single JSON document under the "state" key. 342 | const ( 343 | redisKeyLBLast = "last" 344 | redisKeyLBVersion = "version" 345 | ) 346 | 347 | // LeakyBucketRedis is a Redis implementation of a LeakyBucketStateBackend. 348 | type LeakyBucketRedis struct { 349 | cli redis.UniversalClient 350 | prefix string 351 | ttl time.Duration 352 | raceCheck bool 353 | lastVersion int64 354 | } 355 | 356 | // NewLeakyBucketRedis creates a new LeakyBucketRedis instance. 357 | // Prefix is the key prefix used to store all the keys used in this implementation in Redis. 358 | // TTL is the TTL of the stored keys. 359 | // 360 | // If raceCheck is true and the keys in Redis are modified in between State() and SetState() calls then 361 | // ErrRaceCondition is returned. 362 | func NewLeakyBucketRedis(cli redis.UniversalClient, prefix string, ttl time.Duration, raceCheck bool) *LeakyBucketRedis { 363 | return &LeakyBucketRedis{cli: cli, prefix: prefix, ttl: ttl, raceCheck: raceCheck} 364 | } 365 | 366 | // Deprecated: Legacy format support will be removed in a future version. 
367 | func (t *LeakyBucketRedis) oldState(ctx context.Context) (LeakyBucketState, error) { 368 | var ( 369 | values []interface{} 370 | err error 371 | ) 372 | 373 | done := make(chan struct{}, 1) 374 | 375 | go func() { 376 | defer close(done) 377 | 378 | keys := []string{ 379 | redisKey(t.prefix, redisKeyLBLast), 380 | } 381 | if t.raceCheck { 382 | keys = append(keys, redisKey(t.prefix, redisKeyLBVersion)) 383 | } 384 | 385 | values, err = t.cli.MGet(ctx, keys...).Result() 386 | }() 387 | 388 | select { 389 | case <-done: 390 | 391 | case <-ctx.Done(): 392 | return LeakyBucketState{}, ctx.Err() 393 | } 394 | 395 | if err != nil { 396 | return LeakyBucketState{}, errors.Wrap(err, "failed to get keys from redis") 397 | } 398 | 399 | nilAny := false 400 | 401 | for _, v := range values { 402 | if v == nil { 403 | nilAny = true 404 | 405 | break 406 | } 407 | } 408 | 409 | if nilAny || errors.Is(err, redis.Nil) { 410 | // Keys don't exist, return an empty state. 411 | return LeakyBucketState{}, nil 412 | } 413 | 414 | last, err := strconv.ParseInt(values[0].(string), 10, 64) 415 | if err != nil { 416 | return LeakyBucketState{}, err 417 | } 418 | 419 | if t.raceCheck { 420 | t.lastVersion, err = strconv.ParseInt(values[1].(string), 10, 64) 421 | if err != nil { 422 | return LeakyBucketState{}, err 423 | } 424 | } 425 | 426 | return LeakyBucketState{ 427 | Last: last, 428 | }, nil 429 | } 430 | 431 | // State gets the bucket's state from Redis. 
432 | func (t *LeakyBucketRedis) State(ctx context.Context) (LeakyBucketState, error) { 433 | var err error 434 | 435 | done := make(chan struct{}, 1) 436 | errCh := make(chan error, 1) 437 | 438 | var state LeakyBucketState 439 | 440 | if t.raceCheck { 441 | // reset in a case of returning an empty LeakyBucketState 442 | t.lastVersion = 0 443 | } 444 | 445 | go func() { 446 | defer close(done) 447 | 448 | key := redisKey(t.prefix, "state") 449 | 450 | value, err := t.cli.Get(ctx, key).Result() 451 | if err != nil && !errors.Is(err, redis.Nil) { 452 | errCh <- err 453 | 454 | return 455 | } 456 | 457 | if errors.Is(err, redis.Nil) { 458 | state, err = t.oldState(ctx) 459 | errCh <- err 460 | 461 | return 462 | } 463 | 464 | // Try new format 465 | var item struct { 466 | State LeakyBucketState `json:"state"` 467 | Version int64 `json:"version"` 468 | } 469 | 470 | err = json.Unmarshal([]byte(value), &item) 471 | if err != nil { 472 | errCh <- err 473 | 474 | return 475 | } 476 | 477 | state = item.State 478 | if t.raceCheck { 479 | t.lastVersion = item.Version 480 | } 481 | 482 | errCh <- nil 483 | }() 484 | 485 | select { 486 | case <-done: 487 | err = <-errCh 488 | case <-ctx.Done(): 489 | return LeakyBucketState{}, ctx.Err() 490 | } 491 | 492 | if err != nil { 493 | return LeakyBucketState{}, errors.Wrap(err, "failed to get state from redis") 494 | } 495 | 496 | return state, nil 497 | } 498 | 499 | // SetState updates the state in Redis. 
500 | func (t *LeakyBucketRedis) SetState(ctx context.Context, state LeakyBucketState) error { 501 | var err error 502 | 503 | done := make(chan struct{}, 1) 504 | errCh := make(chan error, 1) 505 | 506 | go func() { 507 | defer close(done) 508 | 509 | key := redisKey(t.prefix, "state") 510 | item := struct { 511 | State LeakyBucketState `json:"state"` 512 | Version int64 `json:"version"` 513 | }{ 514 | State: state, 515 | Version: t.lastVersion + 1, 516 | } 517 | 518 | value, err := json.Marshal(item) 519 | if err != nil { 520 | errCh <- err 521 | 522 | return 523 | } 524 | 525 | if !t.raceCheck { 526 | errCh <- t.cli.Set(ctx, key, value, t.ttl).Err() 527 | 528 | return 529 | } 530 | 531 | callScript := `redis.call('set', KEYS[1], ARGV[1], 'PX', ARGV[3])` 532 | if t.ttl == 0 { 533 | callScript = `redis.call('set', KEYS[1], ARGV[1])` 534 | } 535 | 536 | script := fmt.Sprintf(` 537 | local current = redis.call('get', KEYS[1]) 538 | if current then 539 | local data = cjson.decode(current) 540 | if data.version > tonumber(ARGV[2]) then 541 | return 'RACE_CONDITION' 542 | end 543 | end 544 | %s 545 | return 'OK' 546 | `, callScript) 547 | 548 | result, err := t.cli.Eval(ctx, script, []string{key}, value, t.lastVersion, int64(t.ttl/time.Millisecond)).Result() 549 | if err != nil { 550 | errCh <- err 551 | 552 | return 553 | } 554 | 555 | if result == "RACE_CONDITION" { 556 | errCh <- ErrRaceCondition 557 | 558 | return 559 | } 560 | 561 | errCh <- nil 562 | }() 563 | 564 | select { 565 | case <-done: 566 | err = <-errCh 567 | case <-ctx.Done(): 568 | return ctx.Err() 569 | } 570 | 571 | if err != nil { 572 | return errors.Wrap(err, "failed to save state to redis") 573 | } 574 | 575 | return nil 576 | } 577 | 578 | // Reset resets the state in Redis. 
579 | func (t *LeakyBucketRedis) Reset(ctx context.Context) error { 580 | state := LeakyBucketState{ 581 | Last: 0, 582 | } 583 | 584 | return t.SetState(ctx, state) 585 | } 586 | 587 | // LeakyBucketMemcached is a Memcached implementation of a LeakyBucketStateBackend. 588 | type LeakyBucketMemcached struct { 589 | cli *memcache.Client 590 | key string 591 | ttl time.Duration 592 | raceCheck bool 593 | casId uint64 594 | } 595 | 596 | // NewLeakyBucketMemcached creates a new LeakyBucketMemcached instance. 597 | // Key is the key used to store all the keys used in this implementation in Memcached. 598 | // TTL is the TTL of the stored keys. 599 | // 600 | // If raceCheck is true and the keys in Memcached are modified in between State() and SetState() calls then 601 | // ErrRaceCondition is returned. 602 | func NewLeakyBucketMemcached(cli *memcache.Client, key string, ttl time.Duration, raceCheck bool) *LeakyBucketMemcached { 603 | return &LeakyBucketMemcached{cli: cli, key: key, ttl: ttl, raceCheck: raceCheck} 604 | } 605 | 606 | // State gets the bucket's state from Memcached. 607 | func (t *LeakyBucketMemcached) State(ctx context.Context) (LeakyBucketState, error) { 608 | var ( 609 | item *memcache.Item 610 | err error 611 | state LeakyBucketState 612 | ) 613 | 614 | done := make(chan struct{}, 1) 615 | t.casId = 0 616 | 617 | go func() { 618 | defer close(done) 619 | 620 | item, err = t.cli.Get(t.key) 621 | }() 622 | 623 | select { 624 | case <-done: 625 | 626 | case <-ctx.Done(): 627 | return state, ctx.Err() 628 | } 629 | 630 | if err != nil { 631 | if errors.Is(err, memcache.ErrCacheMiss) { 632 | // Keys don't exist, return an empty state. 
633 | return state, nil 634 | } 635 | 636 | return state, errors.Wrap(err, "failed to get keys from memcached") 637 | } 638 | 639 | b := bytes.NewBuffer(item.Value) 640 | 641 | err = gob.NewDecoder(b).Decode(&state) 642 | if err != nil { 643 | return state, errors.Wrap(err, "failed to Decode") 644 | } 645 | 646 | t.casId = item.CasID 647 | 648 | return state, nil 649 | } 650 | 651 | // SetState updates the state in Memcached. 652 | // The provided fencing token is checked on the Memcached side before saving the keys. 653 | func (t *LeakyBucketMemcached) SetState(ctx context.Context, state LeakyBucketState) error { 654 | var err error 655 | 656 | done := make(chan struct{}, 1) 657 | 658 | var b bytes.Buffer 659 | 660 | err = gob.NewEncoder(&b).Encode(state) 661 | if err != nil { 662 | return errors.Wrap(err, "failed to Encode") 663 | } 664 | 665 | go func() { 666 | defer close(done) 667 | 668 | item := &memcache.Item{ 669 | Key: t.key, 670 | Value: b.Bytes(), 671 | CasID: t.casId, 672 | } 673 | if t.ttl > 30*24*time.Hour { 674 | // If the value is over 30 days, it treats it as UNIX timestamp. 675 | item.Expiration = int32(time.Now().Add(t.ttl).Unix()) 676 | } else if t.ttl > 0 { 677 | // Memcached supports expiration in seconds. It's more precise way. 678 | item.Expiration = int32(math.Ceil(t.ttl.Seconds())) 679 | } 680 | 681 | if t.raceCheck && t.casId > 0 { 682 | err = t.cli.CompareAndSwap(item) 683 | } else { 684 | err = t.cli.Set(item) 685 | } 686 | }() 687 | 688 | select { 689 | case <-done: 690 | 691 | case <-ctx.Done(): 692 | return ctx.Err() 693 | } 694 | 695 | if err != nil && (errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss)) { 696 | return ErrRaceCondition 697 | } 698 | 699 | return errors.Wrap(err, "failed to save keys to memcached") 700 | } 701 | 702 | // Reset resets the state in Memcached. 
703 | func (t *LeakyBucketMemcached) Reset(ctx context.Context) error { 704 | state := LeakyBucketState{ 705 | Last: 0, 706 | } 707 | t.casId = 0 708 | 709 | return t.SetState(ctx, state) 710 | } 711 | 712 | // LeakyBucketDynamoDB is a DyanamoDB implementation of a LeakyBucketStateBackend. 713 | type LeakyBucketDynamoDB struct { 714 | client *dynamodb.Client 715 | tableProps DynamoDBTableProperties 716 | partitionKey string 717 | ttl time.Duration 718 | raceCheck bool 719 | latestVersion int64 720 | keys map[string]types.AttributeValue 721 | } 722 | 723 | // NewLeakyBucketDynamoDB creates a new LeakyBucketDynamoDB instance. 724 | // PartitionKey is the key used to store all the this implementation in DynamoDB. 725 | // 726 | // TableProps describe the table that this backend should work with. This backend requires the following on the table: 727 | // * TTL 728 | // 729 | // TTL is the TTL of the stored item. 730 | // 731 | // If raceCheck is true and the item in DynamoDB are modified in between State() and SetState() calls then 732 | // ErrRaceCondition is returned. 733 | func NewLeakyBucketDynamoDB(client *dynamodb.Client, partitionKey string, tableProps DynamoDBTableProperties, ttl time.Duration, raceCheck bool) *LeakyBucketDynamoDB { 734 | keys := map[string]types.AttributeValue{ 735 | tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: partitionKey}, 736 | } 737 | 738 | if tableProps.SortKeyUsed { 739 | keys[tableProps.SortKeyName] = &types.AttributeValueMemberS{Value: partitionKey} 740 | } 741 | 742 | return &LeakyBucketDynamoDB{ 743 | client: client, 744 | partitionKey: partitionKey, 745 | tableProps: tableProps, 746 | ttl: ttl, 747 | raceCheck: raceCheck, 748 | keys: keys, 749 | } 750 | } 751 | 752 | // State gets the bucket's state from DynamoDB. 
753 | func (t *LeakyBucketDynamoDB) State(ctx context.Context) (LeakyBucketState, error) { 754 | resp, err := dynamoDBGetItem(ctx, t.client, t.getGetItemInput()) 755 | if err != nil { 756 | return LeakyBucketState{}, err 757 | } 758 | 759 | return t.loadStateFromDynamoDB(resp) 760 | } 761 | 762 | // SetState updates the state in DynamoDB. 763 | func (t *LeakyBucketDynamoDB) SetState(ctx context.Context, state LeakyBucketState) error { 764 | input := t.getPutItemInputFromState(state) 765 | 766 | var err error 767 | 768 | done := make(chan struct{}) 769 | 770 | go func() { 771 | defer close(done) 772 | 773 | _, err = dynamoDBputItem(ctx, t.client, input) 774 | }() 775 | 776 | select { 777 | case <-done: 778 | case <-ctx.Done(): 779 | return ctx.Err() 780 | } 781 | 782 | return err 783 | } 784 | 785 | // Reset resets the state in DynamoDB. 786 | func (t *LeakyBucketDynamoDB) Reset(ctx context.Context) error { 787 | state := LeakyBucketState{ 788 | Last: 0, 789 | } 790 | 791 | return t.SetState(ctx, state) 792 | } 793 | 794 | const ( 795 | dynamodbBucketRaceConditionExpression = "Version <= :version" 796 | dynamoDBBucketLastKey = "Last" 797 | dynamoDBBucketVersionKey = "Version" 798 | ) 799 | 800 | func (t *LeakyBucketDynamoDB) getPutItemInputFromState(state LeakyBucketState) *dynamodb.PutItemInput { 801 | item := map[string]types.AttributeValue{} 802 | for k, v := range t.keys { 803 | item[k] = v 804 | } 805 | 806 | item[dynamoDBBucketLastKey] = &types.AttributeValueMemberN{Value: strconv.FormatInt(state.Last, 10)} 807 | 808 | item[dynamoDBBucketVersionKey] = &types.AttributeValueMemberN{Value: strconv.FormatInt(t.latestVersion+1, 10)} 809 | if t.ttl > 0 { 810 | item[t.tableProps.TTLFieldName] = &types.AttributeValueMemberN{Value: strconv.FormatInt(time.Now().Add(t.ttl).Unix(), 10)} 811 | } 812 | 813 | input := &dynamodb.PutItemInput{ 814 | TableName: &t.tableProps.TableName, 815 | Item: item, 816 | } 817 | 818 | if t.raceCheck && t.latestVersion > 0 { 819 | 
input.ConditionExpression = aws.String(dynamodbBucketRaceConditionExpression) 820 | input.ExpressionAttributeValues = map[string]types.AttributeValue{ 821 | ":version": &types.AttributeValueMemberN{Value: strconv.FormatInt(t.latestVersion, 10)}, 822 | } 823 | } 824 | 825 | return input 826 | } 827 | 828 | func (t *LeakyBucketDynamoDB) getGetItemInput() *dynamodb.GetItemInput { 829 | return &dynamodb.GetItemInput{ 830 | TableName: &t.tableProps.TableName, 831 | Key: t.keys, 832 | } 833 | } 834 | 835 | func (t *LeakyBucketDynamoDB) loadStateFromDynamoDB(resp *dynamodb.GetItemOutput) (LeakyBucketState, error) { 836 | state := LeakyBucketState{} 837 | 838 | err := attributevalue.Unmarshal(resp.Item[dynamoDBBucketLastKey], &state.Last) 839 | if err != nil { 840 | return state, fmt.Errorf("unmarshal dynamodb Last attribute failed: %w", err) 841 | } 842 | 843 | if t.raceCheck { 844 | err = attributevalue.Unmarshal(resp.Item[dynamoDBBucketVersionKey], &t.latestVersion) 845 | if err != nil { 846 | return state, fmt.Errorf("unmarshal dynamodb Version attribute failed: %w", err) 847 | } 848 | } 849 | 850 | return state, nil 851 | } 852 | 853 | // CosmosDBLeakyBucketItem represents a document in CosmosDB for LeakyBucket. 854 | type CosmosDBLeakyBucketItem struct { 855 | ID string `json:"id"` 856 | PartitionKey string `json:"partitionKey"` 857 | State LeakyBucketState `json:"state"` 858 | Version int64 `json:"version"` 859 | TTL int64 `json:"ttl,omitempty"` 860 | } 861 | 862 | // LeakyBucketCosmosDB is a CosmosDB implementation of a LeakyBucketStateBackend. 863 | type LeakyBucketCosmosDB struct { 864 | client *azcosmos.ContainerClient 865 | partitionKey string 866 | id string 867 | ttl time.Duration 868 | raceCheck bool 869 | latestVersion int64 870 | } 871 | 872 | // NewLeakyBucketCosmosDB creates a new LeakyBucketCosmosDB instance. 873 | // PartitionKey is the key used to store all the implementation in CosmosDB. 874 | // TTL is the TTL of the stored item. 
875 | // 876 | // If raceCheck is true and the item in CosmosDB is modified in between State() and SetState() calls then 877 | // ErrRaceCondition is returned. 878 | func NewLeakyBucketCosmosDB(client *azcosmos.ContainerClient, partitionKey string, ttl time.Duration, raceCheck bool) *LeakyBucketCosmosDB { 879 | return &LeakyBucketCosmosDB{ 880 | client: client, 881 | partitionKey: partitionKey, 882 | id: "leaky-bucket-" + partitionKey, 883 | ttl: ttl, 884 | raceCheck: raceCheck, 885 | } 886 | } 887 | 888 | func (t *LeakyBucketCosmosDB) State(ctx context.Context) (LeakyBucketState, error) { 889 | var item CosmosDBLeakyBucketItem 890 | 891 | resp, err := t.client.ReadItem(ctx, azcosmos.NewPartitionKey().AppendString(t.partitionKey), t.id, &azcosmos.ItemOptions{}) 892 | if err != nil { 893 | var respErr *azcore.ResponseError 894 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { 895 | return LeakyBucketState{}, nil 896 | } 897 | 898 | return LeakyBucketState{}, err 899 | } 900 | 901 | err = json.Unmarshal(resp.Value, &item) 902 | if err != nil { 903 | return LeakyBucketState{}, errors.Wrap(err, "failed to decode state from Cosmos DB") 904 | } 905 | 906 | if t.raceCheck { 907 | t.latestVersion = item.Version 908 | } 909 | 910 | return item.State, nil 911 | } 912 | 913 | func (t *LeakyBucketCosmosDB) SetState(ctx context.Context, state LeakyBucketState) error { 914 | var err error 915 | 916 | done := make(chan struct{}, 1) 917 | 918 | item := CosmosDBLeakyBucketItem{ 919 | ID: t.id, 920 | PartitionKey: t.partitionKey, 921 | State: state, 922 | Version: t.latestVersion + 1, 923 | } 924 | if t.ttl > 0 { 925 | item.TTL = int64(math.Ceil(t.ttl.Seconds())) 926 | } 927 | 928 | value, err := json.Marshal(item) 929 | if err != nil { 930 | return errors.Wrap(err, "failed to encode state to JSON") 931 | } 932 | 933 | go func() { 934 | defer close(done) 935 | 936 | _, err = t.client.UpsertItem(ctx, azcosmos.NewPartitionKey().AppendString(t.partitionKey), 
value, &azcosmos.ItemOptions{}) 937 | }() 938 | 939 | select { 940 | case <-done: 941 | case <-ctx.Done(): 942 | return ctx.Err() 943 | } 944 | 945 | if err != nil { 946 | var respErr *azcore.ResponseError 947 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict && t.raceCheck { 948 | return ErrRaceCondition 949 | } 950 | 951 | return errors.Wrap(err, "failed to save keys to Cosmos DB") 952 | } 953 | 954 | return nil 955 | } 956 | 957 | func (t *LeakyBucketCosmosDB) Reset(ctx context.Context) error { 958 | state := LeakyBucketState{ 959 | Last: 0, 960 | } 961 | 962 | return t.SetState(ctx, state) 963 | } 964 | --------------------------------------------------------------------------------