├── .github └── FUNDING.yml ├── .prettierrc ├── .travis.yml ├── LICENSE ├── Makefile ├── README.md ├── example_test.go ├── go.mod ├── go.sum ├── lua.go ├── rate.go ├── rate_test.go └── renovate.json /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | custom: ['https://uptrace.dev'] 2 | -------------------------------------------------------------------------------- /.prettierrc: -------------------------------------------------------------------------------- 1 | semi: false 2 | singleQuote: true 3 | proseWrap: always 4 | printWidth: 100 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial 2 | language: go 3 | 4 | services: 5 | - redis-server 6 | 7 | go: 8 | - 1.12.x 9 | - 1.13.x 10 | - 1.14.x 11 | - tip 12 | 13 | matrix: 14 | allow_failures: 15 | - go: tip 16 | 17 | env: 18 | - GO111MODULE=on 19 | 20 | go_import_path: github.com/go-redis/redis_rate 21 | 22 | before_install: 23 | - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go 24 | env GOPATH)/bin v1.21.0 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013 The github.com/go-redis/redis_rate Authors. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 
10 | * Redistributions in binary form must reproduce the above 11 | copyright notice, this list of conditions and the following disclaimer 12 | in the documentation and/or other materials provided with the 13 | distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 19 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 20 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 21 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 22 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 23 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | go test ./... 3 | go test ./... -short -race 4 | go test ./... -run=NONE -bench=. -benchmem 5 | env GOOS=linux GOARCH=386 go test ./... 
6 | golangci-lint run 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rate limiting for go-redis 2 | 3 | [![Build Status](https://travis-ci.org/go-redis/redis_rate.svg?branch=master)](https://travis-ci.org/go-redis/redis_rate) 4 | [![PkgGoDev](https://pkg.go.dev/badge/github.com/go-redis/redis_rate/v10)](https://pkg.go.dev/github.com/go-redis/redis_rate/v10) 5 | 6 | > :heart: [**Uptrace.dev** - distributed traces, logs, and errors in one place](https://uptrace.dev) 7 | 8 | This package is based on [rwz/redis-gcra](https://github.com/rwz/redis-gcra) and implements 9 | [GCRA](https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm) (aka leaky bucket) for rate 10 | limiting based on Redis. The code requires Redis version 3.2 or newer since it relies on the 11 | [replicate_commands](https://redis.io/commands/eval#replicating-commands-instead-of-scripts) 12 | feature. 13 | 14 | ## Installation 15 | 16 | redis_rate supports the last 2 Go versions and requires a Go version with 17 | [modules](https://github.com/golang/go/wiki/Modules) support.
So make sure to initialize a Go 18 | module: 19 | 20 | ```shell 21 | go mod init github.com/my/repo 22 | ``` 23 | 24 | And then install `redis_rate/v10` (note **v10** in the import; omitting it is a popular 25 | mistake): 26 | 27 | ```shell 28 | go get github.com/go-redis/redis_rate/v10 29 | ``` 30 | 31 | ## Example 32 | 33 | ```go 34 | package redis_rate_test 35 | 36 | import ( 37 | "context" 38 | "fmt" 39 | 40 | "github.com/redis/go-redis/v9" 41 | "github.com/go-redis/redis_rate/v10" 42 | ) 43 | 44 | func ExampleNewLimiter() { 45 | ctx := context.Background() 46 | rdb := redis.NewClient(&redis.Options{ 47 | Addr: "localhost:6379", 48 | }) 49 | _ = rdb.FlushDB(ctx).Err() // start from an empty database so the printed output is stable; the error is ignored in this example 50 | 51 | limiter := redis_rate.NewLimiter(rdb) 52 | res, err := limiter.Allow(ctx, "project:123", redis_rate.PerSecond(10)) 53 | if err != nil { 54 | panic(err) 55 | } 56 | fmt.Println("allowed", res.Allowed, "remaining", res.Remaining) 57 | // Output: allowed 1 remaining 9 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package redis_rate_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/redis/go-redis/v9" 8 | 9 | "github.com/go-redis/redis_rate/v10" 10 | ) 11 | 12 | func ExampleNewLimiter() { 13 | ctx := context.Background() 14 | rdb := redis.NewClient(&redis.Options{ 15 | Addr: "localhost:6379", 16 | }) 17 | _ = rdb.FlushDB(ctx).Err() // start from an empty database so the printed output is stable; the error is ignored in this example 18 | 19 | limiter := redis_rate.NewLimiter(rdb) 20 | res, err := limiter.Allow(ctx, "project:123", redis_rate.PerSecond(10)) 21 | if err != nil { 22 | panic(err) 23 | } 24 | fmt.Println("allowed", res.Allowed, "remaining", res.Remaining) 25 | // Output: allowed 1 remaining 9 26 | } 27 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module
github.com/go-redis/redis_rate/v10 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/redis/go-redis/v9 v9.0.2 7 | github.com/stretchr/testify v1.8.1 8 | ) 9 | 10 | require ( 11 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 12 | github.com/davecgh/go-spew v1.1.1 // indirect 13 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 14 | github.com/pmezard/go-difflib v1.0.0 // indirect 15 | gopkg.in/yaml.v3 v3.0.1 // indirect 16 | ) 17 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= 2 | github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= 7 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= 8 | github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= 9 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 10 | github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= 11 | github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= 12 | github.com/onsi/gomega v1.25.0 h1:Vw7br2PCDYijJHSfBOWhov+8cAnUf8MfMaIOV323l6Y= 13 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 14 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 15 | github.com/redis/go-redis/v9 v9.0.0-rc.4 
h1:JUhsiZMTZknz3vn50zSVlkwcSeTGPd51lMO3IKUrWpY= 16 | github.com/redis/go-redis/v9 v9.0.0-rc.4/go.mod h1:Vo3EsyWnicKnSKCA7HhgnvnyA74wOA69Cd2Meli5mmA= 17 | github.com/redis/go-redis/v9 v9.0.2 h1:BA426Zqe/7r56kCcvxYLWe1mkaz71LKF77GwgFzSxfE= 18 | github.com/redis/go-redis/v9 v9.0.2/go.mod h1:/xDTe9EF1LM61hek62Poq2nzQSGj0xSrEtEHbBQevps= 19 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 20 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 21 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 22 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 23 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 24 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 25 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 26 | golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= 27 | golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= 28 | golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= 29 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 30 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 31 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= 32 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 33 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 34 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 35 | -------------------------------------------------------------------------------- /lua.go: -------------------------------------------------------------------------------- 1 | package 
redis_rate 2 | 3 | import "github.com/redis/go-redis/v9" 4 | 5 | // Copyright (c) 2017 Pavel Pravosud 6 | // https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua 7 | var allowN = redis.NewScript(` 8 | -- this script has side-effects, so it requires replicate commands mode 9 | redis.replicate_commands() 10 | 11 | local rate_limit_key = KEYS[1] 12 | local burst = ARGV[1] 13 | local rate = ARGV[2] 14 | local period = ARGV[3] 15 | local cost = tonumber(ARGV[4]) 16 | 17 | local emission_interval = period / rate -- seconds between two events at the steady rate 18 | local increment = emission_interval * cost -- how far this request advances the theoretical arrival time (TAT) 19 | local burst_offset = emission_interval * burst -- how far ahead of now the TAT may run 20 | 21 | -- redis returns time as an array containing two integers: seconds of the epoch 22 | -- time (10 digits) and microseconds (6 digits). for convenience we need to 23 | -- convert them to a floating point number. the resulting number is 16 digits, 24 | -- bordering on the limits of a 64-bit double-precision floating point number. 25 | -- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating 26 | -- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09 27 | -- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
28 | local jan_1_2017 = 1483228800 29 | local now = redis.call("TIME") 30 | now = (now[1] - jan_1_2017) + (now[2] / 1000000) 31 | 32 | local tat = redis.call("GET", rate_limit_key) 33 | 34 | if not tat then 35 | tat = now 36 | else 37 | tat = tonumber(tat) 38 | end 39 | 40 | tat = math.max(tat, now) 41 | 42 | local new_tat = tat + increment 43 | local allow_at = new_tat - burst_offset 44 | 45 | local diff = now - allow_at 46 | local remaining = diff / emission_interval 47 | 48 | if remaining < 0 then -- not enough capacity for cost events: deny without storing a new TAT 49 | local reset_after = tat - now 50 | local retry_after = diff * -1 51 | return { 52 | 0, -- allowed 53 | 0, -- remaining 54 | tostring(retry_after), 55 | tostring(reset_after), 56 | } 57 | end 58 | 59 | local reset_after = new_tat - now 60 | if reset_after > 0 then -- skip the write when the TAT stays at now (zero cost on an idle key); the key expires once the TAT has passed 61 | redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after)) 62 | end 63 | local retry_after = -1 64 | return {cost, remaining, tostring(retry_after), tostring(reset_after)} 65 | `) 66 | 67 | var allowAtMost = redis.NewScript(` 68 | -- this script has side-effects, so it requires replicate commands mode 69 | redis.replicate_commands() 70 | 71 | local rate_limit_key = KEYS[1] 72 | local burst = ARGV[1] 73 | local rate = ARGV[2] 74 | local period = ARGV[3] 75 | local cost = tonumber(ARGV[4]) 76 | 77 | local emission_interval = period / rate 78 | local burst_offset = emission_interval * burst 79 | 80 | -- redis returns time as an array containing two integers: seconds of the epoch 81 | -- time (10 digits) and microseconds (6 digits). for convenience we need to 82 | -- convert them to a floating point number. the resulting number is 16 digits, 83 | -- bordering on the limits of a 64-bit double-precision floating point number. 84 | -- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating 85 | -- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09 86 | -- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
87 | local jan_1_2017 = 1483228800 88 | local now = redis.call("TIME") 89 | now = (now[1] - jan_1_2017) + (now[2] / 1000000) 90 | 91 | local tat = redis.call("GET", rate_limit_key) 92 | 93 | if not tat then 94 | tat = now 95 | else 96 | tat = tonumber(tat) 97 | end 98 | 99 | tat = math.max(tat, now) 100 | 101 | local diff = now - (tat - burst_offset) 102 | local remaining = diff / emission_interval 103 | 104 | if remaining < 1 then -- not even one event fits: deny without storing a new TAT 105 | local reset_after = tat - now 106 | local retry_after = emission_interval - diff 107 | return { 108 | 0, -- allowed 109 | 0, -- remaining 110 | tostring(retry_after), 111 | tostring(reset_after), 112 | } 113 | end 114 | 115 | if remaining < cost then 116 | cost = math.floor(remaining) -- grant whole events only; Redis truncates Lua numbers in replies, so a fractional cost would be charged but not reported 117 | remaining = 0 118 | else 119 | remaining = remaining - cost 120 | end 121 | 122 | local increment = emission_interval * cost 123 | local new_tat = tat + increment 124 | 125 | local reset_after = new_tat - now 126 | if reset_after > 0 then 127 | redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after)) 128 | end 129 | 130 | return { 131 | cost, 132 | remaining, 133 | tostring(-1), 134 | tostring(reset_after), 135 | } 136 | `) 137 | -------------------------------------------------------------------------------- /rate.go: -------------------------------------------------------------------------------- 1 | package redis_rate 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strconv" 7 | "time" 8 | 9 | "github.com/redis/go-redis/v9" 10 | ) 11 | 12 | const redisPrefix = "rate:" // prepended to every key passed to the Limiter 13 | 14 | type rediser interface { // subset of the go-redis API used by the Limiter; *redis.Client and *redis.Ring satisfy it 15 | Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd 16 | EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd 17 | ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd 18 | ScriptLoad(ctx context.Context, script string) *redis.StringCmd 19 | Del(ctx context.Context, keys ...string) *redis.IntCmd 20 | 21 | EvalRO(ctx context.Context, script string,
keys []string, args ...interface{}) *redis.Cmd 22 | EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd 23 | } 24 | 25 | // Limit describes a GCRA rate limit: Rate events per Period, with at most Burst events allowed at once. 26 | type Limit struct { 27 | Rate int 28 | Burst int 29 | Period time.Duration 30 | } 31 | 32 | // String returns a human-readable description such as "10 req/s (burst 10)". 33 | func (l Limit) String() string { 34 | return fmt.Sprintf("%d req/%s (burst %d)", l.Rate, fmtDur(l.Period), l.Burst) 35 | } 36 | 37 | // IsZero reports whether l is the zero Limit. 38 | func (l Limit) IsZero() bool { 39 | return l == Limit{} 40 | } 41 | 42 | // fmtDur abbreviates exactly one second, minute or hour as "s", "m" or "h"; other durations use Duration.String. 43 | func fmtDur(d time.Duration) string { 44 | switch d { 45 | case time.Second: 46 | return "s" 47 | case time.Minute: 48 | return "m" 49 | case time.Hour: 50 | return "h" 51 | } 52 | return d.String() 53 | } 54 | 55 | // PerSecond returns a Limit of rate events per second with Burst equal to rate. 56 | func PerSecond(rate int) Limit { 57 | return Limit{ 58 | Rate: rate, 59 | Period: time.Second, 60 | Burst: rate, 61 | } 62 | } 63 | 64 | // PerMinute returns a Limit of rate events per minute with Burst equal to rate. 65 | func PerMinute(rate int) Limit { 66 | return Limit{ 67 | Rate: rate, 68 | Period: time.Minute, 69 | Burst: rate, 70 | } 71 | } 72 | 73 | // PerHour returns a Limit of rate events per hour with Burst equal to rate. 74 | func PerHour(rate int) Limit { 75 | return Limit{ 76 | Rate: rate, 77 | Period: time.Hour, 78 | Burst: rate, 79 | } 80 | } 81 | 82 | // ------------------------------------------------------------------------------ 83 | 84 | // Limiter controls how frequently events are allowed to happen. 85 | type Limiter struct { 86 | rdb rediser 87 | } 88 | 89 | // NewLimiter returns a Limiter that keeps its state in rdb under keys prefixed with "rate:". 90 | func NewLimiter(rdb rediser) *Limiter { 91 | return &Limiter{ 92 | rdb: rdb, 93 | } 94 | } 95 | 96 | // Allow is a shortcut for AllowN(ctx, key, limit, 1). 97 | func (l Limiter) Allow(ctx context.Context, key string, limit Limit) (*Result, error) { 98 | return l.AllowN(ctx, key, limit, 1) 99 | } 100 | 101 | // AllowN reports whether n events may happen at time now.
102 | // Either all n events are allowed or none are; a denied call does not consume capacity, 103 | // and n == 0 reports the current state without consuming anything. 104 | func (l Limiter) AllowN( 105 | ctx context.Context, 106 | key string, 107 | limit Limit, 108 | n int, 109 | ) (*Result, error) { 110 | values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n} 111 | v, err := allowN.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result() 112 | if err != nil { 113 | return nil, err 114 | } 115 | return parseResult(limit, v) 116 | } 117 | 118 | // AllowAtMost reports whether at most n events may happen at time now. 119 | // It returns the number of allowed events, which is less than or equal to n.
120 | func (l Limiter) AllowAtMost( 121 | ctx context.Context, 122 | key string, 123 | limit Limit, 124 | n int, 125 | ) (*Result, error) { 126 | values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n} 127 | v, err := allowAtMost.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result() 128 | if err != nil { 129 | return nil, err 130 | } 131 | return parseResult(limit, v) 132 | } 133 | 134 | // parseResult converts the {allowed, remaining, retry_after, reset_after} reply of the 135 | // allowN and allowAtMost scripts into a Result. A reply of any other shape is reported 136 | // as an error instead of panicking on a failed type assertion. 137 | func parseResult(limit Limit, reply interface{}) (*Result, error) { 138 | values, ok := reply.([]interface{}) 139 | if !ok || len(values) != 4 { 140 | return nil, fmt.Errorf("redis_rate: unexpected script reply: %#v", reply) 141 | } 142 | 143 | // Counts arrive as Redis integers; durations are sent as strings to keep their fractions. 144 | allowed, ok1 := values[0].(int64) 145 | remaining, ok2 := values[1].(int64) 146 | retryAfterStr, ok3 := values[2].(string) 147 | resetAfterStr, ok4 := values[3].(string) 148 | if !ok1 || !ok2 || !ok3 || !ok4 { 149 | return nil, fmt.Errorf("redis_rate: unexpected script reply: %#v", reply) 150 | } 151 | 152 | retryAfter, err := strconv.ParseFloat(retryAfterStr, 64) 153 | if err != nil { 154 | return nil, err 155 | } 156 | 157 | resetAfter, err := strconv.ParseFloat(resetAfterStr, 64) 158 | if err != nil { 159 | return nil, err 160 | } 161 | 162 | return &Result{ 163 | Limit: limit, 164 | Allowed: int(allowed), 165 | Remaining: int(remaining), 166 | RetryAfter: dur(retryAfter), 167 | ResetAfter: dur(resetAfter), 168 | }, nil 169 | } 170 | 171 | // Reset deletes the state stored for key, clearing all limitations and previous usage. 172 | func (l *Limiter) Reset(ctx context.Context, key string) error { 173 | return l.rdb.Del(ctx, redisPrefix+key).Err() 174 | } 175 | 176 | // dur converts seconds to a Duration, passing the -1 "not applicable" sentinel through unchanged. 177 | func dur(f float64) time.Duration { 178 | if f == -1 { 179 | return -1 180 | } 181 | return time.Duration(f * float64(time.Second)) 182 | } 183 | 184 | // Result is the outcome of an Allow, AllowN or AllowAtMost call. 185 | type Result struct { 186 | // Limit is the limit that was used to obtain this result. 187 | Limit Limit 188 | 189 | // Allowed is the number of events that may happen at time now. 190 | Allowed int 191 | 192 | // Remaining is the maximum number of requests that could be 193 | // permitted instantaneously for this key given the current 194 | // state. For example, if a rate limiter allows 10 requests per 195 | // second and has already received 6 requests for this key this 196 | // second, Remaining would be 4. 197 | Remaining int 198 | 199 | // RetryAfter is the time until the next request will be permitted. 200 | // It is -1 unless the rate limit has been exceeded.
201 | RetryAfter time.Duration 202 | 203 | // ResetAfter is the time until the limiter returns to its 204 | // initial state for a given key. For example, if a limiter 205 | // allows one request per second and received one request 200ms ago, 206 | // ResetAfter would be 800ms. You can also think of this as the time 207 | // until Remaining is back to Limit.Burst. 208 | ResetAfter time.Duration 209 | } 210 | -------------------------------------------------------------------------------- /rate_test.go: -------------------------------------------------------------------------------- 1 | package redis_rate_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/redis/go-redis/v9" 9 | "github.com/stretchr/testify/require" 10 | 11 | "github.com/go-redis/redis_rate/v10" 12 | ) 13 | 14 | func rateLimiter() *redis_rate.Limiter { 15 | ring := redis.NewRing(&redis.RingOptions{ 16 | Addrs: map[string]string{"server0": ":6379"}, 17 | }) 18 | if err := ring.FlushDB(context.TODO()).Err(); err != nil { 19 | panic(err) 20 | } 21 | return redis_rate.NewLimiter(ring) 22 | } 23 | 24 | func TestAllow(t *testing.T) { 25 | ctx := context.Background() 26 | 27 | l := rateLimiter() 28 | 29 | limit := redis_rate.PerSecond(10) 30 | require.Equal(t, "10 req/s (burst 10)", limit.String()) 31 | require.False(t, limit.IsZero()) 32 | 33 | res, err := l.Allow(ctx, "test_id", limit) 34 | require.NoError(t, err) 35 | require.Equal(t, 1, res.Allowed) 36 | require.Equal(t, 9, res.Remaining) 37 | require.Equal(t, time.Duration(-1), res.RetryAfter) 38 | require.InDelta(t, 100*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 39 | 40 | err = l.Reset(ctx, "test_id") 41 | require.NoError(t, err) 42 | res, err = l.Allow(ctx, "test_id", limit) 43 | require.NoError(t, err) 44 | require.Equal(t, 1, res.Allowed) 45 | require.Equal(t, 9, res.Remaining) 46 | require.Equal(t, time.Duration(-1), res.RetryAfter) 47 | require.InDelta(t, 100*time.Millisecond, res.ResetAfter,
float64(10*time.Millisecond)) 48 | 49 | res, err = l.AllowN(ctx, "test_id", limit, 2) 50 | require.NoError(t, err) 51 | require.Equal(t, 2, res.Allowed) 52 | require.Equal(t, 7, res.Remaining) 53 | require.Equal(t, time.Duration(-1), res.RetryAfter) 54 | require.InDelta(t, 300*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 55 | 56 | res, err = l.AllowN(ctx, "test_id", limit, 7) 57 | require.NoError(t, err) 58 | require.Equal(t, 7, res.Allowed) 59 | require.Equal(t, 0, res.Remaining) 60 | require.Equal(t, time.Duration(-1), res.RetryAfter) 61 | require.InDelta(t, 999*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 62 | 63 | res, err = l.AllowN(ctx, "test_id", limit, 1000) 64 | require.NoError(t, err) 65 | require.Equal(t, 0, res.Allowed) 66 | require.Equal(t, 0, res.Remaining) 67 | require.InDelta(t, 99*time.Second, res.RetryAfter, float64(time.Second)) 68 | require.InDelta(t, 999*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 69 | } 70 | 71 | func TestAllowN_IncrementZero(t *testing.T) { 72 | ctx := context.Background() 73 | l := rateLimiter() 74 | limit := redis_rate.PerSecond(10) 75 | 76 | // Peek at a key that does not exist yet. 77 | res, err := l.AllowN(ctx, "test_id", limit, 0) 78 | require.NoError(t, err) 79 | require.Equal(t, 0, res.Allowed) 80 | require.Equal(t, 10, res.Remaining) 81 | require.Equal(t, time.Duration(-1), res.RetryAfter) 82 | require.Equal(t, time.Duration(0), res.ResetAfter) 83 | 84 | // Now increment it. 85 | res, err = l.Allow(ctx, "test_id", limit) 86 | require.NoError(t, err) 87 | require.Equal(t, 1, res.Allowed) 88 | require.Equal(t, 9, res.Remaining) 89 | require.Equal(t, time.Duration(-1), res.RetryAfter) 90 | require.InDelta(t, 100*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 91 | 92 | // Peek again; a zero-cost call must not consume capacity. 93 | res, err = l.AllowN(ctx, "test_id", limit, 0) 94 | require.NoError(t, err) 95 | require.Equal(t, 0, res.Allowed) 96 | require.Equal(t, 9, res.Remaining) 97 | require.Equal(t, time.Duration(-1), res.RetryAfter)
98 | require.InDelta(t, 100*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 99 | } 100 | 101 | func TestRetryAfter(t *testing.T) { 102 | limit := redis_rate.Limit{ 103 | Rate: 1, 104 | Period: time.Millisecond, 105 | Burst: 1, 106 | } 107 | 108 | ctx := context.Background() 109 | l := rateLimiter() 110 | 111 | for i := 0; i < 1000; i++ { 112 | res, err := l.Allow(ctx, "test_id", limit) 113 | require.NoError(t, err) 114 | 115 | if res.Allowed > 0 { 116 | continue 117 | } 118 | 119 | require.LessOrEqual(t, int64(res.RetryAfter), int64(time.Millisecond)) // a denied call never waits longer than one emission interval (1ms) 120 | } 121 | } 122 | 123 | func TestAllowAtMost(t *testing.T) { 124 | ctx := context.Background() 125 | 126 | l := rateLimiter() 127 | limit := redis_rate.PerSecond(10) 128 | 129 | res, err := l.Allow(ctx, "test_id", limit) 130 | require.NoError(t, err) 131 | require.Equal(t, 1, res.Allowed) 132 | require.Equal(t, 9, res.Remaining) 133 | require.Equal(t, time.Duration(-1), res.RetryAfter) 134 | require.InDelta(t, 100*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 135 | 136 | res, err = l.AllowAtMost(ctx, "test_id", limit, 2) 137 | require.NoError(t, err) 138 | require.Equal(t, 2, res.Allowed) 139 | require.Equal(t, 7, res.Remaining) 140 | require.Equal(t, time.Duration(-1), res.RetryAfter) 141 | require.InDelta(t, 300*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 142 | 143 | res, err = l.AllowN(ctx, "test_id", limit, 0) 144 | require.NoError(t, err) 145 | require.Equal(t, 0, res.Allowed) 146 | require.Equal(t, 7, res.Remaining) 147 | require.Equal(t, time.Duration(-1), res.RetryAfter) 148 | require.InDelta(t, 300*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 149 | 150 | res, err = l.AllowAtMost(ctx, "test_id", limit, 10) 151 | require.NoError(t, err) 152 | require.Equal(t, 7, res.Allowed) 153 | require.Equal(t, 0, res.Remaining) 154 | require.Equal(t, time.Duration(-1), res.RetryAfter) 155 | require.InDelta(t, 999*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond))
156 | 157 | res, err = l.AllowN(ctx, "test_id", limit, 0) 158 | require.NoError(t, err) 159 | require.Equal(t, 0, res.Allowed) 160 | require.Equal(t, 0, res.Remaining) 161 | require.Equal(t, time.Duration(-1), res.RetryAfter) 162 | require.InDelta(t, 999*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 163 | 164 | res, err = l.AllowAtMost(ctx, "test_id", limit, 1000) 165 | require.NoError(t, err) 166 | require.Equal(t, 0, res.Allowed) 167 | require.Equal(t, 0, res.Remaining) 168 | require.InDelta(t, 99*time.Millisecond, res.RetryAfter, float64(10*time.Millisecond)) 169 | require.InDelta(t, 999*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 170 | 171 | res, err = l.AllowN(ctx, "test_id", limit, 1000) 172 | require.NoError(t, err) 173 | require.Equal(t, 0, res.Allowed) 174 | require.Equal(t, 0, res.Remaining) 175 | require.InDelta(t, 99*time.Second, res.RetryAfter, float64(time.Second)) 176 | require.InDelta(t, 999*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 177 | } 178 | 179 | func TestAllowAtMost_IncrementZero(t *testing.T) { 180 | ctx := context.Background() 181 | l := rateLimiter() 182 | limit := redis_rate.PerSecond(10) 183 | 184 | // Peek at a key that does not exist yet. 185 | res, err := l.AllowAtMost(ctx, "test_id", limit, 0) 186 | require.NoError(t, err) 187 | require.Equal(t, 0, res.Allowed) 188 | require.Equal(t, 10, res.Remaining) 189 | require.Equal(t, time.Duration(-1), res.RetryAfter) 190 | require.Equal(t, time.Duration(0), res.ResetAfter) 191 | 192 | // Now increment it. 193 | res, err = l.Allow(ctx, "test_id", limit) 194 | require.NoError(t, err) 195 | require.Equal(t, 1, res.Allowed) 196 | require.Equal(t, 9, res.Remaining) 197 | require.Equal(t, time.Duration(-1), res.RetryAfter) 198 | require.InDelta(t, 100*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 199 | 200 | // Peek again; a zero-cost call must not consume capacity. 201 | res, err = l.AllowAtMost(ctx, "test_id", limit, 0) 202 | require.NoError(t, err) 203 |
require.Equal(t, 0, res.Allowed) 204 | require.Equal(t, 9, res.Remaining) 205 | require.Equal(t, time.Duration(-1), res.RetryAfter) 206 | require.InDelta(t, 100*time.Millisecond, res.ResetAfter, float64(10*time.Millisecond)) 207 | } 208 | 209 | func BenchmarkAllow(b *testing.B) { 210 | ctx := context.Background() 211 | l := rateLimiter() 212 | limit := redis_rate.PerSecond(1e6) 213 | 214 | b.ResetTimer() 215 | 216 | b.RunParallel(func(pb *testing.PB) { 217 | for pb.Next() { 218 | res, err := l.Allow(ctx, "foo", limit) 219 | if err != nil { 220 | b.Fatal(err) 221 | } 222 | if res.Allowed == 0 { 223 | panic("not reached") 224 | } 225 | } 226 | }) 227 | } 228 | 229 | func BenchmarkAllowAtMost(b *testing.B) { 230 | ctx := context.Background() 231 | l := rateLimiter() 232 | limit := redis_rate.PerSecond(1e6) 233 | 234 | b.ResetTimer() 235 | 236 | b.RunParallel(func(pb *testing.PB) { 237 | for pb.Next() { 238 | res, err := l.AllowAtMost(ctx, "foo", limit, 1) 239 | if err != nil { 240 | b.Fatal(err) 241 | } 242 | if res.Allowed == 0 { 243 | panic("not reached") 244 | } 245 | } 246 | }) 247 | } 248 | -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": [ 3 | "config:base" 4 | ] 5 | } 6 | --------------------------------------------------------------------------------