├── go.mod ├── Makefile ├── README.md ├── .github ├── CONTRIBUTING.md └── workflows │ └── test.yml ├── example_test.go ├── .golangci.yml ├── go.sum ├── lua.go ├── store_test.go ├── store.go └── LICENSE /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/sethvargo/go-redisstore 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/gomodule/redigo v1.8.9 7 | github.com/sethvargo/go-limiter v0.7.2 8 | ) 9 | 10 | require github.com/stretchr/testify v1.8.1 // indirect 11 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: 2 | @go test \ 3 | -count=1 \ 4 | -short \ 5 | -shuffle=on \ 6 | -timeout=5m \ 7 | ./... 8 | .PHONY: test 9 | 10 | test-acc: 11 | @go test \ 12 | -count=1 \ 13 | -race \ 14 | -shuffle=on \ 15 | -timeout=10m \ 16 | ./... 17 | .PHONY: test-acc 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Go Rate Limiter 2 | 3 | This package provides a rate limiting interface in Go (Golang), using Redis. See 4 | [sethvargo/go-limiter][limiter] for more information. 5 | 6 | For an instrumented client, see the [sethvargo/go-redisstore-opencensus][opencensus]. 7 | 8 | [limiter]: https://github.com/sethvargo/go-limiter 9 | [opencensus]: https://github.com/sethvargo/go-redisstore-opencensus 10 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | 7 | ## Code reviews 8 | 9 | All submissions, including submissions by project members, require review. 
We 10 | use GitHub pull requests for this purpose. Consult 11 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 12 | information on using pull requests. 13 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package redisstore_test 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "time" 7 | 8 | "github.com/gomodule/redigo/redis" 9 | "github.com/sethvargo/go-redisstore" 10 | ) 11 | 12 | func ExampleNew() { 13 | ctx := context.Background() 14 | 15 | store, err := redisstore.New(&redisstore.Config{ 16 | Tokens: 15, 17 | Interval: time.Minute, 18 | Dial: func() (redis.Conn, error) { 19 | return redis.Dial("tcp", "127.0.0.1:6379", 20 | redis.DialPassword("my-password")) 21 | }, 22 | }) 23 | if err != nil { 24 | log.Fatal(err) 25 | } 26 | defer store.Close(ctx) 27 | 28 | limit, remaining, reset, ok, err := store.Take(ctx, "my-key") 29 | if err != nil { 30 | log.Fatal(err) 31 | } 32 | _, _, _, _ = limit, remaining, reset, ok 33 | } 34 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | # default: '' 3 | modules-download-mode: 'readonly' 4 | 5 | # default: false 6 | allow-parallel-runners: true 7 | 8 | linters: 9 | enable: 10 | - 'asasalint' 11 | - 'asciicheck' 12 | - 'bidichk' 13 | - 'bodyclose' 14 | - 'containedctx' 15 | - 'depguard' 16 | - 'durationcheck' 17 | - 'errcheck' 18 | - 'errchkjson' 19 | - 'errname' 20 | - 'errorlint' 21 | - 'errorlint' 22 | - 'exportloopref' 23 | - 'gofmt' 24 | - 'gofumpt' 25 | - 'goheader' 26 | - 'goimports' 27 | - 'gomodguard' 28 | - 'goprintffuncname' 29 | - 'gosec' 30 | - 'gosimple' 31 | - 'govet' 32 | - 'importas' 33 | - 'ineffassign' 34 | - 'ireturn' 35 | - 'makezero' 36 | - 'misspell' 37 | - 'noctx' 38 | - 'noctx' 39 | - 'paralleltest' 40 | 
- 'prealloc' 41 | - 'predeclared' 42 | - 'revive' 43 | - 'sqlclosecheck' 44 | - 'staticcheck' 45 | - 'stylecheck' 46 | - 'thelper' 47 | - 'typecheck' 48 | - 'unconvert' 49 | - 'unused' 50 | - 'whitespace' 51 | 52 | issues: 53 | # default: [] 54 | exclude: 55 | 56 | # default: 50 57 | max-issues-per-linter: 0 58 | 59 | # default: 3 60 | max-same-issues: 0 61 | 62 | severity: 63 | # default: '' 64 | default-severity: 'error' 65 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: 'test' 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | - 'opencensus' 8 | pull_request: 9 | branches: 10 | - 'main' 11 | - 'opencensus' 12 | 13 | jobs: 14 | test: 15 | strategy: 16 | fail-fast: false 17 | 18 | matrix: 19 | redis: 20 | - '6.0' 21 | - '5.0' 22 | 23 | runs-on: 'ubuntu-latest' 24 | 25 | services: 26 | redis: 27 | image: 'bitnami/redis:${{ matrix.redis }}' 28 | env: 29 | REDIS_PASSWORD: testing123 30 | options: >- 31 | --health-cmd "redis-cli ping" 32 | --health-interval 10s 33 | --health-timeout 5s 34 | --health-retries 5 35 | ports: 36 | - '6379:6379' 37 | 38 | steps: 39 | - uses: 'actions/checkout@v3' 40 | 41 | - uses: 'actions/setup-go@v3' 42 | with: 43 | go-version: '1.17' 44 | 45 | - run: 'go mod download' 46 | 47 | - uses: 'golangci/golangci-lint-action@v3' 48 | with: 49 | only-new-issues: 'true' 50 | skip-cache: 'true' 51 | 52 | - run: 'make test-acc' 53 | env: 54 | REDIS_HOST: '127.0.0.1' 55 | REDIS_PORT: '6379' 56 | REDIS_PASS: 'testing123' 57 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | 
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= 5 | github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= 6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 8 | github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8= 9 | github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU= 10 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 11 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 12 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 13 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 14 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 15 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 16 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 17 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 18 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 19 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 20 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 21 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 22 | -------------------------------------------------------------------------------- /lua.go: -------------------------------------------------------------------------------- 1 | package redisstore 2 | 3 | const luaTemplate =
` 4 | local C_EXPIRE = 'EXPIRE' 5 | local C_HGETALL = 'HGETALL' 6 | local C_HSET = 'HSET' 7 | local F_START = 's' 8 | local F_TICK = 't' 9 | local F_INTERVAL = 'i' 10 | local F_TOKENS = 'k' 11 | local F_MAX = 'm' 12 | 13 | -- speed up access to next 14 | local next = next 15 | 16 | -- get arguments 17 | local key = KEYS[1] 18 | local now = tonumber(ARGV[1]) -- current unix time in nanoseconds 19 | local defmaxtokens = tonumber(ARGV[2]) -- default tokens per interval, only used if no value already exists for the key 20 | local definterval = tonumber(ARGV[3]) -- interval in nanoseconds, only used if no value already exists for the key 21 | 22 | -- hgetall gets all the fields as a lua table. 23 | local hgetall = function (key) 24 | local data = redis.call(C_HGETALL, key) 25 | local result = {} 26 | for i = 1, #data, 2 do 27 | result[data[i]] = data[i+1] 28 | end 29 | return result 30 | end 31 | 32 | -- availabletokens returns the number of available tokens given the last tick, 33 | -- current tick, max, and fill rate. 34 | local availabletokens = function (last, curr, max, fillrate) 35 | local delta = curr - last 36 | local available = delta * fillrate 37 | if available > max then 38 | available = max 39 | end 40 | return available 41 | end 42 | 43 | -- present returns true if the given value is not nil and is not the empty 44 | -- string. 45 | local present = function (val) 46 | return val ~= nil and val ~= '' 47 | end 48 | 49 | -- tick returns the total number of times the interval has occurred between 50 | -- start and current. 51 | local tick = function (start, curr, interval) 52 | local val = math.floor((curr - start) / interval) 53 | if val > 0 then 54 | return val 55 | end 56 | return 0 57 | end 58 | 59 | -- ttl returns the ttl in seconds for the given interval: 3x the interval,
60 | -- rounded up and never below 1, because EXPIRE with 0 deletes the key. 61 | local ttl = function (interval) 62 | return math.max(1, math.ceil(3 * interval / 1000000000)) 63 | end 64 | 65 | 66 | -- 67 | -- begin exec 68 | -- 69 | 70 | local data = hgetall(key) 71 | 72 | local start = now 73 | if present(data[F_START]) then 74 | start = tonumber(data[F_START]) 75 | else 76 | redis.call(C_HSET, key, F_START, now) 77 | redis.call(C_EXPIRE, key, 30) 78 | end 79 | 80 | local lasttick = 0 81 | if present(data[F_TICK]) then 82 | lasttick = tonumber(data[F_TICK]) 83 | else 84 | redis.call(C_HSET, key, F_TICK, 0) 85 | redis.call(C_EXPIRE, key, 30) 86 | end 87 | 88 | local maxtokens = defmaxtokens 89 | if present(data[F_MAX]) then 90 | maxtokens = tonumber(data[F_MAX]) 91 | else 92 | redis.call(C_HSET, key, F_MAX, defmaxtokens) 93 | redis.call(C_EXPIRE, key, 30) 94 | end 95 | 96 | local tokens = maxtokens 97 | if present(data[F_TOKENS]) then 98 | tokens = tonumber(data[F_TOKENS]) 99 | end 100 | 101 | local interval = definterval 102 | if present(data[F_INTERVAL]) then 103 | interval = tonumber(data[F_INTERVAL]) 104 | end 105 | 106 | local currtick = tick(start, now, interval) 107 | local nexttime = start + ((currtick+1) * interval) 108 | 109 | if lasttick < currtick then 110 | local rate = interval / tokens 111 | tokens = availabletokens(lasttick, currtick, maxtokens, rate) 112 | lasttick = currtick 113 | redis.call(C_HSET, key, 114 | F_START, start, 115 | F_TICK, lasttick, 116 | F_INTERVAL, interval, 117 | F_TOKENS, tokens) 118 | redis.call(C_EXPIRE, key, ttl(interval)) 119 | end 120 | 121 | if tokens > 0 then 122 | tokens = tokens - 1 123 | redis.call(C_HSET, key, F_TOKENS, tokens) 124 | redis.call(C_EXPIRE, key, ttl(interval)) 125 | return {maxtokens, tokens, nexttime, true} 126 | end 127 | 128 | return {maxtokens, tokens, nexttime, false} 129 | ` 130 | -------------------------------------------------------------------------------- /store_test.go: -------------------------------------------------------------------------------- 1 | package redisstore 2 |
3 | import ( 4 | "context" 5 | "crypto/rand" 6 | "crypto/sha256" 7 | "fmt" 8 | "os" 9 | "sort" 10 | "testing" 11 | "time" 12 | 13 | "github.com/gomodule/redigo/redis" 14 | ) 15 | 16 | func testKey(tb testing.TB) string { 17 | tb.Helper() 18 | 19 | var b [512]byte 20 | if _, err := rand.Read(b[:]); err != nil { 21 | tb.Fatalf("failed to generate random string: %v", err) 22 | } 23 | digest := fmt.Sprintf("%x", sha256.Sum256(b[:])) 24 | return digest[:32] 25 | } 26 | 27 | func TestStore_Exercise(t *testing.T) { 28 | t.Parallel() 29 | 30 | ctx := context.Background() 31 | 32 | host := os.Getenv("REDIS_HOST") 33 | if host == "" { 34 | t.Fatal("missing REDIS_HOST") 35 | } 36 | 37 | port := os.Getenv("REDIS_PORT") 38 | if port == "" { 39 | port = "6379" 40 | } 41 | 42 | pass := os.Getenv("REDIS_PASS") 43 | 44 | s, err := New(&Config{ 45 | Tokens: 5, 46 | Interval: 3 * time.Second, 47 | Dial: func() (redis.Conn, error) { 48 | return redis.Dial("tcp", host+":"+port, 49 | redis.DialPassword(pass)) 50 | }, 51 | }) 52 | if err != nil { 53 | t.Fatal(err) 54 | } 55 | t.Cleanup(func() { 56 | if err := s.Close(ctx); err != nil { 57 | t.Fatal(err) 58 | } 59 | }) 60 | 61 | key := testKey(t) 62 | 63 | // Get when no config exists 64 | { 65 | limit, remaining, err := s.Get(ctx, key) 66 | if err != nil { 67 | t.Fatal(err) 68 | } 69 | 70 | if got, want := limit, uint64(0); got != want { 71 | t.Errorf("expected %v to be %v", got, want) 72 | } 73 | if got, want := remaining, uint64(0); got != want { 74 | t.Errorf("expected %v to be %v", got, want) 75 | } 76 | } 77 | 78 | // Take with no key configuration - this should use the default values 79 | { 80 | limit, remaining, reset, ok, err := s.Take(ctx, key) 81 | if err != nil { 82 | t.Fatal(err) 83 | } 84 | if !ok { 85 | t.Errorf("expected ok") 86 | } 87 | if got, want := limit, uint64(5); got != want { 88 | t.Errorf("expected %v to be %v", got, want) 89 | } 90 | if got, want := remaining, uint64(4); got != want { 91 | t.Errorf("expected %v 
to be %v", got, want) 92 | } 93 | if got, want := time.Until(time.Unix(0, int64(reset))), 3*time.Second; got > want { 94 | t.Errorf("expected %v to less than %v", got, want) 95 | } 96 | } 97 | 98 | // Get the value 99 | { 100 | limit, remaining, err := s.Get(ctx, key) 101 | if err != nil { 102 | t.Fatal(err) 103 | } 104 | if got, want := limit, uint64(5); got != want { 105 | t.Errorf("expected %v to be %v", got, want) 106 | } 107 | if got, want := remaining, uint64(4); got != want { 108 | t.Errorf("expected %v to be %v", got, want) 109 | } 110 | } 111 | 112 | // Now set a value 113 | { 114 | if err := s.Set(ctx, key, 11, 5*time.Second); err != nil { 115 | t.Fatal(err) 116 | } 117 | } 118 | 119 | // Get the value again 120 | { 121 | limit, remaining, err := s.Get(ctx, key) 122 | if err != nil { 123 | t.Fatal(err) 124 | } 125 | if got, want := limit, uint64(11); got != want { 126 | t.Errorf("expected %v to be %v", got, want) 127 | } 128 | if got, want := remaining, uint64(11); got != want { 129 | t.Errorf("expected %v to be %v", got, want) 130 | } 131 | } 132 | 133 | // Take again, this should use the new values 134 | { 135 | limit, remaining, reset, ok, err := s.Take(ctx, key) 136 | if err != nil { 137 | t.Fatal(err) 138 | } 139 | if !ok { 140 | t.Errorf("expected ok") 141 | } 142 | if got, want := limit, uint64(11); got != want { 143 | t.Errorf("expected %v to be %v", got, want) 144 | } 145 | if got, want := remaining, uint64(10); got != want { 146 | t.Errorf("expected %v to be %v", got, want) 147 | } 148 | if got, want := time.Until(time.Unix(0, int64(reset))), 5*time.Second; got > want { 149 | t.Errorf("expected %v to less than %v", got, want) 150 | } 151 | } 152 | 153 | // Get the value again 154 | { 155 | limit, remaining, err := s.Get(ctx, key) 156 | if err != nil { 157 | t.Fatal(err) 158 | } 159 | if got, want := limit, uint64(11); got != want { 160 | t.Errorf("expected %v to be %v", got, want) 161 | } 162 | if got, want := remaining, uint64(10); got != want 
{ 163 | t.Errorf("expected %v to be %v", got, want) 164 | } 165 | } 166 | 167 | // Burst and take 168 | { 169 | if err := s.Burst(ctx, key, 5); err != nil { 170 | t.Fatal(err) 171 | } 172 | 173 | limit, remaining, reset, ok, err := s.Take(ctx, key) 174 | if err != nil { 175 | t.Fatal(err) 176 | } 177 | if !ok { 178 | t.Errorf("expected ok") 179 | } 180 | if got, want := limit, uint64(11); got != want { 181 | t.Errorf("expected %v to be %v", got, want) 182 | } 183 | if got, want := remaining, uint64(14); got != want { 184 | t.Errorf("expected %v to be %v", got, want) 185 | } 186 | if got, want := time.Until(time.Unix(0, int64(reset))), 5*time.Second; got > want { 187 | t.Errorf("expected %v to less than %v", got, want) 188 | } 189 | } 190 | 191 | // Get the value one final time 192 | { 193 | limit, remaining, err := s.Get(ctx, key) 194 | if err != nil { 195 | t.Fatal(err) 196 | } 197 | if got, want := limit, uint64(11); got != want { 198 | t.Errorf("expected %v to be %v", got, want) 199 | } 200 | if got, want := remaining, uint64(14); got != want { 201 | t.Errorf("expected %v to be %v", got, want) 202 | } 203 | } 204 | } 205 | 206 | func TestStore_Take(t *testing.T) { 207 | t.Parallel() 208 | 209 | if testing.Short() { 210 | t.Skipf("skipping (short)") 211 | } 212 | 213 | ctx := context.Background() 214 | 215 | host := os.Getenv("REDIS_HOST") 216 | if host == "" { 217 | t.Fatal("missing REDIS_HOST") 218 | } 219 | 220 | port := os.Getenv("REDIS_PORT") 221 | if port == "" { 222 | port = "6379" 223 | } 224 | 225 | pass := os.Getenv("REDIS_PASS") 226 | 227 | cases := []struct { 228 | name string 229 | tokens uint64 230 | interval time.Duration 231 | }{ 232 | { 233 | name: "second", 234 | tokens: 10, 235 | interval: 1 * time.Second, 236 | }, 237 | } 238 | 239 | for _, tc := range cases { 240 | tc := tc 241 | 242 | t.Run(tc.name, func(t *testing.T) { 243 | t.Parallel() 244 | 245 | key := testKey(t) 246 | 247 | s, err := New(&Config{ 248 | Interval: tc.interval, 249 | 
Tokens: tc.tokens, 250 | Dial: func() (redis.Conn, error) { 251 | return redis.Dial("tcp", host+":"+port, 252 | redis.DialPassword(pass)) 253 | }, 254 | }) 255 | if err != nil { 256 | t.Fatal(err) 257 | } 258 | t.Cleanup(func() { 259 | if err := s.Close(ctx); err != nil { 260 | t.Fatal(err) 261 | } 262 | }) 263 | 264 | type result struct { 265 | limit, remaining uint64 266 | reset time.Duration 267 | ok bool 268 | err error 269 | } 270 | 271 | // Take twice everything 272 | takeCh := make(chan *result, 2*tc.tokens) 273 | for i := uint64(1); i <= 2*tc.tokens; i++ { 274 | go func() { 275 | limit, remaining, reset, ok, err := s.Take(ctx, key) 276 | takeCh <- &result{limit, remaining, time.Until(time.Unix(0, int64(reset))), ok, err} 277 | }() 278 | } 279 | 280 | // Accumulate and sort results, since they could come in any order 281 | var results []*result 282 | for i := uint64(1); i <= 2*tc.tokens; i++ { 283 | select { 284 | case result := <-takeCh: 285 | results = append(results, result) 286 | case <-time.After(5 * time.Second): 287 | t.Fatal("timeout") 288 | } 289 | } 290 | sort.Slice(results, func(i, j int) bool { 291 | if results[i].remaining == results[j].remaining { 292 | return !results[j].ok 293 | } 294 | return results[i].remaining > results[j].remaining 295 | }) 296 | 297 | for i, result := range results { 298 | if err := result.err; err != nil { 299 | t.Fatal(err) 300 | } 301 | 302 | if got, want := result.limit, tc.tokens; got != want { 303 | t.Errorf("limit: expected %d to be %d", got, want) 304 | } 305 | if got, want := result.reset, tc.interval; got > want { 306 | t.Errorf("reset: expected %d to be less than %d", got, want) 307 | } 308 | 309 | // first half should pass, second half should fail 310 | if uint64(i) < tc.tokens { 311 | if got, want := result.remaining, tc.tokens-uint64(i)-1; got != want { 312 | t.Errorf("remaining: expected %d to be %d", got, want) 313 | } 314 | if got, want := result.ok, true; got != want { 315 | t.Errorf("ok: expected %t 
to be %t", got, want) 316 | } 317 | } else { 318 | if got, want := result.remaining, uint64(0); got != want { 319 | t.Errorf("remaining: expected %d to be %d", got, want) 320 | } 321 | if got, want := result.ok, false; got != want { 322 | t.Errorf("ok: expected %t to be %t", got, want) 323 | } 324 | } 325 | } 326 | 327 | // Wait for entries again 328 | time.Sleep(tc.interval) 329 | 330 | // Verify we can take once more 331 | _, _, _, ok, err := s.Take(ctx, key) 332 | if err != nil { 333 | t.Fatal(err) 334 | } 335 | if !ok { 336 | t.Errorf("expected %t to be %t", ok, true) 337 | } 338 | }) 339 | } 340 | } 341 | -------------------------------------------------------------------------------- /store.go: -------------------------------------------------------------------------------- 1 | // Package redisstore defines a redis-backed storage system for limiting. 2 | package redisstore 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "strconv" 8 | "sync/atomic" 9 | "time" 10 | 11 | "github.com/gomodule/redigo/redis" 12 | "github.com/sethvargo/go-limiter" 13 | ) 14 | 15 | const ( 16 | // hash field keys shared by the Lua script. 17 | fieldInterval = "i" 18 | fieldMaxTokens = "m" 19 | fieldTokens = "k" 20 | 21 | // weekSeconds is the number of seconds in a week. 22 | weekSeconds = 60 * 60 * 24 * 7 23 | 24 | // Common Redis commands 25 | cmdEXPIRE = "EXPIRE" 26 | cmdHINCRBY = "HINCRBY" 27 | cmdHMGET = "HMGET" 28 | cmdHSET = "HSET" 29 | cmdPING = "PING" 30 | ) 31 | 32 | var _ limiter.Store = (*Store)(nil) 33 | 34 | type Store struct { 35 | tokens uint64 36 | interval time.Duration 37 | pool *redis.Pool 38 | luaScript *redis.Script 39 | 40 | stopped uint32 41 | } 42 | 43 | // Config is used as input to New. It defines the behavior of the storage 44 | // system. 45 | type Config struct { 46 | // Tokens is the number of tokens to allow per interval. The default value is 47 | // 1. 48 | Tokens uint64 49 | 50 | // Interval is the time interval upon which to enforce rate limiting. 
The 51 | // default value is 1 second. 52 | Interval time.Duration 53 | 54 | // Dial is the function to use as the dialer. This is ignored when used with 55 | // NewWithPool. 56 | Dial func() (redis.Conn, error) 57 | } 58 | 59 | // New uses a Redis instance to back a rate limiter that to limit the number of 60 | // permitted events over an interval. 61 | func New(c *Config) (*Store, error) { 62 | return NewWithPool(c, &redis.Pool{ 63 | MaxActive: 100, 64 | IdleTimeout: 5 * time.Minute, 65 | Dial: c.Dial, 66 | TestOnBorrow: func(c redis.Conn, _ time.Time) error { 67 | _, err := c.Do(cmdPING) 68 | return fmt.Errorf("failed to borrow: %w", err) 69 | }, 70 | }) 71 | } 72 | 73 | // NewWithPool creates a new limiter using the given redis pool. Use this to 74 | // customize lower-level details about the pool. 75 | func NewWithPool(c *Config, pool *redis.Pool) (*Store, error) { 76 | if c == nil { 77 | c = new(Config) 78 | } 79 | 80 | tokens := uint64(1) 81 | if c.Tokens > 0 { 82 | tokens = c.Tokens 83 | } 84 | 85 | interval := 1 * time.Second 86 | if c.Interval > 0 { 87 | interval = c.Interval 88 | } 89 | 90 | luaScript := redis.NewScript(1, luaTemplate) 91 | 92 | s := &Store{ 93 | tokens: tokens, 94 | interval: interval, 95 | pool: pool, 96 | luaScript: luaScript, 97 | } 98 | return s, nil 99 | } 100 | 101 | // Take attempts to remove a token from the named key. If the take is 102 | // successful, it returns true, otherwise false. It also returns the configured 103 | // limit, remaining tokens, and reset time, if one was found. Any errors 104 | // connecting to the store or parsing the return value are considered failures 105 | // and fail the take. 106 | func (s *Store) Take(ctx context.Context, key string) (limit uint64, remaining uint64, next uint64, ok bool, retErr error) { 107 | // If the store is stopped, all requests are rejected. 
108 | if atomic.LoadUint32(&s.stopped) == 1 { 109 | retErr = limiter.ErrStopped 110 | return 111 | } 112 | 113 | // Get the current time, since this is when the function was called, and we 114 | // want to limit from call time, not invoke time. 115 | now := uint64(time.Now().UTC().UnixNano()) 116 | 117 | // Get a client from the pool. 118 | conn, err := s.pool.GetContext(ctx) 119 | if err != nil { 120 | retErr = fmt.Errorf("failed to get connection from pool: %w", err) 121 | return 122 | } 123 | if err := conn.Err(); err != nil { 124 | retErr = fmt.Errorf("connection is not usable: %w", err) 125 | return 126 | } 127 | defer closeConnection(conn, &retErr) 128 | 129 | nowStr := strconv.FormatUint(now, 10) 130 | tokensStr := strconv.FormatUint(s.tokens, 10) 131 | intervalStr := strconv.FormatInt(s.interval.Nanoseconds(), 10) 132 | a, err := redis.Int64s(s.luaScript.Do(conn, key, nowStr, tokensStr, intervalStr)) 133 | if err != nil { 134 | retErr = fmt.Errorf("failed to run script: %w", err) 135 | return 136 | } 137 | 138 | if len(a) < 4 { 139 | retErr = fmt.Errorf("response has less than 4 values: %#v", a) 140 | return 141 | } 142 | 143 | limit, remaining, next, ok = uint64(a[0]), uint64(a[1]), uint64(a[2]), a[3] == 1 144 | return 145 | } 146 | 147 | // Get gets the current limit and remaining tokens for the key. It does not 148 | // reduce or reset any counters. 149 | func (s *Store) Get(ctx context.Context, key string) (limit, remaining uint64, retErr error) { 150 | // If the store is stopped, all requests are rejected. 151 | if atomic.LoadUint32(&s.stopped) == 1 { 152 | retErr = limiter.ErrStopped 153 | return 154 | } 155 | 156 | // Get a client from the pool. 
157 | conn, err := s.pool.GetContext(ctx) 158 | if err != nil { 159 | retErr = fmt.Errorf("failed to get connection from pool: %w", err) 160 | return 161 | } 162 | if err := conn.Err(); err != nil { 163 | retErr = fmt.Errorf("connection is not usable: %w", err) 164 | return 165 | } 166 | defer closeConnection(conn, &retErr) 167 | 168 | result, err := redis.Int64s(conn.Do(cmdHMGET, key, fieldMaxTokens, fieldTokens)) 169 | if err != nil { 170 | retErr = fmt.Errorf("failed to get key: %w", err) 171 | return 172 | } 173 | 174 | if got, want := len(result), 2; got != want { 175 | retErr = fmt.Errorf("not enough keys returned, expected %d got %d", want, got) 176 | return 177 | } 178 | 179 | limit = uint64(result[0]) 180 | remaining = uint64(result[1]) 181 | return 182 | } 183 | 184 | // Set sets the key's limit to the provided value and interval. 185 | func (s *Store) Set(ctx context.Context, key string, tokens uint64, interval time.Duration) (retErr error) { 186 | // If the store is stopped, all requests are rejected. 187 | if atomic.LoadUint32(&s.stopped) == 1 { 188 | retErr = limiter.ErrStopped 189 | return 190 | } 191 | 192 | // Get a client from the pool. 193 | conn, err := s.pool.GetContext(ctx) 194 | if err != nil { 195 | retErr = fmt.Errorf("failed to get connection from pool: %w", err) 196 | return 197 | } 198 | if err := conn.Err(); err != nil { 199 | retErr = fmt.Errorf("connection is not usable: %w", err) 200 | return 201 | } 202 | defer closeConnection(conn, &retErr) 203 | 204 | // Set configuration on the key. 205 | tokensStr := strconv.FormatUint(tokens, 10) 206 | intervalStr := strconv.FormatInt(interval.Nanoseconds(), 10) 207 | if err := conn.Send(cmdHSET, key, 208 | fieldTokens, tokensStr, 209 | fieldMaxTokens, tokensStr, 210 | fieldInterval, intervalStr, 211 | ); err != nil { 212 | retErr = fmt.Errorf("failed to set key: %w", err) 213 | return 214 | } 215 | 216 | // Set the key to expire. 
This will prevent a leak when a key's configuration 217 | // is set, but nothing is ever taken from the bucket. 218 | if err := conn.Send(cmdEXPIRE, key, weekSeconds); err != nil { 219 | retErr = fmt.Errorf("failed to set expire on key: %w", err) 220 | return 221 | } 222 | 223 | return 224 | } 225 | 226 | // Burst adds the given tokens to the key's bucket. 227 | func (s *Store) Burst(ctx context.Context, key string, tokens uint64) (retErr error) { 228 | // If the store is stopped, all requests are rejected. 229 | if atomic.LoadUint32(&s.stopped) == 1 { 230 | retErr = limiter.ErrStopped 231 | return 232 | } 233 | 234 | // Get a client from the pool. 235 | conn, err := s.pool.GetContext(ctx) 236 | if err != nil { 237 | retErr = fmt.Errorf("failed to get connection from pool: %w", err) 238 | return 239 | } 240 | if err := conn.Err(); err != nil { 241 | retErr = fmt.Errorf("connection is not usable: %w", err) 242 | return 243 | } 244 | defer closeConnection(conn, &retErr) 245 | 246 | // Set configuration on the key. 247 | tokensStr := strconv.FormatUint(tokens, 10) 248 | if err := conn.Send(cmdHINCRBY, key, fieldTokens, tokensStr); err != nil { 249 | retErr = fmt.Errorf("failed to set key: %w", err) 250 | return 251 | } 252 | 253 | // Set the key to expire. This will prevent a leak when a key's configuration 254 | // is set, but nothing is ever taken from the bucket. 255 | if err := conn.Send(cmdEXPIRE, key, weekSeconds); err != nil { 256 | retErr = fmt.Errorf("failed to set expire on key: %w", err) 257 | return 258 | } 259 | 260 | return 261 | } 262 | 263 | // Close stops the memory limiter and cleans up any outstanding sessions. You 264 | // should always call CloseWithContext() as it releases any open network 265 | // connections. 266 | func (s *Store) Close(_ context.Context) error { 267 | if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) { 268 | return nil 269 | } 270 | 271 | // Close the connection pool. 
272 | if err := s.pool.Close(); err != nil { 273 | return fmt.Errorf("failed to close pool: %w", err) 274 | } 275 | return nil 276 | } 277 | 278 | // closeConnection is a helper for closing the connection object. It is used in 279 | // defer statements to alter the provided error pointer before the final result 280 | // is bubbled up the stack. 281 | func closeConnection(c redis.Conn, err *error) { 282 | nerr := c.Close() 283 | if *err == nil { 284 | *err = nerr 285 | } 286 | } 287 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | --------------------------------------------------------------------------------