├── .gitignore
├── codecov.yml
├── .github
│   ├── dependabot.yaml
│   └── workflows
│       ├── staticcheck.yml
│       ├── blockwatch.yml
│       ├── vet.yml
│       ├── golangci-lint.yml
│       └── tests.yml
├── .pre-commit-config.yaml
├── examples
│   ├── gprc_service.go
│   ├── helloworld
│   │   ├── helloworld.proto
│   │   ├── helloworld_grpc.pb.go
│   │   └── helloworld.pb.go
│   ├── example_grpc_simple_limiter_test.go
│   └── example_grpc_ip_limiter_test.go
├── revive.toml
├── .golangci.yml
├── LICENSE
├── cosmos_test.go
├── Makefile
├── docker-compose.yml
├── locks_test.go
├── limiters.go
├── cosmosdb.go
├── registry_test.go
├── registry.go
├── dynamodb_test.go
├── go.mod
├── dynamodb.go
├── concurrent_buffer_test.go
├── fixedwindow_test.go
├── README.md
├── locks.go
├── slidingwindow_test.go
├── limiters_test.go
├── concurrent_buffer.go
├── fixedwindow.go
├── tokenbucket_test.go
├── leakybucket_test.go
├── slidingwindow.go
└── leakybucket.go
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | coverage.txt
3 | .vscode
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 |   status:
3 |     patch: off
4 |
--------------------------------------------------------------------------------
/.github/dependabot.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 |   - package-ecosystem: gomod
4 |     directory: /
5 |     schedule:
6 |       interval: monthly
7 |     groups:
8 |       all-dependencies:
9 |         patterns:
10 |           - "*"
11 |
--------------------------------------------------------------------------------
/.github/workflows/staticcheck.yml:
--------------------------------------------------------------------------------
1 | name: Static Checks
2 |
3 | on:
4 |   push:
5 |     branches:
6 |       - master
7 |   pull_request:
8 |     branches:
9 |       - master
10 |
11 | jobs:
12 |   staticcheck:
13 |     name: Linter
14 |     runs-on: ubuntu-latest
15 |     steps:
16 |       - uses: actions/checkout@v4
17 |       - uses: dominikh/staticcheck-action@v1.3.1
18 |         with:
19 |           version: "latest"
--------------------------------------------------------------------------------
/.github/workflows/blockwatch.yml:
--------------------------------------------------------------------------------
1 | name: Blockwatch linter
2 |
3 | on:
4 |   push:
5 |     branches:
6 |       - master
7 |   pull_request:
8 |     branches:
9 |       - master
10 |
11 | jobs:
12 |   blockwatch:
13 |     name: Blockwatch
14 |     runs-on: ubuntu-latest
15 |     steps:
16 |       - uses: actions/checkout@v5
17 |         with:
18 |           fetch-depth: 2 # Required to diff against the base branch
19 |       - uses: mennanov/blockwatch-action@v1
20 |
--------------------------------------------------------------------------------
/.github/workflows/vet.yml:
--------------------------------------------------------------------------------
1 | name: Linter
2 |
3 | on:
4 |   push:
5 |     branches:
6 |       - master
7 |   pull_request:
8 |     branches:
9 |       - master
10 |
11 | jobs:
12 |   vet:
13 |     name: Go vet
14 |     runs-on: ubuntu-latest
15 |     steps:
16 |       - uses: actions/checkout@v4
17 |       - name: Setup Go 1.25
18 |         uses: actions/setup-go@v5
19 |         with:
20 |           go-version: '1.25'
21 |       - name: Go vet
22 |         run: go vet ./...
23 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: local
3 |     hooks:
4 |       - id: blockwatch
5 |         name: blockwatch
6 |         entry: bash -c 'git diff --patch --cached --unified=0 | blockwatch'
7 |         language: system
8 |         stages: [ pre-commit ]
9 |         pass_filenames: false
10 |       - id: golangci-lint
11 |         name: golangci-lint
12 |         entry: golangci-lint run
13 |         language: system
14 |         types: [go]
15 |         pass_filenames: false
16 |
--------------------------------------------------------------------------------
/examples/gprc_service.go:
--------------------------------------------------------------------------------
1 | package examples
2 |
3 | import (
4 | "context"
5 |
6 | pb "github.com/mennanov/limiters/examples/helloworld"
7 | )
8 |
9 | const (
10 | Port = ":50051"
11 | )
12 |
13 | // Server is used to implement helloworld.GreeterServer.
14 | type Server struct {
15 | pb.UnimplementedGreeterServer
16 | }
17 |
18 | // SayHello implements helloworld.GreeterServer.
19 | func (s *Server) SayHello(_ context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
20 | return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil
21 | }
22 |
23 | var _ pb.GreeterServer = new(Server)
24 |
--------------------------------------------------------------------------------
/examples/helloworld/helloworld.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package helloworld;
4 | option go_package = "github.com/mennanov/limiters/examples/helloworld";
5 |
6 | // The greeting service definition.
7 | service Greeter {
8 | // Sends a greeting
9 | rpc SayHello (HelloRequest) returns (HelloReply) {
10 | }
11 | }
12 |
13 | // The request message containing the user's name.
14 | message HelloRequest {
15 | string name = 1;
16 | }
17 |
18 | // The response message containing the greetings
19 | message HelloReply {
20 | string message = 1;
21 | }
22 |
23 | message GoodbyeRequest {
24 | string name = 1;
25 | }
26 |
27 | message GoodbyeReply {
28 | string message = 1;
29 | }
--------------------------------------------------------------------------------
/revive.toml:
--------------------------------------------------------------------------------
1 | ignoreGeneratedHeader = false
2 | severity = "warning"
3 | confidence = 0.8
4 | errorCode = 0
5 | warningCode = 0
6 |
7 | [rule.blank-imports]
8 | [rule.context-as-argument]
9 | [rule.context-keys-type]
10 | [rule.dot-imports]
11 | [rule.error-return]
12 | [rule.error-strings]
13 | [rule.error-naming]
14 | [rule.exported]
15 | [rule.if-return]
16 | [rule.increment-decrement]
17 | [rule.var-naming]
18 | [rule.var-declaration]
19 | [rule.package-comments]
20 | [rule.range]
21 | [rule.receiver-naming]
22 | [rule.time-naming]
23 | [rule.unexported-return]
24 | [rule.indent-error-flow]
25 | [rule.errorf]
26 | [rule.empty-block]
27 | [rule.superfluous-else]
28 | [rule.unused-parameter]
29 | [rule.unreachable-code]
30 | [rule.redefines-builtin-id]
31 |
--------------------------------------------------------------------------------
/.github/workflows/golangci-lint.yml:
--------------------------------------------------------------------------------
1 | name: golangci-lint
2 | on:
3 |   push:
4 |     branches:
5 |       - main
6 |       - master
7 |   pull_request:
8 |
9 | permissions:
10 |   contents: read
11 |   # Optional: allow read access to pull request. Use with `only-new-issues` option.
12 |   # pull-requests: read
13 |
14 | jobs:
15 |   golangci:
16 |     name: lint
17 |     runs-on: ubuntu-latest
18 |     steps:
19 |       - uses: actions/checkout@v4
20 |       - uses: actions/setup-go@v5
21 |         with:
22 |           go-version: stable
23 |       - name: golangci-lint
24 |         uses: golangci/golangci-lint-action@v8
25 |         with:
26 |           #
27 |           version: "v2.5.0"
28 |           #
29 |
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | version: "2"
2 | linters:
3 | default: all
4 | disable:
5 | - cyclop
6 | - depguard
7 | - dupl
8 | - dupword
9 | - exhaustruct
10 | - forcetypeassert
11 | - funcorder
12 | - funlen
13 | - gochecknoglobals
14 | - goconst
15 | - gocritic
16 | - godox
17 | - gomoddirectives
18 | - gosec
19 | - inamedparam
20 | - intrange
21 | - lll
22 | - mnd
23 | - musttag
24 | - paralleltest
25 | - perfsprint
26 | - recvcheck
27 | - revive
28 | - tagliatelle
29 | - testifylint
30 | - unparam
31 | - varnamelen
32 | - wrapcheck
33 | - wsl
34 | exclusions:
35 | generated: lax
36 | presets:
37 | - comments
38 | - common-false-positives
39 | - legacy
40 | - std-error-handling
41 | paths:
42 | - third_party$
43 | - builtin$
44 | - examples$
45 | formatters:
46 | enable:
47 | - gci
48 | - gofmt
49 | - gofumpt
50 | - goimports
51 | exclusions:
52 | generated: lax
53 | paths:
54 | - third_party$
55 | - builtin$
56 | - examples$
57 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Renat Mennanov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cosmos_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
7 | )
8 |
9 | const (
10 | testCosmosDBName = "limiters-db-test"
11 | testCosmosContainerName = "limiters-container-test"
12 | )
13 |
14 | var defaultTTL int32 = 86400
15 |
16 | func CreateCosmosDBContainer(ctx context.Context, client *azcosmos.Client) error {
17 | resp, err := client.CreateDatabase(ctx, azcosmos.DatabaseProperties{
18 | ID: testCosmosDBName,
19 | }, &azcosmos.CreateDatabaseOptions{})
20 | if err != nil {
21 | return err
22 | }
23 |
24 | dbClient, err := client.NewDatabase(resp.DatabaseProperties.ID)
25 | if err != nil {
26 | return err
27 | }
28 |
29 | _, err = dbClient.CreateContainer(ctx, azcosmos.ContainerProperties{
30 | ID: testCosmosContainerName,
31 | DefaultTimeToLive: &defaultTTL,
32 | PartitionKeyDefinition: azcosmos.PartitionKeyDefinition{
33 | Paths: []string{`/partitionKey`},
34 | },
35 | }, &azcosmos.CreateContainerOptions{})
36 |
37 | return err
38 | }
39 |
40 | func DeleteCosmosDBContainer(ctx context.Context, client *azcosmos.Client) error {
41 | dbClient, err := client.NewDatabase(testCosmosDBName)
42 | if err != nil {
43 | return err
44 | }
45 |
46 | _, err = dbClient.Delete(ctx, &azcosmos.DeleteDatabaseOptions{})
47 |
48 | return err
49 | }
50 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | #
2 | SERVICES = ETCD_ENDPOINTS="127.0.0.1:2379" REDIS_ADDR="127.0.0.1:6379" REDIS_NODES="127.0.0.1:11000,127.0.0.1:11001,127.0.0.1:11002,127.0.0.1:11003,127.0.0.1:11004,127.0.0.1:11005" ZOOKEEPER_ENDPOINTS="127.0.0.1" CONSUL_ADDR="127.0.0.1:8500" AWS_ADDR="127.0.0.1:8000" MEMCACHED_ADDR="127.0.0.1:11211" POSTGRES_URL="postgres://postgres:postgres@localhost:5432/?sslmode=disable" COSMOS_ADDR="127.0.0.1:8081"
3 | #
4 |
5 | all: gofumpt goimports lint test benchmark
6 |
7 | docker-compose-up:
8 | 	docker compose up -d
9 |
10 | test: docker-compose-up
11 | 	@$(SERVICES) go test -race -v -failfast
12 |
13 | benchmark: docker-compose-up
14 | 	@$(SERVICES) go test -race -run=nonexistent -bench=.
15 |
16 | lint:
17 | #
18 | 	@(which golangci-lint && golangci-lint --version | grep 2.5.0) || (curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/v2.5.0/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v2.5.0)
19 | #
20 | 	golangci-lint run --fix ./...
21 |
22 | goimports:
23 | 	@which goimports 2>&1 > /dev/null || go install golang.org/x/tools/cmd/goimports@latest
24 | 	goimports -w .
25 |
26 | gofumpt:
27 | 	@which gofumpt 2>&1 > /dev/null || go install mvdan.cc/gofumpt@latest
28 | 	gofumpt -l -w .
29 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 |   #
3 |   etcd:
4 |     image: gcr.io/etcd-development/etcd:v3.6.5
5 |     environment:
6 |       ETCD_LISTEN_CLIENT_URLS: "http://0.0.0.0:2379"
7 |       ETCD_ADVERTISE_CLIENT_URLS: "http://0.0.0.0:2379"
8 |     ports:
9 |       - "2379:2379"
10 |
11 |   redis:
12 |     image: redis:8.2
13 |     ports:
14 |       - "6379:6379"
15 |
16 |   redis-cluster:
17 |     image: grokzen/redis-cluster:7.0.10
18 |     environment:
19 |       IP: "0.0.0.0"
20 |       INITIAL_PORT: 11000
21 |     ports:
22 |       - "11000-11005:11000-11005"
23 |
24 |   memcached:
25 |     image: memcached:1.6
26 |     ports:
27 |       - "11211:11211"
28 |
29 |   consul:
30 |     image: hashicorp/consul:1.21
31 |     ports:
32 |       - "8500:8500"
33 |
34 |   zookeeper:
35 |     image: zookeeper:3.9
36 |     ports:
37 |       - "2181:2181"
38 |
39 |   dynamodb-local:
40 |     image: "amazon/dynamodb-local:latest"
41 |     command: "-jar DynamoDBLocal.jar -inMemory"
42 |     ports:
43 |       - "8000:8000"
44 |
45 |   postgresql:
46 |     image: postgres:18
47 |     environment:
48 |       POSTGRES_PASSWORD: postgres
49 |     ports:
50 |       - "5432:5432"
51 |
52 |   cosmos:
53 |     image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview
54 |     command: [ "--protocol", "http" ]
55 |     environment:
56 |       PROTOCOL: http
57 |       COSMOS_HTTP_CONNECTION_WITHOUT_TLS_ALLOWED: "true"
58 |     ports:
59 |       - "8081:8081"
60 |   #
61 |
--------------------------------------------------------------------------------
/locks_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | "context"
5 | "sync"
6 | "testing"
7 | "time"
8 |
9 | "github.com/mennanov/limiters"
10 | )
11 |
12 | func (s *LimitersTestSuite) useLock(lock limiters.DistLocker, shared *int, sleep time.Duration) {
13 | s.NoError(lock.Lock(context.TODO()))
14 |
15 | sh := *shared
16 | // Imitate heavy work...
17 | time.Sleep(sleep)
18 | // Check for the race condition.
19 | s.Equal(sh, *shared)
20 | *shared++
21 |
22 | s.NoError(lock.Unlock(context.Background()))
23 | }
24 |
25 | func (s *LimitersTestSuite) TestDistLockers() {
26 | locks1 := s.distLockers(false)
27 |
28 | locks2 := s.distLockers(false)
29 | for name := range locks1 {
30 | s.Run(name, func() {
31 | var shared int
32 |
33 | rounds := 6
34 | sleep := time.Millisecond * 50
35 |
36 | for i := 0; i < rounds; i++ {
37 | wg := sync.WaitGroup{}
38 | wg.Add(2)
39 |
40 | go func(k string) {
41 | defer wg.Done()
42 |
43 | s.useLock(locks1[k], &shared, sleep)
44 | }(name)
45 | go func(k string) {
46 | defer wg.Done()
47 |
48 | s.useLock(locks2[k], &shared, sleep)
49 | }(name)
50 |
51 | wg.Wait()
52 | }
53 |
54 | s.Equal(rounds*2, shared)
55 | })
56 | }
57 | }
58 |
59 | func BenchmarkDistLockers(b *testing.B) {
60 | s := new(LimitersTestSuite)
61 | s.SetT(&testing.T{})
62 | s.SetupSuite()
63 |
64 | lockers := s.distLockers(false)
65 | for name, locker := range lockers {
66 | b.Run(name, func(b *testing.B) {
67 | for i := 0; i < b.N; i++ {
68 | s.Require().NoError(locker.Lock(context.Background()))
69 | s.Require().NoError(locker.Unlock(context.Background()))
70 | }
71 | })
72 | }
73 |
74 | s.TearDownSuite()
75 | }
76 |
--------------------------------------------------------------------------------
/limiters.go:
--------------------------------------------------------------------------------
1 | // Package limiters provides general purpose rate limiter implementations.
2 | package limiters
3 |
4 | import (
5 | "errors"
6 | "log"
7 | "time"
8 | )
9 |
10 | var (
11 | // ErrLimitExhausted is returned by the Limiter in case the number of requests overflows the capacity of a Limiter.
12 | ErrLimitExhausted = errors.New("requests limit exhausted")
13 |
14 | // ErrRaceCondition is returned when there is a race condition while saving a state of a rate limiter.
15 | ErrRaceCondition = errors.New("race condition detected")
16 | )
17 |
18 | // Logger wraps the Log method for logging.
19 | type Logger interface {
20 | // Log logs the given arguments.
21 | Log(v ...interface{})
22 | }
23 |
24 | // StdLogger implements the Logger interface.
25 | type StdLogger struct{}
26 |
27 | // NewStdLogger creates a new instance of StdLogger.
28 | func NewStdLogger() *StdLogger {
29 | return &StdLogger{}
30 | }
31 |
32 | // Log delegates the logging to the std logger.
33 | func (l *StdLogger) Log(v ...interface{}) {
34 | log.Println(v...)
35 | }
36 |
37 | // Clock encapsulates a system clock.
38 | // It is used by the rate limiters to obtain the current time and can be mocked in tests.
39 | type Clock interface {
40 | // Now returns the current system time.
41 | Now() time.Time
42 | }
43 |
44 | // SystemClock implements the Clock interface by using the real system clock.
45 | type SystemClock struct{}
46 |
47 | // NewSystemClock creates a new instance of SystemClock.
48 | func NewSystemClock() *SystemClock {
49 | return &SystemClock{}
50 | }
51 |
52 | // Now returns the current system time.
53 | func (c *SystemClock) Now() time.Time {
54 | return time.Now()
55 | }
56 |
57 | // Sleep blocks (sleeps) for the given duration.
58 | func (c *SystemClock) Sleep(d time.Duration) {
59 | time.Sleep(d)
60 | }
61 |
--------------------------------------------------------------------------------
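
Note: the Clock interface above is what the test suite's fake clock satisfies. A minimal sketch of such a fake (the type name and layout are illustrative, not the repository's actual newFakeClock helper from limiters_test.go, which is not shown in this excerpt):

    package limiters_test

    import (
        "sync"
        "time"

        "github.com/mennanov/limiters"
    )

    // fakeClock is a manually advanced clock for deterministic tests.
    type fakeClock struct {
        mu  sync.Mutex
        now time.Time
    }

    // Now returns the frozen current time.
    func (c *fakeClock) Now() time.Time {
        c.mu.Lock()
        defer c.mu.Unlock()

        return c.now
    }

    // Sleep advances the clock without blocking.
    func (c *fakeClock) Sleep(d time.Duration) {
        c.mu.Lock()
        defer c.mu.Unlock()

        c.now = c.now.Add(d)
    }

    // Compile-time check that fakeClock satisfies limiters.Clock.
    var _ limiters.Clock = (*fakeClock)(nil)
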
/cosmosdb.go:
--------------------------------------------------------------------------------
1 | package limiters
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "net/http"
7 | "time"
8 |
9 | "github.com/Azure/azure-sdk-for-go/sdk/azcore"
10 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
11 | "github.com/pkg/errors"
12 | )
13 |
14 | type cosmosItem struct {
15 | Count int64 `json:"Count"`
16 | PartitionKey string `json:"partitionKey"`
17 | ID string `json:"id"`
18 | TTL int32 `json:"ttl"`
19 | }
20 |
21 | // incrementCosmosItemRMW increments the count of a cosmosItem in a read-modify-write loop.
22 | // It uses optimistic concurrency control (ETags) to ensure consistency.
23 | func incrementCosmosItemRMW(ctx context.Context, client *azcosmos.ContainerClient, partitionKey, id string, ttl time.Duration) (int64, error) {
24 | pk := azcosmos.NewPartitionKey().AppendString(partitionKey)
25 |
26 | for {
27 | err := ctx.Err()
28 | if err != nil {
29 | return 0, err
30 | }
31 |
32 | resp, err := client.ReadItem(ctx, pk, id, nil)
33 | if err != nil {
34 | return 0, errors.Wrap(err, "read of cosmos value failed")
35 | }
36 |
37 | var existing cosmosItem
38 |
39 | err = json.Unmarshal(resp.Value, &existing)
40 | if err != nil {
41 | return 0, errors.Wrap(err, "unmarshal of cosmos value failed")
42 | }
43 |
44 | existing.Count++
45 | existing.TTL = int32(ttl.Seconds()) // Cosmos DB TTL is expressed in seconds; casting the Duration (nanoseconds) directly would overflow int32.
46 |
47 | updated, err := json.Marshal(existing)
48 | if err != nil {
49 | return 0, err
50 | }
51 |
52 | _, err = client.ReplaceItem(ctx, pk, id, updated, &azcosmos.ItemOptions{
53 | IfMatchEtag: &resp.ETag,
54 | })
55 | if err == nil {
56 | return existing.Count, nil
57 | }
58 |
59 | var respErr *azcore.ResponseError
60 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusPreconditionFailed {
61 | continue
62 | }
63 |
64 | return 0, errors.Wrap(err, "replace of cosmos value failed")
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
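
A note on usage: incrementCosmosItemRMW only replaces an existing document, so the counter item has to be seeded first. A minimal in-package sketch of doing that (the helper name is illustrative, and it assumes azcosmos.ContainerClient.CreateItem, which is not shown in this repository excerpt):

    // createCosmosItem seeds a zero-valued counter so that incrementCosmosItemRMW can later replace it.
    func createCosmosItem(ctx context.Context, client *azcosmos.ContainerClient, partitionKey, id string, ttl time.Duration) error {
        item, err := json.Marshal(cosmosItem{
            Count:        0,
            PartitionKey: partitionKey,
            ID:           id,
            TTL:          int32(ttl.Seconds()),
        })
        if err != nil {
            return err
        }

        pk := azcosmos.NewPartitionKey().AppendString(partitionKey)
        _, err = client.CreateItem(ctx, pk, item, nil)

        return errors.Wrap(err, "create of cosmos value failed")
    }
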
/registry_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strconv"
7 | "testing"
8 | "time"
9 |
10 | "github.com/mennanov/limiters"
11 | "github.com/stretchr/testify/assert"
12 | "github.com/stretchr/testify/require"
13 | )
14 |
15 | type testingLimiter struct{}
16 |
17 | func newTestingLimiter() *testingLimiter {
18 | return &testingLimiter{}
19 | }
20 |
21 | func (l *testingLimiter) Limit(context.Context) (time.Duration, error) {
22 | return 0, nil
23 | }
24 |
25 | func TestRegistry_GetOrCreate(t *testing.T) {
26 | registry := limiters.NewRegistry()
27 | called := false
28 | clock := newFakeClock()
29 | limiter := newTestingLimiter()
30 | l := registry.GetOrCreate("key", func() interface{} {
31 | called = true
32 |
33 | return limiter
34 | }, time.Second, clock.Now())
35 | assert.Equal(t, limiter, l)
36 | // Verify that the closure was called to create a value.
37 | assert.True(t, called)
38 | called = false
39 | l = registry.GetOrCreate("key", func() interface{} {
40 | called = true
41 |
42 | return newTestingLimiter()
43 | }, time.Second, clock.Now())
44 | assert.Equal(t, limiter, l)
45 | // Verify that the closure was NOT called to create a value as it already exists.
46 | assert.False(t, called)
47 | }
48 |
49 | func TestRegistry_DeleteExpired(t *testing.T) {
50 | registry := limiters.NewRegistry()
51 | clock := newFakeClock()
52 | // Add limiters to the registry.
53 | for i := 1; i <= 10; i++ {
54 | registry.GetOrCreate(fmt.Sprintf("key%d", i), func() interface{} {
55 | return newTestingLimiter()
56 | }, time.Second*time.Duration(i), clock.Now())
57 | }
58 |
59 | clock.Sleep(time.Second * 3)
60 | // "touch" the "key3" value that is about to be expired so that its expiration time is extended for 1s.
61 | registry.GetOrCreate("key3", func() interface{} {
62 | return newTestingLimiter()
63 | }, time.Second, clock.Now())
64 |
65 | assert.Equal(t, 2, registry.DeleteExpired(clock.Now()))
66 |
67 | for i := 1; i <= 10; i++ {
68 | if i <= 2 {
69 | assert.False(t, registry.Exists(fmt.Sprintf("key%d", i)))
70 | } else {
71 | assert.True(t, registry.Exists(fmt.Sprintf("key%d", i)))
72 | }
73 | }
74 | }
75 |
76 | func TestRegistry_Delete(t *testing.T) {
77 | registry := limiters.NewRegistry()
78 | clock := newFakeClock()
79 | item := &struct{}{}
80 | require.Equal(t, item, registry.GetOrCreate("key", func() interface{} {
81 | return item
82 | }, time.Second, clock.Now()))
83 | require.Equal(t, item, registry.GetOrCreate("key", func() interface{} {
84 | return &struct{}{}
85 | }, time.Second, clock.Now()))
86 | registry.Delete("key")
87 | assert.False(t, registry.Exists("key"))
88 | }
89 |
90 | // This test is meant to be run with the -race flag: it fails only if Registry is not safe for concurrent use.
91 | func TestRegistry_ConcurrentUsage(t *testing.T) {
92 | registry := limiters.NewRegistry()
93 | clock := newFakeClock()
94 |
95 | for i := 0; i < 10; i++ {
96 | go func(i int) {
97 | registry.GetOrCreate(strconv.Itoa(i), func() interface{} { return &struct{}{} }, 0, clock.Now())
98 | }(i)
99 | }
100 |
101 | for i := 0; i < 10; i++ {
102 | go func(i int) {
103 | registry.DeleteExpired(clock.Now())
104 | }(i)
105 | }
106 | }
107 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Go tests
2 |
3 | on:
4 |   push:
5 |     branches:
6 |       - master
7 |   pull_request:
8 |     branches:
9 |       - master
10 |
11 | jobs:
12 |   go-test:
13 |
14 |     runs-on: ubuntu-latest
15 |     strategy:
16 |       matrix:
17 |         #
18 |         go-version: [ '1.23', '1.24', '1.25' ]
19 |         #
20 |
21 |     services:
22 |       #
23 |       etcd:
24 |         image: gcr.io/etcd-development/etcd:v3.6.5
25 |         env:
26 |           ETCD_LISTEN_CLIENT_URLS: "http://0.0.0.0:2379"
27 |           ETCD_ADVERTISE_CLIENT_URLS: "http://0.0.0.0:2379"
28 |         ports:
29 |           - 2379:2379
30 |       redis:
31 |         image: redis:8.2
32 |         ports:
33 |           - 6379:6379
34 |       redis-cluster:
35 |         image: grokzen/redis-cluster:7.0.10
36 |         env:
37 |           IP: 0.0.0.0
38 |           INITIAL_PORT: 11000
39 |         ports:
40 |           - 11000-11005:11000-11005
41 |       memcached:
42 |         image: memcached:1.6
43 |         ports:
44 |           - 11211:11211
45 |       consul:
46 |         image: hashicorp/consul:1.21
47 |         ports:
48 |           - 8500:8500
49 |       zookeeper:
50 |         image: zookeeper:3.9
51 |         ports:
52 |           - 2181:2181
53 |       postgresql:
54 |         image: postgres:18
55 |         env:
56 |           POSTGRES_PASSWORD: postgres
57 |         ports:
58 |           - 5432:5432
59 |       cosmos:
60 |         image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview
61 |         env:
62 |           PROTOCOL: http
63 |           COSMOS_HTTP_CONNECTION_WITHOUT_TLS_ALLOWED: "true"
64 |         ports:
65 |           - "8081:8081"
66 |       #
67 |
68 |     steps:
69 |       - uses: actions/checkout@v4
70 |       - name: Setup Go ${{ matrix.go-version }}
71 |         uses: actions/setup-go@v5
72 |         with:
73 |           go-version: ${{ matrix.go-version }}
74 |       - name: Setup DynamoDB Local
75 |         uses: rrainn/dynamodb-action@v4.0.0
76 |         with:
77 |           port: 8000
78 |       - name: Run tests
79 |         #
80 |         env:
81 |           ETCD_ENDPOINTS: 'localhost:2379'
82 |           REDIS_ADDR: 'localhost:6379'
83 |           REDIS_NODES: 'localhost:11000,localhost:11001,localhost:11002,localhost:11003,localhost:11004,localhost:11005'
84 |           CONSUL_ADDR: 'localhost:8500'
85 |           ZOOKEEPER_ENDPOINTS: 'localhost:2181'
86 |           AWS_ADDR: 'localhost:8000'
87 |           MEMCACHED_ADDR: '127.0.0.1:11211'
88 |           POSTGRES_URL: postgres://postgres:postgres@localhost:5432/?sslmode=disable
89 |           COSMOS_ADDR: '127.0.0.1:8081'
90 |         #
91 |         run: |
92 |           if [ "${{ matrix.go-version }}" = "1.25" ]; then
93 |             go test -race -coverprofile=coverage.txt -covermode=atomic ./...
94 |           else
95 |             go test -race ./...
96 |           fi
97 |       - uses: codecov/codecov-action@v5
98 |         if: matrix.go-version == '1.25'
99 |         with:
100 |           verbose: true
101 |
102 |
--------------------------------------------------------------------------------
/registry.go:
--------------------------------------------------------------------------------
1 | package limiters
2 |
3 | import (
4 | "container/heap"
5 | "sync"
6 | "time"
7 | )
8 |
9 | // pqItem is an item in the priority queue.
10 | type pqItem struct {
11 | value interface{}
12 | exp time.Time
13 | index int
14 | key string
15 | }
16 |
17 | // gcPq is a priority queue.
18 | type gcPq []*pqItem
19 |
20 | func (pq gcPq) Len() int { return len(pq) }
21 |
22 | func (pq gcPq) Less(i, j int) bool {
23 | return pq[i].exp.Before(pq[j].exp)
24 | }
25 |
26 | func (pq gcPq) Swap(i, j int) {
27 | pq[i], pq[j] = pq[j], pq[i]
28 | pq[i].index = i
29 | pq[j].index = j
30 | }
31 |
32 | func (pq *gcPq) Push(x interface{}) {
33 | n := len(*pq)
34 | item := x.(*pqItem)
35 | item.index = n
36 | *pq = append(*pq, item)
37 | }
38 |
39 | func (pq *gcPq) Pop() interface{} {
40 | old := *pq
41 | n := len(old)
42 | item := old[n-1]
43 | item.index = -1 // for safety
44 | *pq = old[0 : n-1]
45 |
46 | return item
47 | }
48 |
49 | // Registry is a thread-safe garbage-collectable registry of values.
50 | type Registry struct {
51 | // Guards all the fields below it.
52 | mx sync.Mutex
53 | pq *gcPq
54 | m map[string]*pqItem
55 | }
56 |
57 | // NewRegistry creates a new instance of Registry.
58 | func NewRegistry() *Registry {
59 | pq := make(gcPq, 0)
60 |
61 | return &Registry{pq: &pq, m: make(map[string]*pqItem)}
62 | }
63 |
64 | // GetOrCreate gets an existing value by key and updates its expiration time.
65 | // If the key lookup fails, it creates a new value by calling the provided value closure and puts it on the queue.
66 | func (r *Registry) GetOrCreate(key string, value func() interface{}, ttl time.Duration, now time.Time) interface{} {
67 | r.mx.Lock()
68 | defer r.mx.Unlock()
69 |
70 | item, ok := r.m[key]
71 | if ok {
72 | // Update the expiration time.
73 | item.exp = now.Add(ttl)
74 | heap.Fix(r.pq, item.index)
75 | } else {
76 | item = &pqItem{
77 | value: value(),
78 | exp: now.Add(ttl),
79 | key: key,
80 | }
81 | heap.Push(r.pq, item)
82 | r.m[key] = item
83 | }
84 |
85 | return item.value
86 | }
87 |
88 | // DeleteExpired deletes expired items from the registry and returns the number of deleted items.
89 | func (r *Registry) DeleteExpired(now time.Time) int {
90 | r.mx.Lock()
91 | defer r.mx.Unlock()
92 |
93 | c := 0
94 |
95 | for len(*r.pq) != 0 {
96 | item := (*r.pq)[0]
97 | if now.Before(item.exp) {
98 | break
99 | }
100 |
101 | delete(r.m, item.key)
102 | heap.Pop(r.pq)
103 |
104 | c++
105 | }
106 |
107 | return c
108 | }
109 |
110 | // Delete deletes an item from the registry.
111 | func (r *Registry) Delete(key string) {
112 | r.mx.Lock()
113 | defer r.mx.Unlock()
114 |
115 | item, ok := r.m[key]
116 | if !ok {
117 | return
118 | }
119 |
120 | delete(r.m, key)
121 | heap.Remove(r.pq, item.index)
122 | }
123 |
124 | // Exists returns true if an item with the given key exists in the registry.
125 | func (r *Registry) Exists(key string) bool {
126 | r.mx.Lock()
127 | defer r.mx.Unlock()
128 |
129 | _, ok := r.m[key]
130 |
131 | return ok
132 | }
133 |
134 | // Len returns the number of items in the registry.
135 | func (r *Registry) Len() int {
136 | r.mx.Lock()
137 | defer r.mx.Unlock()
138 |
139 | return len(*r.pq)
140 | }
141 |
--------------------------------------------------------------------------------
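
A minimal sketch of how Registry is typically driven (per-key value creation plus periodic garbage collection, as in the gRPC IP limiter example later in this dump; the key and the one-minute interval here are illustrative):

    package main

    import (
        "fmt"
        "time"

        "github.com/mennanov/limiters"
    )

    func main() {
        registry := limiters.NewRegistry()
        clock := limiters.NewSystemClock()

        // Periodically evict entries whose TTL has elapsed.
        go func() {
            for {
                clock.Sleep(time.Minute)
                registry.DeleteExpired(clock.Now())
            }
        }()

        // GetOrCreate returns the existing value for "client-42" or creates it,
        // extending its expiration time either way.
        value := registry.GetOrCreate("client-42", func() interface{} {
            return "per-client limiter goes here"
        }, time.Minute, clock.Now())
        fmt.Println(value, registry.Len())
    }
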
/dynamodb_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | "context"
5 | "time"
6 |
7 | "github.com/aws/aws-sdk-go-v2/aws"
8 | "github.com/aws/aws-sdk-go-v2/service/dynamodb"
9 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
10 | "github.com/mennanov/limiters"
11 | "github.com/pkg/errors"
12 | )
13 |
14 | const testDynamoDBTableName = "limiters-test"
15 |
16 | func CreateTestDynamoDBTable(ctx context.Context, client *dynamodb.Client) error {
17 | _, err := client.CreateTable(ctx, &dynamodb.CreateTableInput{
18 | TableName: aws.String(testDynamoDBTableName),
19 | BillingMode: types.BillingModePayPerRequest,
20 | KeySchema: []types.KeySchemaElement{
21 | {
22 | AttributeName: aws.String("PK"),
23 | KeyType: types.KeyTypeHash,
24 | },
25 | {
26 | AttributeName: aws.String("SK"),
27 | KeyType: types.KeyTypeRange,
28 | },
29 | },
30 | AttributeDefinitions: []types.AttributeDefinition{
31 | {
32 | AttributeName: aws.String("PK"),
33 | AttributeType: types.ScalarAttributeTypeS,
34 | },
35 | {
36 | AttributeName: aws.String("SK"),
37 | AttributeType: types.ScalarAttributeTypeS,
38 | },
39 | },
40 | })
41 | if err != nil {
42 | return errors.Wrap(err, "create test dynamodb table failed")
43 | }
44 |
45 | _, err = client.UpdateTimeToLive(ctx, &dynamodb.UpdateTimeToLiveInput{
46 | TableName: aws.String(testDynamoDBTableName),
47 | TimeToLiveSpecification: &types.TimeToLiveSpecification{
48 | AttributeName: aws.String("TTL"),
49 | Enabled: aws.Bool(true),
50 | },
51 | })
52 | if err != nil {
53 | return errors.Wrap(err, "set dynamodb ttl failed")
54 | }
55 |
56 | ctx, cancel := context.WithTimeout(ctx, time.Second*10)
57 | defer cancel()
58 |
59 | for {
60 | resp, err := client.DescribeTable(ctx, &dynamodb.DescribeTableInput{
61 | TableName: aws.String(testDynamoDBTableName),
62 | })
63 | if err != nil {
64 | return errors.Wrap(err, "failed to describe test table")
65 | }
66 |
67 | if resp.Table.TableStatus == types.TableStatusActive {
68 | return nil
69 | }
70 |
71 | select {
72 | case <-ctx.Done():
73 | return errors.New("failed to verify dynamodb test table is created")
74 | default:
75 | }
76 | }
77 | }
78 |
79 | func DeleteTestDynamoDBTable(ctx context.Context, client *dynamodb.Client) error {
80 | _, err := client.DeleteTable(ctx, &dynamodb.DeleteTableInput{
81 | TableName: aws.String(testDynamoDBTableName),
82 | })
83 | if err != nil {
84 | return errors.Wrap(err, "delete test dynamodb table failed")
85 | }
86 |
87 | return nil
88 | }
89 |
90 | func (s *LimitersTestSuite) TestDynamoRaceCondition() {
91 | backend := limiters.NewLeakyBucketDynamoDB(s.dynamodbClient, "race-check", s.dynamoDBTableProps, time.Minute, true)
92 |
93 | err := backend.SetState(context.Background(), limiters.LeakyBucketState{})
94 | s.Require().NoError(err)
95 |
96 | _, err = backend.State(context.Background())
97 | s.Require().NoError(err)
98 |
99 | _, err = s.dynamodbClient.PutItem(context.Background(), &dynamodb.PutItemInput{
100 | Item: map[string]types.AttributeValue{
101 | s.dynamoDBTableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: "race-check"},
102 | s.dynamoDBTableProps.SortKeyName: &types.AttributeValueMemberS{Value: "race-check"},
103 | "Version": &types.AttributeValueMemberN{Value: "5"},
104 | },
105 | TableName: &s.dynamoDBTableProps.TableName,
106 | })
107 | s.Require().NoError(err)
108 |
109 | err = backend.SetState(context.Background(), limiters.LeakyBucketState{})
110 | s.Require().ErrorIs(err, limiters.ErrRaceCondition, err)
111 | }
112 |
--------------------------------------------------------------------------------
/examples/example_grpc_simple_limiter_test.go:
--------------------------------------------------------------------------------
1 | package examples_test
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "log"
8 | "net"
9 | "os"
10 | "strings"
11 | "time"
12 |
13 | "github.com/mennanov/limiters"
14 | "github.com/mennanov/limiters/examples"
15 | pb "github.com/mennanov/limiters/examples/helloworld"
16 | "github.com/redis/go-redis/v9"
17 | clientv3 "go.etcd.io/etcd/client/v3"
18 | "google.golang.org/grpc"
19 | "google.golang.org/grpc/codes"
20 | "google.golang.org/grpc/credentials/insecure"
21 | "google.golang.org/grpc/status"
22 | )
23 |
24 | func Example_simpleGRPCLimiter() {
25 | // Set up a gRPC server.
26 | lc := net.ListenConfig{}
27 |
28 | lis, err := lc.Listen(context.Background(), "tcp", examples.Port)
29 | if err != nil {
30 | log.Fatalf("failed to listen: %v", err)
31 | }
32 | defer lis.Close()
33 | // Connect to etcd.
34 | etcdClient, err := clientv3.New(clientv3.Config{
35 | Endpoints: strings.Split(os.Getenv("ETCD_ENDPOINTS"), ","),
36 | DialTimeout: time.Second,
37 | })
38 | if err != nil {
39 | log.Fatalf("could not connect to etcd: %v", err)
40 | }
41 | defer etcdClient.Close()
42 | // Connect to Redis.
43 | redisClient := redis.NewClient(&redis.Options{
44 | Addr: os.Getenv("REDIS_ADDR"),
45 | })
46 | defer redisClient.Close()
47 |
48 | rate := time.Second * 3
49 | limiter := limiters.NewTokenBucket(
50 | 2,
51 | rate,
52 | limiters.NewLockEtcd(etcdClient, "/ratelimiter_lock/simple/", limiters.NewStdLogger()),
53 | limiters.NewTokenBucketRedis(
54 | redisClient,
55 | "ratelimiter/simple",
56 | rate, false),
57 | limiters.NewSystemClock(), limiters.NewStdLogger(),
58 | )
59 |
60 | // Add a unary interceptor middleware to rate limit all requests.
61 | s := grpc.NewServer(grpc.UnaryInterceptor(
62 | func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
63 | w, err := limiter.Limit(ctx)
64 | if errors.Is(err, limiters.ErrLimitExhausted) {
65 | return nil, status.Errorf(codes.ResourceExhausted, "try again later in %s", w)
66 | } else if err != nil {
67 | // The limiter failed. This error should be logged and examined.
68 | log.Println(err)
69 |
70 | return nil, status.Error(codes.Internal, "internal error")
71 | }
72 |
73 | return handler(ctx, req)
74 | }))
75 |
76 | pb.RegisterGreeterServer(s, &examples.Server{})
77 |
78 | go func() {
79 | // Start serving.
80 | err = s.Serve(lis)
81 | if err != nil {
82 | log.Fatalf("failed to serve: %v", err)
83 | }
84 | }()
85 |
86 | defer s.GracefulStop()
87 |
88 | // Set up a client connection to the server.
89 | conn, err := grpc.NewClient(fmt.Sprintf("localhost%s", examples.Port), grpc.WithTransportCredentials(insecure.NewCredentials()))
90 | if err != nil {
91 | log.Fatalf("did not connect: %v", err)
92 | }
93 | defer conn.Close()
94 |
95 | c := pb.NewGreeterClient(conn)
96 |
97 | // Contact the server and print out its response.
98 | ctx, cancel := context.WithTimeout(context.Background(), time.Second)
99 | defer cancel()
100 |
101 | r, err := c.SayHello(ctx, &pb.HelloRequest{Name: "Alice"})
102 | if err != nil {
103 | log.Fatalf("could not greet: %v", err)
104 | }
105 |
106 | fmt.Println(r.GetMessage())
107 |
108 | r, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Bob"})
109 | if err != nil {
110 | log.Fatalf("could not greet: %v", err)
111 | }
112 |
113 | fmt.Println(r.GetMessage())
114 |
115 | _, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Peter"})
116 | if err == nil {
117 | log.Fatal("error expected, but got nil")
118 | }
119 |
120 | fmt.Println(err)
121 | // Output: Hello Alice
122 | // Hello Bob
123 | // rpc error: code = ResourceExhausted desc = try again later in 3s
124 | }
125 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/mennanov/limiters
2 |
3 | //
4 | go 1.25.3
5 |
6 | //
7 |
8 | require (
9 | github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
10 | github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.4.1
11 | github.com/alessandro-c/gomemcached-lock v1.0.0
12 | github.com/aws/aws-sdk-go-v2 v1.40.1
13 | github.com/aws/aws-sdk-go-v2/config v1.32.3
14 | github.com/aws/aws-sdk-go-v2/credentials v1.19.3
15 | github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.27
16 | github.com/aws/aws-sdk-go-v2/service/dynamodb v1.53.3
17 | github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
18 | github.com/cenkalti/backoff/v3 v3.2.2
19 | github.com/go-redsync/redsync/v4 v4.14.1
20 | github.com/google/uuid v1.6.0
21 | github.com/hashicorp/consul/api v1.33.0
22 | github.com/lib/pq v1.10.9
23 | github.com/pkg/errors v0.9.1
24 | github.com/redis/go-redis/v9 v9.17.1
25 | github.com/samuel/go-zookeeper v0.0.0-20201211165307-7117e9ea2414
26 | github.com/stretchr/testify v1.11.1
27 | go.etcd.io/etcd/api/v3 v3.6.6
28 | go.etcd.io/etcd/client/v3 v3.6.6
29 | google.golang.org/grpc v1.77.0
30 | google.golang.org/protobuf v1.36.10
31 | )
32 |
33 | require (
34 | github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect
35 | github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
36 | github.com/armon/go-metrics v0.4.1 // indirect
37 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.15 // indirect
38 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.15 // indirect
39 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.15 // indirect
40 | github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
41 | github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.7 // indirect
42 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect
43 | github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.15 // indirect
44 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.15 // indirect
45 | github.com/aws/aws-sdk-go-v2/service/signin v1.0.3 // indirect
46 | github.com/aws/aws-sdk-go-v2/service/sso v1.30.6 // indirect
47 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.11 // indirect
48 | github.com/aws/aws-sdk-go-v2/service/sts v1.41.3 // indirect
49 | github.com/aws/smithy-go v1.24.0 // indirect
50 | github.com/cespare/xxhash/v2 v2.3.0 // indirect
51 | github.com/coreos/go-semver v0.3.1 // indirect
52 | github.com/coreos/go-systemd/v22 v22.5.0 // indirect
53 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
54 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
55 | github.com/fatih/color v1.16.0 // indirect
56 | github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
57 | github.com/gogo/protobuf v1.3.2 // indirect
58 | github.com/golang/protobuf v1.5.4 // indirect
59 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect
60 | github.com/hashicorp/errwrap v1.1.0 // indirect
61 | github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
62 | github.com/hashicorp/go-hclog v1.5.0 // indirect
63 | github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
64 | github.com/hashicorp/go-multierror v1.1.1 // indirect
65 | github.com/hashicorp/go-rootcerts v1.0.2 // indirect
66 | github.com/hashicorp/golang-lru v0.5.4 // indirect
67 | github.com/hashicorp/serf v0.10.1 // indirect
68 | github.com/mattn/go-colorable v0.1.13 // indirect
69 | github.com/mattn/go-isatty v0.0.20 // indirect
70 | github.com/mitchellh/go-homedir v1.1.0 // indirect
71 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
72 | github.com/thanhpk/randstr v1.0.4 // indirect
73 | go.etcd.io/etcd/client/pkg/v3 v3.6.6 // indirect
74 | go.uber.org/multierr v1.11.0 // indirect
75 | go.uber.org/zap v1.27.0 // indirect
76 | golang.org/x/exp v0.0.0-20250808145144-a408d31f581a // indirect
77 | golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 // indirect
78 | golang.org/x/sys v0.37.0 // indirect
79 | golang.org/x/text v0.30.0 // indirect
80 | google.golang.org/genproto/googleapis/api v0.0.0-20251022142026-3a174f9686a8 // indirect
81 | google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect
82 | gopkg.in/yaml.v3 v3.0.1 // indirect
83 | )
84 |
--------------------------------------------------------------------------------
/dynamodb.go:
--------------------------------------------------------------------------------
1 | package limiters
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/aws/aws-sdk-go-v2/aws"
7 | "github.com/aws/aws-sdk-go-v2/service/dynamodb"
8 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
9 | "github.com/pkg/errors"
10 | )
11 |
12 | // DynamoDBTableProperties are supplied to DynamoDB limiter backends.
13 | // This struct informs the backend what the name of the table is and what the names of the key fields are.
14 | type DynamoDBTableProperties struct {
15 | // TableName is the name of the table.
16 | TableName string
17 | // PartitionKeyName is the name of the PartitionKey attribute.
18 | PartitionKeyName string
19 | // SortKeyName is the name of the SortKey attribute.
20 | SortKeyName string
21 | // SortKeyUsed indicates if a SortKey is present on the table.
22 | SortKeyUsed bool
23 | // TTLFieldName is the name of the attribute configured for TTL.
24 | TTLFieldName string
25 | }
26 |
27 | // LoadDynamoDBTableProperties fetches a table description with the supplied client and returns a DynamoDBTableProperties struct.
28 | func LoadDynamoDBTableProperties(ctx context.Context, client *dynamodb.Client, tableName string) (DynamoDBTableProperties, error) {
29 | resp, err := client.DescribeTable(ctx, &dynamodb.DescribeTableInput{
30 | TableName: &tableName,
31 | })
32 | if err != nil {
33 | return DynamoDBTableProperties{}, errors.Wrap(err, "describe dynamodb table failed")
34 | }
35 |
36 | ttlResp, err := client.DescribeTimeToLive(ctx, &dynamodb.DescribeTimeToLiveInput{
37 | TableName: &tableName,
38 | })
39 | if err != nil {
40 | return DynamoDBTableProperties{}, errors.Wrap(err, "describe dynamodb table ttl failed")
41 | }
42 |
43 | data, err := loadTableData(resp.Table, ttlResp.TimeToLiveDescription)
44 | if err != nil {
45 | return data, err
46 | }
47 |
48 | return data, nil
49 | }
50 |
51 | func loadTableData(table *types.TableDescription, ttl *types.TimeToLiveDescription) (DynamoDBTableProperties, error) {
52 | data := DynamoDBTableProperties{
53 | TableName: *table.TableName,
54 | }
55 |
56 | data, err := loadTableKeys(data, table)
57 | if err != nil {
58 | return data, errors.Wrap(err, "invalid dynamodb table")
59 | }
60 |
61 | return populateTableTTL(data, ttl), nil
62 | }
63 |
64 | func loadTableKeys(data DynamoDBTableProperties, table *types.TableDescription) (DynamoDBTableProperties, error) {
65 | for _, key := range table.KeySchema {
66 | if key.KeyType == types.KeyTypeHash {
67 | data.PartitionKeyName = *key.AttributeName
68 |
69 | continue
70 | }
71 |
72 | data.SortKeyName = *key.AttributeName
73 | data.SortKeyUsed = true
74 | }
75 |
76 | for _, attribute := range table.AttributeDefinitions {
77 | name := *attribute.AttributeName
78 | if name != data.PartitionKeyName && name != data.SortKeyName {
79 | continue
80 | }
81 |
82 | if name == data.PartitionKeyName && attribute.AttributeType != types.ScalarAttributeTypeS {
83 | return data, errors.New("dynamodb partition key must be of type S")
84 | } else if data.SortKeyUsed && name == data.SortKeyName && attribute.AttributeType != types.ScalarAttributeTypeS {
85 | return data, errors.New("dynamodb sort key must be of type S")
86 | }
87 | }
88 |
89 | return data, nil
90 | }
91 |
92 | func populateTableTTL(data DynamoDBTableProperties, ttl *types.TimeToLiveDescription) DynamoDBTableProperties {
93 | if ttl.TimeToLiveStatus != types.TimeToLiveStatusEnabled {
94 | return data
95 | }
96 |
97 | data.TTLFieldName = *ttl.AttributeName
98 |
99 | return data
100 | }
101 |
102 | func dynamoDBputItem(ctx context.Context, client *dynamodb.Client, input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) {
103 | resp, err := client.PutItem(ctx, input)
104 | if err != nil {
105 | var cErr *types.ConditionalCheckFailedException
106 | if errors.As(err, &cErr) {
107 | return nil, ErrRaceCondition
108 | }
109 |
110 | return nil, errors.Wrap(err, "unable to set dynamodb item")
111 | }
112 |
113 | return resp, nil
114 | }
115 |
116 | func dynamoDBGetItem(ctx context.Context, client *dynamodb.Client, input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
117 | input.ConsistentRead = aws.Bool(true)
118 |
119 | var (
120 | resp *dynamodb.GetItemOutput
121 | err error
122 | )
123 |
124 | done := make(chan struct{})
125 |
126 | go func() {
127 | defer close(done)
128 |
129 | resp, err = client.GetItem(ctx, input)
130 | }()
131 |
132 | select {
133 | case <-done:
134 | case <-ctx.Done():
135 | return nil, ctx.Err()
136 | }
137 |
138 | if err != nil {
139 | return nil, errors.Wrap(err, "unable to retrieve dynamodb item")
140 | }
141 |
142 | return resp, nil
143 | }
144 |
--------------------------------------------------------------------------------
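
A minimal sketch of wiring the helpers above together (the table name "limiters", the key and the one-minute TTL are illustrative; NewLeakyBucketDynamoDB and its arguments are taken from their use in dynamodb_test.go):

    package main

    import (
        "context"
        "log"
        "time"

        awsconfig "github.com/aws/aws-sdk-go-v2/config"
        "github.com/aws/aws-sdk-go-v2/service/dynamodb"
        "github.com/mennanov/limiters"
    )

    func main() {
        ctx := context.Background()

        cfg, err := awsconfig.LoadDefaultConfig(ctx)
        if err != nil {
            log.Fatalf("load AWS config: %v", err)
        }
        client := dynamodb.NewFromConfig(cfg)

        // Discover the key schema and TTL attribute of an existing table.
        props, err := limiters.LoadDynamoDBTableProperties(ctx, client, "limiters")
        if err != nil {
            log.Fatalf("load table properties: %v", err)
        }

        // The resulting backend can then be passed to a leaky bucket limiter.
        backend := limiters.NewLeakyBucketDynamoDB(client, "some-key", props, time.Minute, true)
        _ = backend
    }
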
/concurrent_buffer_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "sync"
7 | "testing"
8 | "time"
9 |
10 | "github.com/google/uuid"
11 | l "github.com/mennanov/limiters"
12 | )
13 |
14 | func (s *LimitersTestSuite) concurrentBuffers(capacity int64, ttl time.Duration, clock l.Clock) map[string]*l.ConcurrentBuffer {
15 | buffers := make(map[string]*l.ConcurrentBuffer)
16 |
17 | for lockerName, locker := range s.lockers(true) {
18 | for bName, b := range s.concurrentBufferBackends(ttl, clock) {
19 | buffers[lockerName+":"+bName] = l.NewConcurrentBuffer(locker, b, capacity, s.logger)
20 | }
21 | }
22 |
23 | return buffers
24 | }
25 |
26 | func (s *LimitersTestSuite) concurrentBufferBackends(ttl time.Duration, clock l.Clock) map[string]l.ConcurrentBufferBackend {
27 | return map[string]l.ConcurrentBufferBackend{
28 | "ConcurrentBufferInMemory": l.NewConcurrentBufferInMemory(l.NewRegistry(), ttl, clock),
29 | "ConcurrentBufferRedis": l.NewConcurrentBufferRedis(s.redisClient, uuid.New().String(), ttl, clock),
30 | "ConcurrentBufferRedisCluster": l.NewConcurrentBufferRedis(s.redisClusterClient, uuid.New().String(), ttl, clock),
31 | "ConcurrentBufferMemcached": l.NewConcurrentBufferMemcached(s.memcacheClient, uuid.New().String(), ttl, clock),
32 | }
33 | }
34 |
35 | func (s *LimitersTestSuite) TestConcurrentBufferNoOverflow() {
36 | clock := newFakeClock()
37 | capacity := int64(10)
38 |
39 | ttl := time.Second
40 | for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) {
41 | s.Run(name, func() {
42 | wg := sync.WaitGroup{}
43 | for i := int64(0); i < capacity; i++ {
44 | wg.Add(1)
45 |
46 | go func(i int64, buffer *l.ConcurrentBuffer) {
47 | defer wg.Done()
48 |
49 | key := fmt.Sprintf("key%d", i)
50 | s.NoError(buffer.Limit(context.TODO(), key))
51 | s.NoError(buffer.Done(context.TODO(), key))
52 | }(i, buffer)
53 | }
54 |
55 | wg.Wait()
56 | s.NoError(buffer.Limit(context.TODO(), "last"))
57 | s.NoError(buffer.Done(context.TODO(), "last"))
58 | })
59 | }
60 | }
61 |
62 | func (s *LimitersTestSuite) TestConcurrentBufferOverflow() {
63 | clock := newFakeClock()
64 | capacity := int64(3)
65 |
66 | ttl := time.Second
67 | for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) {
68 | s.Run(name, func() {
69 | mu := sync.Mutex{}
70 |
71 | var errors []error
72 |
73 | wg := sync.WaitGroup{}
74 | for i := int64(0); i <= capacity; i++ {
75 | wg.Add(1)
76 |
77 | go func(i int64, buffer *l.ConcurrentBuffer) {
78 | defer wg.Done()
79 |
80 | err := buffer.Limit(context.TODO(), fmt.Sprintf("key%d", i))
81 | if err != nil {
82 | mu.Lock()
83 |
84 | errors = append(errors, err)
85 |
86 | mu.Unlock()
87 | }
88 | }(i, buffer)
89 | }
90 |
91 | wg.Wait()
92 | s.Equal([]error{l.ErrLimitExhausted}, errors)
93 | })
94 | }
95 | }
96 |
97 | func (s *LimitersTestSuite) TestConcurrentBufferExpiredKeys() {
98 | clock := newFakeClock()
99 | capacity := int64(2)
100 |
101 | ttl := time.Second
102 | for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) {
103 | s.Run(name, func() {
104 | s.Require().NoError(buffer.Limit(context.TODO(), "key1"))
105 | clock.Sleep(ttl / 2)
106 | s.Require().NoError(buffer.Limit(context.TODO(), "key2"))
107 | clock.Sleep(ttl / 2)
108 | // No error is expected (even though the following request overflows the capacity) as the first key has already
109 | // expired by this time.
110 | s.NoError(buffer.Limit(context.TODO(), "key3"))
111 | })
112 | }
113 | }
114 |
115 | func (s *LimitersTestSuite) TestConcurrentBufferDuplicateKeys() {
116 | clock := newFakeClock()
117 | capacity := int64(2)
118 |
119 | ttl := time.Second
120 | for name, buffer := range s.concurrentBuffers(capacity, ttl, clock) {
121 | s.Run(name, func() {
122 | s.Require().NoError(buffer.Limit(context.TODO(), "key1"))
123 | s.Require().NoError(buffer.Limit(context.TODO(), "key2"))
124 | // No error is expected as it should just update the timestamp of the existing key.
125 | s.NoError(buffer.Limit(context.TODO(), "key1"))
126 | })
127 | }
128 | }
129 |
130 | func BenchmarkConcurrentBuffers(b *testing.B) {
131 | s := new(LimitersTestSuite)
132 | s.SetT(&testing.T{})
133 | s.SetupSuite()
134 |
135 | capacity := int64(1)
136 | ttl := time.Second
137 | clock := newFakeClock()
138 |
139 | buffers := s.concurrentBuffers(capacity, ttl, clock)
140 | for name, buffer := range buffers {
141 | b.Run(name, func(b *testing.B) {
142 | for i := 0; i < b.N; i++ {
143 | s.Require().NoError(buffer.Limit(context.TODO(), "key1"))
144 | }
145 | })
146 | }
147 |
148 | s.TearDownSuite()
149 | }
150 |
--------------------------------------------------------------------------------
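
A minimal sketch of the ConcurrentBuffer API exercised by the tests above, with the etcd lock and the Redis backend borrowed from the gRPC examples (the endpoints, key names, and the capacity of 10 are illustrative):

    package main

    import (
        "context"
        "log"
        "time"

        "github.com/mennanov/limiters"
        "github.com/redis/go-redis/v9"
        clientv3 "go.etcd.io/etcd/client/v3"
    )

    func main() {
        ctx := context.Background()

        etcdClient, err := clientv3.New(clientv3.Config{Endpoints: []string{"127.0.0.1:2379"}, DialTimeout: time.Second})
        if err != nil {
            log.Fatalf("could not connect to etcd: %v", err)
        }
        defer etcdClient.Close()

        redisClient := redis.NewClient(&redis.Options{Addr: "127.0.0.1:6379"})
        defer redisClient.Close()

        // At most 10 requests may be "in flight" at any moment across all processes sharing this buffer.
        buffer := limiters.NewConcurrentBuffer(
            limiters.NewLockEtcd(etcdClient, "/lock/concurrent_buffer/", limiters.NewStdLogger()),
            limiters.NewConcurrentBufferRedis(redisClient, "concurrent_buffer", time.Minute, limiters.NewSystemClock()),
            10, limiters.NewStdLogger())

        // Limit registers the in-flight request; Done removes it when the work completes.
        if err := buffer.Limit(ctx, "request-1"); err != nil {
            log.Fatalf("limit exhausted or backend error: %v", err)
        }
        defer func() {
            if err := buffer.Done(ctx, "request-1"); err != nil {
                log.Printf("done failed: %v", err)
            }
        }()
    }
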
/examples/example_grpc_ip_limiter_test.go:
--------------------------------------------------------------------------------
1 | // Package examples implements a gRPC server for Greeter service using rate limiters.
2 | package examples_test
3 |
4 | import (
5 | "context"
6 | "errors"
7 | "fmt"
8 | "log"
9 | "net"
10 | "os"
11 | "strings"
12 | "time"
13 |
14 | "github.com/mennanov/limiters"
15 | "github.com/mennanov/limiters/examples"
16 | pb "github.com/mennanov/limiters/examples/helloworld"
17 | "github.com/redis/go-redis/v9"
18 | clientv3 "go.etcd.io/etcd/client/v3"
19 | "google.golang.org/grpc"
20 | "google.golang.org/grpc/codes"
21 | "google.golang.org/grpc/credentials/insecure"
22 | "google.golang.org/grpc/peer"
23 | "google.golang.org/grpc/status"
24 | )
25 |
26 | func Example_ipGRPCLimiter() {
27 | // Set up a gRPC server.
28 | lc := net.ListenConfig{}
29 |
30 | lis, err := lc.Listen(context.Background(), "tcp", examples.Port)
31 | if err != nil {
32 | log.Fatalf("failed to listen: %v", err)
33 | }
34 | defer lis.Close()
35 | // Connect to etcd.
36 | etcdClient, err := clientv3.New(clientv3.Config{
37 | Endpoints: strings.Split(os.Getenv("ETCD_ENDPOINTS"), ","),
38 | DialTimeout: time.Second,
39 | })
40 | if err != nil {
41 | log.Fatalf("could not connect to etcd: %v", err)
42 | }
43 | defer etcdClient.Close()
44 | // Connect to Redis.
45 | redisClient := redis.NewClient(&redis.Options{
46 | Addr: os.Getenv("REDIS_ADDR"),
47 | })
48 | defer redisClient.Close()
49 |
50 | logger := limiters.NewStdLogger()
51 | // Registry is needed to keep track of previously created limiters. It can remove the expired limiters to free up
52 | // memory.
53 | registry := limiters.NewRegistry()
54 | // The rate is used to define the token bucket refill rate and also the TTL for the limiters (both in Redis and in
55 | // the registry).
56 | rate := time.Second * 3
57 | clock := limiters.NewSystemClock()
58 |
59 | go func() {
60 | // Garbage collect the old limiters to prevent memory leaks.
61 | for {
62 | <-time.After(rate)
63 | registry.DeleteExpired(clock.Now())
64 | }
65 | }()
66 |
67 | // Add a unary interceptor middleware to rate limit requests.
68 | s := grpc.NewServer(grpc.UnaryInterceptor(
69 | func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
70 | p, ok := peer.FromContext(ctx)
71 |
72 | var ip string
73 |
74 | if !ok {
75 | log.Println("no peer info available")
76 |
77 | ip = "unknown"
78 | } else {
79 | ip = p.Addr.String()
80 | }
81 |
82 | // Create an IP address based rate limiter.
83 | bucket := registry.GetOrCreate(ip, func() interface{} {
84 | return limiters.NewTokenBucket(
85 | 2,
86 | rate,
87 | limiters.NewLockEtcd(etcdClient, fmt.Sprintf("/lock/ip/%s", ip), logger),
88 | limiters.NewTokenBucketRedis(
89 | redisClient,
90 | fmt.Sprintf("/ratelimiter/ip/%s", ip),
91 | rate, false),
92 | clock, logger)
93 | }, rate, clock.Now())
94 |
95 | w, err := bucket.(*limiters.TokenBucket).Limit(ctx)
96 | if errors.Is(err, limiters.ErrLimitExhausted) {
97 | return nil, status.Errorf(codes.ResourceExhausted, "try again later in %s", w)
98 | } else if err != nil {
99 | // The limiter failed. This error should be logged and examined.
100 | log.Println(err)
101 |
102 | return nil, status.Error(codes.Internal, "internal error")
103 | }
104 |
105 | return handler(ctx, req)
106 | }))
107 |
108 | pb.RegisterGreeterServer(s, &examples.Server{})
109 |
110 | go func() {
111 | // Start serving.
112 | err = s.Serve(lis)
113 | if err != nil {
114 | log.Fatalf("failed to serve: %v", err)
115 | }
116 | }()
117 |
118 | defer s.GracefulStop()
119 |
120 | // Set up a client connection to the server.
121 | conn, err := grpc.NewClient(fmt.Sprintf("localhost%s", examples.Port), grpc.WithTransportCredentials(insecure.NewCredentials()))
122 | if err != nil {
123 | log.Fatalf("did not connect: %v", err)
124 | }
125 | defer conn.Close()
126 |
127 | c := pb.NewGreeterClient(conn)
128 |
129 | // Contact the server and print out its response.
130 | ctx, cancel := context.WithTimeout(context.Background(), time.Second)
131 | defer cancel()
132 |
133 | r, err := c.SayHello(ctx, &pb.HelloRequest{Name: "Alice"})
134 | if err != nil {
135 | log.Fatalf("could not greet: %v", err)
136 | }
137 |
138 | fmt.Println(r.GetMessage())
139 |
140 | r, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Bob"})
141 | if err != nil {
142 | log.Fatalf("could not greet: %v", err)
143 | }
144 |
145 | fmt.Println(r.GetMessage())
146 |
147 | _, err = c.SayHello(ctx, &pb.HelloRequest{Name: "Peter"})
148 | if err == nil {
149 | log.Fatal("error expected, but got nil")
150 | }
151 |
152 | fmt.Println(err)
153 | // Output: Hello Alice
154 | // Hello Bob
155 | // rpc error: code = ResourceExhausted desc = try again later in 3s
156 | }
157 |
--------------------------------------------------------------------------------
/examples/helloworld/helloworld_grpc.pb.go:
--------------------------------------------------------------------------------
1 | // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
2 | // versions:
3 | // - protoc-gen-go-grpc v1.5.1
4 | // - protoc v5.28.3
5 | // source: helloworld.proto
6 |
7 | package helloworld
8 |
9 | import (
10 | context "context"
11 |
12 | grpc "google.golang.org/grpc"
13 | codes "google.golang.org/grpc/codes"
14 | status "google.golang.org/grpc/status"
15 | )
16 |
17 | // This is a compile-time assertion to ensure that this generated file
18 | // is compatible with the grpc package it is being compiled against.
19 | // Requires gRPC-Go v1.64.0 or later.
20 | const _ = grpc.SupportPackageIsVersion9
21 |
22 | const (
23 | Greeter_SayHello_FullMethodName = "/helloworld.Greeter/SayHello"
24 | )
25 |
26 | // GreeterClient is the client API for Greeter service.
27 | //
28 | // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
29 | //
30 | // The greeting service definition.
31 | type GreeterClient interface {
32 | // Sends a greeting
33 | SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error)
34 | }
35 |
36 | type greeterClient struct {
37 | cc grpc.ClientConnInterface
38 | }
39 |
40 | func NewGreeterClient(cc grpc.ClientConnInterface) GreeterClient {
41 | return &greeterClient{cc}
42 | }
43 |
44 | func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) {
45 | cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
46 | out := new(HelloReply)
47 | err := c.cc.Invoke(ctx, Greeter_SayHello_FullMethodName, in, out, cOpts...)
48 | if err != nil {
49 | return nil, err
50 | }
51 | return out, nil
52 | }
53 |
54 | // GreeterServer is the server API for Greeter service.
55 | // All implementations must embed UnimplementedGreeterServer
56 | // for forward compatibility.
57 | //
58 | // The greeting service definition.
59 | type GreeterServer interface {
60 | // Sends a greeting
61 | SayHello(context.Context, *HelloRequest) (*HelloReply, error)
62 | mustEmbedUnimplementedGreeterServer()
63 | }
64 |
65 | // UnimplementedGreeterServer must be embedded to have
66 | // forward compatible implementations.
67 | //
68 | // NOTE: this should be embedded by value instead of pointer to avoid a nil
69 | // pointer dereference when methods are called.
70 | type UnimplementedGreeterServer struct{}
71 |
72 | func (UnimplementedGreeterServer) SayHello(context.Context, *HelloRequest) (*HelloReply, error) {
73 | return nil, status.Errorf(codes.Unimplemented, "method SayHello not implemented")
74 | }
75 | func (UnimplementedGreeterServer) mustEmbedUnimplementedGreeterServer() {}
76 | func (UnimplementedGreeterServer) testEmbeddedByValue() {}
77 |
78 | // UnsafeGreeterServer may be embedded to opt out of forward compatibility for this service.
79 | // Use of this interface is not recommended, as added methods to GreeterServer will
80 | // result in compilation errors.
81 | type UnsafeGreeterServer interface {
82 | mustEmbedUnimplementedGreeterServer()
83 | }
84 |
85 | func RegisterGreeterServer(s grpc.ServiceRegistrar, srv GreeterServer) {
86 | 	// If the following call panics, it indicates UnimplementedGreeterServer was
87 | // embedded by pointer and is nil. This will cause panics if an
88 | // unimplemented method is ever invoked, so we test this at initialization
89 | // time to prevent it from happening at runtime later due to I/O.
90 | if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
91 | t.testEmbeddedByValue()
92 | }
93 | s.RegisterService(&Greeter_ServiceDesc, srv)
94 | }
95 |
96 | func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
97 | in := new(HelloRequest)
98 | if err := dec(in); err != nil {
99 | return nil, err
100 | }
101 | if interceptor == nil {
102 | return srv.(GreeterServer).SayHello(ctx, in)
103 | }
104 | info := &grpc.UnaryServerInfo{
105 | Server: srv,
106 | FullMethod: Greeter_SayHello_FullMethodName,
107 | }
108 | handler := func(ctx context.Context, req interface{}) (interface{}, error) {
109 | return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest))
110 | }
111 | return interceptor(ctx, in, info, handler)
112 | }
113 |
114 | // Greeter_ServiceDesc is the grpc.ServiceDesc for Greeter service.
115 | // It's only intended for direct use with grpc.RegisterService,
116 | // and not to be introspected or modified (even as a copy)
117 | var Greeter_ServiceDesc = grpc.ServiceDesc{
118 | ServiceName: "helloworld.Greeter",
119 | HandlerType: (*GreeterServer)(nil),
120 | Methods: []grpc.MethodDesc{
121 | {
122 | MethodName: "SayHello",
123 | Handler: _Greeter_SayHello_Handler,
124 | },
125 | },
126 | Streams: []grpc.StreamDesc{},
127 | Metadata: "helloworld.proto",
128 | }
129 |
--------------------------------------------------------------------------------
/fixedwindow_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "github.com/google/uuid"
9 | l "github.com/mennanov/limiters"
10 | )
11 |
12 | // fixedWindows returns all the possible FixedWindow combinations.
13 | func (s *LimitersTestSuite) fixedWindows(capacity int64, rate time.Duration, clock l.Clock) map[string]*l.FixedWindow {
14 | windows := make(map[string]*l.FixedWindow)
15 | for name, inc := range s.fixedWindowIncrementers() {
16 | windows[name] = l.NewFixedWindow(capacity, rate, inc, clock)
17 | }
18 |
19 | return windows
20 | }
21 |
22 | func (s *LimitersTestSuite) fixedWindowIncrementers() map[string]l.FixedWindowIncrementer {
23 | return map[string]l.FixedWindowIncrementer{
24 | "FixedWindowInMemory": l.NewFixedWindowInMemory(),
25 | "FixedWindowRedis": l.NewFixedWindowRedis(s.redisClient, uuid.New().String()),
26 | "FixedWindowRedisCluster": l.NewFixedWindowRedis(s.redisClusterClient, uuid.New().String()),
27 | "FixedWindowMemcached": l.NewFixedWindowMemcached(s.memcacheClient, uuid.New().String()),
28 | "FixedWindowDynamoDB": l.NewFixedWindowDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps),
29 | "FixedWindowCosmosDB": l.NewFixedWindowCosmosDB(s.cosmosContainerClient, uuid.New().String()),
30 | }
31 | }
32 |
33 | var fixedWindowTestCases = []struct {
34 | capacity int64
35 | rate time.Duration
36 | requestCount int
37 | requestRate time.Duration
38 | missExpected int
39 | }{
40 | {
41 | capacity: 2,
42 | rate: time.Millisecond * 100,
43 | requestCount: 20,
44 | requestRate: time.Millisecond * 25,
45 | missExpected: 10,
46 | },
47 | {
48 | capacity: 4,
49 | rate: time.Millisecond * 100,
50 | requestCount: 20,
51 | requestRate: time.Millisecond * 25,
52 | missExpected: 0,
53 | },
54 | {
55 | capacity: 2,
56 | rate: time.Millisecond * 100,
57 | requestCount: 15,
58 | requestRate: time.Millisecond * 33,
59 | missExpected: 5,
60 | },
61 | }
62 |
63 | func (s *LimitersTestSuite) TestFixedWindowFakeClock() {
64 | clock := newFakeClockWithTime(time.Date(2019, 8, 30, 0, 0, 0, 0, time.UTC))
65 | for _, testCase := range fixedWindowTestCases {
66 | for name, bucket := range s.fixedWindows(testCase.capacity, testCase.rate, clock) {
67 | s.Run(name, func() {
68 | clock.reset()
69 |
70 | miss := 0
71 |
72 | for i := 0; i < testCase.requestCount; i++ {
73 | // No pause for the first request.
74 | if i > 0 {
75 | clock.Sleep(testCase.requestRate)
76 | }
77 |
78 | _, err := bucket.Limit(context.TODO())
79 | if err != nil {
80 | s.Equal(l.ErrLimitExhausted, err)
81 |
82 | miss++
83 | }
84 | }
85 |
86 | s.Equal(testCase.missExpected, miss, testCase)
87 | })
88 | }
89 | }
90 | }
91 |
92 | func (s *LimitersTestSuite) TestFixedWindowOverflow() {
93 | clock := newFakeClockWithTime(time.Date(2019, 8, 30, 0, 0, 0, 0, time.UTC))
94 | for name, bucket := range s.fixedWindows(2, time.Second, clock) {
95 | s.Run(name, func() {
96 | clock.reset()
97 |
98 | w, err := bucket.Limit(context.TODO())
99 | s.Require().NoError(err)
100 | s.Equal(time.Duration(0), w)
101 | w, err = bucket.Limit(context.TODO())
102 | s.Require().NoError(err)
103 | s.Equal(time.Duration(0), w)
104 | w, err = bucket.Limit(context.TODO())
105 | s.Require().Equal(l.ErrLimitExhausted, err)
106 | s.Equal(time.Second, w)
107 | clock.Sleep(time.Second)
108 |
109 | w, err = bucket.Limit(context.TODO())
110 | s.Require().NoError(err)
111 | s.Equal(time.Duration(0), w)
112 | })
113 | }
114 | }
115 |
116 | func (s *LimitersTestSuite) TestFixedWindowDynamoDBPartitionKey() {
117 | clock := newFakeClockWithTime(time.Date(2019, 8, 30, 0, 0, 0, 0, time.UTC))
118 | incrementor := l.NewFixedWindowDynamoDB(s.dynamodbClient, "partitionKey1", s.dynamoDBTableProps)
119 | window := l.NewFixedWindow(2, time.Millisecond*100, incrementor, clock)
120 |
121 | w, err := window.Limit(context.TODO())
122 | s.Require().NoError(err)
123 | s.Equal(time.Duration(0), w)
124 | w, err = window.Limit(context.TODO())
125 | s.Require().NoError(err)
126 | s.Equal(time.Duration(0), w)
127 | // The third call should fail for the "partitionKey1", but succeed for "partitionKey2".
128 | w, err = window.Limit(l.NewFixedWindowDynamoDBContext(context.Background(), "partitionKey2"))
129 | s.Require().NoError(err)
130 | s.Equal(time.Duration(0), w)
131 | }
132 |
133 | func BenchmarkFixedWindows(b *testing.B) {
134 | s := new(LimitersTestSuite)
135 | s.SetT(&testing.T{})
136 | s.SetupSuite()
137 |
138 | capacity := int64(1)
139 | rate := time.Second
140 | clock := newFakeClock()
141 |
142 | windows := s.fixedWindows(capacity, rate, clock)
143 | for name, window := range windows {
144 | b.Run(name, func(b *testing.B) {
145 | for i := 0; i < b.N; i++ {
146 | _, err := window.Limit(context.TODO())
147 | s.Require().NoError(err)
148 | }
149 | })
150 | }
151 |
152 | s.TearDownSuite()
153 | }
154 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Distributed rate limiters for Golang
2 | [](https://github.com/mennanov/limiters/actions/workflows/tests.yml)
3 | [](https://codecov.io/gh/mennanov/limiters)
4 | [](https://goreportcard.com/report/github.com/mennanov/limiters)
5 | [](https://godoc.org/github.com/mennanov/limiters)
6 |
7 | Rate limiters for distributed applications in Golang with configurable back-ends and distributed locks.
8 | Any back-end or lock that implements the required minimalistic interfaces can be used.
9 | The most common implementations are already provided.
10 |
11 | - [`Token bucket`](https://en.wikipedia.org/wiki/Token_bucket)
12 | - in-memory (local)
13 | - redis
14 | - memcached
15 | - etcd
16 | - dynamodb
17 | - cosmos db
18 |
19 | Allows requests at a certain input rate with possible bursts configured by the capacity parameter.
20 | The output rate equals the input rate.
21 | Precise (no over or under-limiting), but requires a lock (provided).
22 |
23 | - [`Leaky bucket`](https://en.wikipedia.org/wiki/Leaky_bucket#As_a_queue)
24 | - in-memory (local)
25 | - redis
26 | - memcached
27 | - etcd
28 | - dynamodb
29 | - cosmos db
30 |
31 | Puts requests in a FIFO queue to be processed at a constant rate.
32 | There are no restrictions on the input rate except for the capacity of the queue.
33 | Requires a lock (provided).
34 |
35 | - [`Fixed window counter`](https://konghq.com/blog/how-to-design-a-scalable-rate-limiting-algorithm/)
36 | - in-memory (local)
37 | - redis
38 | - memcached
39 | - dynamodb
40 | - cosmos db
41 |
42 | A simple and resource-efficient algorithm that does not need a lock.
43 | Precision may be adjusted by the size of the window.
44 | May be lenient when many requests arrive around the boundary between two adjacent windows.
45 |
46 | - [`Sliding window counter`](https://konghq.com/blog/how-to-design-a-scalable-rate-limiting-algorithm/)
47 | - in-memory (local)
48 | - redis
49 | - memcached
50 | - dynamodb
51 | - cosmos db
52 |
53 | Smooths out the bursts around the boundary between two adjacent windows.
54 | Needs twice as much memory as the `Fixed Window` algorithm (2 windows instead of 1 at a time).
55 | It will disallow _all_ requests when a client is flooding the service with requests.
56 | It's the client's responsibility to handle a disallowed request properly: wait before retrying.
57 |
58 | - `Concurrent buffer`
59 | - in-memory (local)
60 | - redis
61 | - memcached
62 |
63 | Allows concurrent requests up to the given capacity.
64 | Requires a lock (provided).
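
Below is a minimal sketch of the `Concurrent buffer` with the in-memory backend and no distributed lock.
It assumes `limiters.NewRegistry()` constructs the in-memory registry used to track in-flight requests; the key and capacity are placeholders.

```go
buffer := limiters.NewConcurrentBuffer(
	limiters.NewLockNoop(), // the in-memory backend is already thread-safe
	// Keys expire after 1 minute in case Done() is never called for them.
	limiters.NewConcurrentBufferInMemory(limiters.NewRegistry(), time.Minute, limiters.NewSystemClock()),
	10, // allow at most 10 concurrent requests
	limiters.NewStdLogger(),
)

handleRequest := func(ctx context.Context, key string) error {
	if err := buffer.Limit(ctx, key); err != nil {
		// limiters.ErrLimitExhausted means there are too many concurrent requests;
		// any other error is a backend failure.
		return err
	}
	// Free the slot once the request has been processed.
	defer func() { _ = buffer.Done(ctx, key) }()

	// ... process the request ...
	return nil
}
```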
65 |
66 | ## gRPC example
67 |
68 | An example of a global token bucket rate limiter for a gRPC service:
69 | ```go
70 | // examples/example_grpc_simple_limiter_test.go
71 | rate := time.Second * 3
72 | limiter := limiters.NewTokenBucket(
73 | 2,
74 | rate,
75 | 	limiters.NewLockEtcd(etcdClient, "/ratelimiter_lock/simple/", limiters.NewStdLogger()),
76 | limiters.NewTokenBucketRedis(
77 | redisClient,
78 | "ratelimiter/simple",
79 | rate, false),
80 | limiters.NewSystemClock(), limiters.NewStdLogger(),
81 | )
82 |
83 | // Add a unary interceptor middleware to rate limit all requests.
84 | s := grpc.NewServer(grpc.UnaryInterceptor(
85 | func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
86 | w, err := limiter.Limit(ctx)
87 | if err == limiters.ErrLimitExhausted {
88 | return nil, status.Errorf(codes.ResourceExhausted, "try again later in %s", w)
89 | } else if err != nil {
90 | // The limiter failed. This error should be logged and examined.
91 | log.Println(err)
92 | return nil, status.Error(codes.Internal, "internal error")
93 | }
94 | return handler(ctx, req)
95 | }))
96 | ```
97 |
98 | For something closer to a real-world example, see the IP-address-based gRPC global rate limiter in the
99 | [examples](examples/example_grpc_ip_limiter_test.go) directory.
100 |
101 | ## DynamoDB
102 |
103 | The use of DynamoDB requires a DynamoDB table to be created prior to use. An existing table can be used or a new one can be created. The required attributes depend on the limiter backend:
104 |
105 | * Partition Key
106 | - String
107 | - Required for all Backends
108 | * Sort Key
109 | - String
110 | - Backends:
111 | - FixedWindow
112 | - SlidingWindow
113 | * TTL
114 | - Number
115 | - Backends:
116 | - FixedWindow
117 | - SlidingWindow
118 | - LeakyBucket
119 | - TokenBucket
120 |
121 | All DynamoDB backends accept a `DynamoDBTableProperties` struct as a parameter. It can be created manually or obtained by calling `LoadDynamoDBTableProperties` with the table name. When using `LoadDynamoDBTableProperties`, the table description is fetched from AWS and verified to ensure the table can be used by the limiter backends. The results of `LoadDynamoDBTableProperties` are cached.
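
For illustration, below is a minimal sketch of a fixed window limiter backed by DynamoDB.
The table name `rate-limits`, the partition key `client-42` and the pre-configured `dynamodbClient` are placeholders.

```go
// Load and verify the table description once; the result is cached by the library.
props, err := limiters.LoadDynamoDBTableProperties(ctx, dynamodbClient, "rate-limits")
if err != nil {
	log.Fatalf("failed to load the DynamoDB table properties: %v", err)
}

// A fixed window of at most 10 requests per minute stored in DynamoDB.
window := limiters.NewFixedWindow(
	10,
	time.Minute,
	limiters.NewFixedWindowDynamoDB(dynamodbClient, "client-42", props),
	limiters.NewSystemClock(),
)

if _, err := window.Limit(ctx); errors.Is(err, limiters.ErrLimitExhausted) {
	// Reject the request.
}
```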
122 |
123 | ## Azure Cosmos DB for NoSQL
124 |
125 | The use of Cosmos DB requires the creation of a database and container prior to use.
126 |
127 | The container must have a default TTL set, otherwise TTL [will not take effect](https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/time-to-live#time-to-live-configurations).
128 |
129 | The partition key should be `/partitionKey`.
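
A minimal sketch of a fixed window limiter backed by Cosmos DB is shown below.
The connection string, the database and container names and the partition key are placeholders.

```go
client, err := azcosmos.NewClientFromConnectionString(connectionString, &azcosmos.ClientOptions{})
if err != nil {
	log.Fatalf("failed to create a Cosmos DB client: %v", err)
}

container, err := client.NewContainer("limiters-db", "limiters-container")
if err != nil {
	log.Fatalf("failed to create a Cosmos DB container client: %v", err)
}

// A fixed window of at most 10 requests per minute stored in Cosmos DB.
window := limiters.NewFixedWindow(
	10,
	time.Minute,
	limiters.NewFixedWindowCosmosDB(container, "client-42"),
	limiters.NewSystemClock(),
)
```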
130 |
131 | ## Distributed locks
132 |
133 | Some algorithms require a distributed lock to guarantee consistency during concurrent requests.
134 | If there is only one running application instance, no distributed lock is needed,
135 | as all the algorithms are thread-safe (use `LockNoop`).
136 |
137 | Supported backends:
138 | - [etcd](https://etcd.io/)
139 | - [Consul](https://www.consul.io/)
140 | - [Zookeeper](https://zookeeper.apache.org/)
141 | - [Redis](https://redis.io/)
142 | - [Memcached](https://memcached.org/)
143 | - [PostgreSQL](https://www.postgresql.org/)
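
For example, a Redis-based lock can be combined with a Redis-backed token bucket.
This is a sketch only: `redisClient` is an already configured `*redis.Client`, and `goredis` refers to `github.com/go-redsync/redsync/v4/redis/goredis/v9`.

```go
locker := limiters.NewLockRedis(goredis.NewPool(redisClient), "ratelimiter_lock")

bucket := limiters.NewTokenBucket(
	10,          // capacity (burst size)
	time.Second, // refill rate: one token per second
	locker,
	limiters.NewTokenBucketRedis(redisClient, "ratelimiter", time.Second, false),
	limiters.NewSystemClock(),
	limiters.NewStdLogger(),
)

if _, err := bucket.Limit(ctx); errors.Is(err, limiters.ErrLimitExhausted) {
	// Reject the request.
}
```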
144 |
145 | ## Memcached
146 |
147 | It's important to understand that Memcached is not ideal for implementing reliable locks or data persistence due to its inherent limitations:
148 |
149 | - No guaranteed data retention: Memcached can evict data at any point due to memory pressure, even if it appears to have space available. This can lead to unexpected lock releases or data loss.
150 | - Lack of distributed locking features: Memcached doesn't provide the distributed coordination primitives required for consistent locking across multiple servers.
151 |
152 | If a Memcached deployment already exists and occasional bursts of traffic caused by unexpectedly evicted data are acceptable, the Memcached-based implementations are convenient; otherwise, the Redis-based implementations are a better choice.
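
If Memcached is used anyway, the lock acquisition backoff can be tuned.
A sketch, assuming the address and key names are placeholders (`memcache` is `github.com/bradfitz/gomemcache/memcache`, `backoff` is `github.com/cenkalti/backoff/v3`):

```go
mc := memcache.New("127.0.0.1:11211")

// Retry acquiring the lock every 50ms, at most 20 times, instead of the default policy.
locker := limiters.NewLockMemcached(mc, "ratelimiter_lock").
	WithLockAcquireBackoff(backoff.WithMaxRetries(backoff.NewConstantBackOff(50*time.Millisecond), 20))

// A concurrent buffer of at most 10 in-flight requests, fully backed by Memcached.
buffer := limiters.NewConcurrentBuffer(
	locker,
	limiters.NewConcurrentBufferMemcached(mc, "ratelimiter_buffer", time.Minute, limiters.NewSystemClock()),
	10,
	limiters.NewStdLogger(),
)
```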
153 |
154 | ## Testing
155 |
156 | Run tests locally:
157 | ```bash
158 | make test
159 | ```
160 | Run benchmarks locally:
161 | ```bash
162 | make benchmark
163 | ```
164 | Run both locally:
165 | ```bash
166 | make
167 | ```
168 |
--------------------------------------------------------------------------------
/locks.go:
--------------------------------------------------------------------------------
1 | package limiters
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "time"
7 |
8 | lock "github.com/alessandro-c/gomemcached-lock"
9 | "github.com/alessandro-c/gomemcached-lock/adapters/gomemcache"
10 | "github.com/bradfitz/gomemcache/memcache"
11 | "github.com/cenkalti/backoff/v3"
12 | "github.com/go-redsync/redsync/v4"
13 | redsyncredis "github.com/go-redsync/redsync/v4/redis"
14 | "github.com/hashicorp/consul/api"
15 | _ "github.com/lib/pq"
16 | "github.com/pkg/errors"
17 | "github.com/samuel/go-zookeeper/zk"
18 | clientv3 "go.etcd.io/etcd/client/v3"
19 | "go.etcd.io/etcd/client/v3/concurrency"
20 | )
21 |
22 | // DistLocker is a context aware distributed locker (interface is similar to sync.Locker).
23 | type DistLocker interface {
24 | // Lock locks the locker.
25 | Lock(ctx context.Context) error
26 | // Unlock unlocks the previously successfully locked lock.
27 | Unlock(ctx context.Context) error
28 | }
29 |
30 | // LockNoop is a no-op implementation of the DistLocker interface.
31 | // It should only be used with the in-memory backends as they are already thread-safe and don't need distributed locks.
32 | type LockNoop struct{}
33 |
34 | // NewLockNoop creates a new LockNoop.
35 | func NewLockNoop() *LockNoop {
36 | return &LockNoop{}
37 | }
38 |
39 | // Lock imitates locking.
40 | func (n LockNoop) Lock(ctx context.Context) error {
41 | return ctx.Err()
42 | }
43 |
44 | // Unlock does nothing.
45 | func (n LockNoop) Unlock(_ context.Context) error {
46 | return nil
47 | }
48 |
49 | // LockEtcd implements the DistLocker interface using etcd.
50 | //
51 | // See https://github.com/etcd-io/etcd/blob/master/Documentation/learning/why.md#using-etcd-for-distributed-coordination
52 | type LockEtcd struct {
53 | cli *clientv3.Client
54 | prefix string
55 | logger Logger
56 | mu *concurrency.Mutex
57 | session *concurrency.Session
58 | }
59 |
60 | // NewLockEtcd creates a new instance of LockEtcd.
61 | func NewLockEtcd(cli *clientv3.Client, prefix string, logger Logger) *LockEtcd {
62 | return &LockEtcd{cli: cli, prefix: prefix, logger: logger}
63 | }
64 |
65 | // Lock creates a new session-based lock in etcd and locks it.
66 | func (l *LockEtcd) Lock(ctx context.Context) error {
67 | var err error
68 |
69 | l.session, err = concurrency.NewSession(l.cli, concurrency.WithTTL(1))
70 | if err != nil {
71 | return errors.Wrap(err, "failed to create an etcd session")
72 | }
73 |
74 | l.mu = concurrency.NewMutex(l.session, l.prefix)
75 |
76 | return errors.Wrap(l.mu.Lock(ctx), "failed to lock a mutex in etcd")
77 | }
78 |
79 | // Unlock unlocks the previously locked lock.
80 | func (l *LockEtcd) Unlock(ctx context.Context) error {
81 | defer func() {
82 | err := l.session.Close()
83 | if err != nil {
84 | l.logger.Log(err)
85 | }
86 | }()
87 |
88 | return errors.Wrap(l.mu.Unlock(ctx), "failed to unlock a mutex in etcd")
89 | }
90 |
91 | // LockConsul is a wrapper around github.com/hashicorp/consul/api.Lock that implements the DistLocker interface.
92 | type LockConsul struct {
93 | lock *api.Lock
94 | }
95 |
96 | // NewLockConsul creates a new LockConsul instance.
97 | func NewLockConsul(lock *api.Lock) *LockConsul {
98 | return &LockConsul{lock: lock}
99 | }
100 |
101 | // Lock locks the lock in Consul.
102 | func (l *LockConsul) Lock(ctx context.Context) error {
103 | _, err := l.lock.Lock(ctx.Done())
104 |
105 | return errors.Wrap(err, "failed to lock a mutex in consul")
106 | }
107 |
108 | // Unlock unlocks the lock in Consul.
109 | func (l *LockConsul) Unlock(_ context.Context) error {
110 | return l.lock.Unlock()
111 | }
112 |
113 | // LockZookeeper is a wrapper around github.com/samuel/go-zookeeper/zk.Lock that implements the DistLocker interface.
114 | type LockZookeeper struct {
115 | lock *zk.Lock
116 | }
117 |
118 | // NewLockZookeeper creates a new instance of LockZookeeper.
119 | func NewLockZookeeper(lock *zk.Lock) *LockZookeeper {
120 | return &LockZookeeper{lock: lock}
121 | }
122 |
123 | // Lock locks the lock in Zookeeper.
124 | // TODO: add context aware support once https://github.com/samuel/go-zookeeper/pull/168 is merged.
125 | func (l *LockZookeeper) Lock(_ context.Context) error {
126 | return l.lock.Lock()
127 | }
128 |
129 | // Unlock unlocks the lock in Zookeeper.
130 | func (l *LockZookeeper) Unlock(_ context.Context) error {
131 | return l.lock.Unlock()
132 | }
133 |
134 | // LockRedis is a wrapper around github.com/go-redsync/redsync that implements the DistLocker interface.
135 | type LockRedis struct {
136 | mutex *redsync.Mutex
137 | }
138 |
139 | // NewLockRedis creates a new instance of LockRedis.
140 | func NewLockRedis(pool redsyncredis.Pool, mutexName string, options ...redsync.Option) *LockRedis {
141 | rs := redsync.New(pool)
142 | mutex := rs.NewMutex(mutexName, options...)
143 |
144 | return &LockRedis{mutex: mutex}
145 | }
146 |
147 | // Lock locks the lock in Redis.
148 | func (l *LockRedis) Lock(ctx context.Context) error {
149 | err := l.mutex.LockContext(ctx)
150 |
151 | return errors.Wrap(err, "failed to lock a mutex in redis")
152 | }
153 |
154 | // Unlock unlocks the lock in Redis.
155 | func (l *LockRedis) Unlock(ctx context.Context) error {
156 | ok, err := l.mutex.UnlockContext(ctx)
157 | if !ok || err != nil {
158 | return errors.Wrap(err, "failed to unlock a mutex in redis")
159 | }
160 |
161 | return nil
162 | }
163 |
164 | // LockMemcached is a wrapper around github.com/alessandro-c/gomemcached-lock that implements the DistLocker interface.
165 | // It is the caller's responsibility to ensure the uniqueness of mutexName and not to use the same key in multiple
166 | // Memcached-based implementations.
167 | type LockMemcached struct {
168 | locker *lock.Locker
169 | mutexName string
170 | backoff backoff.BackOff
171 | }
172 |
173 | // NewLockMemcached creates a new instance of LockMemcached.
174 | // The default backoff is to retry every 100ms, up to 100 times (10 seconds in total).
175 | func NewLockMemcached(client *memcache.Client, mutexName string) *LockMemcached {
176 | adapter := gomemcache.New(client)
177 | locker := lock.New(adapter, mutexName, "")
178 | b := backoff.WithMaxRetries(backoff.NewConstantBackOff(100*time.Millisecond), 100)
179 |
180 | return &LockMemcached{
181 | locker: locker,
182 | mutexName: mutexName,
183 | backoff: b,
184 | }
185 | }
186 |
187 | // WithLockAcquireBackoff sets the backoff policy for retrying an operation.
188 | func (l *LockMemcached) WithLockAcquireBackoff(b backoff.BackOff) *LockMemcached {
189 | l.backoff = b
190 |
191 | return l
192 | }
193 |
194 | // Lock locks the lock in Memcached.
195 | func (l *LockMemcached) Lock(ctx context.Context) error {
196 | o := func() error { return l.locker.Lock(time.Minute) }
197 |
198 | return backoff.Retry(o, l.backoff)
199 | }
200 |
201 | // Unlock unlocks the lock in Memcached.
202 | func (l *LockMemcached) Unlock(ctx context.Context) error {
203 | return l.locker.Release()
204 | }
205 |
206 | // LockPostgreSQL is an implementation of the DistLocker interface using PostgreSQL's advisory lock.
207 | type LockPostgreSQL struct {
208 | db *sql.DB
209 | id int64
210 | tx *sql.Tx
211 | }
212 |
213 | // NewLockPostgreSQL creates a new LockPostgreSQL.
214 | func NewLockPostgreSQL(db *sql.DB, id int64) *LockPostgreSQL {
215 | return &LockPostgreSQL{db, id, nil}
216 | }
217 |
218 | // Make sure LockPostgreSQL implements DistLocker interface.
219 | var _ DistLocker = (*LockPostgreSQL)(nil)
220 |
221 | // Lock acquires an advisory lock in PostgreSQL.
222 | func (l *LockPostgreSQL) Lock(ctx context.Context) error {
223 | var err error
224 |
225 | l.tx, err = l.db.BeginTx(ctx, &sql.TxOptions{})
226 | if err != nil {
227 | return err
228 | }
229 |
230 | _, err = l.tx.ExecContext(ctx, "SELECT pg_advisory_xact_lock($1)", l.id)
231 |
232 | return err
233 | }
234 |
235 | // Unlock releases an advisory lock in PostgreSQL.
236 | func (l *LockPostgreSQL) Unlock(ctx context.Context) error {
237 | return l.tx.Rollback()
238 | }
239 |
--------------------------------------------------------------------------------
/slidingwindow_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "github.com/google/uuid"
9 | l "github.com/mennanov/limiters"
10 | )
11 |
12 | // slidingWindows returns all the possible SlidingWindow combinations.
13 | func (s *LimitersTestSuite) slidingWindows(capacity int64, rate time.Duration, clock l.Clock, epsilon float64) map[string]*l.SlidingWindow {
14 | windows := make(map[string]*l.SlidingWindow)
15 | for name, inc := range s.slidingWindowIncrementers() {
16 | windows[name] = l.NewSlidingWindow(capacity, rate, inc, clock, epsilon)
17 | }
18 |
19 | return windows
20 | }
21 |
22 | func (s *LimitersTestSuite) slidingWindowIncrementers() map[string]l.SlidingWindowIncrementer {
23 | return map[string]l.SlidingWindowIncrementer{
24 | "SlidingWindowInMemory": l.NewSlidingWindowInMemory(),
25 | "SlidingWindowRedis": l.NewSlidingWindowRedis(s.redisClient, uuid.New().String()),
26 | "SlidingWindowRedisCluster": l.NewSlidingWindowRedis(s.redisClusterClient, uuid.New().String()),
27 | "SlidingWindowMemcached": l.NewSlidingWindowMemcached(s.memcacheClient, uuid.New().String()),
28 | "SlidingWindowDynamoDB": l.NewSlidingWindowDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps),
29 | "SlidingWindowCosmosDB": l.NewSlidingWindowCosmosDB(s.cosmosContainerClient, uuid.New().String()),
30 | }
31 | }
32 |
33 | var slidingWindowTestCases = []struct {
34 | capacity int64
35 | rate time.Duration
36 | epsilon float64
37 | results []struct {
38 | w time.Duration
39 | e error
40 | }
41 | requests int
42 | delta float64
43 | }{
44 | {
45 | capacity: 1,
46 | rate: time.Second,
47 | epsilon: 1e-9,
48 | requests: 6,
49 | results: []struct {
50 | w time.Duration
51 | e error
52 | }{
53 | {
54 | 0, nil,
55 | },
56 | {
57 | time.Second * 2, l.ErrLimitExhausted,
58 | },
59 | {
60 | 0, nil,
61 | },
62 | {
63 | time.Second * 2, l.ErrLimitExhausted,
64 | },
65 | {
66 | 0, nil,
67 | },
68 | {
69 | time.Second * 2, l.ErrLimitExhausted,
70 | },
71 | },
72 | },
73 | {
74 | capacity: 2,
75 | rate: time.Second,
76 | epsilon: 3e-9,
77 | requests: 10,
78 | delta: 1,
79 | results: []struct {
80 | w time.Duration
81 | e error
82 | }{
83 | {
84 | 0, nil,
85 | },
86 | {
87 | 0, nil,
88 | },
89 | {
90 | time.Second + time.Second*2/3, l.ErrLimitExhausted,
91 | },
92 | {
93 | 0, nil,
94 | },
95 | {
96 | time.Second/3 + time.Second/2, l.ErrLimitExhausted,
97 | },
98 | {
99 | 0, nil,
100 | },
101 | {
102 | time.Second, l.ErrLimitExhausted,
103 | },
104 | {
105 | 0, nil,
106 | },
107 | {
108 | time.Second, l.ErrLimitExhausted,
109 | },
110 | {
111 | 0, nil,
112 | },
113 | },
114 | },
115 | {
116 | capacity: 3,
117 | rate: time.Second,
118 | epsilon: 1e-9,
119 | requests: 11,
120 | delta: 0,
121 | results: []struct {
122 | w time.Duration
123 | e error
124 | }{
125 | {
126 | 0, nil,
127 | },
128 | {
129 | 0, nil,
130 | },
131 | {
132 | 0, nil,
133 | },
134 | {
135 | time.Second + time.Second/2, l.ErrLimitExhausted,
136 | },
137 | {
138 | 0, nil,
139 | },
140 | {
141 | time.Second / 2, l.ErrLimitExhausted,
142 | },
143 | {
144 | 0, nil,
145 | },
146 | {
147 | time.Second, l.ErrLimitExhausted,
148 | },
149 | {
150 | 0, nil,
151 | },
152 | {
153 | time.Second, l.ErrLimitExhausted,
154 | },
155 | {
156 | 0, nil,
157 | },
158 | },
159 | },
160 | {
161 | capacity: 4,
162 | rate: time.Second,
163 | epsilon: 1e-9,
164 | requests: 17,
165 | delta: 0,
166 | results: []struct {
167 | w time.Duration
168 | e error
169 | }{
170 | {
171 | 0, nil,
172 | },
173 | {
174 | 0, nil,
175 | },
176 | {
177 | 0, nil,
178 | },
179 | {
180 | 0, nil,
181 | },
182 | {
183 | time.Second + time.Second*2/5, l.ErrLimitExhausted,
184 | },
185 | {
186 | 0, nil,
187 | },
188 | {
189 | time.Second * 2 / 5, l.ErrLimitExhausted,
190 | },
191 | {
192 | 0, nil,
193 | },
194 | {
195 | time.Second/5 + time.Second/4, l.ErrLimitExhausted,
196 | },
197 | {
198 | 0, nil,
199 | },
200 | {
201 | time.Second / 2, l.ErrLimitExhausted,
202 | },
203 | {
204 | 0, nil,
205 | },
206 | {
207 | time.Second / 2, l.ErrLimitExhausted,
208 | },
209 | {
210 | 0, nil,
211 | },
212 | {
213 | time.Second / 2, l.ErrLimitExhausted,
214 | },
215 | {
216 | 0, nil,
217 | },
218 | {
219 | time.Second / 2, l.ErrLimitExhausted,
220 | },
221 | },
222 | },
223 | {
224 | capacity: 5,
225 | rate: time.Second,
226 | epsilon: 3e-9,
227 | requests: 18,
228 | delta: 1,
229 | results: []struct {
230 | w time.Duration
231 | e error
232 | }{
233 | {
234 | 0, nil,
235 | },
236 | {
237 | 0, nil,
238 | },
239 | {
240 | 0, nil,
241 | },
242 | {
243 | 0, nil,
244 | },
245 | {
246 | 0, nil,
247 | },
248 | {
249 | time.Second + time.Second/3, l.ErrLimitExhausted,
250 | },
251 | {
252 | 0, nil,
253 | },
254 | {
255 | time.Second * 2 / 6, l.ErrLimitExhausted,
256 | },
257 | {
258 | 0, nil,
259 | },
260 | {
261 | time.Second * 2 / 6, l.ErrLimitExhausted,
262 | },
263 | {
264 | 0, nil,
265 | },
266 | {
267 | time.Second / 2, l.ErrLimitExhausted,
268 | },
269 | {
270 | 0, nil,
271 | },
272 | {
273 | time.Second / 2, l.ErrLimitExhausted,
274 | },
275 | {
276 | 0, nil,
277 | },
278 | {
279 | time.Second / 2, l.ErrLimitExhausted,
280 | },
281 | {
282 | 0, nil,
283 | },
284 | {
285 | time.Second / 2, l.ErrLimitExhausted,
286 | },
287 | {
288 | 0, nil,
289 | },
290 | },
291 | },
292 | }
293 |
294 | func (s *LimitersTestSuite) TestSlidingWindowOverflowAndWait() {
295 | clock := newFakeClockWithTime(time.Date(2019, 9, 3, 0, 0, 0, 0, time.UTC))
296 | for _, testCase := range slidingWindowTestCases {
297 | for name, bucket := range s.slidingWindows(testCase.capacity, testCase.rate, clock, testCase.epsilon) {
298 | s.Run(name, func() {
299 | clock.reset()
300 |
301 | for i := 0; i < testCase.requests; i++ {
302 | w, err := bucket.Limit(context.TODO())
303 |
304 | s.Require().LessOrEqual(i, len(testCase.results)-1)
305 | s.InDelta(testCase.results[i].w, w, testCase.delta, i)
306 | s.Equal(testCase.results[i].e, err, i)
307 | clock.Sleep(w)
308 | }
309 | })
310 | }
311 | }
312 | }
313 |
314 | func (s *LimitersTestSuite) TestSlidingWindowOverflowAndNoWait() {
315 | capacity := int64(3)
316 |
317 | clock := newFakeClock()
318 | for name, bucket := range s.slidingWindows(capacity, time.Second, clock, 1e-9) {
319 | s.Run(name, func() {
320 | clock.reset()
321 |
322 | // Keep sending requests until it reaches the capacity.
323 | for i := int64(0); i < capacity; i++ {
324 | w, err := bucket.Limit(context.TODO())
325 | s.Require().NoError(err)
326 | s.Require().Equal(time.Duration(0), w)
327 | clock.Sleep(time.Millisecond)
328 | }
329 |
330 | // The next request will be the first one to be rejected.
331 | w, err := bucket.Limit(context.TODO())
332 | s.Require().Equal(l.ErrLimitExhausted, err)
333 |
334 | expected := clock.Now().Add(w)
335 |
336 | // Send a few more requests, all of them should be told to come back at the same time.
337 | for i := int64(0); i < capacity; i++ {
338 | w, err = bucket.Limit(context.TODO())
339 | s.Require().Equal(l.ErrLimitExhausted, err)
340 |
341 | actual := clock.Now().Add(w)
342 | s.Require().Equal(expected, actual, i)
343 | clock.Sleep(time.Millisecond)
344 | }
345 |
346 | // Wait until it is ready.
347 | clock.Sleep(expected.Sub(clock.Now()))
348 |
349 | w, err = bucket.Limit(context.TODO())
350 | s.Require().NoError(err)
351 | s.Require().Equal(time.Duration(0), w)
352 | })
353 | }
354 | }
355 |
356 | func BenchmarkSlidingWindows(b *testing.B) {
357 | s := new(LimitersTestSuite)
358 | s.SetT(&testing.T{})
359 | s.SetupSuite()
360 |
361 | capacity := int64(1)
362 | rate := time.Second
363 | clock := newFakeClock()
364 | epsilon := 1e-9
365 |
366 | windows := s.slidingWindows(capacity, rate, clock, epsilon)
367 | for name, window := range windows {
368 | b.Run(name, func(b *testing.B) {
369 | for i := 0; i < b.N; i++ {
370 | _, err := window.Limit(context.TODO())
371 | s.Require().NoError(err)
372 | }
373 | })
374 | }
375 |
376 | s.TearDownSuite()
377 | }
378 |
--------------------------------------------------------------------------------
/limiters_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "fmt"
7 | "hash/fnv"
8 | "os"
9 | "strings"
10 | "sync"
11 | "testing"
12 | "time"
13 |
14 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
15 | "github.com/aws/aws-sdk-go-v2/aws"
16 | "github.com/aws/aws-sdk-go-v2/config"
17 | "github.com/aws/aws-sdk-go-v2/credentials"
18 | "github.com/aws/aws-sdk-go-v2/service/dynamodb"
19 | "github.com/bradfitz/gomemcache/memcache"
20 | "github.com/go-redsync/redsync/v4/redis/goredis/v9"
21 | "github.com/google/uuid"
22 | "github.com/hashicorp/consul/api"
23 | l "github.com/mennanov/limiters"
24 | "github.com/redis/go-redis/v9"
25 | "github.com/samuel/go-zookeeper/zk"
26 | "github.com/stretchr/testify/suite"
27 | clientv3 "go.etcd.io/etcd/client/v3"
28 | )
29 |
30 | type fakeClock struct {
31 | mu sync.Mutex
32 | initial time.Time
33 | t time.Time
34 | }
35 |
36 | func newFakeClock() *fakeClock {
37 | now := time.Now()
38 |
39 | return &fakeClock{t: now, initial: now}
40 | }
41 |
42 | func newFakeClockWithTime(t time.Time) *fakeClock {
43 | return &fakeClock{t: t, initial: t}
44 | }
45 |
46 | func (c *fakeClock) Now() time.Time {
47 | c.mu.Lock()
48 | defer c.mu.Unlock()
49 |
50 | return c.t
51 | }
52 |
53 | func (c *fakeClock) Sleep(d time.Duration) {
54 | if d == 0 {
55 | return
56 | }
57 |
58 | c.mu.Lock()
59 | defer c.mu.Unlock()
60 |
61 | c.t = c.t.Add(d)
62 | }
63 |
64 | func (c *fakeClock) reset() {
65 | c.mu.Lock()
66 | defer c.mu.Unlock()
67 |
68 | c.t = c.initial
69 | }
70 |
71 | type LimitersTestSuite struct {
72 | suite.Suite
73 |
74 | etcdClient *clientv3.Client
75 | redisClient *redis.Client
76 | redisClusterClient *redis.ClusterClient
77 | consulClient *api.Client
78 | zkConn *zk.Conn
79 | logger *l.StdLogger
80 | dynamodbClient *dynamodb.Client
81 | dynamoDBTableProps l.DynamoDBTableProperties
82 | memcacheClient *memcache.Client
83 | pgDb *sql.DB
84 | cosmosClient *azcosmos.Client
85 | cosmosContainerClient *azcosmos.ContainerClient
86 | }
87 |
88 | func (s *LimitersTestSuite) SetupSuite() {
89 | var err error
90 |
91 | s.etcdClient, err = clientv3.New(clientv3.Config{
92 | Endpoints: strings.Split(os.Getenv("ETCD_ENDPOINTS"), ","),
93 | DialTimeout: time.Second,
94 | })
95 | s.Require().NoError(err)
96 | _, err = s.etcdClient.Status(context.Background(), s.etcdClient.Endpoints()[0])
97 | s.Require().NoError(err)
98 |
99 | s.redisClient = redis.NewClient(&redis.Options{
100 | Addr: os.Getenv("REDIS_ADDR"),
101 | })
102 | s.Require().NoError(s.redisClient.Ping(context.Background()).Err())
103 |
104 | s.redisClusterClient = redis.NewClusterClient(&redis.ClusterOptions{
105 | Addrs: strings.Split(os.Getenv("REDIS_NODES"), ","),
106 | })
107 | s.Require().NoError(s.redisClusterClient.Ping(context.Background()).Err())
108 |
109 | s.consulClient, err = api.NewClient(&api.Config{Address: os.Getenv("CONSUL_ADDR")})
110 | s.Require().NoError(err)
111 | _, err = s.consulClient.Status().Leader()
112 | s.Require().NoError(err)
113 |
114 | s.zkConn, _, err = zk.Connect(strings.Split(os.Getenv("ZOOKEEPER_ENDPOINTS"), ","), time.Second)
115 | s.Require().NoError(err)
116 | _, _, err = s.zkConn.Get("/")
117 | s.Require().NoError(err)
118 | s.logger = l.NewStdLogger()
119 |
120 | awsCfg, err := config.LoadDefaultConfig(context.Background(),
121 | config.WithRegion("us-east-1"),
122 | config.WithCredentialsProvider(credentials.StaticCredentialsProvider{
123 | Value: aws.Credentials{
124 | AccessKeyID: "dummy", SecretAccessKey: "dummy", SessionToken: "dummy",
125 | Source: "Hard-coded credentials; values are irrelevant for local DynamoDB",
126 | },
127 | }),
128 | )
129 | s.Require().NoError(err)
130 |
131 | endpoint := fmt.Sprintf("http://%s", os.Getenv("AWS_ADDR"))
132 | s.dynamodbClient = dynamodb.NewFromConfig(awsCfg, func(options *dynamodb.Options) {
133 | options.BaseEndpoint = &endpoint
134 | })
135 |
136 | _ = DeleteTestDynamoDBTable(context.Background(), s.dynamodbClient)
137 | s.Require().NoError(CreateTestDynamoDBTable(context.Background(), s.dynamodbClient))
138 | s.dynamoDBTableProps, err = l.LoadDynamoDBTableProperties(context.Background(), s.dynamodbClient, testDynamoDBTableName)
139 | s.Require().NoError(err)
140 |
141 | s.memcacheClient = memcache.New(strings.Split(os.Getenv("MEMCACHED_ADDR"), ",")...)
142 | s.Require().NoError(s.memcacheClient.Ping())
143 |
144 | s.pgDb, err = sql.Open("postgres", os.Getenv("POSTGRES_URL"))
145 | s.Require().NoError(err)
146 | s.Require().NoError(s.pgDb.PingContext(context.Background()))
147 |
148 | // https://learn.microsoft.com/en-us/azure/cosmos-db/emulator#authentication
149 | connString := fmt.Sprintf("AccountEndpoint=http://%s/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;", os.Getenv("COSMOS_ADDR"))
150 | s.cosmosClient, err = azcosmos.NewClientFromConnectionString(connString, &azcosmos.ClientOptions{})
151 | s.Require().NoError(err)
152 |
153 | err = CreateCosmosDBContainer(context.Background(), s.cosmosClient)
154 | if err != nil {
155 | s.Require().ErrorContains(err, "Database 'limiters-db-test' already exists")
156 | }
157 |
158 | s.cosmosContainerClient, err = s.cosmosClient.NewContainer(testCosmosDBName, testCosmosContainerName)
159 | s.Require().NoError(err)
160 | }
161 |
162 | func (s *LimitersTestSuite) TearDownSuite() {
163 | s.Assert().NoError(s.etcdClient.Close())
164 | s.Assert().NoError(s.redisClient.Close())
165 | s.Assert().NoError(s.redisClusterClient.Close())
166 | s.Assert().NoError(DeleteTestDynamoDBTable(context.Background(), s.dynamodbClient))
167 | s.Assert().NoError(s.memcacheClient.Close())
168 | s.Assert().NoError(s.pgDb.Close())
169 | s.Assert().NoError(DeleteCosmosDBContainer(context.Background(), s.cosmosClient))
170 | }
171 |
172 | func TestLimitersTestSuite(t *testing.T) {
173 | suite.Run(t, new(LimitersTestSuite))
174 | }
175 |
176 | // lockers returns all possible lockers (including noop).
177 | func (s *LimitersTestSuite) lockers(generateKeys bool) map[string]l.DistLocker {
178 | lockers := s.distLockers(generateKeys)
179 | lockers["LockNoop"] = l.NewLockNoop()
180 |
181 | return lockers
182 | }
183 |
184 | func hash(s string) int64 {
185 | h := fnv.New32a()
186 |
187 | _, err := h.Write([]byte(s))
188 | if err != nil {
189 | panic(err)
190 | }
191 |
192 | return int64(h.Sum32())
193 | }
194 |
195 | // distLockers returns distributed lockers only.
196 | func (s *LimitersTestSuite) distLockers(generateKeys bool) map[string]l.DistLocker {
197 | randomKey := uuid.New().String()
198 | consulKey := randomKey
199 | etcdKey := randomKey
200 | zkKey := "/" + randomKey
201 | redisKey := randomKey
202 | memcacheKey := randomKey
203 | pgKey := randomKey
204 |
205 | if !generateKeys {
206 | consulKey = "dist_locker"
207 | etcdKey = "dist_locker"
208 | zkKey = "/dist_locker"
209 | redisKey = "dist_locker"
210 | memcacheKey = "dist_locker"
211 | pgKey = "dist_locker"
212 | }
213 |
214 | consulLock, err := s.consulClient.LockKey(consulKey)
215 | s.Require().NoError(err)
216 |
217 | return map[string]l.DistLocker{
218 | "LockEtcd": l.NewLockEtcd(s.etcdClient, etcdKey, s.logger),
219 | "LockConsul": l.NewLockConsul(consulLock),
220 | "LockZookeeper": l.NewLockZookeeper(zk.NewLock(s.zkConn, zkKey, zk.WorldACL(zk.PermAll))),
221 | "LockRedis": l.NewLockRedis(goredis.NewPool(s.redisClient), redisKey),
222 | "LockRedisCluster": l.NewLockRedis(goredis.NewPool(s.redisClusterClient), redisKey),
223 | "LockMemcached": l.NewLockMemcached(s.memcacheClient, memcacheKey),
224 | "LockPostgreSQL": l.NewLockPostgreSQL(s.pgDb, hash(pgKey)),
225 | }
226 | }
227 |
228 | func (s *LimitersTestSuite) TestLimitContextCancelled() {
229 | clock := newFakeClock()
230 | capacity := int64(2)
231 | rate := time.Second
232 |
233 | limiters := make(map[string]interface{})
234 | for n, b := range s.tokenBuckets(capacity, rate, time.Second, clock) {
235 | limiters[n] = b
236 | }
237 |
238 | for n, b := range s.leakyBuckets(capacity, rate, time.Second, clock) {
239 | limiters[n] = b
240 | }
241 |
242 | for n, w := range s.fixedWindows(capacity, rate, clock) {
243 | limiters[n] = w
244 | }
245 |
246 | for n, w := range s.slidingWindows(capacity, rate, clock, 1e-9) {
247 | limiters[n] = w
248 | }
249 |
250 | for n, b := range s.concurrentBuffers(capacity, rate, clock) {
251 | limiters[n] = b
252 | }
253 |
254 | type rateLimiter interface {
255 | Limit(context.Context) (time.Duration, error)
256 | }
257 |
258 | type concurrentLimiter interface {
259 | Limit(context.Context, string) error
260 | }
261 |
262 | for name, limiter := range limiters {
263 | s.Run(name, func() {
264 | done1 := make(chan struct{})
265 |
266 | go func(limiter interface{}) {
267 | defer close(done1)
268 | 				// The context is cancelled shortly after it is created.
269 | ctx, cancel := context.WithCancel(context.Background())
270 | cancel()
271 |
272 | switch lim := limiter.(type) {
273 | case rateLimiter:
274 | _, err := lim.Limit(ctx)
275 | s.Error(err, "%T", limiter)
276 |
277 | case concurrentLimiter:
278 | s.Error(lim.Limit(ctx, "key"), "%T", limiter)
279 | }
280 | }(limiter)
281 |
282 | done2 := make(chan struct{})
283 |
284 | go func(limiter interface{}) {
285 | defer close(done2)
286 |
287 | <-done1
288 |
289 | ctx := context.Background()
290 |
291 | switch lim := limiter.(type) {
292 | case rateLimiter:
293 | _, err := lim.Limit(ctx)
294 | s.NoError(err, "%T", limiter)
295 |
296 | case concurrentLimiter:
297 | s.NoError(lim.Limit(ctx, "key"), "%T", limiter)
298 | }
299 | }(limiter)
300 | 			// Verify that the second goroutine succeeded in calling the Limit() method.
301 | <-done2
302 | })
303 | }
304 | }
305 |
--------------------------------------------------------------------------------
/concurrent_buffer.go:
--------------------------------------------------------------------------------
1 | package limiters
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/gob"
7 | "fmt"
8 | "sync"
9 | "time"
10 |
11 | "github.com/bradfitz/gomemcache/memcache"
12 | "github.com/pkg/errors"
13 | "github.com/redis/go-redis/v9"
14 | )
15 |
16 | // ConcurrentBufferBackend wraps the Add and Remove methods.
17 | type ConcurrentBufferBackend interface {
18 | // Add adds the request with the given key to the buffer and returns the total number of requests in it.
19 | Add(ctx context.Context, key string) (int64, error)
20 | // Remove removes the request from the buffer.
21 | Remove(ctx context.Context, key string) error
22 | }
23 |
24 | // ConcurrentBuffer implements a limiter that allows concurrent requests up to the given capacity.
25 | type ConcurrentBuffer struct {
26 | locker DistLocker
27 | backend ConcurrentBufferBackend
28 | logger Logger
29 | capacity int64
30 | mu sync.Mutex
31 | }
32 |
33 | // NewConcurrentBuffer creates a new ConcurrentBuffer instance.
34 | func NewConcurrentBuffer(locker DistLocker, concurrentStateBackend ConcurrentBufferBackend, capacity int64, logger Logger) *ConcurrentBuffer {
35 | return &ConcurrentBuffer{locker: locker, backend: concurrentStateBackend, capacity: capacity, logger: logger}
36 | }
37 |
38 | // Limit puts the request identified by the key in a buffer.
39 | func (c *ConcurrentBuffer) Limit(ctx context.Context, key string) error {
40 | c.mu.Lock()
41 | defer c.mu.Unlock()
42 |
43 | err := c.locker.Lock(ctx)
44 | if err != nil {
45 | return err
46 | }
47 |
48 | defer func() {
49 | err := c.locker.Unlock(ctx)
50 | if err != nil {
51 | c.logger.Log(err)
52 | }
53 | }()
54 | // Optimistically add the new request.
55 | counter, err := c.backend.Add(ctx, key)
56 | if err != nil {
57 | return err
58 | }
59 |
60 | if counter > c.capacity {
61 | // Rollback the Add() operation.
62 | err = c.backend.Remove(ctx, key)
63 | if err != nil {
64 | c.logger.Log(err)
65 | }
66 |
67 | return ErrLimitExhausted
68 | }
69 |
70 | return nil
71 | }
72 |
73 | // Done removes the request identified by the key from the buffer.
74 | func (c *ConcurrentBuffer) Done(ctx context.Context, key string) error {
75 | return c.backend.Remove(ctx, key)
76 | }
77 |
78 | // ConcurrentBufferInMemory is an in-memory implementation of ConcurrentBufferBackend.
79 | type ConcurrentBufferInMemory struct {
80 | clock Clock
81 | ttl time.Duration
82 | mu sync.Mutex
83 | registry *Registry
84 | }
85 |
86 | // NewConcurrentBufferInMemory creates a new instance of ConcurrentBufferInMemory.
87 | // When the TTL of a key is exceeded, the key is removed from the buffer. This is needed in case the process that added
88 | // that key to the buffer did not call Done() for some reason.
89 | func NewConcurrentBufferInMemory(registry *Registry, ttl time.Duration, clock Clock) *ConcurrentBufferInMemory {
90 | return &ConcurrentBufferInMemory{clock: clock, ttl: ttl, registry: registry}
91 | }
92 |
93 | // Add adds the request with the given key to the buffer and returns the total number of requests in it.
94 | // It also removes the keys with expired TTL.
95 | func (c *ConcurrentBufferInMemory) Add(ctx context.Context, key string) (int64, error) {
96 | c.mu.Lock()
97 | defer c.mu.Unlock()
98 |
99 | now := c.clock.Now()
100 | c.registry.DeleteExpired(now)
101 | c.registry.GetOrCreate(key, func() interface{} {
102 | return struct{}{}
103 | }, c.ttl, now)
104 |
105 | return int64(c.registry.Len()), ctx.Err()
106 | }
107 |
108 | // Remove removes the request from the buffer.
109 | func (c *ConcurrentBufferInMemory) Remove(_ context.Context, key string) error {
110 | c.mu.Lock()
111 | defer c.mu.Unlock()
112 |
113 | c.registry.Delete(key)
114 |
115 | return nil
116 | }
117 |
118 | // ConcurrentBufferRedis implements ConcurrentBufferBackend in Redis.
119 | type ConcurrentBufferRedis struct {
120 | clock Clock
121 | cli redis.UniversalClient
122 | key string
123 | ttl time.Duration
124 | }
125 |
126 | // NewConcurrentBufferRedis creates a new instance of ConcurrentBufferRedis.
127 | // When the TTL of a key is exceeded, the key is removed from the buffer. This is needed in case the process that added
128 | // that key to the buffer did not call Done() for some reason.
129 | func NewConcurrentBufferRedis(cli redis.UniversalClient, key string, ttl time.Duration, clock Clock) *ConcurrentBufferRedis {
130 | return &ConcurrentBufferRedis{clock: clock, cli: cli, key: key, ttl: ttl}
131 | }
132 |
133 | // Add adds the request with the given key to the sorted set in Redis and returns the total number of requests in it.
134 | // It also removes the keys with expired TTL.
135 | func (c *ConcurrentBufferRedis) Add(ctx context.Context, key string) (int64, error) {
136 | var (
137 | countCmd *redis.IntCmd
138 | err error
139 | )
140 |
141 | done := make(chan struct{})
142 |
143 | go func() {
144 | defer close(done)
145 |
146 | _, err = c.cli.Pipelined(ctx, func(pipeliner redis.Pipeliner) error {
147 | // Remove expired items.
148 | now := c.clock.Now()
149 | pipeliner.ZRemRangeByScore(ctx, c.key, "-inf", fmt.Sprintf("%d", now.Add(-c.ttl).UnixNano()))
150 | pipeliner.ZAdd(ctx, c.key, redis.Z{
151 | Score: float64(now.UnixNano()),
152 | Member: key,
153 | })
154 | pipeliner.Expire(ctx, c.key, c.ttl)
155 | countCmd = pipeliner.ZCard(ctx, c.key)
156 |
157 | return nil
158 | })
159 | }()
160 |
161 | select {
162 | case <-ctx.Done():
163 | return 0, ctx.Err()
164 |
165 | case <-done:
166 | if err != nil {
167 | return 0, errors.Wrap(err, "failed to add an item to redis set")
168 | }
169 |
170 | return countCmd.Val(), nil
171 | }
172 | }
173 |
174 | // Remove removes the request identified by the key from the sorted set in Redis.
175 | func (c *ConcurrentBufferRedis) Remove(ctx context.Context, key string) error {
176 | return errors.Wrap(c.cli.ZRem(ctx, c.key, key).Err(), "failed to remove an item from redis set")
177 | }
178 |
179 | // ConcurrentBufferMemcached implements ConcurrentBufferBackend in Memcached.
180 | type ConcurrentBufferMemcached struct {
181 | clock Clock
182 | cli *memcache.Client
183 | key string
184 | ttl time.Duration
185 | }
186 |
187 | // NewConcurrentBufferMemcached creates a new instance of ConcurrentBufferMemcached.
188 | // When the TTL of a key is exceeded, the key is removed from the buffer. This is needed in case the process that added
189 | // that key to the buffer did not call Done() for some reason.
190 | func NewConcurrentBufferMemcached(cli *memcache.Client, key string, ttl time.Duration, clock Clock) *ConcurrentBufferMemcached {
191 | return &ConcurrentBufferMemcached{clock: clock, cli: cli, key: key, ttl: ttl}
192 | }
193 |
194 | type SortedSetNode struct {
195 | CreatedAt int64
196 | Value string
197 | }
198 |
199 | // Add adds the request with the given key to the slice in Memcached and returns the total number of requests in it.
200 | // It also removes the keys with expired TTL.
201 | func (c *ConcurrentBufferMemcached) Add(ctx context.Context, element string) (int64, error) {
202 | var err error
203 |
204 | done := make(chan struct{})
205 | now := c.clock.Now()
206 |
207 | var (
208 | newNodes []SortedSetNode
209 | casId uint64 = 0
210 | )
211 |
212 | go func() {
213 | defer close(done)
214 |
215 | var item *memcache.Item
216 |
217 | item, err = c.cli.Get(c.key)
218 | if err != nil {
219 | if !errors.Is(err, memcache.ErrCacheMiss) {
220 | return
221 | }
222 | } else {
223 | casId = item.CasID
224 | b := bytes.NewBuffer(item.Value)
225 |
226 | var oldNodes []SortedSetNode
227 |
228 | _ = gob.NewDecoder(b).Decode(&oldNodes)
229 | for _, node := range oldNodes {
230 | if node.CreatedAt > now.UnixNano() && node.Value != element {
231 | newNodes = append(newNodes, node)
232 | }
233 | }
234 | }
235 |
236 | newNodes = append(newNodes, SortedSetNode{CreatedAt: now.Add(c.ttl).UnixNano(), Value: element})
237 |
238 | var b bytes.Buffer
239 |
240 | _ = gob.NewEncoder(&b).Encode(newNodes)
241 |
242 | item = &memcache.Item{
243 | Key: c.key,
244 | Value: b.Bytes(),
245 | CasID: casId,
246 | }
247 | if casId > 0 {
248 | err = c.cli.CompareAndSwap(item)
249 | } else {
250 | err = c.cli.Add(item)
251 | }
252 | }()
253 |
254 | select {
255 | case <-ctx.Done():
256 | return 0, ctx.Err()
257 |
258 | case <-done:
259 | if err != nil {
260 | if errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss) {
261 | return c.Add(ctx, element)
262 | }
263 |
264 | return 0, errors.Wrap(err, "failed to add in memcached")
265 | }
266 |
267 | return int64(len(newNodes)), nil
268 | }
269 | }
270 |
271 | // Remove removes the request identified by the key from the slice in Memcached.
272 | func (c *ConcurrentBufferMemcached) Remove(ctx context.Context, key string) error {
273 | var err error
274 |
275 | now := c.clock.Now()
276 |
277 | var (
278 | newNodes []SortedSetNode
279 | casID uint64
280 | )
281 |
282 | item, err := c.cli.Get(c.key)
283 | if err != nil {
284 | if errors.Is(err, memcache.ErrCacheMiss) {
285 | return nil
286 | }
287 |
288 | return errors.Wrap(err, "failed to Get")
289 | }
290 |
291 | casID = item.CasID
292 |
293 | var oldNodes []SortedSetNode
294 |
295 | _ = gob.NewDecoder(bytes.NewBuffer(item.Value)).Decode(&oldNodes)
296 | for _, node := range oldNodes {
297 | if node.CreatedAt > now.UnixNano() && node.Value != key {
298 | newNodes = append(newNodes, node)
299 | }
300 | }
301 |
302 | var b bytes.Buffer
303 |
304 | _ = gob.NewEncoder(&b).Encode(newNodes)
305 | item = &memcache.Item{
306 | Key: c.key,
307 | Value: b.Bytes(),
308 | CasID: casID,
309 | }
310 |
311 | err = c.cli.CompareAndSwap(item)
312 | if err != nil && (errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss)) {
313 | return c.Remove(ctx, key)
314 | }
315 |
316 | return errors.Wrap(err, "failed to CompareAndSwap")
317 | }
318 |
--------------------------------------------------------------------------------
/examples/helloworld/helloworld.pb.go:
--------------------------------------------------------------------------------
1 | // Code generated by protoc-gen-go. DO NOT EDIT.
2 | // versions:
3 | // protoc-gen-go v1.35.1
4 | // protoc v5.28.3
5 | // source: helloworld.proto
6 |
7 | package helloworld
8 |
9 | import (
10 | reflect "reflect"
11 | sync "sync"
12 |
13 | protoreflect "google.golang.org/protobuf/reflect/protoreflect"
14 | protoimpl "google.golang.org/protobuf/runtime/protoimpl"
15 | )
16 |
17 | const (
18 | // Verify that this generated code is sufficiently up-to-date.
19 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
20 | // Verify that runtime/protoimpl is sufficiently up-to-date.
21 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
22 | )
23 |
24 | // The request message containing the user's name.
25 | type HelloRequest struct {
26 | state protoimpl.MessageState
27 | sizeCache protoimpl.SizeCache
28 | unknownFields protoimpl.UnknownFields
29 |
30 | Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
31 | }
32 |
33 | func (x *HelloRequest) Reset() {
34 | *x = HelloRequest{}
35 | mi := &file_helloworld_proto_msgTypes[0]
36 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
37 | ms.StoreMessageInfo(mi)
38 | }
39 |
40 | func (x *HelloRequest) String() string {
41 | return protoimpl.X.MessageStringOf(x)
42 | }
43 |
44 | func (*HelloRequest) ProtoMessage() {}
45 |
46 | func (x *HelloRequest) ProtoReflect() protoreflect.Message {
47 | mi := &file_helloworld_proto_msgTypes[0]
48 | if x != nil {
49 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
50 | if ms.LoadMessageInfo() == nil {
51 | ms.StoreMessageInfo(mi)
52 | }
53 | return ms
54 | }
55 | return mi.MessageOf(x)
56 | }
57 |
58 | // Deprecated: Use HelloRequest.ProtoReflect.Descriptor instead.
59 | func (*HelloRequest) Descriptor() ([]byte, []int) {
60 | return file_helloworld_proto_rawDescGZIP(), []int{0}
61 | }
62 |
63 | func (x *HelloRequest) GetName() string {
64 | if x != nil {
65 | return x.Name
66 | }
67 | return ""
68 | }
69 |
70 | // The response message containing the greetings
71 | type HelloReply struct {
72 | state protoimpl.MessageState
73 | sizeCache protoimpl.SizeCache
74 | unknownFields protoimpl.UnknownFields
75 |
76 | Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
77 | }
78 |
79 | func (x *HelloReply) Reset() {
80 | *x = HelloReply{}
81 | mi := &file_helloworld_proto_msgTypes[1]
82 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
83 | ms.StoreMessageInfo(mi)
84 | }
85 |
86 | func (x *HelloReply) String() string {
87 | return protoimpl.X.MessageStringOf(x)
88 | }
89 |
90 | func (*HelloReply) ProtoMessage() {}
91 |
92 | func (x *HelloReply) ProtoReflect() protoreflect.Message {
93 | mi := &file_helloworld_proto_msgTypes[1]
94 | if x != nil {
95 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
96 | if ms.LoadMessageInfo() == nil {
97 | ms.StoreMessageInfo(mi)
98 | }
99 | return ms
100 | }
101 | return mi.MessageOf(x)
102 | }
103 |
104 | // Deprecated: Use HelloReply.ProtoReflect.Descriptor instead.
105 | func (*HelloReply) Descriptor() ([]byte, []int) {
106 | return file_helloworld_proto_rawDescGZIP(), []int{1}
107 | }
108 |
109 | func (x *HelloReply) GetMessage() string {
110 | if x != nil {
111 | return x.Message
112 | }
113 | return ""
114 | }
115 |
116 | type GoodbyeRequest struct {
117 | state protoimpl.MessageState
118 | sizeCache protoimpl.SizeCache
119 | unknownFields protoimpl.UnknownFields
120 |
121 | Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
122 | }
123 |
124 | func (x *GoodbyeRequest) Reset() {
125 | *x = GoodbyeRequest{}
126 | mi := &file_helloworld_proto_msgTypes[2]
127 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
128 | ms.StoreMessageInfo(mi)
129 | }
130 |
131 | func (x *GoodbyeRequest) String() string {
132 | return protoimpl.X.MessageStringOf(x)
133 | }
134 |
135 | func (*GoodbyeRequest) ProtoMessage() {}
136 |
137 | func (x *GoodbyeRequest) ProtoReflect() protoreflect.Message {
138 | mi := &file_helloworld_proto_msgTypes[2]
139 | if x != nil {
140 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
141 | if ms.LoadMessageInfo() == nil {
142 | ms.StoreMessageInfo(mi)
143 | }
144 | return ms
145 | }
146 | return mi.MessageOf(x)
147 | }
148 |
149 | // Deprecated: Use GoodbyeRequest.ProtoReflect.Descriptor instead.
150 | func (*GoodbyeRequest) Descriptor() ([]byte, []int) {
151 | return file_helloworld_proto_rawDescGZIP(), []int{2}
152 | }
153 |
154 | func (x *GoodbyeRequest) GetName() string {
155 | if x != nil {
156 | return x.Name
157 | }
158 | return ""
159 | }
160 |
161 | type GoodbyeReply struct {
162 | state protoimpl.MessageState
163 | sizeCache protoimpl.SizeCache
164 | unknownFields protoimpl.UnknownFields
165 |
166 | Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
167 | }
168 |
169 | func (x *GoodbyeReply) Reset() {
170 | *x = GoodbyeReply{}
171 | mi := &file_helloworld_proto_msgTypes[3]
172 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
173 | ms.StoreMessageInfo(mi)
174 | }
175 |
176 | func (x *GoodbyeReply) String() string {
177 | return protoimpl.X.MessageStringOf(x)
178 | }
179 |
180 | func (*GoodbyeReply) ProtoMessage() {}
181 |
182 | func (x *GoodbyeReply) ProtoReflect() protoreflect.Message {
183 | mi := &file_helloworld_proto_msgTypes[3]
184 | if x != nil {
185 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
186 | if ms.LoadMessageInfo() == nil {
187 | ms.StoreMessageInfo(mi)
188 | }
189 | return ms
190 | }
191 | return mi.MessageOf(x)
192 | }
193 |
194 | // Deprecated: Use GoodbyeReply.ProtoReflect.Descriptor instead.
195 | func (*GoodbyeReply) Descriptor() ([]byte, []int) {
196 | return file_helloworld_proto_rawDescGZIP(), []int{3}
197 | }
198 |
199 | func (x *GoodbyeReply) GetMessage() string {
200 | if x != nil {
201 | return x.Message
202 | }
203 | return ""
204 | }
205 |
206 | var File_helloworld_proto protoreflect.FileDescriptor
207 |
208 | var file_helloworld_proto_rawDesc = []byte{
209 | 0x0a, 0x10, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x70, 0x72, 0x6f,
210 | 0x74, 0x6f, 0x12, 0x0a, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x22, 0x22,
211 | 0x0a, 0x0c, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12,
212 | 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61,
213 | 0x6d, 0x65, 0x22, 0x26, 0x0a, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x70, 0x6c, 0x79,
214 | 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28,
215 | 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x24, 0x0a, 0x0e, 0x47, 0x6f,
216 | 0x6f, 0x64, 0x62, 0x79, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04,
217 | 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65,
218 | 0x22, 0x28, 0x0a, 0x0c, 0x47, 0x6f, 0x6f, 0x64, 0x62, 0x79, 0x65, 0x52, 0x65, 0x70, 0x6c, 0x79,
219 | 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28,
220 | 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0x49, 0x0a, 0x07, 0x47, 0x72,
221 | 0x65, 0x65, 0x74, 0x65, 0x72, 0x12, 0x3e, 0x0a, 0x08, 0x53, 0x61, 0x79, 0x48, 0x65, 0x6c, 0x6c,
222 | 0x6f, 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48,
223 | 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x68, 0x65,
224 | 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65,
225 | 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
226 | 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x65, 0x6e, 0x6e, 0x61, 0x6e, 0x6f, 0x76, 0x2f, 0x6c, 0x69, 0x6d,
227 | 0x69, 0x74, 0x65, 0x72, 0x73, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x68,
228 | 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
229 | 0x33,
230 | }
231 |
232 | var (
233 | file_helloworld_proto_rawDescOnce sync.Once
234 | file_helloworld_proto_rawDescData = file_helloworld_proto_rawDesc
235 | )
236 |
237 | func file_helloworld_proto_rawDescGZIP() []byte {
238 | file_helloworld_proto_rawDescOnce.Do(func() {
239 | file_helloworld_proto_rawDescData = protoimpl.X.CompressGZIP(file_helloworld_proto_rawDescData)
240 | })
241 | return file_helloworld_proto_rawDescData
242 | }
243 |
244 | var file_helloworld_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
245 | var file_helloworld_proto_goTypes = []any{
246 | (*HelloRequest)(nil), // 0: helloworld.HelloRequest
247 | (*HelloReply)(nil), // 1: helloworld.HelloReply
248 | (*GoodbyeRequest)(nil), // 2: helloworld.GoodbyeRequest
249 | (*GoodbyeReply)(nil), // 3: helloworld.GoodbyeReply
250 | }
251 | var file_helloworld_proto_depIdxs = []int32{
252 | 0, // 0: helloworld.Greeter.SayHello:input_type -> helloworld.HelloRequest
253 | 1, // 1: helloworld.Greeter.SayHello:output_type -> helloworld.HelloReply
254 | 1, // [1:2] is the sub-list for method output_type
255 | 0, // [0:1] is the sub-list for method input_type
256 | 0, // [0:0] is the sub-list for extension type_name
257 | 0, // [0:0] is the sub-list for extension extendee
258 | 0, // [0:0] is the sub-list for field type_name
259 | }
260 |
261 | func init() { file_helloworld_proto_init() }
262 | func file_helloworld_proto_init() {
263 | if File_helloworld_proto != nil {
264 | return
265 | }
266 | type x struct{}
267 | out := protoimpl.TypeBuilder{
268 | File: protoimpl.DescBuilder{
269 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
270 | RawDescriptor: file_helloworld_proto_rawDesc,
271 | NumEnums: 0,
272 | NumMessages: 4,
273 | NumExtensions: 0,
274 | NumServices: 1,
275 | },
276 | GoTypes: file_helloworld_proto_goTypes,
277 | DependencyIndexes: file_helloworld_proto_depIdxs,
278 | MessageInfos: file_helloworld_proto_msgTypes,
279 | }.Build()
280 | File_helloworld_proto = out.File
281 | file_helloworld_proto_rawDesc = nil
282 | file_helloworld_proto_goTypes = nil
283 | file_helloworld_proto_depIdxs = nil
284 | }
285 |
--------------------------------------------------------------------------------
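A minimal sketch of how the generated messages above might be used outside of gRPC. It assumes the standard google.golang.org/protobuf/proto package and imports the generated package via the go_package path declared in the descriptor (github.com/mennanov/limiters/examples/helloworld); the round trip itself is illustrative:

package main

import (
	"fmt"
	"log"

	"google.golang.org/protobuf/proto"

	pb "github.com/mennanov/limiters/examples/helloworld"
)

func main() {
	// Build a request, marshal it to the wire format and decode it back,
	// the same round trip a gRPC client and server perform under the hood.
	req := &pb.HelloRequest{Name: "world"}

	raw, err := proto.Marshal(req)
	if err != nil {
		log.Fatal(err)
	}

	var decoded pb.HelloRequest
	if err := proto.Unmarshal(raw, &decoded); err != nil {
		log.Fatal(err)
	}

	fmt.Println(decoded.GetName()) // "world"
}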
/fixedwindow.go:
--------------------------------------------------------------------------------
1 | package limiters
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "log"
8 | "net/http"
9 | "strconv"
10 | "sync"
11 | "time"
12 |
13 | "github.com/Azure/azure-sdk-for-go/sdk/azcore"
14 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
15 | "github.com/aws/aws-sdk-go-v2/aws"
16 | "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
17 | "github.com/aws/aws-sdk-go-v2/service/dynamodb"
18 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
19 | "github.com/bradfitz/gomemcache/memcache"
20 | "github.com/pkg/errors"
21 | "github.com/redis/go-redis/v9"
22 | )
23 |
24 | // FixedWindowIncrementer wraps the Increment method.
25 | type FixedWindowIncrementer interface {
26 | // Increment increments the request counter for the window and returns the counter value.
27 | // TTL is the time duration before the next window.
28 | Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error)
29 | }
30 |
31 | // FixedWindow implements a Fixed Window rate limiting algorithm.
32 | //
33 | // A simple and memory-efficient algorithm that does not need a distributed lock.
34 | // However, it may be lenient when many requests arrive around the boundary between two adjacent windows.
35 | type FixedWindow struct {
36 | backend FixedWindowIncrementer
37 | clock Clock
38 | rate time.Duration
39 | capacity int64
40 | mu sync.Mutex
41 | window time.Time
42 | overflow bool
43 | }
44 |
45 | // NewFixedWindow creates a new instance of FixedWindow.
46 | // Capacity is the maximum number of requests allowed per window.
47 | // Rate is the window size.
48 | func NewFixedWindow(capacity int64, rate time.Duration, fixedWindowIncrementer FixedWindowIncrementer, clock Clock) *FixedWindow {
49 | return &FixedWindow{backend: fixedWindowIncrementer, clock: clock, rate: rate, capacity: capacity}
50 | }
51 |
52 | // Limit returns the time duration to wait before the request can be processed.
53 | // It returns ErrLimitExhausted if the request overflows the window's capacity.
54 | func (f *FixedWindow) Limit(ctx context.Context) (time.Duration, error) {
55 | f.mu.Lock()
56 | defer f.mu.Unlock()
57 |
58 | now := f.clock.Now()
59 |
60 | window := now.Truncate(f.rate)
61 | if f.window != window {
62 | f.window = window
63 | f.overflow = false
64 | }
65 |
66 | ttl := f.rate - now.Sub(window)
67 | if f.overflow {
68 | // If the window is already overflowed don't increment the counter.
69 | return ttl, ErrLimitExhausted
70 | }
71 |
72 | c, err := f.backend.Increment(ctx, window, ttl)
73 | if err != nil {
74 | return 0, err
75 | }
76 |
77 | if c > f.capacity {
78 | f.overflow = true
79 |
80 | return ttl, ErrLimitExhausted
81 | }
82 |
83 | return 0, nil
84 | }
85 |
86 | // FixedWindowInMemory is an in-memory implementation of FixedWindowIncrementer.
87 | type FixedWindowInMemory struct {
88 | mu sync.Mutex
89 | c int64
90 | window time.Time
91 | }
92 |
93 | // NewFixedWindowInMemory creates a new instance of FixedWindowInMemory.
94 | func NewFixedWindowInMemory() *FixedWindowInMemory {
95 | return &FixedWindowInMemory{}
96 | }
97 |
98 | // Increment increments the window's counter.
99 | func (f *FixedWindowInMemory) Increment(ctx context.Context, window time.Time, _ time.Duration) (int64, error) {
100 | f.mu.Lock()
101 | defer f.mu.Unlock()
102 |
103 | if window != f.window {
104 | f.c = 0
105 | f.window = window
106 | }
107 |
108 | f.c++
109 |
110 | return f.c, ctx.Err()
111 | }
112 |
113 | // FixedWindowRedis implements FixedWindow in Redis.
114 | type FixedWindowRedis struct {
115 | cli redis.UniversalClient
116 | prefix string
117 | }
118 |
119 | // NewFixedWindowRedis returns a new instance of FixedWindowRedis.
120 | // Prefix is the key prefix used to store all the keys used in this implementation in Redis.
121 | func NewFixedWindowRedis(cli redis.UniversalClient, prefix string) *FixedWindowRedis {
122 | return &FixedWindowRedis{cli: cli, prefix: prefix}
123 | }
124 |
125 | // Increment increments the window's counter in Redis.
126 | func (f *FixedWindowRedis) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) {
127 | var (
128 | incr *redis.IntCmd
129 | err error
130 | )
131 |
132 | done := make(chan struct{})
133 |
134 | go func() {
135 | defer close(done)
136 |
137 | _, err = f.cli.Pipelined(ctx, func(pipeliner redis.Pipeliner) error {
138 | key := fmt.Sprintf("%d", window.UnixNano())
139 | incr = pipeliner.Incr(ctx, redisKey(f.prefix, key))
140 | pipeliner.PExpire(ctx, redisKey(f.prefix, key), ttl)
141 |
142 | return nil
143 | })
144 | }()
145 |
146 | select {
147 | case <-done:
148 | if err != nil {
149 | return 0, errors.Wrap(err, "redis transaction failed")
150 | }
151 |
152 | return incr.Val(), incr.Err()
153 | case <-ctx.Done():
154 | return 0, ctx.Err()
155 | }
156 | }
157 |
158 | // FixedWindowMemcached implements FixedWindow in Memcached.
159 | type FixedWindowMemcached struct {
160 | cli *memcache.Client
161 | prefix string
162 | }
163 |
164 | // NewFixedWindowMemcached returns a new instance of FixedWindowMemcached.
165 | // Prefix is the key prefix used to store all the keys used in this implementation in Memcached.
166 | func NewFixedWindowMemcached(cli *memcache.Client, prefix string) *FixedWindowMemcached {
167 | return &FixedWindowMemcached{cli: cli, prefix: prefix}
168 | }
169 |
170 | // Increment increments the window's counter in Memcached.
171 | func (f *FixedWindowMemcached) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) {
172 | var (
173 | newValue uint64
174 | err error
175 | )
176 |
177 | done := make(chan struct{})
178 |
179 | go func() {
180 | defer close(done)
181 |
182 | key := fmt.Sprintf("%s:%d", f.prefix, window.UnixNano())
183 |
184 | newValue, err = f.cli.Increment(key, 1)
185 | if err != nil && errors.Is(err, memcache.ErrCacheMiss) {
186 | newValue = 1
187 | item := &memcache.Item{
188 | Key: key,
189 | Value: []byte(strconv.FormatUint(newValue, 10)),
190 | }
191 | err = f.cli.Add(item)
192 | }
193 | }()
194 |
195 | select {
196 | case <-done:
197 | if err != nil {
198 | if errors.Is(err, memcache.ErrNotStored) {
199 | return f.Increment(ctx, window, ttl)
200 | }
201 |
202 | return 0, errors.Wrap(err, "failed to Increment or Add")
203 | }
204 |
205 | return int64(newValue), err
206 | case <-ctx.Done():
207 | return 0, ctx.Err()
208 | }
209 | }
210 |
211 | // FixedWindowDynamoDB implements FixedWindow in DynamoDB.
212 | type FixedWindowDynamoDB struct {
213 | client *dynamodb.Client
214 | partitionKey string
215 | tableProps DynamoDBTableProperties
216 | }
217 |
218 | // NewFixedWindowDynamoDB creates a new instance of FixedWindowDynamoDB.
219 | // PartitionKey is the key used to store all the data for this implementation in DynamoDB.
220 | //
221 | // TableProps describe the table that this backend should work with. This backend requires the following on the table:
222 | // * SortKey
223 | // * TTL.
224 | func NewFixedWindowDynamoDB(client *dynamodb.Client, partitionKey string, props DynamoDBTableProperties) *FixedWindowDynamoDB {
225 | return &FixedWindowDynamoDB{
226 | client: client,
227 | partitionKey: partitionKey,
228 | tableProps: props,
229 | }
230 | }
231 |
232 | type contextKey int
233 |
234 | var fixedWindowDynamoDBPartitionKey contextKey
235 |
236 | // NewFixedWindowDynamoDBContext creates a context for FixedWindowDynamoDB with a partition key.
237 | //
238 | // This context can be used to control the partition key per-request.
239 | //
240 | // Deprecated: NewFixedWindowDynamoDBContext is deprecated and will be removed in future versions.
241 | // Separate FixedWindow rate limiters should be used for different partition keys instead.
242 | // Consider using the `Registry` to manage multiple FixedWindow instances with different partition keys.
243 | func NewFixedWindowDynamoDBContext(ctx context.Context, partitionKey string) context.Context {
244 | log.Printf("DEPRECATED: NewFixedWindowDynamoDBContext is deprecated and will be removed in future versions.")
245 |
246 | return context.WithValue(ctx, fixedWindowDynamoDBPartitionKey, partitionKey)
247 | }
248 |
249 | const (
250 | fixedWindowDynamoDBUpdateExpression = "SET #C = if_not_exists(#C, :def) + :inc, #TTL = :ttl"
251 | dynamodbWindowCountKey = "Count"
252 | )
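// Illustrative note on the update expression above: "#C" is bound to the "Count"
// attribute and "#TTL" to the table's TTL attribute (see ExpressionAttributeNames
// in Increment below). "if_not_exists(#C, :def) + :inc" initializes the counter
// to 0 on the first request of a window and then increments it atomically, so a
// single UpdateItem call both creates the item for a new window and returns the
// updated count via ReturnValues=ALL_NEW.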
253 |
254 | // Increment increments the window's counter in DynamoDB.
255 | func (f *FixedWindowDynamoDB) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) {
256 | var (
257 | resp *dynamodb.UpdateItemOutput
258 | err error
259 | )
260 |
261 | done := make(chan struct{})
262 |
263 | go func() {
264 | defer close(done)
265 |
266 | partitionKey := f.partitionKey
267 | if key, ok := ctx.Value(fixedWindowDynamoDBPartitionKey).(string); ok {
268 | partitionKey = key
269 | }
270 |
271 | resp, err = f.client.UpdateItem(ctx, &dynamodb.UpdateItemInput{
272 | Key: map[string]types.AttributeValue{
273 | f.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: partitionKey},
274 | f.tableProps.SortKeyName: &types.AttributeValueMemberS{Value: strconv.FormatInt(window.UnixNano(), 10)},
275 | },
276 | UpdateExpression: aws.String(fixedWindowDynamoDBUpdateExpression),
277 | ExpressionAttributeNames: map[string]string{
278 | "#TTL": f.tableProps.TTLFieldName,
279 | "#C": dynamodbWindowCountKey,
280 | },
281 | ExpressionAttributeValues: map[string]types.AttributeValue{
282 | ":ttl": &types.AttributeValueMemberN{Value: strconv.FormatInt(time.Now().Add(ttl).Unix(), 10)},
283 | ":def": &types.AttributeValueMemberN{Value: "0"},
284 | ":inc": &types.AttributeValueMemberN{Value: "1"},
285 | },
286 | TableName: &f.tableProps.TableName,
287 | ReturnValues: types.ReturnValueAllNew,
288 | })
289 | }()
290 |
291 | select {
292 | case <-done:
293 | case <-ctx.Done():
294 | return 0, ctx.Err()
295 | }
296 |
297 | if err != nil {
298 | return 0, errors.Wrap(err, "dynamodb update item failed")
299 | }
300 |
301 | var count float64
302 |
303 | err = attributevalue.Unmarshal(resp.Attributes[dynamodbWindowCountKey], &count)
304 | if err != nil {
305 | return 0, errors.Wrap(err, "unmarshal of dynamodb attribute value failed")
306 | }
307 |
308 | return int64(count), nil
309 | }
310 |
311 | // FixedWindowCosmosDB implements FixedWindow in CosmosDB.
312 | type FixedWindowCosmosDB struct {
313 | client *azcosmos.ContainerClient
314 | partitionKey string
315 | }
316 |
317 | // NewFixedWindowCosmosDB creates a new instance of FixedWindowCosmosDB.
318 | // PartitionKey is the key used for partitioning data into multiple partitions.
319 | func NewFixedWindowCosmosDB(client *azcosmos.ContainerClient, partitionKey string) *FixedWindowCosmosDB {
320 | return &FixedWindowCosmosDB{
321 | client: client,
322 | partitionKey: partitionKey,
323 | }
324 | }
325 |
326 | func (f *FixedWindowCosmosDB) Increment(ctx context.Context, window time.Time, ttl time.Duration) (int64, error) {
327 | id := strconv.FormatInt(window.UnixNano(), 10)
328 | tmp := cosmosItem{
329 | ID: id,
330 | PartitionKey: f.partitionKey,
331 | Count: 1,
332 | TTL: int32(ttl),
333 | }
334 |
335 | ops := azcosmos.PatchOperations{}
336 | ops.AppendIncrement(`/Count`, 1)
337 |
338 | patchResp, err := f.client.PatchItem(ctx, azcosmos.NewPartitionKey().AppendString(f.partitionKey), id, ops, &azcosmos.ItemOptions{
339 | EnableContentResponseOnWrite: true,
340 | })
341 | if err == nil {
342 | // value exists and was updated
343 | err = json.Unmarshal(patchResp.Value, &tmp)
344 | if err != nil {
345 | return 0, errors.Wrap(err, "unmarshal of cosmos value failed")
346 | }
347 |
348 | return tmp.Count, nil
349 | }
350 |
351 | var respErr *azcore.ResponseError
352 | if !errors.As(err, &respErr) || (respErr.StatusCode != http.StatusNotFound && respErr.StatusCode != http.StatusBadRequest) {
353 | return 0, errors.Wrap(err, `patch of cosmos value failed`)
354 | }
355 |
356 | newValue, err := json.Marshal(tmp)
357 | if err != nil {
358 | return 0, errors.Wrap(err, "marshal of cosmos value failed")
359 | }
360 |
361 | // Try to CreateItem. If it fails with Conflict, the item already exists (and Patch failed with 400), so we fall back to Read-Modify-Write.
362 | _, err = f.client.CreateItem(ctx, azcosmos.NewPartitionKey().AppendString(f.partitionKey), newValue, &azcosmos.ItemOptions{
363 | SessionToken: patchResp.SessionToken,
364 | })
365 | if err != nil {
366 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict {
367 | // Fallback to Read-Modify-Write
368 | return incrementCosmosItemRMW(ctx, f.client, f.partitionKey, id, ttl)
369 | }
370 |
371 | return 0, errors.Wrap(err, "upsert of cosmos value failed")
372 | }
373 |
374 | return tmp.Count, nil
375 | }
376 |
--------------------------------------------------------------------------------
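A minimal usage sketch for the FixedWindow limiter defined in fixedwindow.go above, wired to the in-memory incrementer and system clock whose constructors appear in that file; the capacity, window size and loop are illustrative, not part of the library:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	l "github.com/mennanov/limiters"
)

func main() {
	// Allow at most 3 requests per 1-second window.
	limiter := l.NewFixedWindow(3, time.Second, l.NewFixedWindowInMemory(), l.NewSystemClock())

	for i := 0; i < 5; i++ {
		wait, err := limiter.Limit(context.Background())
		switch {
		case errors.Is(err, l.ErrLimitExhausted):
			fmt.Printf("request %d rejected, next window starts in %s\n", i, wait)
		case err != nil:
			panic(err)
		default:
			fmt.Printf("request %d allowed\n", i)
		}
	}
}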
/tokenbucket_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strings"
7 | "sync"
8 | "testing"
9 | "time"
10 |
11 | "github.com/google/uuid"
12 | l "github.com/mennanov/limiters"
13 | "github.com/redis/go-redis/v9"
14 | )
15 |
16 | var tokenBucketUniformTestCases = []struct {
17 | capacity int64
18 | refillRate time.Duration
19 | requestCount int64
20 | requestRate time.Duration
21 | missExpected int
22 | delta float64
23 | }{
24 | {
25 | 5,
26 | time.Millisecond * 100,
27 | 10,
28 | time.Millisecond * 25,
29 | 3,
30 | 1,
31 | },
32 | {
33 | 10,
34 | time.Millisecond * 100,
35 | 13,
36 | time.Millisecond * 25,
37 | 0,
38 | 1,
39 | },
40 | {
41 | 10,
42 | time.Millisecond * 100,
43 | 15,
44 | time.Millisecond * 33,
45 | 2,
46 | 1,
47 | },
48 | {
49 | 10,
50 | time.Millisecond * 100,
51 | 16,
52 | time.Millisecond * 33,
53 | 2,
54 | 1,
55 | },
56 | {
57 | 10,
58 | time.Millisecond * 10,
59 | 20,
60 | time.Millisecond * 10,
61 | 0,
62 | 1,
63 | },
64 | {
65 | 1,
66 | time.Millisecond * 200,
67 | 10,
68 | time.Millisecond * 100,
69 | 5,
70 | 2,
71 | },
72 | }
73 |
74 | // tokenBuckets returns all the possible TokenBucket combinations.
75 | func (s *LimitersTestSuite) tokenBuckets(capacity int64, refillRate, ttl time.Duration, clock l.Clock) map[string]*l.TokenBucket {
76 | buckets := make(map[string]*l.TokenBucket)
77 |
78 | for lockerName, locker := range s.lockers(true) {
79 | for backendName, backend := range s.tokenBucketBackends(ttl) {
80 | buckets[lockerName+":"+backendName] = l.NewTokenBucket(capacity, refillRate, locker, backend, clock, s.logger)
81 | }
82 | }
83 |
84 | return buckets
85 | }
86 |
87 | func (s *LimitersTestSuite) tokenBucketBackends(ttl time.Duration) map[string]l.TokenBucketStateBackend {
88 | return map[string]l.TokenBucketStateBackend{
89 | "TokenBucketInMemory": l.NewTokenBucketInMemory(),
90 | "TokenBucketEtcdNoRaceCheck": l.NewTokenBucketEtcd(s.etcdClient, uuid.New().String(), ttl, false),
91 | "TokenBucketEtcdWithRaceCheck": l.NewTokenBucketEtcd(s.etcdClient, uuid.New().String(), ttl, true),
92 | "TokenBucketRedisNoRaceCheck": l.NewTokenBucketRedis(s.redisClient, uuid.New().String(), ttl, false),
93 | "TokenBucketRedisWithRaceCheck": l.NewTokenBucketRedis(s.redisClient, uuid.New().String(), ttl, true),
94 | "TokenBucketRedisClusterNoRaceCheck": l.NewTokenBucketRedis(s.redisClusterClient, uuid.New().String(), ttl, false),
95 | "TokenBucketRedisClusterWithRaceCheck": l.NewTokenBucketRedis(s.redisClusterClient, uuid.New().String(), ttl, true),
96 | "TokenBucketMemcachedNoRaceCheck": l.NewTokenBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, false),
97 | "TokenBucketMemcachedWithRaceCheck": l.NewTokenBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, true),
98 | "TokenBucketDynamoDBNoRaceCheck": l.NewTokenBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, ttl, false),
99 | "TokenBucketDynamoDBWithRaceCheck": l.NewTokenBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, ttl, true),
100 | "TokenBucketCosmosDBNoRaceCheck": l.NewTokenBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), ttl, false),
101 | "TokenBucketCosmosDBWithRaceCheck": l.NewTokenBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), ttl, true),
102 | }
103 | }
104 |
105 | func (s *LimitersTestSuite) TestTokenBucketRealClock() {
106 | clock := l.NewSystemClock()
107 | for _, testCase := range tokenBucketUniformTestCases {
108 | for name, bucket := range s.tokenBuckets(testCase.capacity, testCase.refillRate, 0, clock) {
109 | s.Run(name, func() {
110 | wg := sync.WaitGroup{}
111 | // mu guards the miss variable below.
112 | var mu sync.Mutex
113 |
114 | miss := 0
115 |
116 | for i := int64(0); i < testCase.requestCount; i++ {
117 | // No pause for the first request.
118 | if i > 0 {
119 | clock.Sleep(testCase.requestRate)
120 | }
121 |
122 | wg.Add(1)
123 |
124 | go func(bucket *l.TokenBucket) {
125 | defer wg.Done()
126 |
127 | _, err := bucket.Limit(context.TODO())
128 | if err != nil {
129 | s.Equal(l.ErrLimitExhausted, err, "%T %v", bucket, bucket)
130 | mu.Lock()
131 |
132 | miss++
133 |
134 | mu.Unlock()
135 | }
136 | }(bucket)
137 | }
138 |
139 | wg.Wait()
140 | s.InDelta(testCase.missExpected, miss, testCase.delta, testCase)
141 | })
142 | }
143 | }
144 | }
145 |
146 | func (s *LimitersTestSuite) TestTokenBucketFakeClock() {
147 | for _, testCase := range tokenBucketUniformTestCases {
148 | clock := newFakeClock()
149 | for name, bucket := range s.tokenBuckets(testCase.capacity, testCase.refillRate, 0, clock) {
150 | s.Run(name, func() {
151 | clock.reset()
152 |
153 | miss := 0
154 |
155 | for i := int64(0); i < testCase.requestCount; i++ {
156 | // No pause for the first request.
157 | if i > 0 {
158 | clock.Sleep(testCase.requestRate)
159 | }
160 |
161 | _, err := bucket.Limit(context.TODO())
162 | if err != nil {
163 | s.Equal(l.ErrLimitExhausted, err)
164 |
165 | miss++
166 | }
167 | }
168 |
169 | s.InDelta(testCase.missExpected, miss, testCase.delta, testCase)
170 | })
171 | }
172 | }
173 | }
174 |
175 | func (s *LimitersTestSuite) TestTokenBucketOverflow() {
176 | clock := newFakeClock()
177 |
178 | rate := 100 * time.Millisecond
179 | for name, bucket := range s.tokenBuckets(2, rate, 0, clock) {
180 | s.Run(name, func() {
181 | clock.reset()
182 |
183 | wait, err := bucket.Limit(context.TODO())
184 | s.Require().NoError(err)
185 | s.Equal(time.Duration(0), wait)
186 | wait, err = bucket.Limit(context.TODO())
187 | s.Require().NoError(err)
188 | s.Equal(time.Duration(0), wait)
189 | // The third call should fail.
190 | wait, err = bucket.Limit(context.TODO())
191 | s.Require().Equal(l.ErrLimitExhausted, err)
192 | s.Equal(rate, wait)
193 | clock.Sleep(wait)
194 | // Retry the last call.
195 | wait, err = bucket.Limit(context.TODO())
196 | s.Require().NoError(err)
197 | s.Equal(time.Duration(0), wait)
198 | })
199 | }
200 | }
201 |
202 | func (s *LimitersTestSuite) TestTokenTakeMaxOverflow() {
203 | clock := newFakeClock()
204 |
205 | rate := 100 * time.Millisecond
206 | for name, bucket := range s.tokenBuckets(3, rate, 0, clock) {
207 | s.Run(name, func() {
208 | clock.reset()
209 | // Take max, as it's below capacity.
210 | taken, err := bucket.TakeMax(context.TODO(), 2)
211 | s.Require().NoError(err)
212 | s.Equal(int64(2), taken)
213 | // Take as much as is available.
214 | taken, err = bucket.TakeMax(context.TODO(), 2)
215 | s.Require().NoError(err)
216 | s.Equal(int64(1), taken)
217 | // Now it is not able to take any.
218 | taken, err = bucket.TakeMax(context.TODO(), 2)
219 | s.Require().NoError(err)
220 | s.Equal(int64(0), taken)
221 | clock.Sleep(rate)
222 | // Retry the last call.
223 | taken, err = bucket.TakeMax(context.TODO(), 2)
224 | s.Require().NoError(err)
225 | s.Equal(int64(1), taken)
226 | })
227 | }
228 | }
229 |
230 | func (s *LimitersTestSuite) TestTokenBucketReset() {
231 | clock := newFakeClock()
232 |
233 | rate := 100 * time.Millisecond
234 | for name, bucket := range s.tokenBuckets(2, rate, 0, clock) {
235 | s.Run(name, func() {
236 | clock.reset()
237 |
238 | wait, err := bucket.Limit(context.TODO())
239 | s.Require().NoError(err)
240 | s.Equal(time.Duration(0), wait)
241 | wait, err = bucket.Limit(context.TODO())
242 | s.Require().NoError(err)
243 | s.Equal(time.Duration(0), wait)
244 | // The third call should fail.
245 | wait, err = bucket.Limit(context.TODO())
246 | s.Require().Equal(l.ErrLimitExhausted, err)
247 | s.Equal(rate, wait)
248 |
249 | err = bucket.Reset(context.TODO())
250 | s.Require().NoError(err)
251 | // Retry the last call.
252 | wait, err = bucket.Limit(context.TODO())
253 | s.Require().NoError(err)
254 | s.Equal(time.Duration(0), wait)
255 | })
256 | }
257 | }
258 |
259 | func (s *LimitersTestSuite) TestTokenBucketRefill() {
260 | for name, backend := range s.tokenBucketBackends(0) {
261 | s.Run(name, func() {
262 | clock := newFakeClock()
263 |
264 | bucket := l.NewTokenBucket(4, time.Millisecond*100, l.NewLockNoop(), backend, clock, s.logger)
265 | sleepDurations := []int{150, 90, 50, 70}
266 | desiredAvailable := []int64{3, 2, 2, 2}
267 |
268 | _, err := bucket.Limit(context.Background())
269 | s.Require().NoError(err)
270 |
271 | _, err = backend.State(context.Background())
272 | s.Require().NoError(err, "unable to retrieve backend state")
273 |
274 | for i := range sleepDurations {
275 | clock.Sleep(time.Millisecond * time.Duration(sleepDurations[i]))
276 |
277 | _, err := bucket.Limit(context.Background())
278 | s.Require().NoError(err)
279 |
280 | state, err := backend.State(context.Background())
281 | s.Require().NoError(err, "unable to retrieve backend state")
282 |
283 | s.Require().Equal(desiredAvailable[i], state.Available)
284 | }
285 | })
286 | }
287 | }
288 |
289 | // setTokenBucketStateInOldFormat is a test utility method for writing state in the old format to Redis.
290 | func setTokenBucketStateInOldFormat(ctx context.Context, cli *redis.Client, prefix string, state l.TokenBucketState, ttl time.Duration) error {
291 | _, err := cli.TxPipelined(ctx, func(pipeliner redis.Pipeliner) error {
292 | err := pipeliner.Set(ctx, fmt.Sprintf("{%s}last", prefix), state.Last, ttl).Err()
293 | if err != nil {
294 | return err
295 | }
296 |
297 | err = pipeliner.Set(ctx, fmt.Sprintf("{%s}available", prefix), state.Available, ttl).Err()
298 | if err != nil {
299 | return err
300 | }
301 |
302 | return nil
303 | })
304 |
305 | return err
306 | }
307 |
308 | // TestTokenBucketRedisBackwardCompatibility tests that the new State method can read data written in the old format.
309 | func (s *LimitersTestSuite) TestTokenBucketRedisBackwardCompatibility() {
310 | // Create a new TokenBucketRedis instance
311 | prefix := uuid.New().String()
312 | backend := l.NewTokenBucketRedis(s.redisClient, prefix, time.Minute, false)
313 |
314 | // Write state using old format
315 | ctx := context.Background()
316 | expectedState := l.TokenBucketState{
317 | Last: 12345,
318 | Available: 67890,
319 | }
320 |
321 | // Write directly to Redis using old format
322 | err := setTokenBucketStateInOldFormat(ctx, s.redisClient, prefix, expectedState, time.Minute)
323 | s.Require().NoError(err, "Failed to set state using old format")
324 |
325 | // Read state using new format (State)
326 | actualState, err := backend.State(ctx)
327 | s.Require().NoError(err, "Failed to get state using new format")
328 |
329 | // Verify the state is correctly read
330 | s.Equal(expectedState.Last, actualState.Last, "Last values should match")
331 | s.Equal(expectedState.Available, actualState.Available, "Available values should match")
332 | }
333 |
334 | func (s *LimitersTestSuite) TestTokenBucketNoExpiration() {
335 | clock := l.NewSystemClock()
336 | buckets := s.tokenBuckets(1, time.Minute, 0, clock)
337 |
338 | // Take all capacity from all buckets
339 | for _, bucket := range buckets {
340 | _, _ = bucket.Limit(context.TODO())
341 | }
342 |
343 | // Wait for 2 seconds to check if it treats 0 TTL as infinite
344 | clock.Sleep(2 * time.Second)
345 |
346 | // Expect all buckets to be still filled
347 | for name, bucket := range buckets {
348 | s.Run(name, func() {
349 | _, err := bucket.Limit(context.TODO())
350 | s.Require().Equal(l.ErrLimitExhausted, err)
351 | })
352 | }
353 | }
354 |
355 | func (s *LimitersTestSuite) TestTokenBucketTTLExpiration() {
356 | clock := l.NewSystemClock()
357 | buckets := s.tokenBuckets(1, time.Minute, time.Second, clock)
358 |
359 | // Ignore in-memory bucket, as it has no expiration,
360 | // ignore DynamoDB, as amazon/dynamodb-local doesn't support TTLs.
361 | for k := range buckets {
362 | if strings.Contains(k, "BucketInMemory") || strings.Contains(k, "BucketDynamoDB") {
363 | delete(buckets, k)
364 | }
365 | }
366 |
367 | // Take all capacity from all buckets
368 | for _, bucket := range buckets {
369 | _, _ = bucket.Limit(context.TODO())
370 | }
371 |
372 | // Wait for 3 seconds to check if the items have been deleted successfully
373 | clock.Sleep(3 * time.Second)
374 |
375 | // Expect all buckets to be still empty (as the data expired)
376 | for name, bucket := range buckets {
377 | s.Run(name, func() {
378 | wait, err := bucket.Limit(context.TODO())
379 | s.Require().Equal(time.Duration(0), wait)
380 | s.Require().NoError(err)
381 | })
382 | }
383 | }
384 |
385 | func BenchmarkTokenBuckets(b *testing.B) {
386 | s := new(LimitersTestSuite)
387 | s.SetT(&testing.T{})
388 | s.SetupSuite()
389 |
390 | capacity := int64(1)
391 | rate := 100 * time.Millisecond
392 | clock := newFakeClock()
393 |
394 | buckets := s.tokenBuckets(capacity, rate, 0, clock)
395 | for name, bucket := range buckets {
396 | b.Run(name, func(b *testing.B) {
397 | for i := 0; i < b.N; i++ {
398 | _, err := bucket.Limit(context.TODO())
399 | s.Require().NoError(err)
400 | }
401 | })
402 | }
403 |
404 | s.TearDownSuite()
405 | }
406 |
--------------------------------------------------------------------------------
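For reference, a standalone sketch that mirrors how the tests above assemble a TokenBucket: in-memory state backend, no-op lock, system clock and the standard logger; the capacity and refill rate are illustrative:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	l "github.com/mennanov/limiters"
)

func main() {
	// 2 tokens of capacity, one token refilled every 100ms.
	bucket := l.NewTokenBucket(2, 100*time.Millisecond, l.NewLockNoop(),
		l.NewTokenBucketInMemory(), l.NewSystemClock(), l.NewStdLogger())

	for i := 0; i < 4; i++ {
		wait, err := bucket.Limit(context.Background())
		switch {
		case errors.Is(err, l.ErrLimitExhausted):
			fmt.Printf("request %d throttled, retry in %s\n", i, wait)
		case err != nil:
			panic(err)
		default:
			fmt.Printf("request %d allowed\n", i)
		}
	}
}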
/leakybucket_test.go:
--------------------------------------------------------------------------------
1 | package limiters_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strings"
7 | "sync"
8 | "testing"
9 | "time"
10 |
11 | "github.com/google/uuid"
12 | l "github.com/mennanov/limiters"
13 | "github.com/redis/go-redis/v9"
14 | "github.com/stretchr/testify/require"
15 | )
16 |
17 | // leakyBuckets returns all the possible leakyBuckets combinations.
18 | func (s *LimitersTestSuite) leakyBuckets(capacity int64, rate, ttl time.Duration, clock l.Clock) map[string]*l.LeakyBucket {
19 | buckets := make(map[string]*l.LeakyBucket)
20 |
21 | for lockerName, locker := range s.lockers(true) {
22 | for backendName, backend := range s.leakyBucketBackends(ttl) {
23 | buckets[lockerName+":"+backendName] = l.NewLeakyBucket(capacity, rate, locker, backend, clock, s.logger)
24 | }
25 | }
26 |
27 | return buckets
28 | }
29 |
30 | func (s *LimitersTestSuite) leakyBucketBackends(ttl time.Duration) map[string]l.LeakyBucketStateBackend {
31 | return map[string]l.LeakyBucketStateBackend{
32 | "LeakyBucketInMemory": l.NewLeakyBucketInMemory(),
33 | "LeakyBucketEtcdNoRaceCheck": l.NewLeakyBucketEtcd(s.etcdClient, uuid.New().String(), ttl, false),
34 | "LeakyBucketEtcdWithRaceCheck": l.NewLeakyBucketEtcd(s.etcdClient, uuid.New().String(), ttl, true),
35 | "LeakyBucketRedisNoRaceCheck": l.NewLeakyBucketRedis(s.redisClient, uuid.New().String(), ttl, false),
36 | "LeakyBucketRedisWithRaceCheck": l.NewLeakyBucketRedis(s.redisClient, uuid.New().String(), ttl, true),
37 | "LeakyBucketRedisClusterNoRaceCheck": l.NewLeakyBucketRedis(s.redisClusterClient, uuid.New().String(), ttl, false),
38 | "LeakyBucketRedisClusterWithRaceCheck": l.NewLeakyBucketRedis(s.redisClusterClient, uuid.New().String(), ttl, true),
39 | "LeakyBucketMemcachedNoRaceCheck": l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, false),
40 | "LeakyBucketMemcachedWithRaceCheck": l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, true),
41 | "LeakyBucketDynamoDBNoRaceCheck": l.NewLeakyBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, ttl, false),
42 | "LeakyBucketDynamoDBWithRaceCheck": l.NewLeakyBucketDynamoDB(s.dynamodbClient, uuid.New().String(), s.dynamoDBTableProps, ttl, true),
43 | "LeakyBucketCosmosDBNoRaceCheck": l.NewLeakyBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), ttl, false),
44 | "LeakyBucketCosmosDBWithRaceCheck": l.NewLeakyBucketCosmosDB(s.cosmosContainerClient, uuid.New().String(), ttl, true),
45 | }
46 | }
47 |
48 | func (s *LimitersTestSuite) TestLeakyBucketRealClock() {
49 | capacity := int64(10)
50 | rate := time.Millisecond * 10
51 |
52 | clock := l.NewSystemClock()
53 | for _, requestRate := range []time.Duration{rate / 2} {
54 | for name, bucket := range s.leakyBuckets(capacity, rate, 0, clock) {
55 | s.Run(name, func() {
56 | wg := sync.WaitGroup{}
57 | mu := sync.Mutex{}
58 |
59 | var totalWait time.Duration
60 |
61 | for i := int64(0); i < capacity; i++ {
62 | // No pause for the first request.
63 | if i > 0 {
64 | clock.Sleep(requestRate)
65 | }
66 |
67 | wg.Add(1)
68 |
69 | go func(bucket *l.LeakyBucket) {
70 | defer wg.Done()
71 |
72 | wait, err := bucket.Limit(context.TODO())
73 | s.Require().NoError(err)
74 |
75 | if wait > 0 {
76 | mu.Lock()
77 |
78 | totalWait += wait
79 |
80 | mu.Unlock()
81 | clock.Sleep(wait)
82 | }
83 | }(bucket)
84 | }
85 |
86 | wg.Wait()
87 |
88 | expectedWait := time.Duration(0)
89 | if rate > requestRate {
90 | expectedWait = time.Duration(float64(rate-requestRate) * float64(capacity-1) / 2 * float64(capacity))
91 | }
92 |
93 | // Allow 25ms of lag for each request.
94 | // TODO: figure out if this is enough for slow PCs and possibly avoid hard-coding it.
95 | delta := float64(time.Duration(capacity) * time.Millisecond * 25)
96 | s.InDelta(expectedWait, totalWait, delta, "request rate: %d, bucket: %v", requestRate, bucket)
97 | })
98 | }
99 | }
100 | }
101 |
102 | func (s *LimitersTestSuite) TestLeakyBucketFakeClock() {
103 | capacity := int64(10)
104 | rate := time.Millisecond * 100
105 |
106 | clock := newFakeClock()
107 | for _, requestRate := range []time.Duration{rate * 2, rate, rate / 2, rate / 3, rate / 4, 0} {
108 | for name, bucket := range s.leakyBuckets(capacity, rate, 0, clock) {
109 | s.Run(name, func() {
110 | clock.reset()
111 | start := clock.Now()
112 |
113 | for i := int64(0); i < capacity; i++ {
114 | // No pause for the first request.
115 | if i > 0 {
116 | clock.Sleep(requestRate)
117 | }
118 |
119 | wait, err := bucket.Limit(context.TODO())
120 | s.Require().NoError(err)
121 | clock.Sleep(wait)
122 | }
123 |
124 | interval := rate
125 | if requestRate > rate {
126 | interval = requestRate
127 | }
128 |
129 | s.Equal(interval*time.Duration(capacity-1), clock.Now().Sub(start), "request rate: %d, bucket: %v", requestRate, bucket)
130 | })
131 | }
132 | }
133 | }
134 |
135 | func (s *LimitersTestSuite) TestLeakyBucketOverflow() {
136 | rate := 100 * time.Millisecond
137 | capacity := int64(2)
138 |
139 | clock := newFakeClock()
140 | for name, bucket := range s.leakyBuckets(capacity, rate, 0, clock) {
141 | s.Run(name, func() {
142 | clock.reset()
143 | // The first call has no wait since there were no calls before.
144 | wait, err := bucket.Limit(context.TODO())
145 | s.Require().NoError(err)
146 | s.Equal(time.Duration(0), wait)
147 | // The second call increments the queue size by 1.
148 | wait, err = bucket.Limit(context.TODO())
149 | s.Require().NoError(err)
150 | s.Equal(rate, wait)
151 | // The third call overflows the bucket capacity.
152 | wait, err = bucket.Limit(context.TODO())
153 | s.Require().Equal(l.ErrLimitExhausted, err)
154 | s.Equal(rate*2, wait)
155 | // Move the Clock 1 position forward.
156 | clock.Sleep(rate)
157 | // Retry the last call. This time it should succeed.
158 | wait, err = bucket.Limit(context.TODO())
159 | s.Require().NoError(err)
160 | s.Equal(rate, wait)
161 | })
162 | }
163 | }
164 |
165 | func (s *LimitersTestSuite) TestLeakyBucketNoExpiration() {
166 | clock := l.NewSystemClock()
167 | buckets := s.leakyBuckets(1, time.Minute, 0, clock)
168 |
169 | // Take all capacity from all buckets
170 | for _, bucket := range buckets {
171 | _, _ = bucket.Limit(context.TODO())
172 | _, _ = bucket.Limit(context.TODO())
173 | }
174 |
175 | // Wait for 2 seconds to check if it treats 0 TTL as infinite
176 | clock.Sleep(2 * time.Second)
177 |
178 | // Expect all buckets to be still filled
179 | for name, bucket := range buckets {
180 | s.Run(name, func() {
181 | _, err := bucket.Limit(context.TODO())
182 | s.Require().Equal(l.ErrLimitExhausted, err)
183 | })
184 | }
185 | }
186 |
187 | func (s *LimitersTestSuite) TestLeakyBucketTTLExpiration() {
188 | clock := l.NewSystemClock()
189 | buckets := s.leakyBuckets(1, time.Minute, time.Second, clock)
190 |
191 | // Ignore in-memory bucket, as it has no expiration,
192 | // ignore DynamoDB, as amazon/dynamodb-local doesn't support TTLs.
193 | for k := range buckets {
194 | if strings.Contains(k, "BucketInMemory") || strings.Contains(k, "BucketDynamoDB") {
195 | delete(buckets, k)
196 | }
197 | }
198 |
199 | // Take all capacity from all buckets
200 | for _, bucket := range buckets {
201 | _, _ = bucket.Limit(context.TODO())
202 | _, _ = bucket.Limit(context.TODO())
203 | }
204 |
205 | // Wait for 3 seconds to check if the items have been deleted successfully
206 | clock.Sleep(3 * time.Second)
207 |
208 | // Expect all buckets to be empty (as the data expired)
209 | for name, bucket := range buckets {
210 | s.Run(name, func() {
211 | wait, err := bucket.Limit(context.TODO())
212 | s.Require().Equal(time.Duration(0), wait)
213 | s.Require().NoError(err)
214 | })
215 | }
216 | }
217 |
218 | func (s *LimitersTestSuite) TestLeakyBucketReset() {
219 | rate := 100 * time.Millisecond
220 | capacity := int64(2)
221 |
222 | clock := newFakeClock()
223 | for name, bucket := range s.leakyBuckets(capacity, rate, 0, clock) {
224 | s.Run(name, func() {
225 | clock.reset()
226 | // The first call has no wait since there were no calls before.
227 | wait, err := bucket.Limit(context.TODO())
228 | s.Require().NoError(err)
229 | s.Equal(time.Duration(0), wait)
230 | // The second call increments the queue size by 1.
231 | wait, err = bucket.Limit(context.TODO())
232 | s.Require().NoError(err)
233 | s.Equal(rate, wait)
234 | // The third call overflows the bucket capacity.
235 | wait, err = bucket.Limit(context.TODO())
236 | s.Require().Equal(l.ErrLimitExhausted, err)
237 | s.Equal(rate*2, wait)
238 | // Reset the bucket
239 | err = bucket.Reset(context.TODO())
240 | s.Require().NoError(err)
241 | // Retry the last call. This time it should succeed.
242 | wait, err = bucket.Limit(context.TODO())
243 | s.Require().NoError(err)
244 | s.Equal(time.Duration(0), wait)
245 | })
246 | }
247 | }
248 |
249 | // TestLeakyBucketMemcachedExpiry verifies Memcached TTL expiry for LeakyBucketMemcached (no race check).
250 | func (s *LimitersTestSuite) TestLeakyBucketMemcachedExpiry() {
251 | ttl := time.Second
252 | backend := l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, false)
253 | bucket := l.NewLeakyBucket(2, time.Second, l.NewLockNoop(), backend, l.NewSystemClock(), s.logger)
254 | ctx := context.Background()
255 | _, err := bucket.Limit(ctx)
256 | s.Require().NoError(err)
257 | _, err = bucket.Limit(ctx)
258 | s.Require().NoError(err)
259 | state, err := backend.State(ctx)
260 | s.Require().NoError(err)
261 | s.NotEqual(int64(0), state.Last, "Last should be set after tokens are taken")
262 | time.Sleep(ttl + 1500*time.Millisecond)
263 |
264 | state, err = backend.State(ctx)
265 | s.Require().NoError(err)
266 | s.Equal(int64(0), state.Last, "State should be zero after expiry, got: %+v", state)
267 | }
268 |
269 | // TestLeakyBucketMemcachedExpiryWithRaceCheck verifies Memcached TTL expiry for LeakyBucketMemcached (race check enabled).
270 | func (s *LimitersTestSuite) TestLeakyBucketMemcachedExpiryWithRaceCheck() {
271 | ttl := time.Second
272 | backend := l.NewLeakyBucketMemcached(s.memcacheClient, uuid.New().String(), ttl, true)
273 | bucket := l.NewLeakyBucket(2, time.Second, l.NewLockNoop(), backend, l.NewSystemClock(), s.logger)
274 | ctx := context.Background()
275 | _, err := bucket.Limit(ctx)
276 | s.Require().NoError(err)
277 | _, err = bucket.Limit(ctx)
278 | s.Require().NoError(err)
279 | state, err := backend.State(ctx)
280 | s.Require().NoError(err)
281 | s.NotEqual(int64(0), state.Last, "Last should be set after tokens are taken")
282 | time.Sleep(ttl + 1500*time.Millisecond)
283 |
284 | state, err = backend.State(ctx)
285 | s.Require().NoError(err)
286 | s.Equal(int64(0), state.Last, "State should be zero after expiry, got: %+v", state)
287 | }
288 |
289 | func TestLeakyBucket_ZeroCapacity_ReturnsError(t *testing.T) {
290 | capacity := int64(0)
291 | rate := time.Hour
292 | logger := l.NewStdLogger()
293 | bucket := l.NewLeakyBucket(capacity, rate, l.NewLockNoop(), l.NewLeakyBucketInMemory(), newFakeClock(), logger)
294 | wait, err := bucket.Limit(context.TODO())
295 | require.Equal(t, l.ErrLimitExhausted, err)
296 | require.Equal(t, time.Duration(0), wait)
297 | }
298 |
299 | func BenchmarkLeakyBuckets(b *testing.B) {
300 | s := new(LimitersTestSuite)
301 | s.SetT(&testing.T{})
302 | s.SetupSuite()
303 |
304 | capacity := int64(1)
305 | rate := 100 * time.Millisecond
306 | clock := newFakeClock()
307 |
308 | buckets := s.leakyBuckets(capacity, rate, 0, clock)
309 | for name, bucket := range buckets {
310 | b.Run(name, func(b *testing.B) {
311 | for i := 0; i < b.N; i++ {
312 | _, err := bucket.Limit(context.TODO())
313 | s.Require().NoError(err)
314 | }
315 | })
316 | }
317 |
318 | s.TearDownSuite()
319 | }
320 |
321 | // setStateInOldFormat is a test utility method for writing state in the old format to Redis.
322 | func setStateInOldFormat(ctx context.Context, cli *redis.Client, prefix string, state l.LeakyBucketState, ttl time.Duration) error {
323 | _, err := cli.TxPipelined(ctx, func(pipeliner redis.Pipeliner) error {
324 | err := pipeliner.Set(ctx, fmt.Sprintf("{%s}last", prefix), state.Last, ttl).Err()
325 | if err != nil {
326 | return err
327 | }
328 |
329 | return nil
330 | })
331 |
332 | return err
333 | }
334 |
335 | // TestLeakyBucketRedisBackwardCompatibility tests that the new State method can read data written in the old format.
336 | func (s *LimitersTestSuite) TestLeakyBucketRedisBackwardCompatibility() {
337 | // Create a new LeakyBucketRedis instance
338 | prefix := uuid.New().String()
339 | backend := l.NewLeakyBucketRedis(s.redisClient, prefix, time.Minute, false)
340 |
341 | // Write state using old format
342 | ctx := context.Background()
343 | expectedState := l.LeakyBucketState{
344 | Last: 12345,
345 | }
346 |
347 | // Write directly to Redis using old format
348 | err := setStateInOldFormat(ctx, s.redisClient, prefix, expectedState, time.Minute)
349 | s.Require().NoError(err, "Failed to set state using old format")
350 |
351 | // Read state using new format (State)
352 | actualState, err := backend.State(ctx)
353 | s.Require().NoError(err, "Failed to get state using new format")
354 |
355 | // Verify the state is correctly read
356 | s.Equal(expectedState.Last, actualState.Last, "Last values should match")
357 | }
358 |
--------------------------------------------------------------------------------
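Similarly, a standalone sketch of the LeakyBucket pattern exercised by the tests above: the returned wait tells the caller how long to delay so that requests drain at the configured rate, and ErrLimitExhausted means the queue is full; the values are illustrative:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	l "github.com/mennanov/limiters"
)

func main() {
	// Queue capacity of 3 requests, drained at one request per 200ms.
	bucket := l.NewLeakyBucket(3, 200*time.Millisecond, l.NewLockNoop(),
		l.NewLeakyBucketInMemory(), l.NewSystemClock(), l.NewStdLogger())

	for i := 0; i < 5; i++ {
		wait, err := bucket.Limit(context.Background())
		if errors.Is(err, l.ErrLimitExhausted) {
			fmt.Printf("request %d rejected: bucket is full\n", i)
			continue
		} else if err != nil {
			panic(err)
		}
		time.Sleep(wait) // wait for this request's slot in the queue
		fmt.Printf("request %d processed after %s\n", i, wait)
	}
}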
/slidingwindow.go:
--------------------------------------------------------------------------------
1 | package limiters
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "math"
8 | "net/http"
9 | "strconv"
10 | "sync"
11 | "time"
12 |
13 | "github.com/Azure/azure-sdk-for-go/sdk/azcore"
14 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
15 | "github.com/aws/aws-sdk-go-v2/aws"
16 | "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
17 | "github.com/aws/aws-sdk-go-v2/service/dynamodb"
18 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
19 | "github.com/bradfitz/gomemcache/memcache"
20 | "github.com/pkg/errors"
21 | "github.com/redis/go-redis/v9"
22 | )
23 |
24 | // SlidingWindowIncrementer wraps the Increment method.
25 | type SlidingWindowIncrementer interface {
26 | // Increment increments the request counter for the current window and returns the counter values for the previous
27 | // window and the current one.
28 | // TTL is the time duration before the next window.
29 | Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (prevCount, currCount int64, err error)
30 | }
31 |
32 | // SlidingWindow implements a Sliding Window rate limiting algorithm.
33 | //
34 | // It does not require a distributed lock and uses a minimal amount of memory; however, it will disallow all requests
35 | // when a client is flooding the service with requests.
36 | // It is the client's responsibility to handle a disallowed request and wait before making a new request.
37 | type SlidingWindow struct {
38 | backend SlidingWindowIncrementer
39 | clock Clock
40 | rate time.Duration
41 | capacity int64
42 | epsilon float64
43 | }
44 |
45 | // NewSlidingWindow creates a new instance of SlidingWindow.
46 | // Capacity is the maximum number of requests allowed per window.
47 | // Rate is the window size.
48 | // Epsilon is the max-allowed range of difference when comparing the current weighted number of requests with capacity.
49 | func NewSlidingWindow(capacity int64, rate time.Duration, slidingWindowIncrementer SlidingWindowIncrementer, clock Clock, epsilon float64) *SlidingWindow {
50 | return &SlidingWindow{backend: slidingWindowIncrementer, clock: clock, rate: rate, capacity: capacity, epsilon: epsilon}
51 | }
52 |
53 | // Limit returns the time duration to wait before the request can be processed.
54 | // It returns ErrLimitExhausted if the request overflows the capacity.
55 | func (s *SlidingWindow) Limit(ctx context.Context) (time.Duration, error) {
56 | now := s.clock.Now()
57 | currWindow := now.Truncate(s.rate)
58 | prevWindow := currWindow.Add(-s.rate)
59 | ttl := s.rate - now.Sub(currWindow)
60 |
61 | prev, curr, err := s.backend.Increment(ctx, prevWindow, currWindow, ttl+s.rate)
62 | if err != nil {
63 | return 0, err
64 | }
65 |
66 | // "prev" and "curr" are capped at "s.capacity + s.epsilon" using math.Ceil to round up any fractional values,
67 | // ensuring that in the worst case, "total" can be slightly greater than "s.capacity".
68 | prev = int64(math.Min(float64(prev), math.Ceil(float64(s.capacity)+s.epsilon)))
69 | curr = int64(math.Min(float64(curr), math.Ceil(float64(s.capacity)+s.epsilon)))
70 |
71 | total := float64(prev*int64(ttl))/float64(s.rate) + float64(curr)
72 | if total-float64(s.capacity) >= s.epsilon {
73 | var wait time.Duration
74 | if curr <= s.capacity-1 && prev > 0 {
75 | wait = ttl - time.Duration(float64(s.capacity-1-curr)/float64(prev)*float64(s.rate))
76 | } else {
77 | // If prev == 0.
78 | wait = ttl + time.Duration((1-float64(s.capacity-1)/float64(curr))*float64(s.rate))
79 | }
80 |
81 | return wait, ErrLimitExhausted
82 | }
83 |
84 | return 0, nil
85 | }
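// Worked example with illustrative numbers: capacity = 10, rate = 1s, and the
// request arrives 300ms into the current window, so ttl = 700ms. With prev = 8
// and curr = 5 after the increment, total = 8*0.7 + 5 = 10.6 >= capacity and the
// request is rejected. The returned wait = ttl - (capacity-1-curr)/prev*rate =
// 700ms - 4/8*1s = 200ms is chosen so that the previous window's decayed weight
// plus curr equals capacity-1, i.e. there is room for exactly one more request
// once the wait elapses.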
86 |
87 | // SlidingWindowInMemory is an in-memory implementation of SlidingWindowIncrementer.
88 | type SlidingWindowInMemory struct {
89 | mu sync.Mutex
90 | prevC, currC int64
91 | prevW, currW time.Time
92 | }
93 |
94 | // NewSlidingWindowInMemory creates a new instance of SlidingWindowInMemory.
95 | func NewSlidingWindowInMemory() *SlidingWindowInMemory {
96 | return &SlidingWindowInMemory{}
97 | }
98 |
99 | // Increment increments the current window's counter and returns the number of requests in the previous window and the
100 | // current one.
101 | func (s *SlidingWindowInMemory) Increment(ctx context.Context, prev, curr time.Time, _ time.Duration) (int64, int64, error) {
102 | s.mu.Lock()
103 | defer s.mu.Unlock()
104 |
105 | if curr != s.currW {
106 | if prev.Equal(s.currW) {
107 | s.prevW = s.currW
108 | s.prevC = s.currC
109 | } else {
110 | s.prevW = time.Time{}
111 | s.prevC = 0
112 | }
113 |
114 | s.currW = curr
115 | s.currC = 0
116 | }
117 |
118 | s.currC++
119 |
120 | return s.prevC, s.currC, ctx.Err()
121 | }
122 |
123 | // SlidingWindowRedis implements SlidingWindow in Redis.
124 | type SlidingWindowRedis struct {
125 | cli redis.UniversalClient
126 | prefix string
127 | }
128 |
129 | // NewSlidingWindowRedis creates a new instance of SlidingWindowRedis.
130 | func NewSlidingWindowRedis(cli redis.UniversalClient, prefix string) *SlidingWindowRedis {
131 | return &SlidingWindowRedis{cli: cli, prefix: prefix}
132 | }
133 |
134 | // Increment increments the current window's counter in Redis and returns the number of requests in the previous window
135 | // and the current one.
136 | func (s *SlidingWindowRedis) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) {
137 | var (
138 | incr *redis.IntCmd
139 | prevCountCmd *redis.StringCmd
140 | err error
141 | )
142 |
143 | done := make(chan struct{})
144 |
145 | go func() {
146 | defer close(done)
147 |
148 | _, err = s.cli.Pipelined(ctx, func(pipeliner redis.Pipeliner) error {
149 | currKey := fmt.Sprintf("%d", curr.UnixNano())
150 | incr = pipeliner.Incr(ctx, redisKey(s.prefix, currKey))
151 | pipeliner.PExpire(ctx, redisKey(s.prefix, currKey), ttl)
152 | prevCountCmd = pipeliner.Get(ctx, redisKey(s.prefix, fmt.Sprintf("%d", prev.UnixNano())))
153 |
154 | return nil
155 | })
156 | }()
157 |
158 | var prevCount int64
159 |
160 | select {
161 | case <-done:
162 | if errors.Is(err, redis.TxFailedErr) {
163 | return 0, 0, errors.Wrap(err, "redis transaction failed")
164 | } else if errors.Is(err, redis.Nil) {
165 | prevCount = 0
166 | } else if err != nil {
167 | return 0, 0, errors.Wrap(err, "unexpected error from redis")
168 | } else {
169 | prevCount, err = strconv.ParseInt(prevCountCmd.Val(), 10, 64)
170 | if err != nil {
171 | return 0, 0, errors.Wrap(err, "failed to parse response from redis")
172 | }
173 | }
174 |
175 | return prevCount, incr.Val(), nil
176 | case <-ctx.Done():
177 | return 0, 0, ctx.Err()
178 | }
179 | }
180 |
181 | // SlidingWindowMemcached implements SlidingWindow in Memcached.
182 | type SlidingWindowMemcached struct {
183 | cli *memcache.Client
184 | prefix string
185 | }
186 |
187 | // NewSlidingWindowMemcached creates a new instance of SlidingWindowMemcached.
188 | func NewSlidingWindowMemcached(cli *memcache.Client, prefix string) *SlidingWindowMemcached {
189 | return &SlidingWindowMemcached{cli: cli, prefix: prefix}
190 | }
191 |
192 | // Increment increments the current window's counter in Memcached and returns the number of requests in the previous window
193 | // and the current one.
194 | func (s *SlidingWindowMemcached) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) {
195 | var (
196 | prevCount uint64
197 | currCount uint64
198 | err error
199 | )
200 |
201 | done := make(chan struct{})
202 |
203 | go func() {
204 | defer close(done)
205 |
206 | var item *memcache.Item
207 |
208 | prevKey := fmt.Sprintf("%s:%d", s.prefix, prev.UnixNano())
209 |
210 | item, err = s.cli.Get(prevKey)
211 | if err != nil {
212 | if errors.Is(err, memcache.ErrCacheMiss) {
213 | err = nil
214 | prevCount = 0
215 | } else {
216 | return
217 | }
218 | } else {
219 | prevCount, err = strconv.ParseUint(string(item.Value), 10, 64)
220 | if err != nil {
221 | return
222 | }
223 | }
224 |
225 | currKey := fmt.Sprintf("%s:%d", s.prefix, curr.UnixNano())
226 |
227 | currCount, err = s.cli.Increment(currKey, 1)
228 | if err != nil && errors.Is(err, memcache.ErrCacheMiss) {
229 | currCount = 1
230 | item = &memcache.Item{
231 | Key: currKey,
232 | Value: []byte(strconv.FormatUint(currCount, 10)),
233 | }
234 | err = s.cli.Add(item)
235 | }
236 | }()
237 |
238 | select {
239 | case <-done:
240 | if err != nil {
241 | if errors.Is(err, memcache.ErrNotStored) {
242 | return s.Increment(ctx, prev, curr, ttl)
243 | }
244 |
245 | return 0, 0, err
246 | }
247 |
248 | return int64(prevCount), int64(currCount), nil
249 | case <-ctx.Done():
250 | return 0, 0, ctx.Err()
251 | }
252 | }
253 |
254 | // SlidingWindowDynamoDB implements SlidingWindow in DynamoDB.
255 | type SlidingWindowDynamoDB struct {
256 | client *dynamodb.Client
257 | partitionKey string
258 | tableProps DynamoDBTableProperties
259 | }
260 |
261 | // NewSlidingWindowDynamoDB creates a new instance of SlidingWindowDynamoDB.
262 | // PartitionKey is the key used to store all the data for this implementation in DynamoDB.
263 | //
264 | // TableProps describe the table that this backend should work with. This backend requires the following on the table:
265 | // * SortKey
266 | // * TTL.
267 | func NewSlidingWindowDynamoDB(client *dynamodb.Client, partitionKey string, props DynamoDBTableProperties) *SlidingWindowDynamoDB {
268 | return &SlidingWindowDynamoDB{
269 | client: client,
270 | partitionKey: partitionKey,
271 | tableProps: props,
272 | }
273 | }
274 |
275 | // Increment increments the current window's counter in DynamoDB and returns the number of requests in the previous window
276 | // and the current one.
277 | func (s *SlidingWindowDynamoDB) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) {
278 | wg := &sync.WaitGroup{}
279 | wg.Add(2)
280 |
281 | done := make(chan struct{})
282 |
283 | go func() {
284 | defer close(done)
285 |
286 | wg.Wait()
287 | }()
288 |
289 | var (
290 | currentCount int64
291 | currentErr error
292 | )
293 |
294 | go func() {
295 | defer wg.Done()
296 |
297 | resp, err := s.client.UpdateItem(ctx, &dynamodb.UpdateItemInput{
298 | Key: map[string]types.AttributeValue{
299 | s.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: s.partitionKey},
300 | s.tableProps.SortKeyName: &types.AttributeValueMemberS{Value: strconv.FormatInt(curr.UnixNano(), 10)},
301 | },
302 | UpdateExpression: aws.String(fixedWindowDynamoDBUpdateExpression),
303 | ExpressionAttributeNames: map[string]string{
304 | "#TTL": s.tableProps.TTLFieldName,
305 | "#C": dynamodbWindowCountKey,
306 | },
307 | ExpressionAttributeValues: map[string]types.AttributeValue{
308 | ":ttl": &types.AttributeValueMemberN{Value: strconv.FormatInt(time.Now().Add(ttl).Unix(), 10)},
309 | ":def": &types.AttributeValueMemberN{Value: "0"},
310 | ":inc": &types.AttributeValueMemberN{Value: "1"},
311 | },
312 | TableName: &s.tableProps.TableName,
313 | ReturnValues: types.ReturnValueAllNew,
314 | })
315 | if err != nil {
316 | currentErr = errors.Wrap(err, "dynamodb get item failed")
317 |
318 | return
319 | }
320 |
321 | var tmp float64
322 |
323 | err = attributevalue.Unmarshal(resp.Attributes[dynamodbWindowCountKey], &tmp)
324 | if err != nil {
325 | currentErr = errors.Wrap(err, "unmarshal of dynamodb attribute value failed")
326 |
327 | return
328 | }
329 |
330 | currentCount = int64(tmp)
331 | }()
332 |
333 | var (
334 | priorCount int64
335 | priorErr error
336 | )
337 |
338 | go func() {
339 | defer wg.Done()
340 |
341 | resp, err := s.client.GetItem(ctx, &dynamodb.GetItemInput{
342 | TableName: aws.String(s.tableProps.TableName),
343 | Key: map[string]types.AttributeValue{
344 | s.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: s.partitionKey},
345 | s.tableProps.SortKeyName: &types.AttributeValueMemberS{Value: strconv.FormatInt(prev.UnixNano(), 10)},
346 | },
347 | ConsistentRead: aws.Bool(true),
348 | })
349 | if err != nil {
350 | priorCount, priorErr = 0, errors.Wrap(err, "dynamodb get item failed")
351 |
352 | return
353 | }
354 |
355 | if len(resp.Item) == 0 {
356 | priorCount = 0
357 |
358 | return
359 | }
360 |
361 | var count float64
362 |
363 | err = attributevalue.Unmarshal(resp.Item[dynamodbWindowCountKey], &count)
364 | if err != nil {
365 | priorCount, priorErr = 0, errors.Wrap(err, "unmarshal of dynamodb attribute value failed")
366 |
367 | return
368 | }
369 |
370 | priorCount = int64(count)
371 | }()
372 |
373 | select {
374 | case <-done:
375 | case <-ctx.Done():
376 | return 0, 0, ctx.Err()
377 | }
378 |
379 | if currentErr != nil {
380 | return 0, 0, errors.Wrap(currentErr, "failed to update current count")
381 | } else if priorErr != nil {
382 | return 0, 0, errors.Wrap(priorErr, "failed to get previous count")
383 | }
384 |
385 | return priorCount, currentCount, nil
386 | }
387 |
388 | // SlidingWindowCosmosDB implements SlidingWindow in Azure Cosmos DB.
389 | type SlidingWindowCosmosDB struct {
390 | client *azcosmos.ContainerClient
391 | partitionKey string
392 | }
393 |
394 | // NewSlidingWindowCosmosDB creates a new instance of SlidingWindowCosmosDB.
395 | // PartitionKey is the key under which all the data for this implementation is stored in Cosmos DB.
396 | func NewSlidingWindowCosmosDB(client *azcosmos.ContainerClient, partitionKey string) *SlidingWindowCosmosDB {
397 | return &SlidingWindowCosmosDB{
398 | client: client,
399 | partitionKey: partitionKey,
400 | }
401 | }
402 |
403 | // Increment increments the current window's counter in Cosmos and returns the number of requests in the previous window
404 | // and the current one.
405 | func (s *SlidingWindowCosmosDB) Increment(ctx context.Context, prev, curr time.Time, ttl time.Duration) (int64, int64, error) {
406 | wg := &sync.WaitGroup{}
407 | wg.Add(2)
408 |
409 | done := make(chan struct{})
410 |
411 | go func() {
412 | defer close(done)
413 |
414 | wg.Wait()
415 | }()
416 |
417 | var (
418 | currentCount int64
419 | currentErr error
420 | )
421 |
422 | go func() {
423 | defer wg.Done()
424 |
425 | id := strconv.FormatInt(curr.UnixNano(), 10)
426 | tmp := cosmosItem{
427 | ID: id,
428 | PartitionKey: s.partitionKey,
429 | Count: 1,
430 | TTL: int32(ttl.Seconds()), // Cosmos DB expects the TTL in seconds.
431 | }
432 |
433 | ops := azcosmos.PatchOperations{}
434 | ops.AppendIncrement(`/Count`, 1)
435 |
436 | patchResp, err := s.client.PatchItem(ctx, azcosmos.NewPartitionKey().AppendString(s.partitionKey), id, ops, &azcosmos.ItemOptions{
437 | EnableContentResponseOnWrite: true,
438 | })
439 | if err == nil {
440 | // value exists and was updated
441 | err = json.Unmarshal(patchResp.Value, &tmp)
442 | if err != nil {
443 | currentErr = errors.Wrap(err, "unmarshal of cosmos value current failed")
444 |
445 | return
446 | }
447 |
448 | currentCount = tmp.Count
449 |
450 | return
451 | }
452 |
453 | var respErr *azcore.ResponseError
454 | if !errors.As(err, &respErr) || (respErr.StatusCode != http.StatusNotFound && respErr.StatusCode != http.StatusBadRequest) {
455 | currentErr = errors.Wrap(err, `patch of cosmos value current failed`)
456 |
457 | return
458 | }
459 |
460 | newValue, err := json.Marshal(tmp)
461 | if err != nil {
462 | currentErr = errors.Wrap(err, "marshal of cosmos value current failed")
463 |
464 | return
465 | }
466 |
467 | _, err = s.client.CreateItem(ctx, azcosmos.NewPartitionKey().AppendString(s.partitionKey), newValue, &azcosmos.ItemOptions{
468 | SessionToken: patchResp.SessionToken,
469 | })
470 | if err != nil {
471 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict {
472 | // Fallback to Read-Modify-Write
473 | currentCount, currentErr = incrementCosmosItemRMW(ctx, s.client, s.partitionKey, id, ttl)
474 |
475 | return
476 | }
477 |
478 | currentErr = errors.Wrap(err, "upsert of cosmos value current failed")
479 |
480 | return
481 | }
482 |
483 | currentCount = tmp.Count
484 | }()
485 |
486 | var (
487 | priorCount int64
488 | priorErr error
489 | )
490 |
491 | go func() {
492 | defer wg.Done()
493 |
494 | id := strconv.FormatInt(prev.UnixNano(), 10)
495 |
496 | resp, err := s.client.ReadItem(ctx, azcosmos.NewPartitionKey().AppendString(s.partitionKey), id, &azcosmos.ItemOptions{})
497 | if err != nil {
498 | var azerr *azcore.ResponseError
499 | if errors.As(err, &azerr) && azerr.StatusCode == http.StatusNotFound {
500 | priorCount, priorErr = 0, nil
501 |
502 | return
503 | }
504 |
505 | priorErr = errors.Wrap(err, "cosmos get item prior failed")
506 |
507 | return
508 | }
509 |
510 | var tmp cosmosItem
511 |
512 | err = json.Unmarshal(resp.Value, &tmp)
513 | if err != nil {
514 | priorErr = errors.Wrap(err, "unmarshal of cosmos value prior failed")
515 |
516 | return
517 | }
518 |
519 | priorCount = tmp.Count
520 | }()
521 |
522 | select {
523 | case <-done:
524 | case <-ctx.Done():
525 | return 0, 0, ctx.Err()
526 | }
527 |
528 | if currentErr != nil {
529 | return 0, 0, errors.Wrap(currentErr, "failed to update current count")
530 | } else if priorErr != nil {
531 | return 0, 0, errors.Wrap(priorErr, "failed to get previous count")
532 | }
533 |
534 | return priorCount, currentCount, nil
535 | }
536 |
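// Illustrative sketch of how the two counters returned by Increment feed the classic
// sliding-window estimate. The window boundaries, the capacity check and the helper name
// below are assumptions made for this example only; the SlidingWindow limiter defined
// elsewhere in this package performs an equivalent computation.
func exampleSlidingWindowEstimate(ctx context.Context, backend *SlidingWindowDynamoDB, capacity int64, window time.Duration) (bool, error) {
	now := time.Now()
	currStart := now.Truncate(window)
	prevStart := currStart.Add(-window)

	prevCount, currCount, err := backend.Increment(ctx, prevStart, currStart, 2*window)
	if err != nil {
		return false, err
	}

	// Weight the previous window by the share of it that still overlaps the
	// sliding window ending at `now`.
	weight := 1 - float64(now.Sub(currStart))/float64(window)
	estimate := int64(float64(prevCount)*weight) + currCount

	return estimate <= capacity, nil
}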
--------------------------------------------------------------------------------
/leakybucket.go:
--------------------------------------------------------------------------------
1 | package limiters
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/gob"
7 | "encoding/json"
8 | "fmt"
9 | "math"
10 | "net/http"
11 | "strconv"
12 | "sync"
13 | "time"
14 |
15 | "github.com/Azure/azure-sdk-for-go/sdk/azcore"
16 | "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
17 | "github.com/aws/aws-sdk-go-v2/aws"
18 | "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
19 | "github.com/aws/aws-sdk-go-v2/service/dynamodb"
20 | "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
21 | "github.com/bradfitz/gomemcache/memcache"
22 | "github.com/pkg/errors"
23 | "github.com/redis/go-redis/v9"
24 | "go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
25 | clientv3 "go.etcd.io/etcd/client/v3"
26 | )
27 |
28 | // LeakyBucketState represents the state of a LeakyBucket.
29 | type LeakyBucketState struct {
30 | // Last is the Unix timestamp in nanoseconds of the most recent request.
31 | Last int64
32 | }
33 |
34 | // IzZero returns true if the bucket state is zero valued.
35 | func (s LeakyBucketState) IzZero() bool {
36 | return s.Last == 0
37 | }
38 |
39 | // LeakyBucketStateBackend interface encapsulates the logic of retrieving and persisting the state of a LeakyBucket.
40 | type LeakyBucketStateBackend interface {
41 | // State gets the current state of the LeakyBucket.
42 | State(ctx context.Context) (LeakyBucketState, error)
43 | // SetState sets (persists) the current state of the LeakyBucket.
44 | SetState(ctx context.Context, state LeakyBucketState) error
45 | // Reset resets (persists) the current state of the LeakyBucket.
46 | Reset(ctx context.Context) error
47 | }
48 |
49 | // LeakyBucket implements the https://en.wikipedia.org/wiki/Leaky_bucket#As_a_queue algorithm.
50 | type LeakyBucket struct {
51 | locker DistLocker
52 | backend LeakyBucketStateBackend
53 | clock Clock
54 | logger Logger
55 | // Capacity is the maximum number of requests that can be queued in the bucket.
56 | capacity int64
57 | // Rate is the output rate: one request per rate duration, stored in nanoseconds.
58 | rate int64
59 | mu sync.Mutex
60 | }
61 |
62 | // NewLeakyBucket creates a new instance of LeakyBucket.
63 | func NewLeakyBucket(capacity int64, rate time.Duration, locker DistLocker, leakyBucketStateBackend LeakyBucketStateBackend, clock Clock, logger Logger) *LeakyBucket {
64 | return &LeakyBucket{
65 | locker: locker,
66 | backend: leakyBucketStateBackend,
67 | clock: clock,
68 | logger: logger,
69 | capacity: capacity,
70 | rate: int64(rate),
71 | }
72 | }
73 |
74 | // Limit returns the time duration to wait before the request can be processed.
75 | // It returns ErrLimitExhausted if the request overflows the bucket's capacity. In that case the returned duration
76 | // is how long the request would have had to wait to be processed had the bucket not overflowed.
77 | func (t *LeakyBucket) Limit(ctx context.Context) (time.Duration, error) {
78 | t.mu.Lock()
79 | defer t.mu.Unlock()
80 |
81 | err := t.locker.Lock(ctx)
82 | if err != nil {
83 | return 0, err
84 | }
85 |
86 | defer func() {
87 | err := t.locker.Unlock(ctx)
88 | if err != nil {
89 | t.logger.Log(err)
90 | }
91 | }()
92 |
93 | state, err := t.backend.State(ctx)
94 | if err != nil {
95 | return 0, err
96 | }
97 |
98 | now := t.clock.Now().UnixNano()
99 | if now < state.Last {
100 | // The queue has requests in it: move the current request to the last position + 1.
101 | state.Last += t.rate
102 | } else {
103 | // The queue is empty.
104 | // The offset is the duration to wait in case the last request happened less than rate duration ago.
105 | var offset int64
106 |
107 | delta := now - state.Last
108 | if delta < t.rate {
109 | offset = t.rate - delta
110 | }
111 |
112 | state.Last = now + offset
113 | }
114 |
115 | wait := state.Last - now
116 | if wait/t.rate >= t.capacity {
117 | return time.Duration(wait), ErrLimitExhausted
118 | }
119 |
120 | err = t.backend.SetState(ctx, state)
121 | if err != nil {
122 | return 0, err
123 | }
124 |
125 | return time.Duration(wait), nil
126 | }
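// Illustrative usage sketch for Limit, assuming NewLockNoop, NewLeakyBucketInMemory,
// NewSystemClock and NewStdLogger are available in this package; handle() is a hypothetical
// request handler. With rate = 100ms and an empty queue, a request arriving 30ms after the
// previous one waits 70ms; with capacity = 10, the 11th request of an instantaneous burst is
// rejected because wait/rate reaches the capacity.
func exampleLeakyBucketUsage(ctx context.Context, handle func()) error {
	bucket := NewLeakyBucket(10, 100*time.Millisecond, NewLockNoop(), NewLeakyBucketInMemory(), NewSystemClock(), NewStdLogger())

	wait, err := bucket.Limit(ctx)
	if errors.Is(err, ErrLimitExhausted) {
		// The request would overflow the bucket's capacity: reject it.
		return err
	} else if err != nil {
		// Locker or backend failure.
		return err
	}
	// Sleeping for `wait` keeps the output rate at one request per `rate`.
	time.Sleep(wait)
	handle()

	return nil
}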
127 |
128 | // Reset resets the bucket.
129 | func (t *LeakyBucket) Reset(ctx context.Context) error {
130 | return t.backend.Reset(ctx)
131 | }
132 |
133 | // LeakyBucketInMemory is an in-memory implementation of LeakyBucketStateBackend.
134 | type LeakyBucketInMemory struct {
135 | state LeakyBucketState
136 | }
137 |
138 | // NewLeakyBucketInMemory creates a new instance of LeakyBucketInMemory.
139 | func NewLeakyBucketInMemory() *LeakyBucketInMemory {
140 | return &LeakyBucketInMemory{}
141 | }
142 |
143 | // State gets the current state of the bucket.
144 | func (l *LeakyBucketInMemory) State(ctx context.Context) (LeakyBucketState, error) {
145 | return l.state, ctx.Err()
146 | }
147 |
148 | // SetState sets the current state of the bucket.
149 | func (l *LeakyBucketInMemory) SetState(ctx context.Context, state LeakyBucketState) error {
150 | l.state = state
151 |
152 | return ctx.Err()
153 | }
154 |
155 | // Reset resets the current state of the bucket.
156 | func (l *LeakyBucketInMemory) Reset(ctx context.Context) error {
157 | state := LeakyBucketState{
158 | Last: 0,
159 | }
160 |
161 | return l.SetState(ctx, state)
162 | }
163 |
164 | const (
165 | etcdKeyLBLease = "lease"
166 | etcdKeyLBLast = "last"
167 | )
168 |
169 | // LeakyBucketEtcd is an etcd implementation of a LeakyBucketStateBackend.
170 | // See the TokenBucketEtcd description for the details on etcd usage.
171 | type LeakyBucketEtcd struct {
172 | // prefix is the etcd key prefix.
173 | prefix string
174 | cli *clientv3.Client
175 | leaseID clientv3.LeaseID
176 | ttl time.Duration
177 | raceCheck bool
178 | lastVersion int64
179 | }
180 |
181 | // NewLeakyBucketEtcd creates a new LeakyBucketEtcd instance.
182 | // Prefix is used as an etcd key prefix for all keys stored in etcd by this algorithm.
183 | // TTL is the TTL of the etcd lease used to store all the keys.
184 | //
185 | // If raceCheck is true and the keys in etcd are modified in between State() and SetState() calls then
186 | // ErrRaceCondition is returned.
187 | func NewLeakyBucketEtcd(cli *clientv3.Client, prefix string, ttl time.Duration, raceCheck bool) *LeakyBucketEtcd {
188 | return &LeakyBucketEtcd{
189 | prefix: prefix,
190 | cli: cli,
191 | ttl: ttl,
192 | raceCheck: raceCheck,
193 | }
194 | }
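// Illustrative sketch of wiring the etcd backend; the endpoint address, key prefix and
// function name are assumptions made for this example only.
func exampleLeakyBucketEtcdBackend() (*LeakyBucketEtcd, error) {
	cli, err := clientv3.New(clientv3.Config{Endpoints: []string{"127.0.0.1:2379"}})
	if err != nil {
		return nil, err
	}
	// A 10s lease TTL keeps the state keys alive between requests; raceCheck makes
	// SetState return ErrRaceCondition on concurrent modifications.
	return NewLeakyBucketEtcd(cli, "/limiters/leaky-bucket/client-42", 10*time.Second, true), nil
}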
195 |
196 | // State gets the bucket's current state from etcd.
197 | // If there is no state available in etcd then the initial bucket's state is returned.
198 | func (l *LeakyBucketEtcd) State(ctx context.Context) (LeakyBucketState, error) {
199 | // Reset the lease ID as it will be updated by the successful Get operation below.
200 | l.leaseID = 0
201 | // Get all the keys under the prefix in a single request.
202 | r, err := l.cli.Get(ctx, l.prefix, clientv3.WithRange(incPrefix(l.prefix)))
203 | if err != nil {
204 | return LeakyBucketState{}, errors.Wrapf(err, "failed to get keys in range ['%s', '%s') from etcd", l.prefix, incPrefix(l.prefix))
205 | }
206 |
207 | if len(r.Kvs) == 0 {
208 | return LeakyBucketState{}, nil
209 | }
210 |
211 | state := LeakyBucketState{}
212 |
213 | parsed := 0
214 | if l.ttl == 0 {
215 | // Ignore lease when there is no expiration
216 | parsed |= 2
217 | }
218 |
219 | var v int64
220 |
221 | for _, kv := range r.Kvs {
222 | switch string(kv.Key) {
223 | case etcdKey(l.prefix, etcdKeyLBLast):
224 | v, err = parseEtcdInt64(kv)
225 | if err != nil {
226 | return LeakyBucketState{}, err
227 | }
228 |
229 | state.Last = v
230 | parsed |= 1
231 | l.lastVersion = kv.Version
232 |
233 | case etcdKey(l.prefix, etcdKeyLBLease):
234 | v, err = parseEtcdInt64(kv)
235 | if err != nil {
236 | return LeakyBucketState{}, err
237 | }
238 |
239 | l.leaseID = clientv3.LeaseID(v)
240 | parsed |= 2
241 | }
242 | }
243 |
244 | if parsed != 3 {
245 | return LeakyBucketState{}, errors.New("failed to get state from etcd: some keys are missing")
246 | }
247 |
248 | return state, nil
249 | }
250 |
251 | // createLease creates a new lease in etcd and updates the t.leaseID value.
252 | func (l *LeakyBucketEtcd) createLease(ctx context.Context) error {
253 | lease, err := l.cli.Grant(ctx, int64(math.Ceil(l.ttl.Seconds())))
254 | if err != nil {
255 | return errors.Wrap(err, "failed to create a new lease in etcd")
256 | }
257 |
258 | l.leaseID = lease.ID
259 |
260 | return nil
261 | }
262 |
263 | // save saves the state to etcd using the existing lease.
264 | func (l *LeakyBucketEtcd) save(ctx context.Context, state LeakyBucketState) error {
265 | var opts []clientv3.OpOption
266 | if l.ttl > 0 {
267 | opts = append(opts, clientv3.WithLease(l.leaseID))
268 | }
269 |
270 | ops := []clientv3.Op{
271 | clientv3.OpPut(etcdKey(l.prefix, etcdKeyLBLast), fmt.Sprintf("%d", state.Last), opts...),
272 | }
273 | if l.ttl > 0 {
274 | ops = append(ops, clientv3.OpPut(etcdKey(l.prefix, etcdKeyLBLease), fmt.Sprintf("%d", l.leaseID), opts...))
275 | }
276 |
277 | if !l.raceCheck {
278 | _, err := l.cli.Txn(ctx).Then(ops...).Commit()
279 | if err != nil {
280 | return errors.Wrap(err, "failed to commit a transaction to etcd")
281 | }
282 |
283 | return nil
284 | }
285 | // Put the keys only if they have not been modified since the most recent read.
286 | r, err := l.cli.Txn(ctx).If(
287 | clientv3.Compare(clientv3.Version(etcdKey(l.prefix, etcdKeyLBLast)), ">", l.lastVersion),
288 | ).Else(ops...).Commit()
289 | if err != nil {
290 | return errors.Wrap(err, "failed to commit a transaction to etcd")
291 | }
292 |
293 | if !r.Succeeded {
294 | return nil
295 | }
296 |
297 | return ErrRaceCondition
298 | }
299 |
300 | // SetState updates the state of the bucket in etcd.
301 | func (l *LeakyBucketEtcd) SetState(ctx context.Context, state LeakyBucketState) error {
302 | if l.ttl == 0 {
303 | // Avoid maintaining the lease when it has no TTL
304 | return l.save(ctx, state)
305 | }
306 |
307 | if l.leaseID == 0 {
308 | // Lease does not exist, create one.
309 | err := l.createLease(ctx)
310 | if err != nil {
311 | return err
312 | }
313 | // No need to send KeepAlive for the newly created lease: save the state immediately.
314 | return l.save(ctx, state)
315 | }
316 | // Send the KeepAlive request to extend the existing lease.
317 | _, err := l.cli.KeepAliveOnce(ctx, l.leaseID)
318 | if errors.Is(err, rpctypes.ErrLeaseNotFound) {
319 | // Create a new lease since the current one has expired.
320 | err = l.createLease(ctx)
321 | if err != nil {
322 | return err
323 | }
324 | } else if err != nil {
325 | return errors.Wrapf(err, "failed to extend the lease '%d'", l.leaseID)
326 | }
327 |
328 | return l.save(ctx, state)
329 | }
330 |
331 | // Reset resets the state of the bucket in etcd.
332 | func (l *LeakyBucketEtcd) Reset(ctx context.Context) error {
333 | state := LeakyBucketState{
334 | Last: 0,
335 | }
336 |
337 | return l.SetState(ctx, state)
338 | }
339 |
340 | // Deprecated: These legacy keys will be removed in a future version.
341 | // The state is now stored in a single JSON document under the "state" key.
342 | const (
343 | redisKeyLBLast = "last"
344 | redisKeyLBVersion = "version"
345 | )
346 |
347 | // LeakyBucketRedis is a Redis implementation of a LeakyBucketStateBackend.
348 | type LeakyBucketRedis struct {
349 | cli redis.UniversalClient
350 | prefix string
351 | ttl time.Duration
352 | raceCheck bool
353 | lastVersion int64
354 | }
355 |
356 | // NewLeakyBucketRedis creates a new LeakyBucketRedis instance.
357 | // Prefix is the key prefix under which all the keys used by this implementation are stored in Redis.
358 | // TTL is the TTL of the stored keys.
359 | //
360 | // If raceCheck is true and the keys in Redis are modified in between State() and SetState() calls then
361 | // ErrRaceCondition is returned.
362 | func NewLeakyBucketRedis(cli redis.UniversalClient, prefix string, ttl time.Duration, raceCheck bool) *LeakyBucketRedis {
363 | return &LeakyBucketRedis{cli: cli, prefix: prefix, ttl: ttl, raceCheck: raceCheck}
364 | }
365 |
366 | // Deprecated: Legacy format support will be removed in a future version.
367 | func (t *LeakyBucketRedis) oldState(ctx context.Context) (LeakyBucketState, error) {
368 | var (
369 | values []interface{}
370 | err error
371 | )
372 |
373 | done := make(chan struct{}, 1)
374 |
375 | go func() {
376 | defer close(done)
377 |
378 | keys := []string{
379 | redisKey(t.prefix, redisKeyLBLast),
380 | }
381 | if t.raceCheck {
382 | keys = append(keys, redisKey(t.prefix, redisKeyLBVersion))
383 | }
384 |
385 | values, err = t.cli.MGet(ctx, keys...).Result()
386 | }()
387 |
388 | select {
389 | case <-done:
390 |
391 | case <-ctx.Done():
392 | return LeakyBucketState{}, ctx.Err()
393 | }
394 |
395 | if err != nil {
396 | return LeakyBucketState{}, errors.Wrap(err, "failed to get keys from redis")
397 | }
398 |
399 | nilAny := false
400 |
401 | for _, v := range values {
402 | if v == nil {
403 | nilAny = true
404 |
405 | break
406 | }
407 | }
408 |
409 | if nilAny || errors.Is(err, redis.Nil) {
410 | // Keys don't exist, return an empty state.
411 | return LeakyBucketState{}, nil
412 | }
413 |
414 | last, err := strconv.ParseInt(values[0].(string), 10, 64)
415 | if err != nil {
416 | return LeakyBucketState{}, err
417 | }
418 |
419 | if t.raceCheck {
420 | t.lastVersion, err = strconv.ParseInt(values[1].(string), 10, 64)
421 | if err != nil {
422 | return LeakyBucketState{}, err
423 | }
424 | }
425 |
426 | return LeakyBucketState{
427 | Last: last,
428 | }, nil
429 | }
430 |
431 | // State gets the bucket's state from Redis.
432 | func (t *LeakyBucketRedis) State(ctx context.Context) (LeakyBucketState, error) {
433 | var err error
434 |
435 | done := make(chan struct{}, 1)
436 | errCh := make(chan error, 1)
437 |
438 | var state LeakyBucketState
439 |
440 | if t.raceCheck {
441 | // Reset lastVersion in case an empty LeakyBucketState is returned below.
442 | t.lastVersion = 0
443 | }
444 |
445 | go func() {
446 | defer close(done)
447 |
448 | key := redisKey(t.prefix, "state")
449 |
450 | value, err := t.cli.Get(ctx, key).Result()
451 | if err != nil && !errors.Is(err, redis.Nil) {
452 | errCh <- err
453 |
454 | return
455 | }
456 |
457 | if errors.Is(err, redis.Nil) {
458 | state, err = t.oldState(ctx)
459 | errCh <- err
460 |
461 | return
462 | }
463 |
464 | // Try new format
465 | var item struct {
466 | State LeakyBucketState `json:"state"`
467 | Version int64 `json:"version"`
468 | }
469 |
470 | err = json.Unmarshal([]byte(value), &item)
471 | if err != nil {
472 | errCh <- err
473 |
474 | return
475 | }
476 |
477 | state = item.State
478 | if t.raceCheck {
479 | t.lastVersion = item.Version
480 | }
481 |
482 | errCh <- nil
483 | }()
484 |
485 | select {
486 | case <-done:
487 | err = <-errCh
488 | case <-ctx.Done():
489 | return LeakyBucketState{}, ctx.Err()
490 | }
491 |
492 | if err != nil {
493 | return LeakyBucketState{}, errors.Wrap(err, "failed to get state from redis")
494 | }
495 |
496 | return state, nil
497 | }
498 |
499 | // SetState updates the state in Redis.
500 | func (t *LeakyBucketRedis) SetState(ctx context.Context, state LeakyBucketState) error {
501 | var err error
502 |
503 | done := make(chan struct{}, 1)
504 | errCh := make(chan error, 1)
505 |
506 | go func() {
507 | defer close(done)
508 |
509 | key := redisKey(t.prefix, "state")
510 | item := struct {
511 | State LeakyBucketState `json:"state"`
512 | Version int64 `json:"version"`
513 | }{
514 | State: state,
515 | Version: t.lastVersion + 1,
516 | }
517 |
518 | value, err := json.Marshal(item)
519 | if err != nil {
520 | errCh <- err
521 |
522 | return
523 | }
524 |
525 | if !t.raceCheck {
526 | errCh <- t.cli.Set(ctx, key, value, t.ttl).Err()
527 |
528 | return
529 | }
530 |
531 | callScript := `redis.call('set', KEYS[1], ARGV[1], 'PX', ARGV[3])`
532 | if t.ttl == 0 {
533 | callScript = `redis.call('set', KEYS[1], ARGV[1])`
534 | }
535 |
536 | script := fmt.Sprintf(`
537 | local current = redis.call('get', KEYS[1])
538 | if current then
539 | local data = cjson.decode(current)
540 | if data.version > tonumber(ARGV[2]) then
541 | return 'RACE_CONDITION'
542 | end
543 | end
544 | %s
545 | return 'OK'
546 | `, callScript)
547 |
548 | result, err := t.cli.Eval(ctx, script, []string{key}, value, t.lastVersion, int64(t.ttl/time.Millisecond)).Result()
549 | if err != nil {
550 | errCh <- err
551 |
552 | return
553 | }
554 |
555 | if result == "RACE_CONDITION" {
556 | errCh <- ErrRaceCondition
557 |
558 | return
559 | }
560 |
561 | errCh <- nil
562 | }()
563 |
564 | select {
565 | case <-done:
566 | err = <-errCh
567 | case <-ctx.Done():
568 | return ctx.Err()
569 | }
570 |
571 | if err != nil {
572 | return errors.Wrap(err, "failed to save state to redis")
573 | }
574 |
575 | return nil
576 | }
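// Illustrative sketch of the storage format used by SetState: the whole state is kept under
// a single key derived from the prefix and "state" as a JSON document of the form
// {"state":{"Last":<unix nanos>},"version":<n>}. The helper name below is hypothetical and
// only shows how that document could be inspected directly.
func debugLeakyBucketRedisState(ctx context.Context, cli redis.UniversalClient, prefix string) (LeakyBucketState, int64, error) {
	raw, err := cli.Get(ctx, redisKey(prefix, "state")).Result()
	if err != nil {
		return LeakyBucketState{}, 0, err
	}

	var doc struct {
		State   LeakyBucketState `json:"state"`
		Version int64            `json:"version"`
	}
	if err := json.Unmarshal([]byte(raw), &doc); err != nil {
		return LeakyBucketState{}, 0, err
	}

	return doc.State, doc.Version, nil
}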
577 |
578 | // Reset resets the state in Redis.
579 | func (t *LeakyBucketRedis) Reset(ctx context.Context) error {
580 | state := LeakyBucketState{
581 | Last: 0,
582 | }
583 |
584 | return t.SetState(ctx, state)
585 | }
586 |
587 | // LeakyBucketMemcached is a Memcached implementation of a LeakyBucketStateBackend.
588 | type LeakyBucketMemcached struct {
589 | cli *memcache.Client
590 | key string
591 | ttl time.Duration
592 | raceCheck bool
593 | casId uint64
594 | }
595 |
596 | // NewLeakyBucketMemcached creates a new LeakyBucketMemcached instance.
597 | // Key is the Memcached key used to store the bucket's state.
598 | // TTL is the TTL of the stored key.
599 | //
600 | // If raceCheck is true and the key in Memcached is modified in between State() and SetState() calls then
601 | // ErrRaceCondition is returned.
602 | func NewLeakyBucketMemcached(cli *memcache.Client, key string, ttl time.Duration, raceCheck bool) *LeakyBucketMemcached {
603 | return &LeakyBucketMemcached{cli: cli, key: key, ttl: ttl, raceCheck: raceCheck}
604 | }
605 |
606 | // State gets the bucket's state from Memcached.
607 | func (t *LeakyBucketMemcached) State(ctx context.Context) (LeakyBucketState, error) {
608 | var (
609 | item *memcache.Item
610 | err error
611 | state LeakyBucketState
612 | )
613 |
614 | done := make(chan struct{}, 1)
615 | t.casId = 0
616 |
617 | go func() {
618 | defer close(done)
619 |
620 | item, err = t.cli.Get(t.key)
621 | }()
622 |
623 | select {
624 | case <-done:
625 |
626 | case <-ctx.Done():
627 | return state, ctx.Err()
628 | }
629 |
630 | if err != nil {
631 | if errors.Is(err, memcache.ErrCacheMiss) {
632 | // Keys don't exist, return an empty state.
633 | return state, nil
634 | }
635 |
636 | return state, errors.Wrap(err, "failed to get keys from memcached")
637 | }
638 |
639 | b := bytes.NewBuffer(item.Value)
640 |
641 | err = gob.NewDecoder(b).Decode(&state)
642 | if err != nil {
643 | return state, errors.Wrap(err, "failed to Decode")
644 | }
645 |
646 | t.casId = item.CasID
647 |
648 | return state, nil
649 | }
650 |
651 | // SetState updates the state in Memcached.
652 | // When raceCheck is enabled, the CAS ID from the most recent State() call is verified by Memcached before saving the item.
653 | func (t *LeakyBucketMemcached) SetState(ctx context.Context, state LeakyBucketState) error {
654 | var err error
655 |
656 | done := make(chan struct{}, 1)
657 |
658 | var b bytes.Buffer
659 |
660 | err = gob.NewEncoder(&b).Encode(state)
661 | if err != nil {
662 | return errors.Wrap(err, "failed to Encode")
663 | }
664 |
665 | go func() {
666 | defer close(done)
667 |
668 | item := &memcache.Item{
669 | Key: t.key,
670 | Value: b.Bytes(),
671 | CasID: t.casId,
672 | }
673 | if t.ttl > 30*24*time.Hour {
674 | // Memcached treats expiration values greater than 30 days as an absolute Unix timestamp.
675 | item.Expiration = int32(time.Now().Add(t.ttl).Unix())
676 | } else if t.ttl > 0 {
677 | // Expirations up to 30 days are given in seconds, which is the more precise option.
678 | item.Expiration = int32(math.Ceil(t.ttl.Seconds()))
679 | }
680 |
681 | if t.raceCheck && t.casId > 0 {
682 | err = t.cli.CompareAndSwap(item)
683 | } else {
684 | err = t.cli.Set(item)
685 | }
686 | }()
687 |
688 | select {
689 | case <-done:
690 |
691 | case <-ctx.Done():
692 | return ctx.Err()
693 | }
694 |
695 | if err != nil && (errors.Is(err, memcache.ErrCASConflict) || errors.Is(err, memcache.ErrNotStored) || errors.Is(err, memcache.ErrCacheMiss)) {
696 | return ErrRaceCondition
697 | }
698 |
699 | return errors.Wrap(err, "failed to save keys to memcached")
700 | }
701 |
702 | // Reset resets the state in Memcached.
703 | func (t *LeakyBucketMemcached) Reset(ctx context.Context) error {
704 | state := LeakyBucketState{
705 | Last: 0,
706 | }
707 | t.casId = 0
708 |
709 | return t.SetState(ctx, state)
710 | }
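// Illustrative sketch of wiring the Memcached backend; the server address, key and function
// name are assumptions made for this example only.
func exampleLeakyBucketMemcachedBackend() *LeakyBucketMemcached {
	mc := memcache.New("127.0.0.1:11211")
	// CAS-based raceCheck makes SetState return ErrRaceCondition if another process
	// modified the key after the most recent State() call.
	return NewLeakyBucketMemcached(mc, "leaky-bucket-client-42", time.Minute, true)
}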
711 |
712 | // LeakyBucketDynamoDB is a DynamoDB implementation of a LeakyBucketStateBackend.
713 | type LeakyBucketDynamoDB struct {
714 | client *dynamodb.Client
715 | tableProps DynamoDBTableProperties
716 | partitionKey string
717 | ttl time.Duration
718 | raceCheck bool
719 | latestVersion int64
720 | keys map[string]types.AttributeValue
721 | }
722 |
723 | // NewLeakyBucketDynamoDB creates a new LeakyBucketDynamoDB instance.
724 | // PartitionKey is the key under which all the data for this implementation is stored in DynamoDB.
725 | //
726 | // TableProps describes the table that this backend should work with. This backend requires the following on the table:
727 | // * TTL
728 | //
729 | // TTL is the TTL of the stored item.
730 | //
731 | // If raceCheck is true and the item in DynamoDB is modified in between State() and SetState() calls then
732 | // ErrRaceCondition is returned.
733 | func NewLeakyBucketDynamoDB(client *dynamodb.Client, partitionKey string, tableProps DynamoDBTableProperties, ttl time.Duration, raceCheck bool) *LeakyBucketDynamoDB {
734 | keys := map[string]types.AttributeValue{
735 | tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: partitionKey},
736 | }
737 |
738 | if tableProps.SortKeyUsed {
739 | keys[tableProps.SortKeyName] = &types.AttributeValueMemberS{Value: partitionKey}
740 | }
741 |
742 | return &LeakyBucketDynamoDB{
743 | client: client,
744 | partitionKey: partitionKey,
745 | tableProps: tableProps,
746 | ttl: ttl,
747 | raceCheck: raceCheck,
748 | keys: keys,
749 | }
750 | }
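// Illustrative sketch of the table properties this backend expects. The table and attribute
// names, as well as the function name, are assumptions made for this example only, and only
// the DynamoDBTableProperties fields referenced in this file are set; the struct may define
// additional fields.
func exampleLeakyBucketDynamoDBBackend(client *dynamodb.Client) *LeakyBucketDynamoDB {
	props := DynamoDBTableProperties{
		TableName:        "rate-limits",
		PartitionKeyName: "PK",
		SortKeyUsed:      true,
		SortKeyName:      "SK",
		TTLFieldName:     "TTL",
	}
	// The item expires one minute after the last write; raceCheck enables the conditional
	// PutItem based on the Version attribute.
	return NewLeakyBucketDynamoDB(client, "client-42", props, time.Minute, true)
}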
751 |
752 | // State gets the bucket's state from DynamoDB.
753 | func (t *LeakyBucketDynamoDB) State(ctx context.Context) (LeakyBucketState, error) {
754 | resp, err := dynamoDBGetItem(ctx, t.client, t.getGetItemInput())
755 | if err != nil {
756 | return LeakyBucketState{}, err
757 | }
758 |
759 | return t.loadStateFromDynamoDB(resp)
760 | }
761 |
762 | // SetState updates the state in DynamoDB.
763 | func (t *LeakyBucketDynamoDB) SetState(ctx context.Context, state LeakyBucketState) error {
764 | input := t.getPutItemInputFromState(state)
765 |
766 | var err error
767 |
768 | done := make(chan struct{})
769 |
770 | go func() {
771 | defer close(done)
772 |
773 | _, err = dynamoDBputItem(ctx, t.client, input)
774 | }()
775 |
776 | select {
777 | case <-done:
778 | case <-ctx.Done():
779 | return ctx.Err()
780 | }
781 |
782 | return err
783 | }
784 |
785 | // Reset resets the state in DynamoDB.
786 | func (t *LeakyBucketDynamoDB) Reset(ctx context.Context) error {
787 | state := LeakyBucketState{
788 | Last: 0,
789 | }
790 |
791 | return t.SetState(ctx, state)
792 | }
793 |
794 | const (
795 | dynamodbBucketRaceConditionExpression = "Version <= :version"
796 | dynamoDBBucketLastKey = "Last"
797 | dynamoDBBucketVersionKey = "Version"
798 | )
799 |
800 | func (t *LeakyBucketDynamoDB) getPutItemInputFromState(state LeakyBucketState) *dynamodb.PutItemInput {
801 | item := map[string]types.AttributeValue{}
802 | for k, v := range t.keys {
803 | item[k] = v
804 | }
805 |
806 | item[dynamoDBBucketLastKey] = &types.AttributeValueMemberN{Value: strconv.FormatInt(state.Last, 10)}
807 |
808 | item[dynamoDBBucketVersionKey] = &types.AttributeValueMemberN{Value: strconv.FormatInt(t.latestVersion+1, 10)}
809 | if t.ttl > 0 {
810 | item[t.tableProps.TTLFieldName] = &types.AttributeValueMemberN{Value: strconv.FormatInt(time.Now().Add(t.ttl).Unix(), 10)}
811 | }
812 |
813 | input := &dynamodb.PutItemInput{
814 | TableName: &t.tableProps.TableName,
815 | Item: item,
816 | }
817 |
818 | if t.raceCheck && t.latestVersion > 0 {
819 | input.ConditionExpression = aws.String(dynamodbBucketRaceConditionExpression)
820 | input.ExpressionAttributeValues = map[string]types.AttributeValue{
821 | ":version": &types.AttributeValueMemberN{Value: strconv.FormatInt(t.latestVersion, 10)},
822 | }
823 | }
824 |
825 | return input
826 | }
827 |
828 | func (t *LeakyBucketDynamoDB) getGetItemInput() *dynamodb.GetItemInput {
829 | return &dynamodb.GetItemInput{
830 | TableName: &t.tableProps.TableName,
831 | Key: t.keys,
832 | }
833 | }
834 |
835 | func (t *LeakyBucketDynamoDB) loadStateFromDynamoDB(resp *dynamodb.GetItemOutput) (LeakyBucketState, error) {
836 | state := LeakyBucketState{}
837 |
838 | err := attributevalue.Unmarshal(resp.Item[dynamoDBBucketLastKey], &state.Last)
839 | if err != nil {
840 | return state, fmt.Errorf("unmarshal dynamodb Last attribute failed: %w", err)
841 | }
842 |
843 | if t.raceCheck {
844 | err = attributevalue.Unmarshal(resp.Item[dynamoDBBucketVersionKey], &t.latestVersion)
845 | if err != nil {
846 | return state, fmt.Errorf("unmarshal dynamodb Version attribute failed: %w", err)
847 | }
848 | }
849 |
850 | return state, nil
851 | }
852 |
853 | // CosmosDBLeakyBucketItem represents a document in CosmosDB for LeakyBucket.
854 | type CosmosDBLeakyBucketItem struct {
855 | ID string `json:"id"`
856 | PartitionKey string `json:"partitionKey"`
857 | State LeakyBucketState `json:"state"`
858 | Version int64 `json:"version"`
859 | TTL int64 `json:"ttl,omitempty"`
860 | }
861 |
862 | // LeakyBucketCosmosDB is a CosmosDB implementation of a LeakyBucketStateBackend.
863 | type LeakyBucketCosmosDB struct {
864 | client *azcosmos.ContainerClient
865 | partitionKey string
866 | id string
867 | ttl time.Duration
868 | raceCheck bool
869 | latestVersion int64
870 | }
871 |
872 | // NewLeakyBucketCosmosDB creates a new LeakyBucketCosmosDB instance.
873 | // PartitionKey is the key under which all the data for this implementation is stored in CosmosDB.
874 | // TTL is the TTL of the stored item.
875 | //
876 | // If raceCheck is true and the item in CosmosDB is modified in between State() and SetState() calls then
877 | // ErrRaceCondition is returned.
878 | func NewLeakyBucketCosmosDB(client *azcosmos.ContainerClient, partitionKey string, ttl time.Duration, raceCheck bool) *LeakyBucketCosmosDB {
879 | return &LeakyBucketCosmosDB{
880 | client: client,
881 | partitionKey: partitionKey,
882 | id: "leaky-bucket-" + partitionKey,
883 | ttl: ttl,
884 | raceCheck: raceCheck,
885 | }
886 | }
887 |
888 | func (t *LeakyBucketCosmosDB) State(ctx context.Context) (LeakyBucketState, error) {
889 | var item CosmosDBLeakyBucketItem
890 |
891 | resp, err := t.client.ReadItem(ctx, azcosmos.NewPartitionKey().AppendString(t.partitionKey), t.id, &azcosmos.ItemOptions{})
892 | if err != nil {
893 | var respErr *azcore.ResponseError
894 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound {
895 | return LeakyBucketState{}, nil
896 | }
897 |
898 | return LeakyBucketState{}, err
899 | }
900 |
901 | err = json.Unmarshal(resp.Value, &item)
902 | if err != nil {
903 | return LeakyBucketState{}, errors.Wrap(err, "failed to decode state from Cosmos DB")
904 | }
905 |
906 | if t.raceCheck {
907 | t.latestVersion = item.Version
908 | }
909 |
910 | return item.State, nil
911 | }
912 |
913 | func (t *LeakyBucketCosmosDB) SetState(ctx context.Context, state LeakyBucketState) error {
914 | var err error
915 |
916 | done := make(chan struct{}, 1)
917 |
918 | item := CosmosDBLeakyBucketItem{
919 | ID: t.id,
920 | PartitionKey: t.partitionKey,
921 | State: state,
922 | Version: t.latestVersion + 1,
923 | }
924 | if t.ttl > 0 {
925 | item.TTL = int64(math.Ceil(t.ttl.Seconds()))
926 | }
927 |
928 | value, err := json.Marshal(item)
929 | if err != nil {
930 | return errors.Wrap(err, "failed to encode state to JSON")
931 | }
932 |
933 | go func() {
934 | defer close(done)
935 |
936 | _, err = t.client.UpsertItem(ctx, azcosmos.NewPartitionKey().AppendString(t.partitionKey), value, &azcosmos.ItemOptions{})
937 | }()
938 |
939 | select {
940 | case <-done:
941 | case <-ctx.Done():
942 | return ctx.Err()
943 | }
944 |
945 | if err != nil {
946 | var respErr *azcore.ResponseError
947 | if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict && t.raceCheck {
948 | return ErrRaceCondition
949 | }
950 |
951 | return errors.Wrap(err, "failed to save keys to Cosmos DB")
952 | }
953 |
954 | return nil
955 | }
956 |
957 | func (t *LeakyBucketCosmosDB) Reset(ctx context.Context) error {
958 | state := LeakyBucketState{
959 | Last: 0,
960 | }
961 |
962 | return t.SetState(ctx, state)
963 | }
964 |
--------------------------------------------------------------------------------