├── .circleci └── config.yml ├── .dockerignore ├── .editorconfig ├── .github └── dependabot.yml ├── .gitignore ├── .golangci.yml ├── AUTHORS ├── LICENSE ├── Makefile ├── README.md ├── defaults.go ├── drivers ├── middleware │ ├── fasthttp │ │ ├── middleware.go │ │ ├── middleware_test.go │ │ └── options.go │ ├── gin │ │ ├── middleware.go │ │ ├── middleware_test.go │ │ └── options.go │ └── stdlib │ │ ├── middleware.go │ │ ├── middleware_test.go │ │ └── options.go └── store │ ├── common │ └── context.go │ ├── memory │ ├── cache.go │ ├── cache_test.go │ ├── store.go │ └── store_test.go │ ├── redis │ ├── store.go │ └── store_test.go │ └── tests │ └── tests.go ├── examples └── README.md ├── go.mod ├── go.sum ├── internal ├── bytebuffer │ └── pool.go └── fasttime │ ├── fasttime.go │ └── fasttime_windows.go ├── limiter.go ├── limiter_test.go ├── network.go ├── network_test.go ├── options.go ├── rate.go ├── rate_test.go ├── scripts ├── conf │ └── go │ │ └── Dockerfile ├── go-wrapper ├── lint ├── redis └── test └── store.go /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | jobs: 4 | build: 5 | machine: 6 | image: circleci/classic:edge 7 | docker_layer_caching: true 8 | steps: 9 | - checkout 10 | - run: 11 | name: Checkout submodules 12 | command: | 13 | git submodule sync 14 | git submodule update --init 15 | 16 | - run: 17 | name: Start docker container for redis 18 | command: scripts/redis 19 | 20 | - run: 21 | name: Run tests 22 | command: scripts/go-wrapper scripts/test 23 | environment: 24 | GO111MODULE: on 25 | REDIS_DISABLE_BOOTSTRAP: false 26 | REDIS_URI: redis://localhost:26379/1 27 | 28 | - run: 29 | name: Run linters 30 | command: scripts/go-wrapper scripts/lint 31 | environment: 32 | GO111MODULE: on 33 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # 
Circle CI directory 2 | .circleci 3 | 4 | # Example directory 5 | examples 6 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | indent_size = 4 6 | indent_style = space 7 | insert_final_newline = true 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | 12 | [*.{yml,yaml}] 13 | indent_size = 2 14 | 15 | [*.go] 16 | indent_size = 8 17 | indent_style = tab 18 | 19 | [*.json] 20 | indent_size = 4 21 | indent_style = space 22 | 23 | [Makefile] 24 | indent_style = tab 25 | indent_size = 4 26 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - package-ecosystem: "gomod" 5 | directory: "/" 6 | schedule: 7 | interval: "daily" 8 | commit-message: 9 | prefix: "chore(go.mod):" 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /vendor 2 | .idea 3 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | concurrency: 4 3 | deadline: 1m 4 | issues-exit-code: 1 5 | tests: true 6 | 7 | 8 | output: 9 | format: colored-line-number 10 | print-issued-lines: true 11 | print-linter-name: true 12 | 13 | 14 | linters-settings: 15 | errcheck: 16 | check-type-assertions: false 17 | check-blank: false 18 | govet: 19 | check-shadowing: false 20 | use-installed-packages: false 21 | golint: 22 | min-confidence: 0.8 23 | gofmt: 24 | simplify: true 25 | gocyclo: 26 | min-complexity: 10 27 | maligned: 28 | suggest-new: true 29 | dupl: 30 | 
threshold: 80 31 | goconst: 32 | min-len: 3 33 | min-occurrences: 3 34 | misspell: 35 | locale: US 36 | lll: 37 | line-length: 140 38 | unused: 39 | check-exported: false 40 | unparam: 41 | algo: cha 42 | check-exported: false 43 | nakedret: 44 | max-func-lines: 30 45 | 46 | linters: 47 | enable: 48 | - megacheck 49 | - govet 50 | - errcheck 51 | - gas 52 | - structcheck 53 | - varcheck 54 | - ineffassign 55 | - deadcode 56 | - typecheck 57 | - unconvert 58 | - gocyclo 59 | - gofmt 60 | - misspell 61 | - lll 62 | - nakedret 63 | enable-all: false 64 | disable: 65 | - depguard 66 | - prealloc 67 | - dupl 68 | - maligned 69 | disable-all: false 70 | 71 | 72 | issues: 73 | exclude-use-default: false 74 | max-per-linter: 1024 75 | max-same: 1024 76 | exclude: 77 | - "G304" 78 | - "G101" 79 | - "G104" 80 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Primary contributors: 2 | 3 | Gilles FABIO 4 | Florent MESSA 5 | Thomas LE ROUX 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015-2018 Ulule 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test lint 2 | 3 | test: 4 | @(scripts/test) 5 | 6 | lint: 7 | @(scripts/lint) 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Limiter 2 | 3 | [![Documentation][godoc-img]][godoc-url] 4 | ![License][license-img] 5 | [![Build Status][circle-img]][circle-url] 6 | [![Go Report Card][goreport-img]][goreport-url] 7 | 8 | _Dead simple rate limit middleware for Go._ 9 | 10 | - Simple API 11 | - "Store" approach for backend 12 | - Redis support (but not tied to it) 13 | - Middlewares: HTTP, [FastHTTP][6] and [Gin][4] 14 | 15 | ## Installation 16 | 17 | Using [Go Modules](https://github.com/golang/go/wiki/Modules) 18 | 19 | ```bash 20 | $ go get github.com/ulule/limiter/v3@v3.11.2 21 | ``` 22 | 23 | ## Usage 24 | 25 | In five steps: 26 | 27 | - Create a `limiter.Rate` instance _(the number of requests per period)_ 28 | - Create a `limiter.Store` instance _(see [Redis](https://github.com/ulule/limiter/blob/master/drivers/store/redis/store.go) or [In-Memory](https://github.com/ulule/limiter/blob/master/drivers/store/memory/store.go))_ 29 | - Create a `limiter.Limiter` instance that takes store and rate instances as arguments 30 | - Create a middleware instance using
the middleware of your choice 31 | - Give the limiter instance to your middleware initializer 32 | 33 | **Example:** 34 | 35 | ```go 36 | // Create a rate with the given limit (number of requests) for the given 37 | // period (a time.Duration of your choice). 38 | import "github.com/ulule/limiter/v3" 39 | 40 | rate := limiter.Rate{ 41 | Period: 1 * time.Hour, 42 | Limit: 1000, 43 | } 44 | 45 | // You can also use the simplified format "<limit>-<period>", with the given 46 | // periods: 47 | // 48 | // * "S": second 49 | // * "M": minute 50 | // * "H": hour 51 | // * "D": day 52 | // 53 | // Examples: 54 | // 55 | // * 5 reqs/second: "5-S" 56 | // * 10 reqs/minute: "10-M" 57 | // * 1000 reqs/hour: "1000-H" 58 | // * 2000 reqs/day: "2000-D" 59 | // 60 | rate, err := limiter.NewRateFromFormatted("1000-H") 61 | if err != nil { 62 | panic(err) 63 | } 64 | 65 | // Then, create a store. Here, we use the bundled Redis store. Any store 66 | // compliant with the limiter.Store interface will do the job. The defaults are 67 | // "limiter" as Redis key prefix and a maximum of 3 retries for the key under 68 | // race condition. 69 | import "github.com/ulule/limiter/v3/drivers/store/redis" 70 | 71 | store, err := redis.NewStore(client) 72 | if err != nil { 73 | panic(err) 74 | } 75 | 76 | // Alternatively, you can pass options to the store with the "WithOptions" 77 | // function. For example, for Redis store: 78 | import "github.com/ulule/limiter/v3/drivers/store/redis" 79 | 80 | store, err := redis.NewStoreWithOptions(pool, limiter.StoreOptions{ 81 | Prefix: "your_own_prefix", 82 | }) 83 | if err != nil { 84 | panic(err) 85 | } 86 | 87 | // Or use an in-memory store with a goroutine which clears expired keys. 88 | import "github.com/ulule/limiter/v3/drivers/store/memory" 89 | 90 | store := memory.NewStore() 91 | 92 | // Then, create the limiter instance which takes the store and the rate as arguments. 93 | // Now, you can give this instance to any supported middleware.
94 | instance := limiter.New(store, rate) 95 | 96 | // Alternatively, you can configure the limiter instance with several options. 97 | instance := limiter.New(store, rate, limiter.WithClientIPHeader("True-Client-IP"), limiter.WithIPv6Mask(mask)) 98 | 99 | // Finally, give the limiter instance to your middleware initializer. 100 | import "github.com/ulule/limiter/v3/drivers/middleware/stdlib" 101 | 102 | middleware := stdlib.NewMiddleware(instance) 103 | ``` 104 | 105 | See middleware examples: 106 | 107 | - [HTTP](https://github.com/ulule/limiter-examples/tree/master/http/main.go) 108 | - [Gin](https://github.com/ulule/limiter-examples/tree/master/gin/main.go) 109 | - [Beego](https://github.com/ulule/limiter-examples/blob/master/beego/main.go) 110 | - [Chi](https://github.com/ulule/limiter-examples/tree/master/chi/main.go) 111 | - [Echo](https://github.com/ulule/limiter-examples/tree/master/echo/main.go) 112 | - [Fasthttp](https://github.com/ulule/limiter-examples/tree/master/fasthttp/main.go) 113 | 114 | ## How it works 115 | 116 | The IP address of the request is used as a key in the store. 117 | 118 | If the key does not exist in the store we set a default 119 | value with an expiration period. 120 | 121 | You will find two stores: 122 | 123 | - Redis: relies on [TTL](http://redis.io/commands/ttl) and incrementing the rate limit on each request. 124 | - In-Memory: relies on a fork of [go-cache](https://github.com/patrickmn/go-cache) with a goroutine to clear expired keys using a default interval. 125 | 126 | When the limit is reached, a `429` HTTP status code is sent. 127 | 128 | ## Limiter behind a reverse proxy 129 | 130 | ### Introduction 131 | 132 | If your limiter is behind a reverse proxy, it could be difficult to obtain the "real" client IP. 133 | 134 | Some reverse proxies, like AWS ALB, let through all header values that they don't set themselves, 135 | for example `True-Client-IP` and `X-Real-IP`.
136 | Similarly, `X-Forwarded-For` is a list of comma-separated IPs that gets appended to by each traversed proxy. 137 | The idea is that the first IP _(added by the first proxy)_ is the true client IP. Each subsequent IP is another proxy along the path. 138 | 139 | An attacker can spoof any of those headers, which could be reported as a client IP. 140 | 141 | By default, limiter doesn't trust any of those headers: you have to explicitly enable them in order to use them. 142 | If you enable them, **you must always be aware** that any header added by any _(reverse)_ proxy not controlled 143 | by you **is completely unreliable.** 144 | 145 | ### X-Forwarded-For 146 | 147 | For example, if you make this request to your load balancer: 148 | ```bash 149 | curl -X POST https://example.com/login -H "X-Forwarded-For: 1.2.3.4, 11.22.33.44" 150 | ``` 151 | 152 | And your server behind the load balancer obtains this: 153 | ``` 154 | X-Forwarded-For: 1.2.3.4, 11.22.33.44, <actual client IP> 155 | ``` 156 | 157 | That means you can't use the `X-Forwarded-For` header, because it's **unreliable** and **untrustworthy**. 158 | So keep `TrustForwardHeader` disabled in your limiter options. 159 | 160 | However, if you have configured your reverse proxy to always remove/overwrite `X-Forwarded-For` and/or `X-Real-IP` headers 161 | so that if you execute this _(same)_ request: 162 | ```bash 163 | curl -X POST https://example.com/login -H "X-Forwarded-For: 1.2.3.4, 11.22.33.44" 164 | ``` 165 | 166 | And your server behind the load balancer obtains this: 167 | ``` 168 | X-Forwarded-For: <actual client IP> 169 | ``` 170 | 171 | Then, you can enable `TrustForwardHeader` in your limiter options. 172 | 173 | ### Custom header 174 | 175 | Many CDN and cloud providers add a custom header to define the client IP.
For example, this non-exhaustive list: 176 | 177 | * `Fastly-Client-IP` from Fastly 178 | * `CF-Connecting-IP` from Cloudflare 179 | * `X-Azure-ClientIP` from Azure 180 | 181 | You can use these headers using `ClientIPHeader` in your limiter options. 182 | 183 | ### None of the above 184 | 185 | If none of the above solutions work, please use a custom `KeyGetter` in your middleware. 186 | 187 | You can use this excellent article to help you define the best strategy depending on your network topology and your security needs: 188 | https://adam-p.ca/blog/2022/03/x-forwarded-for/ 189 | 190 | If you have any ideas/suggestions on how we could simplify these steps, don't hesitate to raise an issue. 191 | We would like some feedback on how we could implement these steps in the Limiter API. 192 | 193 | Thank you. 194 | 195 | ## Why Yet Another Package 196 | 197 | You could ask us: why yet another rate limit package? 198 | 199 | Because existing packages did not suit our needs. 200 | 201 | We tried a lot of alternatives: 202 | 203 | 1. [Throttled][1]. This package uses the generic cell-rate algorithm. To cite the 204 | documentation: _"The algorithm has been slightly modified from its usual form to 205 | support limiting with an additional quantity parameter, such as for limiting the 206 | number of bytes uploaded"_. It is brilliant in terms of algorithm but 207 | documentation is quite unclear at the moment, we don't need the _burst_ feature for 208 | now, it is impossible to get a correct `Retry-After` (when the limit is exceeded, we can still 209 | make a few requests, because of the max burst) and it only supports `http.Handler` 210 | middleware (we use [Gin][4]). Currently, we only need to return `429` 211 | and `X-Ratelimit-*` headers for `n reqs/duration`. 212 | 213 | 2. [Speedbump][3]. Good package but maybe too lightweight. No `Reset` support, 214 | only one middleware for the [Gin][4] framework and too Redis-coupled. We would rather 215 | use a "store" approach. 216 | 217 | 3.
[Tollbooth][5]. Good one too but does both too much and too little. It limits by 218 | remote IP, path, methods, custom headers and basic auth usernames... but does not 219 | provide any Redis support (only _in-memory_) or a ready-to-go middleware that sets 220 | `X-Ratelimit-*` headers. `tollbooth.LimitByRequest(limiter, r)` only returns an HTTP 221 | code. 222 | 223 | 4. [ratelimit][2]. Probably the closest to our needs but, once again, too 224 | lightweight, no middleware available and not active (last commit was in August 225 | 2014). Some parts of the code (Redis) come from this project. It deserves much 226 | more love. 227 | 228 | There are many other packages on GitHub but most are either too lightweight, too 229 | old (only support old Go versions) or unmaintained. So that's why we decided to 230 | create yet another one. 231 | 232 | ## Contributing 233 | 234 | - Ping us on Twitter: 235 | - [@oibafsellig](https://twitter.com/oibafsellig) 236 | - [@thoas](https://twitter.com/thoas) 237 | - [@novln\_](https://twitter.com/novln_) 238 | - Fork the [project](https://github.com/ulule/limiter) 239 | - Fix [bugs](https://github.com/ulule/limiter/issues) 240 | 241 | Don't hesitate ;) 242 | 243 | [1]: https://github.com/throttled/throttled 244 | [2]: https://github.com/r8k/ratelimit 245 | [3]: https://github.com/etcinit/speedbump 246 | [4]: https://github.com/gin-gonic/gin 247 | [5]: https://github.com/didip/tollbooth 248 | [6]: https://github.com/valyala/fasthttp 249 | [godoc-url]: https://pkg.go.dev/github.com/ulule/limiter/v3 250 | [godoc-img]: https://pkg.go.dev/badge/github.com/ulule/limiter/v3 251 | [license-img]: https://img.shields.io/badge/license-MIT-blue.svg 252 | [goreport-url]: https://goreportcard.com/report/github.com/ulule/limiter 253 | [goreport-img]: https://goreportcard.com/badge/github.com/ulule/limiter 254 | [circle-url]: https://circleci.com/gh/ulule/limiter/tree/master 255 | [circle-img]:
https://circleci.com/gh/ulule/limiter.svg?style=shield&circle-token=baf62ec320dd871b3a4a7e67fa99530fbc877c99 256 | -------------------------------------------------------------------------------- /defaults.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import "time" 4 | 5 | const ( 6 | // DefaultPrefix is the default prefix to use for the key in the store. 7 | DefaultPrefix = "limiter" 8 | 9 | // DefaultMaxRetry is the default maximum number of key retries under 10 | // race condition (mainly used with database-based stores). 11 | DefaultMaxRetry = 3 12 | 13 | // DefaultCleanUpInterval is the default time duration for cleanup. 14 | DefaultCleanUpInterval = 30 * time.Second 15 | ) 16 | -------------------------------------------------------------------------------- /drivers/middleware/fasthttp/middleware.go: -------------------------------------------------------------------------------- 1 | package fasthttp 2 | 3 | import ( 4 | "github.com/ulule/limiter/v3" 5 | "github.com/valyala/fasthttp" 6 | "strconv" 7 | ) 8 | 9 | // Middleware is the middleware for fasthttp. 10 | type Middleware struct { 11 | Limiter *limiter.Limiter 12 | OnError ErrorHandler 13 | OnLimitReached LimitReachedHandler 14 | KeyGetter KeyGetter 15 | ExcludedKey func(string) bool 16 | } 17 | 18 | // NewMiddleware return a new instance of a fasthttp middleware. 19 | func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware { 20 | middleware := &Middleware{ 21 | Limiter: limiter, 22 | OnError: DefaultErrorHandler, 23 | OnLimitReached: DefaultLimitReachedHandler, 24 | KeyGetter: DefaultKeyGetter, 25 | ExcludedKey: nil, 26 | } 27 | 28 | for _, option := range options { 29 | option.apply(middleware) 30 | } 31 | 32 | return middleware 33 | } 34 | 35 | // Handle fasthttp request. 
36 | func (middleware *Middleware) Handle(next fasthttp.RequestHandler) fasthttp.RequestHandler { 37 | return func(ctx *fasthttp.RequestCtx) { 38 | key := middleware.KeyGetter(ctx) 39 | if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) { 40 | next(ctx) 41 | return 42 | } 43 | 44 | context, err := middleware.Limiter.Get(ctx, key) 45 | if err != nil { 46 | middleware.OnError(ctx, err) 47 | return 48 | } 49 | 50 | ctx.Response.Header.Set("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10)) 51 | ctx.Response.Header.Set("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10)) 52 | ctx.Response.Header.Set("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10)) 53 | 54 | if context.Reached { 55 | middleware.OnLimitReached(ctx) 56 | return 57 | } 58 | 59 | next(ctx) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /drivers/middleware/fasthttp/middleware_test.go: -------------------------------------------------------------------------------- 1 | package fasthttp_test 2 | 3 | import ( 4 | "net" 5 | "strconv" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | libfasthttp "github.com/valyala/fasthttp" 12 | "github.com/valyala/fasthttp/fasthttputil" 13 | 14 | "github.com/ulule/limiter/v3" 15 | "github.com/ulule/limiter/v3/drivers/middleware/fasthttp" 16 | "github.com/ulule/limiter/v3/drivers/store/memory" 17 | ) 18 | 19 | // nolint: gocyclo 20 | func TestFasthttpMiddleware(t *testing.T) { 21 | is := require.New(t) 22 | 23 | store := memory.NewStore() 24 | is.NotZero(store) 25 | 26 | rate, err := limiter.NewRateFromFormatted("10-M") 27 | is.NoError(err) 28 | is.NotZero(rate) 29 | 30 | middleware := fasthttp.NewMiddleware(limiter.New(store, rate)) 31 | 32 | requestHandler := func(ctx *libfasthttp.RequestCtx) { 33 | switch string(ctx.Path()) { 34 | case "/": 35 | ctx.SetStatusCode(libfasthttp.StatusOK) 36 | ctx.SetBodyString("hello") 37 | } 38 | } 39 | 
40 | success := int64(10) 41 | clients := int64(100) 42 | 43 | // 44 | // Sequential 45 | // 46 | 47 | for i := int64(1); i <= clients; i++ { 48 | resp := libfasthttp.AcquireResponse() 49 | req := libfasthttp.AcquireRequest() 50 | req.Header.SetHost("localhost:8081") 51 | req.Header.SetRequestURI("/") 52 | err := serve(middleware.Handle(requestHandler), req, resp) 53 | is.NoError(err) 54 | 55 | if i <= success { 56 | is.Equal(resp.StatusCode(), libfasthttp.StatusOK) 57 | } else { 58 | is.Equal(resp.StatusCode(), libfasthttp.StatusTooManyRequests) 59 | } 60 | } 61 | 62 | // 63 | // Concurrent 64 | // 65 | 66 | store = memory.NewStore() 67 | is.NotZero(store) 68 | 69 | middleware = fasthttp.NewMiddleware(limiter.New(store, rate)) 70 | 71 | requestHandler = func(ctx *libfasthttp.RequestCtx) { 72 | switch string(ctx.Path()) { 73 | case "/": 74 | ctx.SetStatusCode(libfasthttp.StatusOK) 75 | ctx.SetBodyString("hello") 76 | } 77 | } 78 | 79 | wg := &sync.WaitGroup{} 80 | counter := int64(0) 81 | 82 | for i := int64(1); i <= clients; i++ { 83 | wg.Add(1) 84 | 85 | go func() { 86 | resp := libfasthttp.AcquireResponse() 87 | req := libfasthttp.AcquireRequest() 88 | req.Header.SetHost("localhost:8081") 89 | req.Header.SetRequestURI("/") 90 | err := serve(middleware.Handle(requestHandler), req, resp) 91 | is.NoError(err) 92 | 93 | if resp.StatusCode() == libfasthttp.StatusOK { 94 | atomic.AddInt64(&counter, 1) 95 | } 96 | 97 | wg.Done() 98 | }() 99 | } 100 | 101 | wg.Wait() 102 | is.Equal(success, atomic.LoadInt64(&counter)) 103 | 104 | // 105 | // Custom KeyGetter 106 | // 107 | 108 | store = memory.NewStore() 109 | is.NotZero(store) 110 | 111 | counter = int64(0) 112 | keyGetter := func(c *libfasthttp.RequestCtx) string { 113 | v := atomic.AddInt64(&counter, 1) 114 | return strconv.FormatInt(v, 10) 115 | } 116 | 117 | middleware = fasthttp.NewMiddleware(limiter.New(store, rate), fasthttp.WithKeyGetter(keyGetter)) 118 | is.NotZero(middleware) 119 | 120 | requestHandler = 
func(ctx *libfasthttp.RequestCtx) { 121 | switch string(ctx.Path()) { 122 | case "/": 123 | ctx.SetStatusCode(libfasthttp.StatusOK) 124 | ctx.SetBodyString("hello") 125 | } 126 | } 127 | 128 | for i := int64(1); i <= clients; i++ { 129 | resp := libfasthttp.AcquireResponse() 130 | req := libfasthttp.AcquireRequest() 131 | req.Header.SetHost("localhost:8081") 132 | req.Header.SetRequestURI("/") 133 | err := serve(middleware.Handle(requestHandler), req, resp) 134 | is.NoError(err) 135 | is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10)) 136 | } 137 | 138 | // 139 | // Test ExcludedKey 140 | // 141 | 142 | store = memory.NewStore() 143 | is.NotZero(store) 144 | 145 | counter = int64(0) 146 | keyGetterHandler := func(c *libfasthttp.RequestCtx) string { 147 | v := atomic.AddInt64(&counter, 1) 148 | return strconv.FormatInt(v%2, 10) 149 | } 150 | excludedKeyHandler := func(key string) bool { 151 | return key == "1" 152 | } 153 | 154 | middleware = fasthttp.NewMiddleware(limiter.New(store, rate), 155 | fasthttp.WithKeyGetter(keyGetterHandler), fasthttp.WithExcludedKey(excludedKeyHandler)) 156 | is.NotZero(middleware) 157 | 158 | requestHandler = func(ctx *libfasthttp.RequestCtx) { 159 | switch string(ctx.Path()) { 160 | case "/": 161 | ctx.SetStatusCode(libfasthttp.StatusOK) 162 | ctx.SetBodyString("hello") 163 | } 164 | } 165 | 166 | success = 20 167 | for i := int64(1); i <= clients; i++ { 168 | resp := libfasthttp.AcquireResponse() 169 | req := libfasthttp.AcquireRequest() 170 | req.Header.SetHost("localhost:8081") 171 | req.Header.SetRequestURI("/") 172 | err := serve(middleware.Handle(requestHandler), req, resp) 173 | is.NoError(err) 174 | if i <= success || i%2 == 1 { 175 | is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10)) 176 | } else { 177 | is.Equal(libfasthttp.StatusTooManyRequests, resp.StatusCode(), strconv.FormatInt(i, 10)) 178 | } 179 | } 180 | } 181 | 182 | func serve(handler libfasthttp.RequestHandler, req 
*libfasthttp.Request, res *libfasthttp.Response) error { 183 | ln := fasthttputil.NewInmemoryListener() 184 | defer func() { 185 | err := ln.Close() 186 | if err != nil { 187 | panic(err) 188 | } 189 | }() 190 | 191 | go func() { 192 | err := libfasthttp.Serve(ln, handler) 193 | if err != nil { 194 | panic(err) 195 | } 196 | }() 197 | 198 | client := libfasthttp.Client{ 199 | Dial: func(addr string) (net.Conn, error) { 200 | return ln.Dial() 201 | }, 202 | } 203 | 204 | return client.Do(req, res) 205 | } 206 | -------------------------------------------------------------------------------- /drivers/middleware/fasthttp/options.go: -------------------------------------------------------------------------------- 1 | package fasthttp 2 | 3 | import ( 4 | "github.com/valyala/fasthttp" 5 | ) 6 | 7 | // Option is used to define Middleware configuration. 8 | type Option interface { 9 | apply(middleware *Middleware) 10 | } 11 | 12 | type option func(*Middleware) 13 | 14 | func (o option) apply(middleware *Middleware) { 15 | o(middleware) 16 | } 17 | 18 | // ErrorHandler is an handler used to inform when an error has occurred. 19 | type ErrorHandler func(ctx *fasthttp.RequestCtx, err error) 20 | 21 | // WithErrorHandler will configure the Middleware to use the given ErrorHandler. 22 | func WithErrorHandler(handler ErrorHandler) Option { 23 | return option(func(middleware *Middleware) { 24 | middleware.OnError = handler 25 | }) 26 | } 27 | 28 | // DefaultErrorHandler is the default ErrorHandler used by a new Middleware. 29 | func DefaultErrorHandler(ctx *fasthttp.RequestCtx, err error) { 30 | panic(err) 31 | } 32 | 33 | // LimitReachedHandler is an handler used to inform when the limit has exceeded. 34 | type LimitReachedHandler func(ctx *fasthttp.RequestCtx) 35 | 36 | // WithLimitReachedHandler will configure the Middleware to use the given LimitReachedHandler. 
37 | func WithLimitReachedHandler(handler LimitReachedHandler) Option { 38 | return option(func(middleware *Middleware) { 39 | middleware.OnLimitReached = handler 40 | }) 41 | } 42 | 43 | // DefaultLimitReachedHandler is the default LimitReachedHandler used by a new Middleware. 44 | func DefaultLimitReachedHandler(ctx *fasthttp.RequestCtx) { 45 | ctx.SetStatusCode(fasthttp.StatusTooManyRequests) 46 | ctx.Response.SetBodyString("Limit exceeded") 47 | } 48 | 49 | // KeyGetter will define the rate limiter key given the fasthttp Context. 50 | type KeyGetter func(ctx *fasthttp.RequestCtx) string 51 | 52 | // WithKeyGetter will configure the Middleware to use the given KeyGetter. 53 | func WithKeyGetter(KeyGetter KeyGetter) Option { 54 | return option(func(middleware *Middleware) { 55 | middleware.KeyGetter = KeyGetter 56 | }) 57 | } 58 | 59 | // DefaultKeyGetter is the default KeyGetter used by a new Middleware. 60 | // It returns the Client IP address. 61 | func DefaultKeyGetter(ctx *fasthttp.RequestCtx) string { 62 | return ctx.RemoteIP().String() 63 | } 64 | 65 | // WithExcludedKey will configure the Middleware to ignore key(s) using the given function. 66 | func WithExcludedKey(handler func(string) bool) Option { 67 | return option(func(middleware *Middleware) { 68 | middleware.ExcludedKey = handler 69 | }) 70 | } 71 | -------------------------------------------------------------------------------- /drivers/middleware/gin/middleware.go: -------------------------------------------------------------------------------- 1 | package gin 2 | 3 | import ( 4 | "strconv" 5 | 6 | "github.com/gin-gonic/gin" 7 | 8 | "github.com/ulule/limiter/v3" 9 | ) 10 | 11 | // Middleware is the middleware for gin. 12 | type Middleware struct { 13 | Limiter *limiter.Limiter 14 | OnError ErrorHandler 15 | OnLimitReached LimitReachedHandler 16 | KeyGetter KeyGetter 17 | ExcludedKey func(string) bool 18 | } 19 | 20 | // NewMiddleware return a new instance of a gin middleware. 
21 | func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc { 22 | middleware := &Middleware{ 23 | Limiter: limiter, 24 | OnError: DefaultErrorHandler, 25 | OnLimitReached: DefaultLimitReachedHandler, 26 | KeyGetter: DefaultKeyGetter, 27 | ExcludedKey: nil, 28 | } 29 | 30 | for _, option := range options { 31 | option.apply(middleware) 32 | } 33 | 34 | return func(ctx *gin.Context) { 35 | middleware.Handle(ctx) 36 | } 37 | } 38 | 39 | // Handle gin request. 40 | func (middleware *Middleware) Handle(c *gin.Context) { 41 | key := middleware.KeyGetter(c) 42 | if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) { 43 | c.Next() 44 | return 45 | } 46 | 47 | context, err := middleware.Limiter.Get(c, key) 48 | if err != nil { 49 | middleware.OnError(c, err) 50 | c.Abort() 51 | return 52 | } 53 | 54 | c.Header("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10)) 55 | c.Header("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10)) 56 | c.Header("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10)) 57 | 58 | if context.Reached { 59 | middleware.OnLimitReached(c) 60 | c.Abort() 61 | return 62 | } 63 | 64 | c.Next() 65 | } 66 | -------------------------------------------------------------------------------- /drivers/middleware/gin/middleware_test.go: -------------------------------------------------------------------------------- 1 | package gin_test 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strconv" 7 | "sync" 8 | "sync/atomic" 9 | "testing" 10 | 11 | libgin "github.com/gin-gonic/gin" 12 | "github.com/stretchr/testify/require" 13 | 14 | "github.com/ulule/limiter/v3" 15 | "github.com/ulule/limiter/v3/drivers/middleware/gin" 16 | "github.com/ulule/limiter/v3/drivers/store/memory" 17 | ) 18 | 19 | func TestHTTPMiddleware(t *testing.T) { 20 | is := require.New(t) 21 | libgin.SetMode(libgin.TestMode) 22 | 23 | request, err := http.NewRequest("GET", "/", nil) 24 | is.NoError(err) 25 | is.NotNil(request) 
26 | 27 | store := memory.NewStore() 28 | is.NotZero(store) 29 | 30 | rate, err := limiter.NewRateFromFormatted("10-M") 31 | is.NoError(err) 32 | is.NotZero(rate) 33 | 34 | middleware := gin.NewMiddleware(limiter.New(store, rate)) 35 | is.NotZero(middleware) 36 | 37 | router := libgin.New() 38 | router.Use(middleware) 39 | router.GET("/", func(c *libgin.Context) { 40 | c.String(http.StatusOK, "hello") 41 | }) 42 | 43 | success := int64(10) 44 | clients := int64(100) 45 | 46 | // 47 | // Sequential 48 | // 49 | 50 | for i := int64(1); i <= clients; i++ { 51 | 52 | resp := httptest.NewRecorder() 53 | router.ServeHTTP(resp, request) 54 | 55 | if i <= success { 56 | is.Equal(resp.Code, http.StatusOK) 57 | } else { 58 | is.Equal(resp.Code, http.StatusTooManyRequests) 59 | } 60 | } 61 | 62 | // 63 | // Concurrent 64 | // 65 | 66 | store = memory.NewStore() 67 | is.NotZero(store) 68 | 69 | middleware = gin.NewMiddleware(limiter.New(store, rate)) 70 | is.NotZero(middleware) 71 | 72 | router = libgin.New() 73 | router.Use(middleware) 74 | router.GET("/", func(c *libgin.Context) { 75 | c.String(http.StatusOK, "hello") 76 | }) 77 | 78 | wg := &sync.WaitGroup{} 79 | counter := int64(0) 80 | 81 | for i := int64(1); i <= clients; i++ { 82 | wg.Add(1) 83 | go func() { 84 | 85 | resp := httptest.NewRecorder() 86 | router.ServeHTTP(resp, request) 87 | 88 | if resp.Code == http.StatusOK { 89 | atomic.AddInt64(&counter, 1) 90 | } 91 | 92 | wg.Done() 93 | }() 94 | } 95 | 96 | wg.Wait() 97 | is.Equal(success, atomic.LoadInt64(&counter)) 98 | 99 | // 100 | // Custom KeyGetter 101 | // 102 | 103 | store = memory.NewStore() 104 | is.NotZero(store) 105 | 106 | counter = int64(0) 107 | keyGetter := func(c *libgin.Context) string { 108 | v := atomic.AddInt64(&counter, 1) 109 | return strconv.FormatInt(v, 10) 110 | } 111 | 112 | middleware = gin.NewMiddleware(limiter.New(store, rate), gin.WithKeyGetter(keyGetter)) 113 | is.NotZero(middleware) 114 | 115 | router = libgin.New() 116 | 
router.Use(middleware) 117 | router.GET("/", func(c *libgin.Context) { 118 | c.String(http.StatusOK, "hello") 119 | }) 120 | 121 | for i := int64(1); i <= clients; i++ { 122 | resp := httptest.NewRecorder() 123 | router.ServeHTTP(resp, request) 124 | // We should always be ok as the key changes for each request 125 | is.Equal(http.StatusOK, resp.Code, strconv.FormatInt(i, 10)) 126 | } 127 | 128 | // 129 | // Test ExcludedKey 130 | // 131 | store = memory.NewStore() 132 | is.NotZero(store) 133 | counter = int64(0) 134 | excludedKeyFn := func(key string) bool { 135 | return key == "1" 136 | } 137 | middleware = gin.NewMiddleware(limiter.New(store, rate), 138 | gin.WithKeyGetter(func(c *libgin.Context) string { 139 | v := atomic.AddInt64(&counter, 1) 140 | return strconv.FormatInt(v%2, 10) 141 | }), 142 | gin.WithExcludedKey(excludedKeyFn), 143 | ) 144 | is.NotZero(middleware) 145 | 146 | router = libgin.New() 147 | router.Use(middleware) 148 | router.GET("/", func(c *libgin.Context) { 149 | c.String(http.StatusOK, "hello") 150 | }) 151 | success = 20 152 | for i := int64(1); i < clients; i++ { 153 | resp := httptest.NewRecorder() 154 | router.ServeHTTP(resp, request) 155 | if i <= success || i%2 == 1 { 156 | is.Equal(http.StatusOK, resp.Code, strconv.FormatInt(i, 10)) 157 | } else { 158 | is.Equal(resp.Code, http.StatusTooManyRequests) 159 | } 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /drivers/middleware/gin/options.go: -------------------------------------------------------------------------------- 1 | package gin 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/gin-gonic/gin" 7 | ) 8 | 9 | // Option is used to define Middleware configuration. 
10 | type Option interface { 11 | apply(*Middleware) 12 | } 13 | 14 | type option func(*Middleware) 15 | 16 | func (o option) apply(middleware *Middleware) { 17 | o(middleware) 18 | } 19 | 20 | // ErrorHandler is an handler used to inform when an error has occurred. 21 | type ErrorHandler func(c *gin.Context, err error) 22 | 23 | // WithErrorHandler will configure the Middleware to use the given ErrorHandler. 24 | func WithErrorHandler(handler ErrorHandler) Option { 25 | return option(func(middleware *Middleware) { 26 | middleware.OnError = handler 27 | }) 28 | } 29 | 30 | // DefaultErrorHandler is the default ErrorHandler used by a new Middleware. 31 | func DefaultErrorHandler(c *gin.Context, err error) { 32 | panic(err) 33 | } 34 | 35 | // LimitReachedHandler is an handler used to inform when the limit has exceeded. 36 | type LimitReachedHandler func(c *gin.Context) 37 | 38 | // WithLimitReachedHandler will configure the Middleware to use the given LimitReachedHandler. 39 | func WithLimitReachedHandler(handler LimitReachedHandler) Option { 40 | return option(func(middleware *Middleware) { 41 | middleware.OnLimitReached = handler 42 | }) 43 | } 44 | 45 | // DefaultLimitReachedHandler is the default LimitReachedHandler used by a new Middleware. 46 | func DefaultLimitReachedHandler(c *gin.Context) { 47 | c.String(http.StatusTooManyRequests, "Limit exceeded") 48 | } 49 | 50 | // KeyGetter will define the rate limiter key given the gin Context. 51 | type KeyGetter func(c *gin.Context) string 52 | 53 | // WithKeyGetter will configure the Middleware to use the given KeyGetter. 54 | func WithKeyGetter(handler KeyGetter) Option { 55 | return option(func(middleware *Middleware) { 56 | middleware.KeyGetter = handler 57 | }) 58 | } 59 | 60 | // DefaultKeyGetter is the default KeyGetter used by a new Middleware. 61 | // It returns the Client IP address. 
62 | func DefaultKeyGetter(c *gin.Context) string { 63 | return c.ClientIP() 64 | } 65 | 66 | // WithExcludedKey will configure the Middleware to ignore key(s) using the given function. 67 | func WithExcludedKey(handler func(string) bool) Option { 68 | return option(func(middleware *Middleware) { 69 | middleware.ExcludedKey = handler 70 | }) 71 | } 72 | -------------------------------------------------------------------------------- /drivers/middleware/stdlib/middleware.go: -------------------------------------------------------------------------------- 1 | package stdlib 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | 7 | "github.com/ulule/limiter/v3" 8 | ) 9 | 10 | // Middleware is the middleware for basic http.Handler. 11 | type Middleware struct { 12 | Limiter *limiter.Limiter 13 | OnError ErrorHandler 14 | OnLimitReached LimitReachedHandler 15 | KeyGetter KeyGetter 16 | ExcludedKey func(string) bool 17 | } 18 | 19 | // NewMiddleware return a new instance of a basic HTTP middleware. 20 | func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware { 21 | middleware := &Middleware{ 22 | Limiter: limiter, 23 | OnError: DefaultErrorHandler, 24 | OnLimitReached: DefaultLimitReachedHandler, 25 | KeyGetter: DefaultKeyGetter(limiter), 26 | ExcludedKey: nil, 27 | } 28 | 29 | for _, option := range options { 30 | option.apply(middleware) 31 | } 32 | 33 | return middleware 34 | } 35 | 36 | // Handler handles a HTTP request. 
37 | func (middleware *Middleware) Handler(h http.Handler) http.Handler { 38 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 39 | key := middleware.KeyGetter(r) 40 | if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) { 41 | h.ServeHTTP(w, r) 42 | return 43 | } 44 | 45 | context, err := middleware.Limiter.Get(r.Context(), key) 46 | if err != nil { 47 | middleware.OnError(w, r, err) 48 | return 49 | } 50 | 51 | w.Header().Add("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10)) 52 | w.Header().Add("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10)) 53 | w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10)) 54 | 55 | if context.Reached { 56 | middleware.OnLimitReached(w, r) 57 | return 58 | } 59 | 60 | h.ServeHTTP(w, r) 61 | }) 62 | } 63 | -------------------------------------------------------------------------------- /drivers/middleware/stdlib/middleware_test.go: -------------------------------------------------------------------------------- 1 | package stdlib_test 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | 12 | "github.com/ulule/limiter/v3" 13 | "github.com/ulule/limiter/v3/drivers/middleware/stdlib" 14 | "github.com/ulule/limiter/v3/drivers/store/memory" 15 | ) 16 | 17 | func TestHTTPMiddleware(t *testing.T) { 18 | is := require.New(t) 19 | 20 | request, err := http.NewRequest("GET", "/", nil) 21 | is.NoError(err) 22 | is.NotNil(request) 23 | 24 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 25 | _, thr := w.Write([]byte("hello")) 26 | if thr != nil { 27 | panic(thr) 28 | } 29 | }) 30 | 31 | store := memory.NewStore() 32 | is.NotZero(store) 33 | 34 | rate, err := limiter.NewRateFromFormatted("10-M") 35 | is.NoError(err) 36 | is.NotZero(rate) 37 | 38 | middleware := stdlib.NewMiddleware(limiter.New(store, rate)).Handler(handler) 39 | 
is.NotZero(middleware) 40 | 41 | success := int64(10) 42 | clients := int64(100) 43 | 44 | // 45 | // Sequential 46 | // 47 | 48 | for i := int64(1); i <= clients; i++ { 49 | 50 | resp := httptest.NewRecorder() 51 | middleware.ServeHTTP(resp, request) 52 | 53 | if i <= success { 54 | is.Equal(resp.Code, http.StatusOK) 55 | } else { 56 | is.Equal(resp.Code, http.StatusTooManyRequests) 57 | } 58 | } 59 | 60 | // 61 | // Concurrent 62 | // 63 | 64 | store = memory.NewStore() 65 | is.NotZero(store) 66 | 67 | middleware = stdlib.NewMiddleware(limiter.New(store, rate)).Handler(handler) 68 | is.NotZero(middleware) 69 | 70 | wg := &sync.WaitGroup{} 71 | counter := int64(0) 72 | 73 | for i := int64(1); i <= clients; i++ { 74 | wg.Add(1) 75 | go func() { 76 | 77 | resp := httptest.NewRecorder() 78 | middleware.ServeHTTP(resp, request) 79 | 80 | if resp.Code == http.StatusOK { 81 | atomic.AddInt64(&counter, 1) 82 | } 83 | 84 | wg.Done() 85 | }() 86 | } 87 | 88 | wg.Wait() 89 | is.Equal(success, atomic.LoadInt64(&counter)) 90 | 91 | } 92 | -------------------------------------------------------------------------------- /drivers/middleware/stdlib/options.go: -------------------------------------------------------------------------------- 1 | package stdlib 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/ulule/limiter/v3" 7 | ) 8 | 9 | // Option is used to define Middleware configuration. 10 | type Option interface { 11 | apply(*Middleware) 12 | } 13 | 14 | type option func(*Middleware) 15 | 16 | func (o option) apply(middleware *Middleware) { 17 | o(middleware) 18 | } 19 | 20 | // ErrorHandler is an handler used to inform when an error has occurred. 21 | type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) 22 | 23 | // WithErrorHandler will configure the Middleware to use the given ErrorHandler. 
24 | func WithErrorHandler(handler ErrorHandler) Option { 25 | return option(func(middleware *Middleware) { 26 | middleware.OnError = handler 27 | }) 28 | } 29 | 30 | // DefaultErrorHandler is the default ErrorHandler used by a new Middleware. 31 | func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) { 32 | panic(err) 33 | } 34 | 35 | // LimitReachedHandler is an handler used to inform when the limit has exceeded. 36 | type LimitReachedHandler func(w http.ResponseWriter, r *http.Request) 37 | 38 | // WithLimitReachedHandler will configure the Middleware to use the given LimitReachedHandler. 39 | func WithLimitReachedHandler(handler LimitReachedHandler) Option { 40 | return option(func(middleware *Middleware) { 41 | middleware.OnLimitReached = handler 42 | }) 43 | } 44 | 45 | // DefaultLimitReachedHandler is the default LimitReachedHandler used by a new Middleware. 46 | func DefaultLimitReachedHandler(w http.ResponseWriter, r *http.Request) { 47 | http.Error(w, "Limit exceeded", http.StatusTooManyRequests) 48 | } 49 | 50 | // KeyGetter will define the rate limiter key given the gin Context. 51 | type KeyGetter func(r *http.Request) string 52 | 53 | // WithKeyGetter will configure the Middleware to use the given KeyGetter. 54 | func WithKeyGetter(handler KeyGetter) Option { 55 | return option(func(middleware *Middleware) { 56 | middleware.KeyGetter = handler 57 | }) 58 | } 59 | 60 | // DefaultKeyGetter is the default KeyGetter used by a new Middleware. 61 | // It returns the Client IP address. 62 | func DefaultKeyGetter(limiter *limiter.Limiter) func(r *http.Request) string { 63 | return func(r *http.Request) string { 64 | return limiter.GetIPKey(r) 65 | } 66 | } 67 | 68 | // WithExcludedKey will configure the Middleware to ignore key(s) using the given function. 
69 | func WithExcludedKey(handler func(string) bool) Option { 70 | return option(func(middleware *Middleware) { 71 | middleware.ExcludedKey = handler 72 | }) 73 | } 74 | -------------------------------------------------------------------------------- /drivers/store/common/context.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/ulule/limiter/v3" 7 | ) 8 | 9 | // GetContextFromState generate a new limiter.Context from given state. 10 | func GetContextFromState(now time.Time, rate limiter.Rate, expiration time.Time, count int64) limiter.Context { 11 | limit := rate.Limit 12 | remaining := int64(0) 13 | reached := true 14 | 15 | if count <= limit { 16 | remaining = limit - count 17 | reached = false 18 | } 19 | 20 | reset := expiration.Unix() 21 | 22 | return limiter.Context{ 23 | Limit: limit, 24 | Remaining: remaining, 25 | Reset: reset, 26 | Reached: reached, 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /drivers/store/memory/cache.go: -------------------------------------------------------------------------------- 1 | package memory 2 | 3 | import ( 4 | "runtime" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // Forked from https://github.com/patrickmn/go-cache 10 | 11 | // CacheWrapper is used to ensure that the underlying cleaner goroutine used to clean expired keys will not prevent 12 | // Cache from being garbage collected. 13 | type CacheWrapper struct { 14 | *Cache 15 | } 16 | 17 | // A cleaner will periodically delete expired keys from cache. 18 | type cleaner struct { 19 | interval time.Duration 20 | stop chan bool 21 | } 22 | 23 | // Run will periodically delete expired keys from given cache until GC notify that it should stop. 
24 | func (cleaner *cleaner) Run(cache *Cache) { 25 | ticker := time.NewTicker(cleaner.interval) 26 | for { 27 | select { 28 | case <-ticker.C: 29 | cache.Clean() 30 | case <-cleaner.stop: 31 | ticker.Stop() 32 | return 33 | } 34 | } 35 | } 36 | 37 | // stopCleaner is a callback from GC used to stop cleaner goroutine. 38 | func stopCleaner(wrapper *CacheWrapper) { 39 | wrapper.cleaner.stop <- true 40 | wrapper.cleaner = nil 41 | } 42 | 43 | // startCleaner will start a cleaner goroutine for given cache. 44 | func startCleaner(cache *Cache, interval time.Duration) { 45 | cleaner := &cleaner{ 46 | interval: interval, 47 | stop: make(chan bool), 48 | } 49 | 50 | cache.cleaner = cleaner 51 | go cleaner.Run(cache) 52 | } 53 | 54 | // Counter is a simple counter with an expiration. 55 | type Counter struct { 56 | mutex sync.RWMutex 57 | value int64 58 | expiration int64 59 | } 60 | 61 | // Value returns the counter current value. 62 | func (counter *Counter) Value() int64 { 63 | counter.mutex.RLock() 64 | defer counter.mutex.RUnlock() 65 | return counter.value 66 | } 67 | 68 | // Expiration returns the counter expiration. 69 | func (counter *Counter) Expiration() int64 { 70 | counter.mutex.RLock() 71 | defer counter.mutex.RUnlock() 72 | return counter.expiration 73 | } 74 | 75 | // Expired returns true if the counter has expired. 76 | func (counter *Counter) Expired() bool { 77 | counter.mutex.RLock() 78 | defer counter.mutex.RUnlock() 79 | 80 | return counter.expiration == 0 || time.Now().UnixNano() > counter.expiration 81 | } 82 | 83 | // Load returns the value and the expiration of this counter. 84 | // If the counter is expired, it will use the given expiration. 
85 | func (counter *Counter) Load(expiration int64) (int64, int64) { 86 | counter.mutex.RLock() 87 | defer counter.mutex.RUnlock() 88 | 89 | if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration { 90 | return 0, expiration 91 | } 92 | 93 | return counter.value, counter.expiration 94 | } 95 | 96 | // Increment increments given value on this counter. 97 | // If the counter is expired, it will use the given expiration. 98 | // It returns its current value and expiration. 99 | func (counter *Counter) Increment(value int64, expiration int64) (int64, int64) { 100 | counter.mutex.Lock() 101 | defer counter.mutex.Unlock() 102 | 103 | if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration { 104 | counter.value = value 105 | counter.expiration = expiration 106 | return counter.value, counter.expiration 107 | } 108 | 109 | counter.value += value 110 | return counter.value, counter.expiration 111 | } 112 | 113 | // Cache contains a collection of counters. 114 | type Cache struct { 115 | counters sync.Map 116 | cleaner *cleaner 117 | } 118 | 119 | // NewCache returns a new cache. 120 | func NewCache(cleanInterval time.Duration) *CacheWrapper { 121 | 122 | cache := &Cache{} 123 | wrapper := &CacheWrapper{Cache: cache} 124 | 125 | if cleanInterval > 0 { 126 | startCleaner(cache, cleanInterval) 127 | runtime.SetFinalizer(wrapper, stopCleaner) 128 | } 129 | 130 | return wrapper 131 | } 132 | 133 | // LoadOrStore returns the existing counter for the key if present. 134 | // Otherwise, it stores and returns the given counter. 135 | // The loaded result is true if the counter was loaded, false if stored. 
136 | func (cache *Cache) LoadOrStore(key string, counter *Counter) (*Counter, bool) { 137 | val, loaded := cache.counters.LoadOrStore(key, counter) 138 | if val == nil { 139 | return counter, false 140 | } 141 | 142 | actual := val.(*Counter) 143 | return actual, loaded 144 | } 145 | 146 | // Load returns the counter stored in the map for a key, or nil if no counter is present. 147 | // The ok result indicates whether counter was found in the map. 148 | func (cache *Cache) Load(key string) (*Counter, bool) { 149 | val, ok := cache.counters.Load(key) 150 | if val == nil || !ok { 151 | return nil, false 152 | } 153 | actual := val.(*Counter) 154 | return actual, true 155 | } 156 | 157 | // Store sets the counter for a key. 158 | func (cache *Cache) Store(key string, counter *Counter) { 159 | cache.counters.Store(key, counter) 160 | } 161 | 162 | // Delete deletes the value for a key. 163 | func (cache *Cache) Delete(key string) { 164 | cache.counters.Delete(key) 165 | } 166 | 167 | // Range calls handler sequentially for each key and value present in the cache. 168 | // If handler returns false, range stops the iteration. 169 | func (cache *Cache) Range(handler func(key string, counter *Counter)) { 170 | cache.counters.Range(func(k interface{}, v interface{}) bool { 171 | if v == nil { 172 | return true 173 | } 174 | 175 | key := k.(string) 176 | counter := v.(*Counter) 177 | 178 | handler(key, counter) 179 | 180 | return true 181 | }) 182 | } 183 | 184 | // Increment increments given value on key. 185 | // If key is undefined or expired, it will create it. 186 | func (cache *Cache) Increment(key string, value int64, duration time.Duration) (int64, time.Time) { 187 | expiration := time.Now().Add(duration).UnixNano() 188 | 189 | // If counter is in cache, try to load it first. 
190 | counter, loaded := cache.Load(key) 191 | if loaded { 192 | value, expiration = counter.Increment(value, expiration) 193 | return value, time.Unix(0, expiration) 194 | } 195 | 196 | // If it's not in cache, try to atomically create it. 197 | // We do that in two step to reduce memory allocation. 198 | counter, loaded = cache.LoadOrStore(key, &Counter{ 199 | mutex: sync.RWMutex{}, 200 | value: value, 201 | expiration: expiration, 202 | }) 203 | if loaded { 204 | value, expiration = counter.Increment(value, expiration) 205 | return value, time.Unix(0, expiration) 206 | } 207 | 208 | // Otherwise, it has been created, return given value. 209 | return value, time.Unix(0, expiration) 210 | } 211 | 212 | // Get returns key's value and expiration. 213 | func (cache *Cache) Get(key string, duration time.Duration) (int64, time.Time) { 214 | expiration := time.Now().Add(duration).UnixNano() 215 | 216 | counter, ok := cache.Load(key) 217 | if !ok { 218 | return 0, time.Unix(0, expiration) 219 | } 220 | 221 | value, expiration := counter.Load(expiration) 222 | return value, time.Unix(0, expiration) 223 | } 224 | 225 | // Clean will deleted any expired keys. 226 | func (cache *Cache) Clean() { 227 | cache.Range(func(key string, counter *Counter) { 228 | if counter.Expired() { 229 | cache.Delete(key) 230 | } 231 | }) 232 | } 233 | 234 | // Reset changes the key's value and resets the expiration. 
235 | func (cache *Cache) Reset(key string, duration time.Duration) (int64, time.Time) { 236 | cache.Delete(key) 237 | 238 | expiration := time.Now().Add(duration).UnixNano() 239 | return 0, time.Unix(0, expiration) 240 | } 241 | -------------------------------------------------------------------------------- /drivers/store/memory/cache_test.go: -------------------------------------------------------------------------------- 1 | package memory_test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/ulule/limiter/v3/drivers/store/memory" 11 | ) 12 | 13 | func TestCacheIncrementSequential(t *testing.T) { 14 | is := require.New(t) 15 | 16 | key := "foobar" 17 | cache := memory.NewCache(10 * time.Nanosecond) 18 | duration := 50 * time.Millisecond 19 | deleted := time.Now().Add(duration).UnixNano() 20 | epsilon := 0.001 21 | 22 | x, expire := cache.Increment(key, 1, duration) 23 | is.Equal(int64(1), x) 24 | is.InEpsilon(deleted, expire.UnixNano(), epsilon) 25 | 26 | x, expire = cache.Increment(key, 2, duration) 27 | is.Equal(int64(3), x) 28 | is.InEpsilon(deleted, expire.UnixNano(), epsilon) 29 | 30 | time.Sleep(duration) 31 | 32 | deleted = time.Now().Add(duration).UnixNano() 33 | x, expire = cache.Increment(key, 1, duration) 34 | is.Equal(int64(1), x) 35 | is.InEpsilon(deleted, expire.UnixNano(), epsilon) 36 | } 37 | 38 | func TestCacheIncrementConcurrent(t *testing.T) { 39 | is := require.New(t) 40 | 41 | goroutines := 200 42 | ops := 500 43 | 44 | expected := int64(0) 45 | for i := 0; i < goroutines; i++ { 46 | if (i % 3) == 0 { 47 | for j := 0; j < ops; j++ { 48 | expected += int64(i + j) 49 | } 50 | } 51 | } 52 | 53 | key := "foobar" 54 | cache := memory.NewCache(10 * time.Nanosecond) 55 | 56 | wg := &sync.WaitGroup{} 57 | wg.Add(goroutines) 58 | 59 | for i := 0; i < goroutines; i++ { 60 | go func(i int) { 61 | if (i % 3) == 0 { 62 | time.Sleep(1 * time.Second) 63 | for j := 0; j < ops; j++ { 64 
| cache.Increment(key, int64(i+j), (1 * time.Second)) 65 | } 66 | } else { 67 | time.Sleep(50 * time.Millisecond) 68 | stopAt := time.Now().Add(500 * time.Millisecond) 69 | for time.Now().Before(stopAt) { 70 | cache.Increment(key, int64(i), (75 * time.Millisecond)) 71 | } 72 | } 73 | wg.Done() 74 | }(i) 75 | } 76 | wg.Wait() 77 | 78 | value, expire := cache.Get(key, (100 * time.Millisecond)) 79 | is.Equal(expected, value) 80 | is.True(time.Now().Before(expire)) 81 | } 82 | 83 | func TestCacheGet(t *testing.T) { 84 | is := require.New(t) 85 | 86 | key := "foobar" 87 | cache := memory.NewCache(10 * time.Nanosecond) 88 | duration := 50 * time.Millisecond 89 | deleted := time.Now().Add(duration).UnixNano() 90 | epsilon := 0.001 91 | 92 | x, expire := cache.Get(key, duration) 93 | is.Equal(int64(0), x) 94 | is.InEpsilon(deleted, expire.UnixNano(), epsilon) 95 | } 96 | 97 | func TestCacheReset(t *testing.T) { 98 | is := require.New(t) 99 | 100 | key := "foobar" 101 | cache := memory.NewCache(10 * time.Nanosecond) 102 | duration := 50 * time.Millisecond 103 | deleted := time.Now().Add(duration).UnixNano() 104 | epsilon := 0.001 105 | 106 | x, expire := cache.Get(key, duration) 107 | is.Equal(int64(0), x) 108 | is.InEpsilon(deleted, expire.UnixNano(), epsilon) 109 | 110 | x, expire = cache.Increment(key, 1, duration) 111 | is.Equal(int64(1), x) 112 | is.InEpsilon(deleted, expire.UnixNano(), epsilon) 113 | 114 | x, expire = cache.Increment(key, 1, duration) 115 | is.Equal(int64(2), x) 116 | is.InEpsilon(deleted, expire.UnixNano(), epsilon) 117 | 118 | x, expire = cache.Reset(key, duration) 119 | is.Equal(int64(0), x) 120 | is.InEpsilon(deleted, expire.UnixNano(), epsilon) 121 | 122 | x, expire = cache.Increment(key, 1, duration) 123 | is.Equal(int64(1), x) 124 | is.InEpsilon(deleted, expire.UnixNano(), epsilon) 125 | 126 | x, expire = cache.Increment(key, 1, duration) 127 | is.Equal(int64(2), x) 128 | is.InEpsilon(deleted, expire.UnixNano(), epsilon) 129 | } 130 | 
-------------------------------------------------------------------------------- /drivers/store/memory/store.go: -------------------------------------------------------------------------------- 1 | package memory 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | "time" 7 | 8 | "github.com/ulule/limiter/v3" 9 | "github.com/ulule/limiter/v3/drivers/store/common" 10 | ) 11 | 12 | // Store is the in-memory store. 13 | type Store struct { 14 | // Prefix used for the key. 15 | Prefix string 16 | // cache used to store values in-memory. 17 | cache *CacheWrapper 18 | } 19 | 20 | // NewStore creates a new instance of memory store with defaults. 21 | func NewStore() limiter.Store { 22 | return NewStoreWithOptions(limiter.StoreOptions{ 23 | Prefix: limiter.DefaultPrefix, 24 | CleanUpInterval: limiter.DefaultCleanUpInterval, 25 | }) 26 | } 27 | 28 | // NewStoreWithOptions creates a new instance of memory store with options. 29 | func NewStoreWithOptions(options limiter.StoreOptions) limiter.Store { 30 | return &Store{ 31 | Prefix: options.Prefix, 32 | cache: NewCache(options.CleanUpInterval), 33 | } 34 | } 35 | 36 | // Get returns the limit for given identifier. 37 | func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { 38 | count, expiration := store.cache.Increment(store.getCacheKey(key), 1, rate.Period) 39 | 40 | lctx := common.GetContextFromState(time.Now(), rate, expiration, count) 41 | return lctx, nil 42 | } 43 | 44 | // Increment increments the limit by given count & returns the new limit value for given identifier. 
45 | func (store *Store) Increment(ctx context.Context, key string, count int64, rate limiter.Rate) (limiter.Context, error) { 46 | newCount, expiration := store.cache.Increment(store.getCacheKey(key), count, rate.Period) 47 | 48 | lctx := common.GetContextFromState(time.Now(), rate, expiration, newCount) 49 | return lctx, nil 50 | } 51 | 52 | // Peek returns the limit for given identifier, without modification on current values. 53 | func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { 54 | count, expiration := store.cache.Get(store.getCacheKey(key), rate.Period) 55 | 56 | lctx := common.GetContextFromState(time.Now(), rate, expiration, count) 57 | return lctx, nil 58 | } 59 | 60 | // Reset returns the limit for given identifier. 61 | func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { 62 | count, expiration := store.cache.Reset(store.getCacheKey(key), rate.Period) 63 | 64 | lctx := common.GetContextFromState(time.Now(), rate, expiration, count) 65 | return lctx, nil 66 | } 67 | 68 | // getCacheKey returns the full path for an identifier. 
69 | func (store *Store) getCacheKey(key string) string { 70 | buffer := strings.Builder{} 71 | buffer.WriteString(store.Prefix) 72 | buffer.WriteString(":") 73 | buffer.WriteString(key) 74 | return buffer.String() 75 | } 76 | -------------------------------------------------------------------------------- /drivers/store/memory/store_test.go: -------------------------------------------------------------------------------- 1 | package memory_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/ulule/limiter/v3" 8 | "github.com/ulule/limiter/v3/drivers/store/memory" 9 | "github.com/ulule/limiter/v3/drivers/store/tests" 10 | ) 11 | 12 | func TestMemoryStoreSequentialAccess(t *testing.T) { 13 | tests.TestStoreSequentialAccess(t, memory.NewStoreWithOptions(limiter.StoreOptions{ 14 | Prefix: "limiter:memory:sequential-test", 15 | CleanUpInterval: 30 * time.Second, 16 | })) 17 | } 18 | 19 | func TestMemoryStoreConcurrentAccess(t *testing.T) { 20 | tests.TestStoreConcurrentAccess(t, memory.NewStoreWithOptions(limiter.StoreOptions{ 21 | Prefix: "limiter:memory:concurrent-test", 22 | CleanUpInterval: 1 * time.Nanosecond, 23 | })) 24 | } 25 | 26 | func BenchmarkMemoryStoreSequentialAccess(b *testing.B) { 27 | tests.BenchmarkStoreSequentialAccess(b, memory.NewStoreWithOptions(limiter.StoreOptions{ 28 | Prefix: "limiter:memory:sequential-benchmark", 29 | CleanUpInterval: 1 * time.Hour, 30 | })) 31 | } 32 | 33 | func BenchmarkMemoryStoreConcurrentAccess(b *testing.B) { 34 | tests.BenchmarkStoreConcurrentAccess(b, memory.NewStoreWithOptions(limiter.StoreOptions{ 35 | Prefix: "limiter:memory:concurrent-benchmark", 36 | CleanUpInterval: 1 * time.Hour, 37 | })) 38 | } 39 | -------------------------------------------------------------------------------- /drivers/store/redis/store.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | "sync" 7 | "sync/atomic" 8 | "time" 9 | 
10 | "github.com/pkg/errors" 11 | libredis "github.com/redis/go-redis/v9" 12 | 13 | "github.com/ulule/limiter/v3" 14 | "github.com/ulule/limiter/v3/drivers/store/common" 15 | ) 16 | 17 | const ( 18 | luaIncrScript = ` 19 | local key = KEYS[1] 20 | local count = tonumber(ARGV[1]) 21 | local ttl = tonumber(ARGV[2]) 22 | local ret = redis.call("incrby", key, ARGV[1]) 23 | if ret == count then 24 | if ttl > 0 then 25 | redis.call("pexpire", key, ARGV[2]) 26 | end 27 | return {ret, ttl} 28 | end 29 | ttl = redis.call("pttl", key) 30 | return {ret, ttl} 31 | ` 32 | luaPeekScript = ` 33 | local key = KEYS[1] 34 | local v = redis.call("get", key) 35 | if v == false then 36 | return {0, 0} 37 | end 38 | local ttl = redis.call("pttl", key) 39 | return {tonumber(v), ttl} 40 | ` 41 | ) 42 | 43 | // Client is an interface thats allows to use a redis cluster or a redis single client seamlessly. 44 | type Client interface { 45 | Get(ctx context.Context, key string) *libredis.StringCmd 46 | Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *libredis.StatusCmd 47 | Watch(ctx context.Context, handler func(*libredis.Tx) error, keys ...string) error 48 | Del(ctx context.Context, keys ...string) *libredis.IntCmd 49 | SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *libredis.BoolCmd 50 | EvalSha(ctx context.Context, sha string, keys []string, args ...interface{}) *libredis.Cmd 51 | ScriptLoad(ctx context.Context, script string) *libredis.StringCmd 52 | } 53 | 54 | // Store is the redis store. 55 | type Store struct { 56 | // Prefix used for the key. 57 | Prefix string 58 | // MaxRetry is the maximum number of retry under race conditions. 59 | // Deprecated: this option is no longer required since all operations are atomic now. 60 | MaxRetry int 61 | // client used to communicate with redis server. 62 | client Client 63 | // luaMutex is a mutex used to avoid concurrent access on luaIncrSHA and luaPeekSHA. 
64 | luaMutex sync.RWMutex 65 | // luaLoaded is used for CAS and reduce pressure on luaMutex. 66 | luaLoaded uint32 67 | // luaIncrSHA is the SHA of increase and expire key script. 68 | luaIncrSHA string 69 | // luaPeekSHA is the SHA of peek and expire key script. 70 | luaPeekSHA string 71 | } 72 | 73 | // NewStore returns an instance of redis store with defaults. 74 | func NewStore(client Client) (limiter.Store, error) { 75 | return NewStoreWithOptions(client, limiter.StoreOptions{ 76 | Prefix: limiter.DefaultPrefix, 77 | CleanUpInterval: limiter.DefaultCleanUpInterval, 78 | MaxRetry: limiter.DefaultMaxRetry, 79 | }) 80 | } 81 | 82 | // NewStoreWithOptions returns an instance of redis store with options. 83 | func NewStoreWithOptions(client Client, options limiter.StoreOptions) (limiter.Store, error) { 84 | store := &Store{ 85 | client: client, 86 | Prefix: options.Prefix, 87 | MaxRetry: options.MaxRetry, 88 | } 89 | 90 | err := store.preloadLuaScripts(context.Background()) 91 | if err != nil { 92 | return nil, err 93 | } 94 | 95 | return store, nil 96 | } 97 | 98 | // Increment increments the limit by given count & gives back the new limit for given identifier 99 | func (store *Store) Increment(ctx context.Context, key string, count int64, rate limiter.Rate) (limiter.Context, error) { 100 | cmd := store.evalSHA(ctx, store.getLuaIncrSHA, []string{store.getCacheKey(key)}, count, rate.Period.Milliseconds()) 101 | return currentContext(cmd, rate) 102 | } 103 | 104 | // Get returns the limit for given identifier. 105 | func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { 106 | cmd := store.evalSHA(ctx, store.getLuaIncrSHA, []string{store.getCacheKey(key)}, 1, rate.Period.Milliseconds()) 107 | return currentContext(cmd, rate) 108 | } 109 | 110 | // Peek returns the limit for given identifier, without modification on current values. 
111 | func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { 112 | cmd := store.evalSHA(ctx, store.getLuaPeekSHA, []string{store.getCacheKey(key)}) 113 | count, ttl, err := parseCountAndTTL(cmd) 114 | if err != nil { 115 | return limiter.Context{}, err 116 | } 117 | 118 | now := time.Now() 119 | expiration := now.Add(rate.Period) 120 | if ttl > 0 { 121 | expiration = now.Add(time.Duration(ttl) * time.Millisecond) 122 | } 123 | 124 | return common.GetContextFromState(now, rate, expiration, count), nil 125 | } 126 | 127 | // Reset returns the limit for given identifier which is set to zero. 128 | func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { 129 | _, err := store.client.Del(ctx, store.getCacheKey(key)).Result() 130 | if err != nil { 131 | return limiter.Context{}, err 132 | } 133 | 134 | count := int64(0) 135 | now := time.Now() 136 | expiration := now.Add(rate.Period) 137 | 138 | return common.GetContextFromState(now, rate, expiration, count), nil 139 | } 140 | 141 | // getCacheKey returns the full path for an identifier. 142 | func (store *Store) getCacheKey(key string) string { 143 | buffer := strings.Builder{} 144 | buffer.WriteString(store.Prefix) 145 | buffer.WriteString(":") 146 | buffer.WriteString(key) 147 | return buffer.String() 148 | } 149 | 150 | // preloadLuaScripts preloads the "incr" and "peek" lua scripts. 151 | func (store *Store) preloadLuaScripts(ctx context.Context) error { 152 | // Verify if we need to load lua scripts. 153 | // Inspired by sync.Once. 154 | if atomic.LoadUint32(&store.luaLoaded) == 0 { 155 | return store.loadLuaScripts(ctx) 156 | } 157 | return nil 158 | } 159 | 160 | // reloadLuaScripts forces a reload of "incr" and "peek" lua scripts. 161 | func (store *Store) reloadLuaScripts(ctx context.Context) error { 162 | // Reset lua scripts loaded state. 163 | // Inspired by sync.Once. 
164 | atomic.StoreUint32(&store.luaLoaded, 0) 165 | return store.loadLuaScripts(ctx) 166 | } 167 | 168 | // loadLuaScripts load "incr" and "peek" lua scripts. 169 | // WARNING: Please use preloadLuaScripts or reloadLuaScripts, instead of this one. 170 | func (store *Store) loadLuaScripts(ctx context.Context) error { 171 | store.luaMutex.Lock() 172 | defer store.luaMutex.Unlock() 173 | 174 | // Check if scripts are already loaded. 175 | if atomic.LoadUint32(&store.luaLoaded) != 0 { 176 | return nil 177 | } 178 | 179 | luaIncrSHA, err := store.client.ScriptLoad(ctx, luaIncrScript).Result() 180 | if err != nil { 181 | return errors.Wrap(err, `failed to load "incr" lua script`) 182 | } 183 | 184 | luaPeekSHA, err := store.client.ScriptLoad(ctx, luaPeekScript).Result() 185 | if err != nil { 186 | return errors.Wrap(err, `failed to load "peek" lua script`) 187 | } 188 | 189 | store.luaIncrSHA = luaIncrSHA 190 | store.luaPeekSHA = luaPeekSHA 191 | 192 | atomic.StoreUint32(&store.luaLoaded, 1) 193 | 194 | return nil 195 | } 196 | 197 | // getLuaIncrSHA returns a "thread-safe" value for luaIncrSHA. 198 | func (store *Store) getLuaIncrSHA() string { 199 | store.luaMutex.RLock() 200 | defer store.luaMutex.RUnlock() 201 | return store.luaIncrSHA 202 | } 203 | 204 | // getLuaPeekSHA returns a "thread-safe" value for luaPeekSHA. 205 | func (store *Store) getLuaPeekSHA() string { 206 | store.luaMutex.RLock() 207 | defer store.luaMutex.RUnlock() 208 | return store.luaPeekSHA 209 | } 210 | 211 | // evalSHA eval the redis lua sha and load the scripts if missing. 212 | func (store *Store) evalSHA(ctx context.Context, getSha func() string, 213 | keys []string, args ...interface{}) *libredis.Cmd { 214 | 215 | cmd := store.client.EvalSha(ctx, getSha(), keys, args...) 
216 | err := cmd.Err() 217 | if err == nil || !isLuaScriptGone(err) { 218 | return cmd 219 | } 220 | 221 | err = store.reloadLuaScripts(ctx) 222 | if err != nil { 223 | cmd = libredis.NewCmd(ctx) 224 | cmd.SetErr(err) 225 | return cmd 226 | } 227 | 228 | return store.client.EvalSha(ctx, getSha(), keys, args...) 229 | } 230 | 231 | // isLuaScriptGone returns if the error is a missing lua script from redis server. 232 | func isLuaScriptGone(err error) bool { 233 | return strings.HasPrefix(err.Error(), "NOSCRIPT") 234 | } 235 | 236 | // parseCountAndTTL parse count and ttl from lua script output. 237 | func parseCountAndTTL(cmd *libredis.Cmd) (int64, int64, error) { 238 | result, err := cmd.Result() 239 | if err != nil { 240 | return 0, 0, errors.Wrap(err, "an error has occurred with redis command") 241 | } 242 | 243 | fields, ok := result.([]interface{}) 244 | if !ok || len(fields) != 2 { 245 | return 0, 0, errors.New("two elements in result were expected") 246 | } 247 | 248 | count, ok1 := fields[0].(int64) 249 | ttl, ok2 := fields[1].(int64) 250 | if !ok1 || !ok2 { 251 | return 0, 0, errors.New("type of the count and/or ttl should be number") 252 | } 253 | 254 | return count, ttl, nil 255 | } 256 | 257 | func currentContext(cmd *libredis.Cmd, rate limiter.Rate) (limiter.Context, error) { 258 | count, ttl, err := parseCountAndTTL(cmd) 259 | if err != nil { 260 | return limiter.Context{}, err 261 | } 262 | 263 | now := time.Now() 264 | expiration := now.Add(rate.Period) 265 | if ttl > 0 { 266 | expiration = now.Add(time.Duration(ttl) * time.Millisecond) 267 | } 268 | 269 | return common.GetContextFromState(now, rate, expiration, count), nil 270 | } 271 | -------------------------------------------------------------------------------- /drivers/store/redis/store_test.go: -------------------------------------------------------------------------------- 1 | package redis_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | "time" 8 | 9 | libredis 
"github.com/redis/go-redis/v9" 10 | "github.com/stretchr/testify/require" 11 | 12 | "github.com/ulule/limiter/v3" 13 | "github.com/ulule/limiter/v3/drivers/store/redis" 14 | "github.com/ulule/limiter/v3/drivers/store/tests" 15 | ) 16 | 17 | func TestRedisStoreSequentialAccess(t *testing.T) { 18 | is := require.New(t) 19 | 20 | client, err := newRedisClient() 21 | is.NoError(err) 22 | is.NotNil(client) 23 | 24 | store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ 25 | Prefix: "limiter:redis:sequential-test", 26 | }) 27 | is.NoError(err) 28 | is.NotNil(store) 29 | 30 | tests.TestStoreSequentialAccess(t, store) 31 | } 32 | 33 | func TestRedisStoreConcurrentAccess(t *testing.T) { 34 | is := require.New(t) 35 | 36 | client, err := newRedisClient() 37 | is.NoError(err) 38 | is.NotNil(client) 39 | 40 | store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ 41 | Prefix: "limiter:redis:concurrent-test", 42 | }) 43 | is.NoError(err) 44 | is.NotNil(store) 45 | 46 | tests.TestStoreConcurrentAccess(t, store) 47 | } 48 | 49 | func TestRedisClientExpiration(t *testing.T) { 50 | is := require.New(t) 51 | 52 | client, err := newRedisClient() 53 | is.NoError(err) 54 | is.NotNil(client) 55 | 56 | key := "foobar" 57 | value := 642 58 | keyNoExpiration := -1 * time.Nanosecond 59 | keyNotExist := -2 * time.Nanosecond 60 | 61 | ctx := context.Background() 62 | delCmd := client.Del(ctx, key) 63 | _, err = delCmd.Result() 64 | is.NoError(err) 65 | 66 | expCmd := client.PTTL(ctx, key) 67 | ttl, err := expCmd.Result() 68 | is.NoError(err) 69 | is.Equal(keyNotExist, ttl) 70 | 71 | setCmd := client.Set(ctx, key, value, 0) 72 | _, err = setCmd.Result() 73 | is.NoError(err) 74 | 75 | expCmd = client.PTTL(ctx, key) 76 | ttl, err = expCmd.Result() 77 | is.NoError(err) 78 | is.Equal(keyNoExpiration, ttl) 79 | 80 | setCmd = client.Set(ctx, key, value, time.Second) 81 | _, err = setCmd.Result() 82 | is.NoError(err) 83 | 84 | time.Sleep(100 * time.Millisecond) 85 | 86 
| expCmd = client.PTTL(ctx, key) 87 | ttl, err = expCmd.Result() 88 | is.NoError(err) 89 | 90 | expected := int64(0) 91 | actual := int64(ttl) 92 | is.Greater(actual, expected) 93 | } 94 | 95 | func BenchmarkRedisStoreSequentialAccess(b *testing.B) { 96 | is := require.New(b) 97 | 98 | client, err := newRedisClient() 99 | is.NoError(err) 100 | is.NotNil(client) 101 | 102 | store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ 103 | Prefix: "limiter:redis:sequential-benchmark", 104 | }) 105 | is.NoError(err) 106 | is.NotNil(store) 107 | 108 | tests.BenchmarkStoreSequentialAccess(b, store) 109 | } 110 | 111 | func BenchmarkRedisStoreConcurrentAccess(b *testing.B) { 112 | is := require.New(b) 113 | 114 | client, err := newRedisClient() 115 | is.NoError(err) 116 | is.NotNil(client) 117 | 118 | store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ 119 | Prefix: "limiter:redis:concurrent-benchmark", 120 | }) 121 | is.NoError(err) 122 | is.NotNil(store) 123 | 124 | tests.BenchmarkStoreConcurrentAccess(b, store) 125 | } 126 | 127 | func newRedisClient() (*libredis.Client, error) { 128 | uri := "redis://localhost:6379/0" 129 | if os.Getenv("REDIS_URI") != "" { 130 | uri = os.Getenv("REDIS_URI") 131 | } 132 | 133 | opt, err := libredis.ParseURL(uri) 134 | if err != nil { 135 | return nil, err 136 | } 137 | 138 | client := libredis.NewClient(opt) 139 | return client, nil 140 | } 141 | -------------------------------------------------------------------------------- /drivers/store/tests/tests.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/require" 10 | 11 | "github.com/ulule/limiter/v3" 12 | ) 13 | 14 | // TestStoreSequentialAccess verify that store works as expected with a sequential access. 
15 | func TestStoreSequentialAccess(t *testing.T, store limiter.Store) { 16 | is := require.New(t) 17 | ctx := context.Background() 18 | 19 | limiter := limiter.New(store, limiter.Rate{ 20 | Limit: 3, 21 | Period: time.Minute, 22 | }) 23 | 24 | // Check counter increment. 25 | { 26 | for i := 1; i <= 6; i++ { 27 | 28 | if i <= 3 { 29 | 30 | lctx, err := limiter.Peek(ctx, "foo") 31 | is.NoError(err) 32 | is.NotZero(lctx) 33 | is.Equal(int64(3-(i-1)), lctx.Remaining) 34 | is.False(lctx.Reached) 35 | 36 | } 37 | 38 | lctx, err := limiter.Get(ctx, "foo") 39 | is.NoError(err) 40 | is.NotZero(lctx) 41 | 42 | if i <= 3 { 43 | 44 | is.Equal(int64(3), lctx.Limit) 45 | is.Equal(int64(3-i), lctx.Remaining) 46 | is.True((lctx.Reset - time.Now().Unix()) <= 60) 47 | is.False(lctx.Reached) 48 | 49 | lctx, err = limiter.Peek(ctx, "foo") 50 | is.NoError(err) 51 | is.Equal(int64(3-i), lctx.Remaining) 52 | is.False(lctx.Reached) 53 | 54 | } else { 55 | 56 | is.Equal(int64(3), lctx.Limit) 57 | is.Equal(int64(0), lctx.Remaining) 58 | is.True((lctx.Reset - time.Now().Unix()) <= 60) 59 | is.True(lctx.Reached) 60 | 61 | } 62 | } 63 | } 64 | 65 | // Check counter reset. 
66 | { 67 | lctx, err := limiter.Peek(ctx, "foo") 68 | is.NoError(err) 69 | is.NotZero(lctx) 70 | 71 | is.Equal(int64(3), lctx.Limit) 72 | is.Equal(int64(0), lctx.Remaining) 73 | is.True((lctx.Reset - time.Now().Unix()) <= 60) 74 | is.True(lctx.Reached) 75 | 76 | lctx, err = limiter.Reset(ctx, "foo") 77 | is.NoError(err) 78 | is.NotZero(lctx) 79 | 80 | is.Equal(int64(3), lctx.Limit) 81 | is.Equal(int64(3), lctx.Remaining) 82 | is.True((lctx.Reset - time.Now().Unix()) <= 60) 83 | is.False(lctx.Reached) 84 | 85 | lctx, err = limiter.Peek(ctx, "foo") 86 | is.NoError(err) 87 | is.NotZero(lctx) 88 | 89 | is.Equal(int64(3), lctx.Limit) 90 | is.Equal(int64(3), lctx.Remaining) 91 | is.True((lctx.Reset - time.Now().Unix()) <= 60) 92 | is.False(lctx.Reached) 93 | 94 | lctx, err = limiter.Get(ctx, "foo") 95 | is.NoError(err) 96 | is.NotZero(lctx) 97 | 98 | lctx, err = limiter.Reset(ctx, "foo") 99 | is.NoError(err) 100 | is.NotZero(lctx) 101 | 102 | is.Equal(int64(3), lctx.Limit) 103 | is.Equal(int64(3), lctx.Remaining) 104 | is.True((lctx.Reset - time.Now().Unix()) <= 60) 105 | is.False(lctx.Reached) 106 | 107 | lctx, err = limiter.Reset(ctx, "foo") 108 | is.NoError(err) 109 | is.NotZero(lctx) 110 | 111 | is.Equal(int64(3), lctx.Limit) 112 | is.Equal(int64(3), lctx.Remaining) 113 | is.True((lctx.Reset - time.Now().Unix()) <= 60) 114 | is.False(lctx.Reached) 115 | } 116 | } 117 | 118 | // TestStoreConcurrentAccess verify that store works as expected with a concurrent access. 
119 | func TestStoreConcurrentAccess(t *testing.T, store limiter.Store) { 120 | is := require.New(t) 121 | ctx := context.Background() 122 | 123 | limiter := limiter.New(store, limiter.Rate{ 124 | Limit: 100000, 125 | Period: 10 * time.Second, 126 | }) 127 | 128 | goroutines := 500 129 | ops := 500 130 | 131 | wg := &sync.WaitGroup{} 132 | wg.Add(goroutines) 133 | for i := 0; i < goroutines; i++ { 134 | go func(i int) { 135 | for j := 0; j < ops; j++ { 136 | lctx, err := limiter.Get(ctx, "foo") 137 | is.NoError(err) 138 | is.NotZero(lctx) 139 | } 140 | wg.Done() 141 | }(i) 142 | } 143 | wg.Wait() 144 | } 145 | 146 | // BenchmarkStoreSequentialAccess executes a benchmark against a store without parallel setting. 147 | func BenchmarkStoreSequentialAccess(b *testing.B, store limiter.Store) { 148 | ctx := context.Background() 149 | 150 | instance := limiter.New(store, limiter.Rate{ 151 | Limit: 100000, 152 | Period: 10 * time.Second, 153 | }) 154 | 155 | b.ResetTimer() 156 | for i := 0; i < b.N; i++ { 157 | _, _ = instance.Get(ctx, "foo") 158 | } 159 | } 160 | 161 | // BenchmarkStoreConcurrentAccess executes a benchmark against a store with parallel setting. 
162 | func BenchmarkStoreConcurrentAccess(b *testing.B, store limiter.Store) { 163 | ctx := context.Background() 164 | 165 | instance := limiter.New(store, limiter.Rate{ 166 | Limit: 100000, 167 | Period: 10 * time.Second, 168 | }) 169 | 170 | b.ResetTimer() 171 | b.RunParallel(func(pb *testing.PB) { 172 | for pb.Next() { 173 | _, _ = instance.Get(ctx, "foo") 174 | } 175 | }) 176 | } 177 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Limiter examples 2 | 3 | The examples have been moved here: https://github.com/ulule/limiter-examples 4 | 5 | Nonetheless, this is a list of middleware examples at the new location: 6 | 7 | * [HTTP](https://github.com/ulule/limiter-examples/tree/master/http/main.go) 8 | * [Gin](https://github.com/ulule/limiter-examples/tree/master/gin/main.go) 9 | * [Beego](https://github.com/ulule/limiter-examples/blob/master/beego/main.go) 10 | * [Chi](https://github.com/ulule/limiter-examples/tree/master/chi/main.go) 11 | * [Echo](https://github.com/ulule/limiter-examples/tree/master/echo/main.go) 12 | * [fasthttp](https://github.com/ulule/limiter-examples/tree/master/fasthttp/main.go) 13 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/ulule/limiter/v3 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/gin-gonic/gin v1.9.1 7 | github.com/pkg/errors v0.9.1 8 | github.com/redis/go-redis/v9 v9.6.2 9 | github.com/stretchr/testify v1.8.4 10 | github.com/valyala/fasthttp v1.50.0 11 | ) 12 | 13 | require ( 14 | github.com/andybalholm/brotli v1.0.5 // indirect 15 | github.com/bytedance/sonic v1.9.1 // indirect 16 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 17 | github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect 18 | github.com/davecgh/go-spew
v1.1.1 // indirect 19 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 20 | github.com/gabriel-vasile/mimetype v1.4.2 // indirect 21 | github.com/gin-contrib/sse v0.1.0 // indirect 22 | github.com/go-playground/locales v0.14.1 // indirect 23 | github.com/go-playground/universal-translator v0.18.1 // indirect 24 | github.com/go-playground/validator/v10 v10.14.0 // indirect 25 | github.com/goccy/go-json v0.10.2 // indirect 26 | github.com/google/go-cmp v0.5.6 // indirect 27 | github.com/json-iterator/go v1.1.12 // indirect 28 | github.com/klauspost/compress v1.16.3 // indirect 29 | github.com/klauspost/cpuid/v2 v2.2.4 // indirect 30 | github.com/leodido/go-urn v1.2.4 // indirect 31 | github.com/mattn/go-isatty v0.0.19 // indirect 32 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 33 | github.com/modern-go/reflect2 v1.0.2 // indirect 34 | github.com/pelletier/go-toml/v2 v2.0.8 // indirect 35 | github.com/pmezard/go-difflib v1.0.0 // indirect 36 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 37 | github.com/ugorji/go/codec v1.2.11 // indirect 38 | github.com/valyala/bytebufferpool v1.0.0 // indirect 39 | golang.org/x/arch v0.3.0 // indirect 40 | golang.org/x/crypto v0.21.0 // indirect 41 | golang.org/x/net v0.23.0 // indirect 42 | golang.org/x/sys v0.18.0 // indirect 43 | golang.org/x/text v0.14.0 // indirect 44 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect 45 | google.golang.org/protobuf v1.33.0 // indirect 46 | gopkg.in/yaml.v3 v3.0.1 // indirect 47 | ) 48 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= 2 | github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= 3 | github.com/bsm/ginkgo/v2 v2.12.0 
h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= 4 | github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= 5 | github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= 6 | github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= 7 | github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= 8 | github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= 9 | github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= 10 | github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= 11 | github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 12 | github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= 13 | github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= 14 | github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= 15 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 16 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 17 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 18 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= 19 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= 20 | github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= 21 | github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= 22 | github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= 23 | github.com/gin-contrib/sse v0.1.0/go.mod 
h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= 24 | github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= 25 | github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= 26 | github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= 27 | github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 28 | github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= 29 | github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= 30 | github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= 31 | github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= 32 | github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= 33 | github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= 34 | github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= 35 | github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= 36 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 37 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 38 | github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= 39 | github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 40 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 41 | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= 42 | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= 43 | github.com/klauspost/compress v1.16.3 h1:XuJt9zzcnaz6a16/OU53ZjWp/v7/42WcR5t2a0PcNQY= 44 | github.com/klauspost/compress 
v1.16.3/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= 45 | github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= 46 | github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= 47 | github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= 48 | github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= 49 | github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= 50 | github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= 51 | github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 52 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 53 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= 54 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 55 | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= 56 | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= 57 | github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= 58 | github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= 59 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 60 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 61 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 62 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 63 | github.com/redis/go-redis/v9 v9.6.2 h1:w0uvkRbc9KpgD98zcvo5IrVUsn0lXpRMuhNgiHDJzdk= 64 | github.com/redis/go-redis/v9 v9.6.2/go.mod 
h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA= 65 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 66 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 67 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 68 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 69 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 70 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 71 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 72 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 73 | github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 74 | github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 75 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 76 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 77 | github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= 78 | github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= 79 | github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= 80 | github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= 81 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= 82 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= 83 | github.com/valyala/fasthttp v1.50.0 h1:H7fweIlBm0rXLs2q0XbalvJ6r0CUPFWK3/bB4N13e9M= 84 | github.com/valyala/fasthttp v1.50.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA= 85 | github.com/valyala/tcplisten v1.0.0/go.mod 
h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= 86 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= 87 | golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= 88 | golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= 89 | golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= 90 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 91 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 92 | golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= 93 | golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= 94 | golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= 95 | golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= 96 | golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= 97 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 98 | golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 99 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 100 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 101 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 102 | golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= 103 | golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= 104 | golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= 105 | golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= 106 | golang.org/x/net v0.23.0 
h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= 107 | golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= 108 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 109 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 110 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 111 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 112 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 113 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 114 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 115 | golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 116 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 117 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 118 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 119 | golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 120 | golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 121 | golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= 122 | golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 123 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 124 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 125 | golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= 126 | golang.org/x/term v0.6.0/go.mod 
h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= 127 | golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= 128 | golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= 129 | golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= 130 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 131 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 132 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 133 | golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= 134 | golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 135 | golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 136 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 137 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 138 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 139 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 140 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 141 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= 142 | golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= 143 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 144 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 145 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= 146 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 147 | google.golang.org/protobuf v1.26.0-rc.1/go.mod 
h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 148 | google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= 149 | google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= 150 | google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= 151 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 152 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 153 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 154 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 155 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 156 | rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= 157 | -------------------------------------------------------------------------------- /internal/bytebuffer/pool.go: -------------------------------------------------------------------------------- 1 | package bytebuffer 2 | 3 | import ( 4 | "sync" 5 | "unsafe" 6 | ) 7 | 8 | // ByteBuffer is a wrapper around a slice to reduce memory allocation while handling blob of data. 9 | type ByteBuffer struct { 10 | blob []byte 11 | } 12 | 13 | // New creates a new ByteBuffer instance. 14 | func New() *ByteBuffer { 15 | entry := bufferPool.Get().(*ByteBuffer) 16 | entry.blob = entry.blob[:0] 17 | return entry 18 | } 19 | 20 | // Bytes returns the content buffer. 21 | func (buffer *ByteBuffer) Bytes() []byte { 22 | return buffer.blob 23 | } 24 | 25 | // String returns the content buffer. 
26 | func (buffer *ByteBuffer) String() string { 27 | // Copied from strings.(*Builder).String 28 | return *(*string)(unsafe.Pointer(&buffer.blob)) // nolint: gosec 29 | } 30 | 31 | // Concat appends given arguments to blob content 32 | func (buffer *ByteBuffer) Concat(args ...string) { 33 | for i := range args { 34 | buffer.blob = append(buffer.blob, args[i]...) 35 | } 36 | } 37 | 38 | // Close recycles underlying resources of encoder. 39 | func (buffer *ByteBuffer) Close() { 40 | // Proper usage of a sync.Pool requires each entry to have approximately 41 | // the same memory cost. To obtain this property when the stored type 42 | // contains a variably-sized buffer, we add a hard limit on the maximum buffer 43 | // to place back in the pool. 44 | // 45 | // See https://golang.org/issue/23199 46 | if buffer != nil && cap(buffer.blob) < (1<<16) { 47 | bufferPool.Put(buffer) 48 | } 49 | } 50 | 51 | // A byte buffer pool to reduce memory allocation pressure. 52 | var bufferPool = &sync.Pool{ 53 | New: func() interface{} { 54 | return &ByteBuffer{ 55 | blob: make([]byte, 0, 1024), 56 | } 57 | }, 58 | } 59 | -------------------------------------------------------------------------------- /internal/fasttime/fasttime.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | // Package fasttime gets wallclock time, but super fast. 4 | package fasttime 5 | 6 | import ( 7 | _ "unsafe" // import unsafe because we use go:linkname directive. 8 | ) 9 | 10 | // Forked from https://github.com/sethvargo/go-limiter 11 | 12 | //go:noescape 13 | //go:linkname now time.now 14 | func now() (sec int64, nsec int32, mono int64) 15 | 16 | // Now returns a monotonic clock value. The actual value will differ across 17 | // systems, but that's okay because we generally only care about the deltas. 
18 | func Now() uint64 { 19 | sec, nsec, _ := now() 20 | return uint64(sec)*1e9 + uint64(nsec) 21 | } 22 | -------------------------------------------------------------------------------- /internal/fasttime/fasttime_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package fasttime 4 | 5 | import "time" 6 | 7 | // Forked from https://github.com/sethvargo/go-limiter 8 | 9 | // Now returns the wall-clock time in nanoseconds. On Windows, we fall back to 10 | // the standard time.Now().UnixNano(). 11 | func Now() uint64 { 12 | return uint64(time.Now().UnixNano()) 13 | } 14 | -------------------------------------------------------------------------------- /limiter.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // ----------------------------------------------------------------- 8 | // Context 9 | // ----------------------------------------------------------------- 10 | 11 | // Context is the limit context. 12 | type Context struct { 13 | Limit int64 14 | Remaining int64 15 | Reset int64 16 | Reached bool 17 | } 18 | 19 | // ----------------------------------------------------------------- 20 | // Limiter 21 | // ----------------------------------------------------------------- 22 | 23 | // Limiter is the limiter instance. 24 | type Limiter struct { 25 | Store Store 26 | Rate Rate 27 | Options Options 28 | } 29 | 30 | // New returns an instance of Limiter. 31 | func New(store Store, rate Rate, options ...Option) *Limiter { 32 | opt := Options{ 33 | IPv4Mask: DefaultIPv4Mask, 34 | IPv6Mask: DefaultIPv6Mask, 35 | TrustForwardHeader: false, 36 | } 37 | for _, o := range options { 38 | o(&opt) 39 | } 40 | return &Limiter{ 41 | Store: store, 42 | Rate: rate, 43 | Options: opt, 44 | } 45 | } 46 | 47 | // Get returns the limit for given identifier.
48 | func (limiter *Limiter) Get(ctx context.Context, key string) (Context, error) { 49 | return limiter.Store.Get(ctx, key, limiter.Rate) 50 | } 51 | 52 | // Peek returns the limit for given identifier, without modification on current values. 53 | func (limiter *Limiter) Peek(ctx context.Context, key string) (Context, error) { 54 | return limiter.Store.Peek(ctx, key, limiter.Rate) 55 | } 56 | 57 | // Reset sets the limit for given identifier to zero. 58 | func (limiter *Limiter) Reset(ctx context.Context, key string) (Context, error) { 59 | return limiter.Store.Reset(ctx, key, limiter.Rate) 60 | } 61 | 62 | // Increment increments the limit by given count & gives back the new limit for given identifier 63 | func (limiter *Limiter) Increment(ctx context.Context, key string, count int64) (Context, error) { 64 | return limiter.Store.Increment(ctx, key, count, limiter.Rate) 65 | } 66 | -------------------------------------------------------------------------------- /limiter_test.go: -------------------------------------------------------------------------------- 1 | package limiter_test 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/ulule/limiter/v3" 7 | "github.com/ulule/limiter/v3/drivers/store/memory" 8 | ) 9 | 10 | func New(options ...limiter.Option) *limiter.Limiter { 11 | store := memory.NewStore() 12 | rate := limiter.Rate{ 13 | Period: 1 * time.Second, 14 | Limit: int64(10), 15 | } 16 | return limiter.New(store, rate, options...) 17 | } 18 | -------------------------------------------------------------------------------- /network.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | var ( 10 | // DefaultIPv4Mask defines the default IPv4 mask used to obtain user IP. 11 | DefaultIPv4Mask = net.CIDRMask(32, 32) 12 | // DefaultIPv6Mask defines the default IPv6 mask used to obtain user IP. 
13 | DefaultIPv6Mask = net.CIDRMask(128, 128) 14 | ) 15 | 16 | // GetIP returns IP address from request. 17 | // If options is defined and either TrustForwardHeader is true or ClientIPHeader is defined, 18 | // it will lookup IP in HTTP headers. 19 | // Please be advised that using this option could be insecure (ie: spoofed) if your reverse 20 | // proxy is not configured properly to forward a trustworthy client IP. 21 | // Please read the section "Limiter behind a reverse proxy" in the README for further information. 22 | func (limiter *Limiter) GetIP(r *http.Request) net.IP { 23 | return GetIP(r, limiter.Options) 24 | } 25 | 26 | // GetIPWithMask returns IP address from request by applying a mask. 27 | // If options is defined and either TrustForwardHeader is true or ClientIPHeader is defined, 28 | // it will lookup IP in HTTP headers. 29 | // Please be advised that using this option could be insecure (ie: spoofed) if your reverse 30 | // proxy is not configured properly to forward a trustworthy client IP. 31 | // Please read the section "Limiter behind a reverse proxy" in the README for further information. 32 | func (limiter *Limiter) GetIPWithMask(r *http.Request) net.IP { 33 | return GetIPWithMask(r, limiter.Options) 34 | } 35 | 36 | // GetIPKey extracts IP from request and returns hashed IP to use as store key. 37 | // If options is defined and either TrustForwardHeader is true or ClientIPHeader is defined, 38 | // it will lookup IP in HTTP headers. 39 | // Please be advised that using this option could be insecure (ie: spoofed) if your reverse 40 | // proxy is not configured properly to forward a trustworthy client IP. 41 | // Please read the section "Limiter behind a reverse proxy" in the README for further information. 42 | func (limiter *Limiter) GetIPKey(r *http.Request) string { 43 | return limiter.GetIPWithMask(r).String() 44 | } 45 | 46 | // GetIP returns IP address from request. 
47 | // If options is defined and either TrustForwardHeader is true or ClientIPHeader is defined, 48 | // it will lookup IP in HTTP headers. 49 | // Please be advised that using this option could be insecure (ie: spoofed) if your reverse 50 | // proxy is not configured properly to forward a trustworthy client IP. 51 | // Please read the section "Limiter behind a reverse proxy" in the README for further information. 52 | func GetIP(r *http.Request, options ...Options) net.IP { 53 | if len(options) >= 1 { 54 | if options[0].ClientIPHeader != "" { 55 | ip := getIPFromHeader(r, options[0].ClientIPHeader) 56 | if ip != nil { 57 | return ip 58 | } 59 | } 60 | if options[0].TrustForwardHeader { 61 | ip := getIPFromXFFHeader(r) 62 | if ip != nil { 63 | return ip 64 | } 65 | 66 | ip = getIPFromHeader(r, "X-Real-IP") 67 | if ip != nil { 68 | return ip 69 | } 70 | } 71 | } 72 | 73 | remoteAddr := strings.TrimSpace(r.RemoteAddr) 74 | host, _, err := net.SplitHostPort(remoteAddr) 75 | if err != nil { 76 | return net.ParseIP(remoteAddr) 77 | } 78 | 79 | return net.ParseIP(host) 80 | } 81 | 82 | // GetIPWithMask returns IP address from request by applying a mask. 83 | // If options is defined and either TrustForwardHeader is true or ClientIPHeader is defined, 84 | // it will lookup IP in HTTP headers. 85 | // Please be advised that using this option could be insecure (ie: spoofed) if your reverse 86 | // proxy is not configured properly to forward a trustworthy client IP. 87 | // Please read the section "Limiter behind a reverse proxy" in the README for further information. 
88 | func GetIPWithMask(r *http.Request, options ...Options) net.IP { 89 | if len(options) == 0 { 90 | return GetIP(r) 91 | } 92 | 93 | ip := GetIP(r, options[0]) 94 | if ip.To4() != nil { 95 | return ip.Mask(options[0].IPv4Mask) 96 | } 97 | if ip.To16() != nil { 98 | return ip.Mask(options[0].IPv6Mask) 99 | } 100 | return ip 101 | } 102 | 103 | func getIPFromXFFHeader(r *http.Request) net.IP { 104 | headers := r.Header.Values("X-Forwarded-For") 105 | if len(headers) == 0 { 106 | return nil 107 | } 108 | 109 | parts := []string{} 110 | for _, header := range headers { 111 | parts = append(parts, strings.Split(header, ",")...) 112 | } 113 | 114 | for i := range parts { 115 | part := strings.TrimSpace(parts[i]) 116 | ip := net.ParseIP(part) 117 | if ip != nil { 118 | return ip 119 | } 120 | } 121 | 122 | return nil 123 | } 124 | 125 | func getIPFromHeader(r *http.Request, name string) net.IP { 126 | header := strings.TrimSpace(r.Header.Get(name)) 127 | if header == "" { 128 | return nil 129 | } 130 | 131 | ip := net.ParseIP(header) 132 | if ip != nil { 133 | return ip 134 | } 135 | 136 | return nil 137 | } 138 | -------------------------------------------------------------------------------- /network_test.go: -------------------------------------------------------------------------------- 1 | package limiter_test 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/http" 7 | "net/url" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | 12 | "github.com/ulule/limiter/v3" 13 | ) 14 | 15 | func TestGetIP(t *testing.T) { 16 | is := require.New(t) 17 | 18 | limiter1 := New(limiter.WithTrustForwardHeader(false)) 19 | limiter2 := New(limiter.WithTrustForwardHeader(true)) 20 | limiter3 := New(limiter.WithIPv4Mask(net.CIDRMask(24, 32))) 21 | limiter4 := New(limiter.WithIPv6Mask(net.CIDRMask(48, 128))) 22 | 23 | request1 := &http.Request{ 24 | URL: &url.URL{Path: "/"}, 25 | Header: http.Header{}, 26 | RemoteAddr: "8.8.8.8:8888", 27 | } 28 | 29 | request2 := 
&http.Request{ 30 | URL: &url.URL{Path: "/foo"}, 31 | Header: http.Header{}, 32 | RemoteAddr: "8.8.8.8:8888", 33 | } 34 | request2.Header.Add("X-Forwarded-For", "9.9.9.9, 7.7.7.7, 6.6.6.6") 35 | 36 | request3 := &http.Request{ 37 | URL: &url.URL{Path: "/bar"}, 38 | Header: http.Header{}, 39 | RemoteAddr: "8.8.8.8:8888", 40 | } 41 | request3.Header.Add("X-Real-IP", "6.6.6.6") 42 | 43 | request4 := &http.Request{ 44 | URL: &url.URL{Path: "/"}, 45 | Header: http.Header{}, 46 | RemoteAddr: "[2001:db8:cafe:1234:beef::fafa]:8888", 47 | } 48 | 49 | scenarios := []struct { 50 | request *http.Request 51 | limiter *limiter.Limiter 52 | expected net.IP 53 | }{ 54 | { 55 | // 56 | // Scenario #1 : RemoteAddr without proxy. 57 | // 58 | request: request1, 59 | limiter: limiter1, 60 | expected: net.ParseIP("8.8.8.8").To4(), 61 | }, 62 | { 63 | // 64 | // Scenario #2 : X-Forwarded-For without proxy. 65 | // 66 | request: request2, 67 | limiter: limiter1, 68 | expected: net.ParseIP("8.8.8.8").To4(), 69 | }, 70 | { 71 | // 72 | // Scenario #3 : X-Real-IP without proxy. 73 | // 74 | request: request3, 75 | limiter: limiter1, 76 | expected: net.ParseIP("8.8.8.8").To4(), 77 | }, 78 | { 79 | // 80 | // Scenario #4 : RemoteAddr with proxy. 81 | // 82 | request: request1, 83 | limiter: limiter2, 84 | expected: net.ParseIP("8.8.8.8").To4(), 85 | }, 86 | { 87 | // 88 | // Scenario #5 : X-Forwarded-For with proxy. 89 | // 90 | request: request2, 91 | limiter: limiter2, 92 | expected: net.ParseIP("9.9.9.9").To4(), 93 | }, 94 | { 95 | // 96 | // Scenario #6 : X-Real-IP with proxy. 97 | // 98 | request: request3, 99 | limiter: limiter2, 100 | expected: net.ParseIP("6.6.6.6").To4(), 101 | }, 102 | { 103 | // 104 | // Scenario #7 : IPv4 with mask. 105 | // 106 | request: request1, 107 | limiter: limiter3, 108 | expected: net.ParseIP("8.8.8.0").To4(), 109 | }, 110 | { 111 | // 112 | // Scenario #8 : IPv6 with mask. 
113 | // 114 | request: request4, 115 | limiter: limiter4, 116 | expected: net.ParseIP("2001:db8:cafe::").To16(), 117 | }, 118 | } 119 | 120 | for i, scenario := range scenarios { 121 | message := fmt.Sprintf("Scenario #%d", (i + 1)) 122 | ip := scenario.limiter.GetIPWithMask(scenario.request) 123 | is.Equal(scenario.expected, ip, message) 124 | } 125 | } 126 | 127 | func TestGetIPKey(t *testing.T) { 128 | is := require.New(t) 129 | 130 | limiter1 := New(limiter.WithTrustForwardHeader(false)) 131 | limiter2 := New(limiter.WithTrustForwardHeader(true)) 132 | limiter3 := New(limiter.WithIPv4Mask(net.CIDRMask(24, 32))) 133 | limiter4 := New(limiter.WithIPv6Mask(net.CIDRMask(48, 128))) 134 | 135 | request1 := &http.Request{ 136 | URL: &url.URL{Path: "/"}, 137 | Header: http.Header{}, 138 | RemoteAddr: "8.8.8.8:8888", 139 | } 140 | 141 | request2 := &http.Request{ 142 | URL: &url.URL{Path: "/foo"}, 143 | Header: http.Header{}, 144 | RemoteAddr: "8.8.8.8:8888", 145 | } 146 | request2.Header.Add("X-Forwarded-For", "9.9.9.9, 7.7.7.7, 6.6.6.6") 147 | 148 | request3 := &http.Request{ 149 | URL: &url.URL{Path: "/bar"}, 150 | Header: http.Header{}, 151 | RemoteAddr: "8.8.8.8:8888", 152 | } 153 | request3.Header.Add("X-Real-IP", "6.6.6.6") 154 | 155 | request4 := &http.Request{ 156 | URL: &url.URL{Path: "/"}, 157 | Header: http.Header{}, 158 | RemoteAddr: "[2001:db8:cafe:1234:beef::fafa]:8888", 159 | } 160 | 161 | scenarios := []struct { 162 | request *http.Request 163 | limiter *limiter.Limiter 164 | expected string 165 | }{ 166 | { 167 | // 168 | // Scenario #1 : RemoteAddr without proxy. 169 | // 170 | request: request1, 171 | limiter: limiter1, 172 | expected: "8.8.8.8", 173 | }, 174 | { 175 | // 176 | // Scenario #2 : X-Forwarded-For without proxy. 177 | // 178 | request: request2, 179 | limiter: limiter1, 180 | expected: "8.8.8.8", 181 | }, 182 | { 183 | // 184 | // Scenario #3 : X-Real-IP without proxy. 
185 | // 186 | request: request3, 187 | limiter: limiter1, 188 | expected: "8.8.8.8", 189 | }, 190 | { 191 | // 192 | // Scenario #4 : RemoteAddr with proxy. 193 | // 194 | request: request1, 195 | limiter: limiter2, 196 | expected: "8.8.8.8", 197 | }, 198 | { 199 | // 200 | // Scenario #5 : X-Forwarded-For with proxy. 201 | // 202 | request: request2, 203 | limiter: limiter2, 204 | expected: "9.9.9.9", 205 | }, 206 | { 207 | // 208 | // Scenario #6 : X-Real-IP with proxy. 209 | // 210 | request: request3, 211 | limiter: limiter2, 212 | expected: "6.6.6.6", 213 | }, 214 | { 215 | // 216 | // Scenario #7 : IPv4 with mask. 217 | // 218 | request: request1, 219 | limiter: limiter3, 220 | expected: "8.8.8.0", 221 | }, 222 | { 223 | // 224 | // Scenario #8 : IPv6 with mask. 225 | // 226 | request: request4, 227 | limiter: limiter4, 228 | expected: "2001:db8:cafe::", 229 | }, 230 | } 231 | 232 | for i, scenario := range scenarios { 233 | message := fmt.Sprintf("Scenario #%d", (i + 1)) 234 | key := scenario.limiter.GetIPKey(scenario.request) 235 | is.Equal(scenario.expected, key, message) 236 | } 237 | } 238 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "net" 5 | ) 6 | 7 | // Option is a functional option. 8 | type Option func(*Options) 9 | 10 | // Options are limiter options. 11 | type Options struct { 12 | // IPv4Mask defines the mask used to obtain an IPv4 address. 13 | IPv4Mask net.IPMask 14 | // IPv6Mask defines the mask used to obtain an IPv6 address. 15 | IPv6Mask net.IPMask 16 | // TrustForwardHeader enables parsing of X-Real-IP and X-Forwarded-For headers to obtain user IP. 17 | // Please be advised that using this option could be insecure (ie: spoofed) if your reverse 18 | // proxy is not configured properly to forward a trustworthy client IP.
19 | // Please read the section "Limiter behind a reverse proxy" in the README for further information. 20 | TrustForwardHeader bool 21 | // ClientIPHeader defines a custom header (likely defined by your CDN or Cloud provider) to obtain user IP. 22 | // If configured, this option will override "TrustForwardHeader" option. 23 | // Please be advised that using this option could be insecure (ie: spoofed) if your reverse 24 | // proxy is not configured properly to forward a trustworthy client IP. 25 | // Please read the section "Limiter behind a reverse proxy" in the README for further information. 26 | ClientIPHeader string 27 | } 28 | 29 | // WithIPv4Mask will configure the limiter to use given mask for IPv4 address. 30 | func WithIPv4Mask(mask net.IPMask) Option { 31 | return func(o *Options) { 32 | o.IPv4Mask = mask 33 | } 34 | } 35 | 36 | // WithIPv6Mask will configure the limiter to use given mask for IPv6 address. 37 | func WithIPv6Mask(mask net.IPMask) Option { 38 | return func(o *Options) { 39 | o.IPv6Mask = mask 40 | } 41 | } 42 | 43 | // WithTrustForwardHeader will configure the limiter to trust X-Real-IP and X-Forwarded-For headers. 44 | // Please be advised that using this option could be insecure (ie: spoofed) if your reverse 45 | // proxy is not configured properly to forward a trustworthy client IP. 46 | // Please read the section "Limiter behind a reverse proxy" in the README for further information. 47 | func WithTrustForwardHeader(enable bool) Option { 48 | return func(o *Options) { 49 | o.TrustForwardHeader = enable 50 | } 51 | } 52 | 53 | // WithClientIPHeader will configure the limiter to use a custom header to obtain user IP. 54 | // Please be advised that using this option could be insecure (ie: spoofed) if your reverse 55 | // proxy is not configured properly to forward a trustworthy client IP. 56 | // Please read the section "Limiter behind a reverse proxy" in the README for further information. 
57 | func WithClientIPHeader(header string) Option { 58 | return func(o *Options) { 59 | o.ClientIPHeader = header 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /rate.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | "time" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | // Rate is the rate. 12 | type Rate struct { 13 | Formatted string 14 | Period time.Duration 15 | Limit int64 16 | } 17 | 18 | // NewRateFromFormatted returns the rate from the formatted version. 19 | func NewRateFromFormatted(formatted string) (Rate, error) { 20 | rate := Rate{} 21 | 22 | values := strings.Split(formatted, "-") 23 | if len(values) != 2 { 24 | return rate, errors.Errorf("incorrect format '%s'", formatted) 25 | } 26 | 27 | periods := map[string]time.Duration{ 28 | "S": time.Second, // Second 29 | "M": time.Minute, // Minute 30 | "H": time.Hour, // Hour 31 | "D": time.Hour * 24, // Day 32 | } 33 | 34 | limit, period := values[0], strings.ToUpper(values[1]) 35 | 36 | p, ok := periods[period] 37 | if !ok { 38 | return rate, errors.Errorf("incorrect period '%s'", period) 39 | } 40 | 41 | l, err := strconv.ParseInt(limit, 10, 64) 42 | if err != nil { 43 | return rate, errors.Errorf("incorrect limit '%s'", limit) 44 | } 45 | 46 | rate = Rate{ 47 | Formatted: formatted, 48 | Period: p, 49 | Limit: l, 50 | } 51 | 52 | return rate, nil 53 | } 54 | -------------------------------------------------------------------------------- /rate_test.go: -------------------------------------------------------------------------------- 1 | package limiter_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/ulule/limiter/v3" 11 | ) 12 | 13 | // TestRate tests Rate methods. 
14 | func TestRate(t *testing.T) { 15 | is := require.New(t) 16 | 17 | expected := map[string]limiter.Rate{ 18 | "10-S": { 19 | Formatted: "10-S", 20 | Period: 1 * time.Second, 21 | Limit: int64(10), 22 | }, 23 | "356-M": { 24 | Formatted: "356-M", 25 | Period: 1 * time.Minute, 26 | Limit: int64(356), 27 | }, 28 | "3-H": { 29 | Formatted: "3-H", 30 | Period: 1 * time.Hour, 31 | Limit: int64(3), 32 | }, 33 | "2000-D": { 34 | Formatted: "2000-D", 35 | Period: 24 * time.Hour, 36 | Limit: int64(2000), 37 | }, 38 | } 39 | 40 | for k, v := range expected { 41 | r, err := limiter.NewRateFromFormatted(k) 42 | is.NoError(err) 43 | is.True(reflect.DeepEqual(v, r)) 44 | } 45 | 46 | wrongs := []string{ 47 | "10 S", 48 | "10:S", 49 | "AZERTY", 50 | "na wak", 51 | "H-10", 52 | } 53 | 54 | for _, w := range wrongs { 55 | _, err := limiter.NewRateFromFormatted(w) 56 | is.Error(err) 57 | } 58 | 59 | } 60 | -------------------------------------------------------------------------------- /scripts/conf/go/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1-bullseye 2 | 3 | MAINTAINER thomas@leroux.io 4 | 5 | ENV DEBIAN_FRONTEND noninteractive 6 | ENV LANG C.UTF-8 7 | ENV LC_ALL C.UTF-8 8 | 9 | RUN apt-get -y update \ 10 | && apt-get upgrade -y \ 11 | && apt-get -y install git \ 12 | && apt-get clean \ 13 | && rm -rf /var/lib/apt/lists/* \ 14 | && useradd -ms /bin/bash gopher 15 | 16 | COPY go.mod go.sum /media/ulule/limiter/ 17 | RUN chown -R gopher:gopher /media/ulule/limiter 18 | ENV GOPATH /home/gopher/go 19 | ENV PATH $GOPATH/bin:$PATH 20 | USER gopher 21 | 22 | RUN go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest 23 | 24 | WORKDIR /media/ulule/limiter 25 | RUN go mod download 26 | COPY --chown=gopher:gopher . 
/media/ulule/limiter 27 | 28 | CMD [ "/bin/bash" ] 29 | -------------------------------------------------------------------------------- /scripts/go-wrapper: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eo pipefail 4 | 5 | SOURCE_DIRECTORY=$(dirname "${BASH_SOURCE[0]}") 6 | cd "${SOURCE_DIRECTORY}/.." 7 | 8 | ROOT_DIRECTORY=`pwd` 9 | IMAGE_NAME="limiter-go" 10 | DOCKERFILE="scripts/conf/go/Dockerfile" 11 | CONTAINER_IMAGE="golang:1-bullseye" 12 | 13 | if [[ -n "$REDIS_DISABLE_BOOTSTRAP" ]]; then 14 | REDIS_DISABLE_BOOTSTRAP_OPTS="-e REDIS_DISABLE_BOOTSTRAP=$REDIS_DISABLE_BOOTSTRAP" 15 | fi 16 | 17 | if [[ -n "$REDIS_URI" ]]; then 18 | REDIS_URI_OPTS="-e REDIS_URI=$REDIS_URI" 19 | fi 20 | 21 | create_docker_image() { 22 | declare tag="$1" dockerfile="$2" path="$3" 23 | 24 | echo "[go-wrapper] update golang image" 25 | docker pull ${CONTAINER_IMAGE} || true 26 | 27 | echo "[go-wrapper] build docker image" 28 | docker build -f "${dockerfile}" -t "${tag}" "${path}" 29 | } 30 | 31 | do_command() { 32 | declare command="$@" 33 | 34 | echo "[go-wrapper] run '${command}' in docker container" 35 | docker run --rm --net=host ${REDIS_DISABLE_BOOTSTRAP_OPTS} ${REDIS_URI_OPTS} \ 36 | "${IMAGE_NAME}" ${command} 37 | } 38 | 39 | do_usage() { 40 | 41 | echo >&2 "Usage: $0 command" 42 | exit 255 43 | 44 | } 45 | 46 | if [ -z "$1" ]; then 47 | do_usage 48 | fi 49 | 50 | create_docker_image "${IMAGE_NAME}" "${DOCKERFILE}" "${ROOT_DIRECTORY}" 51 | do_command "$@" 52 | 53 | exit 0 54 | -------------------------------------------------------------------------------- /scripts/lint: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eo pipefail 4 | 5 | if [[ ! -x "$(command -v go)" ]]; then 6 | echo >&2 "go runtime is required: https://golang.org/doc/install" 7 | echo >&2 "You can use scripts/go-wrapper $0 to use go in a docker container." 
8 | exit 1 9 | fi 10 | 11 | golinter_path="${GOPATH}/bin/golangci-lint" 12 | 13 | if [[ ! -x "${golinter_path}" ]]; then 14 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest 15 | fi 16 | 17 | SOURCE_DIRECTORY=$(dirname "${BASH_SOURCE[0]}") 18 | cd "${SOURCE_DIRECTORY}/.." 19 | 20 | if [[ -n $1 ]]; then 21 | golangci-lint run "$1" 22 | else 23 | golangci-lint run ./... 24 | fi 25 | -------------------------------------------------------------------------------- /scripts/redis: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eo pipefail 4 | 5 | DOCKER_REDIS_PORT=${DOCKER_REDIS_PORT:-26379} 6 | 7 | CONTAINER_NAME="limiter-redis" 8 | CONTAINER_IMAGE="redis:6.0" 9 | 10 | do_start() { 11 | 12 | if [[ -n "$(docker ps -q -f name="${CONTAINER_NAME}" 2> /dev/null)" ]]; then 13 | echo "[redis] ${CONTAINER_NAME} already started. (use --restart otherwise)" 14 | return 0 15 | fi 16 | 17 | if [[ -n "$(docker ps -a -q -f name="${CONTAINER_NAME}" 2> /dev/null)" ]]; then 18 | echo "[redis] erase previous configuration" 19 | docker stop "${CONTAINER_NAME}" >/dev/null 2>&1 || true 20 | docker rm "${CONTAINER_NAME}" >/dev/null 2>&1 || true 21 | fi 22 | 23 | echo "[redis] update redis images" 24 | docker pull ${CONTAINER_IMAGE} || true 25 | 26 | echo "[redis] start new ${CONTAINER_NAME} container" 27 | docker run --name "${CONTAINER_NAME}" \ 28 | -p ${DOCKER_REDIS_PORT}:6379 \ 29 | -d ${CONTAINER_IMAGE} >/dev/null 30 | 31 | } 32 | 33 | do_stop() { 34 | 35 | echo "[redis] stop ${CONTAINER_NAME} container" 36 | docker stop "${CONTAINER_NAME}" >/dev/null 2>&1 || true 37 | docker rm "${CONTAINER_NAME}" >/dev/null 2>&1 || true 38 | 39 | } 40 | 41 | do_client() { 42 | 43 | echo "[redis] use redis-cli on ${CONTAINER_NAME}" 44 | docker run --rm -it \ 45 | --link "${CONTAINER_NAME}":redis \ 46 | ${CONTAINER_IMAGE} redis-cli -h redis -p 6379 -n 1 47 | 48 | } 49 | 50 | case "$1" in 51 | --stop) 52 | do_stop 53 | ;; 
54 | --restart) 55 | do_stop 56 | do_start 57 | ;; 58 | --client) 59 | do_client 60 | ;; 61 | --start | *) 62 | do_start 63 | ;; 64 | esac 65 | exit 0 66 | -------------------------------------------------------------------------------- /scripts/test: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ ! -x "$(command -v go)" ]]; then 4 | echo >&2 "go runtime is required: https://golang.org/doc/install" 5 | echo >&2 "You can use scripts/go-wrapper $0 to use go in a docker container." 6 | exit 1 7 | fi 8 | 9 | SOURCE_DIRECTORY=$(dirname "${BASH_SOURCE[0]}") 10 | cd "${SOURCE_DIRECTORY}/.." 11 | 12 | if [ -z "$REDIS_DISABLE_BOOTSTRAP" ]; then 13 | export REDIS_URI="redis://localhost:26379/1" 14 | scripts/redis --restart 15 | fi 16 | 17 | go test -count=1 -race -v $(go list ./... | grep -v -E '\/(vendor|examples)\/') 18 | -------------------------------------------------------------------------------- /store.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | // Store is the common interface for limiter stores. 9 | type Store interface { 10 | // Get returns the limit for given identifier. 11 | Get(ctx context.Context, key string, rate Rate) (Context, error) 12 | // Peek returns the limit for given identifier, without modification on current values. 13 | Peek(ctx context.Context, key string, rate Rate) (Context, error) 14 | // Reset resets the limit to zero for given identifier. 15 | Reset(ctx context.Context, key string, rate Rate) (Context, error) 16 | // Increment increments the limit by given count & gives back the new limit for given identifier 17 | Increment(ctx context.Context, key string, count int64, rate Rate) (Context, error) 18 | } 19 | 20 | // StoreOptions are options for store. 21 | type StoreOptions struct { 22 | // Prefix is the prefix to use for the key. 
23 | Prefix string 24 | 25 | // MaxRetry is the maximum number of retries under race conditions on redis store. 26 | // Deprecated: this option is no longer required since all operations are atomic now. 27 | MaxRetry int 28 | 29 | // CleanUpInterval is the interval for cleanup (run garbage collection) on stale entries on memory store. 30 | // Setting this to a low value will optimize memory consumption, but will likely 31 | // reduce performance and increase lock contention. 32 | // Setting this to a high value will maximize throughput, but will increase the memory footprint. 33 | CleanUpInterval time.Duration 34 | } 35 | --------------------------------------------------------------------------------