├── doc.go ├── Makefile ├── go.mod ├── .gitignore ├── .travis.yml ├── example_test.go ├── LICENSE ├── ratelimiter.lua ├── benchmark_test.go ├── example └── main.go ├── go.sum ├── README.md ├── memory.go ├── ratelimiter.go ├── memory_test.go └── ratelimiter_test.go /doc.go: -------------------------------------------------------------------------------- 1 | // Package ratelimiter provides the fastest abstract rate limiter, backed by memory or Redis storage. 2 | package ratelimiter 3 | 4 | // Version is the ratelimiter package version. 5 | const Version = "0.5.3" 6 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: 2 | go test --race -v 3 | 4 | bench: 5 | go test -bench=. 6 | 7 | cover: 8 | rm -f *.coverprofile 9 | go test -coverprofile=ratelimiter.coverprofile 10 | gover 11 | go tool cover -html=ratelimiter.coverprofile 12 | 13 | .PHONY: test bench cover 14 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/teambition/ratelimiter-go 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/go-redis/redis v6.15.2+incompatible 7 | github.com/onsi/ginkgo v1.9.0 // indirect 8 | github.com/onsi/gomega v1.6.0 // indirect 9 | github.com/stretchr/testify v1.3.0 10 | ) 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | *.coverprofile 6 | 7 | # Folders 8 | _obj 9 | _test 10 | 11 | # Architecture specific extensions/prefixes 12 | *.[568vq] 13 | [568vq].out 14 | 15 | *.cgo1.go 16 | *.cgo2.c 17 | _cgo_defun.c 18 | _cgo_gotypes.go 19 | _cgo_export.* 20 | 21 | _testmain.go 22 | 23 | *.exe 24 | *.test 25 | *.prof 26 | 27 | vendor
28 | make.ps1 -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: go 3 | go: 4 | - "1.9.5" 5 | - "1.9.6" 6 | - "1.9.7" 7 | - "1.10" 8 | - "1.10.1" 9 | - "1.10.2" 10 | - "1.10.3" 11 | services: 12 | - redis-server 13 | before_install: 14 | - go get -t -v ./... 15 | - go get github.com/modocache/gover 16 | - go get github.com/mattn/goveralls 17 | script: 18 | - go test -coverprofile=ratelimiter.coverprofile 19 | - gover 20 | - goveralls -coverprofile=ratelimiter.coverprofile -service=travis-ci 21 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package ratelimiter_test 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/go-redis/redis" 8 | "github.com/teambition/ratelimiter-go" 9 | ) 10 | 11 | func ExampleLimiter_Get() { 12 | client := redis.NewClient(&redis.Options{ 13 | Addr: "localhost:6379", 14 | }) 15 | 16 | limiter := ratelimiter.New(ratelimiter.Options{ 17 | Client: &redisClient{client}, 18 | Max: 10, 19 | Duration: time.Second, // limit to 10 requests per second.
20 | }) 21 | 22 | userID := "user-123456" 23 | res, err := limiter.Get(userID) 24 | if err != nil { 25 | panic(err) 26 | } 27 | // fmt.Println(res.Reset) Reset time: 2016-10-11 21:17:53.362 +0800 CST 28 | fmt.Println(res.Total) 29 | fmt.Println(res.Remaining) 30 | fmt.Println(res.Duration) 31 | // Output: 32 | // 10 33 | // 9 34 | // 1s 35 | } 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016-2018 Teambition 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /ratelimiter.lua: -------------------------------------------------------------------------------- 1 | -- KEYS[1] target hash key 2 | -- KEYS[2] target status hash key 3 | -- ARGV[n >= 3] current timestamp, max count, duration, max count, duration, ... 4 | 5 | -- HASH: KEYS[1] 6 | -- field:ct(count) 7 | -- field:lt(limit) 8 | -- field:dn(duration) 9 | -- field:rt(reset) 10 | 11 | local res = {} 12 | local policyCount = (#ARGV - 1) / 2 13 | local limit = redis.call('hmget', KEYS[1], 'ct', 'lt', 'dn', 'rt') 14 | 15 | if limit[1] then 16 | 17 | res[1] = tonumber(limit[1]) - 1 18 | res[2] = tonumber(limit[2]) 19 | res[3] = tonumber(limit[3]) or ARGV[3] 20 | res[4] = tonumber(limit[4]) 21 | 22 | if policyCount > 1 and res[1] == -1 then 23 | redis.call('incr', KEYS[2]) 24 | redis.call('pexpire', KEYS[2], res[3] * 2) 25 | local index = tonumber(redis.call('get', KEYS[2])) 26 | if index == 1 then 27 | redis.call('incr', KEYS[2]) 28 | end 29 | end 30 | 31 | if res[1] >= -1 then 32 | redis.call('hincrby', KEYS[1], 'ct', -1) 33 | else 34 | res[1] = -1 35 | end 36 | 37 | else 38 | 39 | local index = 1 40 | if policyCount > 1 then 41 | index = tonumber(redis.call('get', KEYS[2])) or 1 42 | if index > policyCount then 43 | index = policyCount 44 | end 45 | end 46 | 47 | local total = tonumber(ARGV[index * 2]) 48 | res[1] = total - 1 49 | res[2] = total 50 | res[3] = tonumber(ARGV[index * 2 + 1]) 51 | res[4] = tonumber(ARGV[1]) + res[3] 52 | 53 | redis.call('hmset', KEYS[1], 'ct', res[1], 'lt', res[2], 'dn', res[3], 'rt', res[4]) 54 | redis.call('pexpire', KEYS[1], res[3]) 55 | 56 | end 57 | 58 | return res 59 | -------------------------------------------------------------------------------- /benchmark_test.go: -------------------------------------------------------------------------------- 1 | package ratelimiter_test 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | "testing" 7 | 8 | 
ratelimiter "github.com/teambition/ratelimiter-go" 9 | ) 10 | 11 | func BenchmarkGet(b *testing.B) { 12 | 13 | limiter := ratelimiter.New(ratelimiter.Options{}) 14 | policy := []int{1000000, 1000} 15 | id := getUniqueID() 16 | 17 | b.ReportAllocs() 18 | b.ResetTimer() 19 | for i := 0; i < b.N; i++ { 20 | limiter.Get(id, policy...) 21 | } 22 | } 23 | func BenchmarkGetAndExceeding(b *testing.B) { 24 | 25 | limiter := ratelimiter.New(ratelimiter.Options{}) 26 | policy := []int{100, 1000} 27 | id := getUniqueID() 28 | 29 | b.ReportAllocs() 30 | b.ResetTimer() 31 | for i := 0; i < b.N; i++ { 32 | limiter.Get(id, policy...) 33 | } 34 | } 35 | func BenchmarkGetAndParallel(b *testing.B) { 36 | limiter := ratelimiter.New(ratelimiter.Options{}) 37 | policy := []int{1000000, 1000} 38 | id := getUniqueID() 39 | 40 | b.ReportAllocs() 41 | b.ResetTimer() 42 | 43 | b.RunParallel(func(pb *testing.PB) { 44 | for pb.Next() { 45 | limiter.Get(id, policy...) 46 | } 47 | }) 48 | } 49 | func BenchmarkGetAndClean(b *testing.B) { 50 | limiter := ratelimiter.New(ratelimiter.Options{}) 51 | policy := []int{1000000, 1000} 52 | id := getUniqueID() 53 | 54 | b.ReportAllocs() 55 | b.ResetTimer() 56 | 57 | b.RunParallel(func(pb *testing.PB) { 58 | for pb.Next() { 59 | limiter.Get(id, policy...) 60 | } 61 | //limiter.Clean() 62 | }) 63 | } 64 | func BenchmarkGetForDifferentUser(b *testing.B) { 65 | limiter := ratelimiter.New(ratelimiter.Options{}) 66 | policy := []int{1, 10000} 67 | 68 | b.ReportAllocs() 69 | b.ResetTimer() 70 | 71 | b.RunParallel(func(pb *testing.PB) { 72 | for pb.Next() { 73 | id := getUniqueID() 74 | limiter.Get(id, policy...)
75 | } 76 | //limiter.Clean() 77 | }) 78 | } 79 | 80 | // getUniqueID returns 24 random hex characters, used as a fresh limiter id. 81 | func getUniqueID() string { 82 | buf := make([]byte, 12) 83 | _, err := rand.Read(buf) 84 | if err != nil { 85 | panic(err) 86 | } 87 | return hex.EncodeToString(buf) 88 | } 89 | -------------------------------------------------------------------------------- /example/main.go: -------------------------------------------------------------------------------- 1 | // The ratelimiter-go HTTP Demo 2 | 3 | package main 4 | 5 | import ( 6 | "fmt" 7 | "html" 8 | "log" 9 | "net/http" 10 | "strconv" 11 | "time" 12 | 13 | redis "github.com/go-redis/redis" 14 | ratelimiter "github.com/teambition/ratelimiter-go" 15 | ) 16 | 17 | // Implements RedisClient for redis.Client 18 | type redisClient struct { 19 | *redis.Client 20 | } 21 | 22 | func (c *redisClient) RateDel(key string) error { 23 | return c.Del(key).Err() 24 | } 25 | 26 | func (c *redisClient) RateEvalSha(sha1 string, keys []string, args ...interface{}) (interface{}, error) { 27 | return c.EvalSha(sha1, keys, args...).Result() 28 | } 29 | 30 | func (c *redisClient) RateScriptLoad(script string) (string, error) { 31 | return c.ScriptLoad(script).Result() 32 | } 33 | 34 | func main() { 35 | // use memory 36 | // limiter := ratelimiter.New(ratelimiter.Options{ 37 | // Max: 10, 38 | // Duration: time.Minute, // limit to 10 requests in 1 minute. 39 | // }) 40 | 41 | // or use redis 42 | client := redis.NewClient(&redis.Options{ 43 | Addr: "localhost:6379", 44 | }) 45 | limiter := ratelimiter.New(ratelimiter.Options{ 46 | Max: 10, 47 | Duration: time.Minute, // limit to 10 requests in 1 minute.
48 | Client: &redisClient{client}, 49 | }) 50 | 51 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 52 | res, err := limiter.Get(r.URL.Path) 53 | if err != nil { 54 | http.Error(w, err.Error(), http.StatusInternalServerError) 55 | return 56 | } 57 | 58 | header := w.Header() 59 | header.Set("X-Ratelimit-Limit", strconv.FormatInt(int64(res.Total), 10)) 60 | header.Set("X-Ratelimit-Remaining", strconv.FormatInt(int64(res.Remaining), 10)) 61 | header.Set("X-Ratelimit-Reset", strconv.FormatInt(res.Reset.Unix(), 10)) 62 | 63 | if res.Remaining >= 0 { 64 | w.WriteHeader(http.StatusOK) 65 | fmt.Fprintf(w, "Path: %q\n", html.EscapeString(r.URL.Path)) 66 | fmt.Fprintf(w, "Remaining: %d\n", res.Remaining) 67 | fmt.Fprintf(w, "Total: %d\n", res.Total) 68 | fmt.Fprintf(w, "Duration: %v\n", res.Duration) 69 | fmt.Fprintf(w, "Reset: %v\n", res.Reset) 70 | } else { 71 | // Round up: truncating would report a sub-second wait as "0 seconds". 72 | after := int64((time.Until(res.Reset) + time.Second - 1) / time.Second) 73 | if after < 0 { 74 | after = 0 75 | } 76 | header.Set("Retry-After", strconv.FormatInt(after, 10)) 77 | w.WriteHeader(http.StatusTooManyRequests) 78 | fmt.Fprintf(w, "Rate limit exceeded, retry in %d seconds.\n", after) 79 | } 80 | }) 81 | log.Fatal(http.ListenAndServe(":8080", nil)) 82 | } 83 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= 4 | github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 5 | github.com/go-redis/redis v6.15.2+incompatible h1:9SpNVG76gr6InJGxoZ6IuuxaCOQwDAhzyXg+Bs+0Sb4= 6 | github.com/go-redis/redis v6.15.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= 7 | github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= 8 | github.com/golang/protobuf
v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 9 | github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= 10 | github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= 11 | github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 12 | github.com/onsi/ginkgo v1.9.0 h1:SZjF721BByVj8QH636/8S2DnX4n0Re3SteMmw3N+tzc= 13 | github.com/onsi/ginkgo v1.9.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 14 | github.com/onsi/gomega v1.6.0 h1:8XTW0fcJZEq9q+Upcyws4JSGua2MFysCL5xkaSgHc+M= 15 | github.com/onsi/gomega v1.6.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= 16 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 17 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 18 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 19 | github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= 20 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 21 | golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= 22 | golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 23 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= 24 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 25 | golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e h1:o3PsSEY8E4eXWkXrIP9YJALUkVZqzHJT5DOasTyn8Vs= 26 | golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 27 | golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= 28 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 29 | gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 30 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 31 | gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= 32 | gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= 33 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= 34 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= 35 | gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= 36 | gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ratelimiter-go 2 | ========== 3 | The fastest abstract rate limiter, based on memory or Redis storage.
4 | 5 | [![Build Status](http://img.shields.io/travis/teambition/ratelimiter-go.svg?style=flat-square)](https://travis-ci.org/teambition/ratelimiter-go) 6 | [![Coverage Status](http://img.shields.io/coveralls/teambition/ratelimiter-go.svg?style=flat-square)](https://coveralls.io/r/teambition/ratelimiter-go) 7 | [![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/teambition/ratelimiter-go/master/LICENSE) 8 | [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/teambition/ratelimiter-go) 9 | 10 | ## Features 11 | 12 | - Distributed 13 | - Atomic 14 | - High-performance 15 | - Supports Redis Cluster 16 | - Supports in-memory storage 17 | 18 | ## Installation 19 | 20 | ```sh 21 | go get github.com/teambition/ratelimiter-go 22 | ``` 23 | 24 | ## HTTP Server Demo 25 | From the `github.com/teambition/ratelimiter-go` directory, run: 26 | 27 | ```sh 28 | go run example/main.go 29 | ``` 30 | Then visit: http://127.0.0.1:8080/ 31 | 32 | ```go 33 | package main 34 | 35 | import ( 36 | "fmt" 37 | "html" 38 | "log" 39 | "net/http" 40 | "strconv" 41 | "time" 42 | 43 | ratelimiter "github.com/teambition/ratelimiter-go" 44 | redis "github.com/go-redis/redis" 45 | ) 46 | 47 | // Implements RedisClient for redis.Client 48 | type redisClient struct { 49 | *redis.Client 50 | } 51 | 52 | func (c *redisClient) RateDel(key string) error { 53 | return c.Del(key).Err() 54 | } 55 | func (c *redisClient) RateEvalSha(sha1 string, keys []string, args ...interface{}) (interface{}, error) { 56 | return c.EvalSha(sha1, keys, args...).Result() 57 | } 58 | func (c *redisClient) RateScriptLoad(script string) (string, error) { 59 | return c.ScriptLoad(script).Result() 60 | } 61 | 62 | func main() { 63 | // use memory 64 | // limiter := ratelimiter.New(ratelimiter.Options{ 65 | // Max: 10, 66 | // Duration: time.Minute, // limit to 10 requests in 1 minute.
67 | // }) 68 | 69 | // or use redis 70 | client := redis.NewClient(&redis.Options{ 71 | Addr: "localhost:6379", 72 | }) 73 | limiter := ratelimiter.New(ratelimiter.Options{ 74 | Max: 10, 75 | Duration: time.Minute, // limit to 1000 requests in 1 minute. 76 | Client: &redisClient{client}, 77 | }) 78 | 79 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 80 | res, err := limiter.Get(r.URL.Path) 81 | if err != nil { 82 | http.Error(w, err.Error(), 500) 83 | return 84 | } 85 | 86 | header := w.Header() 87 | header.Set("X-Ratelimit-Limit", strconv.FormatInt(int64(res.Total), 10)) 88 | header.Set("X-Ratelimit-Remaining", strconv.FormatInt(int64(res.Remaining), 10)) 89 | header.Set("X-Ratelimit-Reset", strconv.FormatInt(res.Reset.Unix(), 10)) 90 | 91 | if res.Remaining >= 0 { 92 | w.WriteHeader(200) 93 | fmt.Fprintf(w, "Path: %q\n", html.EscapeString(r.URL.Path)) 94 | fmt.Fprintf(w, "Remaining: %d\n", res.Remaining) 95 | fmt.Fprintf(w, "Total: %d\n", res.Total) 96 | fmt.Fprintf(w, "Duration: %v\n", res.Duration) 97 | fmt.Fprintf(w, "Reset: %v\n", res.Reset) 98 | } else { 99 | after := int64(res.Reset.Sub(time.Now())) / 1e9 100 | header.Set("Retry-After", strconv.FormatInt(after, 10)) 101 | w.WriteHeader(429) 102 | fmt.Fprintf(w, "Rate limit exceeded, retry in %d seconds.\n", after) 103 | } 104 | }) 105 | log.Fatal(http.ListenAndServe(":8080", nil)) 106 | } 107 | ``` 108 | 109 | ## Node.js version: [thunk-ratelimiter](https://github.com/thunks/thunk-ratelimiter) 110 | 111 | ## Documentation 112 | The docs can be found at [godoc.org](https://godoc.org/github.com/teambition/ratelimiter-go), as usual. 113 | 114 | ## License 115 | `ratelimiter-go` is licensed under the [MIT](https://github.com/teambition/ratelimiter-go/blob/master/LICENSE) license. 116 | Copyright © 2016-2018 [Teambition](https://www.teambition.com). 
117 | -------------------------------------------------------------------------------- /memory.go: -------------------------------------------------------------------------------- 1 | package ratelimiter 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // statusCacheItem records which policy index a key has escalated to and when that escalation lapses. 10 | type statusCacheItem struct { 11 | index int 12 | expire time.Time 13 | } 14 | 15 | // limiterCacheItem is the state of one key's current limit window. 16 | type limiterCacheItem struct { 17 | total int 18 | remaining int 19 | duration time.Duration 20 | expire time.Time 21 | } 22 | 23 | type memoryLimiter struct { 24 | max int 25 | duration time.Duration 26 | status map[string]*statusCacheItem 27 | store map[string]*limiterCacheItem 28 | ticker *time.Ticker 29 | lock sync.Mutex // guards status and store 30 | } 31 | 32 | func newMemoryLimiter(opts *Options) *Limiter { 33 | m := &memoryLimiter{ 34 | max: opts.Max, 35 | duration: opts.Duration, 36 | store: make(map[string]*limiterCacheItem), 37 | status: make(map[string]*statusCacheItem), 38 | ticker: time.NewTicker(time.Second), // drives cleanCache once per second 39 | } 40 | go m.cleanCache() // nothing in memory.go stops this goroutine or its ticker 41 | return &Limiter{m, opts.Prefix} 42 | } 43 | 44 | // getLimit implements abstractLimiter: it validates policy (falling back to Max/Duration when empty) and consumes one request for key. 45 | func (m *memoryLimiter) getLimit(key string, policy ...int) ([]interface{}, error) { 46 | length := len(policy) 47 | var args []int 48 | if length == 0 { 49 | args = []int{m.max, int(m.duration / time.Millisecond)} 50 | } else { 51 | args = make([]int, length) 52 | for i, val := range policy { 53 | if val <= 0 { 54 | return nil, errors.New("ratelimiter: must be positive integer") 55 | } 56 | args[i] = policy[i] 57 | } 58 | } 59 | 60 | res := m.getItem(key, args...)
61 | // res is a copy that getItem made while holding m.lock, so it can be 62 | // read here without locking again. 63 | return []interface{}{res.remaining, res.total, res.duration, res.expire}, nil 64 | } 65 | 66 | // removeLimit implements abstractLimiter: it deletes the record and the policy status of key. 67 | func (m *memoryLimiter) removeLimit(key string) error { 68 | statusKey := "{" + key + "}:S" 69 | m.lock.Lock() 70 | defer m.lock.Unlock() 71 | delete(m.store, key) 72 | delete(m.status, statusKey) 73 | return nil 74 | } 75 | 76 | // clean evicts records whose window ended more than one window-duration 77 | // ago, together with their policy status. Rather than scanning the whole 78 | // store it samples: each round looks at the first entry of 24 separate map 79 | // iterations (Go leaves map iteration order unspecified), and another round 80 | // runs only while more than a quarter of the samples were evicted and less 81 | // than 100ms has passed. m.lock is held for the whole call. 82 | func (m *memoryLimiter) clean() { 83 | m.lock.Lock() 84 | defer m.lock.Unlock() 85 | 86 | const samples = 24 87 | start := time.Now() 88 | deadline := start.Add(100 * time.Millisecond) 89 | for { 90 | expired := 0 91 | for i := 0; i < samples; i++ { 92 | for key, item := range m.store { 93 | if item.expire.Add(item.duration).Before(start) { 94 | delete(m.store, key) 95 | delete(m.status, "{"+key+"}:S") 96 | expired++ 97 | } 98 | break // inspect a single entry per map iteration 99 | } 100 | } 101 | if deadline.Before(time.Now()) || expired <= samples/4 { 102 | return 103 | } 104 | } 105 | } 106 | 107 | // getItem consumes one request for key and returns a copy of the updated 108 | // record. The copy is taken while m.lock is held, so callers never read a 109 | // record that a concurrent request is modifying. args holds 110 | // (max, duration-in-milliseconds) pairs; the scheme follows ratelimiter.lua. 111 | func (m *memoryLimiter) getItem(key string, args ...int) limiterCacheItem { 112 | policyCount := len(args) / 2 113 | statusKey := "{" + key + "}:S" 114 | 115 | m.lock.Lock() 116 | defer m.lock.Unlock() 117 | now := time.Now() 118 | 119 | res, ok := m.store[key] 120 | if !ok { 121 | // First request for key: open a window with the first policy. 122 | duration := time.Duration(args[1]) * time.Millisecond 123 | res = &limiterCacheItem{ 124 | total: args[0], 125 | remaining: args[0] - 1, 126 | duration: duration, 127 | expire: now.Add(duration), 128 | } 129 | m.store[key] = res 130 | return *res 131 | } 132 | 
133 | if res.expire.After(now) { 134 | // The request that exhausts the window escalates the policy for the 135 | // next one. As in ratelimiter.lua, a missing or expired status entry 136 | // restarts at index 2; the entry lives for twice this window. 137 | if policyCount > 1 && res.remaining == 0 { 138 | status, found := m.status[statusKey] 139 | if !found { 140 | status = &statusCacheItem{} 141 | m.status[statusKey] = status 142 | } 143 | if found && !status.expire.Before(now) { 144 | status.index++ 145 | } else { 146 | status.index = 2 147 | } 148 | status.expire = now.Add(res.duration * 2) 149 | } 150 | if res.remaining >= 0 { 151 | res.remaining-- 152 | } else { 153 | res.remaining = -1 154 | } 155 | return *res 156 | } 157 | 158 | // The window has elapsed: restart it with the escalated policy from a 159 | // live status entry (capped at the last policy), or else the first one. 160 | index := 1 161 | if policyCount > 1 { 162 | if status, found := m.status[statusKey]; found { 163 | if !status.expire.Before(now) { 164 | index = status.index 165 | if index > policyCount { 166 | index = policyCount 167 | } 168 | } 169 | status.index = index 170 | } 171 | } 172 | duration := time.Duration(args[index*2-1]) * time.Millisecond 173 | res.total = args[index*2-2] 174 | res.remaining = res.total - 1 175 | res.duration = duration 176 | res.expire = now.Add(duration) 177 | return *res 178 | } 179 | 180 | // cleanCache calls clean on every tick of m.ticker. Nothing in this file 181 | // stops the ticker, so the goroutine runs for the rest of the process. 182 | func (m *memoryLimiter) cleanCache() { 183 | for range m.ticker.C { 184 | m.clean() 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /ratelimiter.go: -------------------------------------------------------------------------------- 1 | package ratelimiter 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | // RedisClient is the set of Redis operations a Limiter needs.
12 | // Examples: https://github.com/teambition/ratelimiter-go/blob/master/ratelimiter_test.go#L18 13 | /* 14 | Implements RedisClient for a simple redis client: 15 | 16 | import "github.com/go-redis/redis" 17 | 18 | type redisClient struct { 19 | *redis.Client 20 | } 21 | 22 | func (c *redisClient) RateDel(key string) error { 23 | return c.Del(key).Err() 24 | } 25 | func (c *redisClient) RateEvalSha(sha1 string, keys []string, args ...interface{}) (interface{}, error) { 26 | return c.EvalSha(sha1, keys, args...).Result() 27 | } 28 | func (c *redisClient) RateScriptLoad(script string) (string, error) { 29 | return c.ScriptLoad(script).Result() 30 | } 31 | 32 | Implements RedisClient for a cluster redis client: 33 | 34 | import "github.com/go-redis/redis" 35 | 36 | type clusterClient struct { 37 | *redis.ClusterClient 38 | } 39 | 40 | func (c *clusterClient) RateDel(key string) error { 41 | return c.Del(key).Err() 42 | } 43 | func (c *clusterClient) RateEvalSha(sha1 string, keys []string, args ...interface{}) (interface{}, error) { 44 | return c.EvalSha(sha1, keys, args...).Result() 45 | } 46 | func (c *clusterClient) RateScriptLoad(script string) (string, error) { 47 | var sha1 string 48 | err := c.ForEachMaster(func(client *redis.Client) error { 49 | res, err := client.ScriptLoad(script).Result() 50 | if err == nil { 51 | sha1 = res 52 | } 53 | return err 54 | }) 55 | return sha1, err 56 | } 57 | 58 | Uses it: 59 | 60 | client := redis.NewClient(&redis.Options{ 61 | Addr: "localhost:6379", 62 | }) 63 | limiter := ratelimiter.New(ratelimiter.Options{Client: &redisClient{client}}) 64 | */ 65 | type RedisClient interface { 66 | RateDel(string) error 67 | RateEvalSha(string, []string, ...interface{}) (interface{}, error) 68 | RateScriptLoad(string) (string, error) 69 | } 70 | 71 | // Limiter rate-limits requests per id, keeping its counters in memory or in Redis.
72 | type Limiter struct { 73 | abstractLimiter 74 | prefix string 75 | } 76 | 77 | // Options configures a Limiter. Non-positive Max/Duration and an empty Prefix are replaced by the defaults noted below. 78 | type Options struct { 79 | Max int // Max count per window when Get has no policy; default 100. 80 | Duration time.Duration // Window length when Get has no policy; default 1 minute. 81 | Prefix string // Prepended to every id to form the storage key; default "LIMIT:". 82 | Client RedisClient // Redis backend; if nil, an in-memory limiter is used. 83 | } 84 | 85 | // Result is the limit state returned by Limiter.Get. 86 | type Result struct { 87 | Total int // Max count of the current window: Options.Max or the policy max 88 | Remaining int // Requests left in the window; never below -1, and -1 means the limit is exceeded 89 | Duration time.Duration // Window length: Options.Duration or the policy duration 90 | Reset time.Time // When the current window ends 91 | } 92 | 93 | // New returns a Limiter configured by opts, filling in defaults for unset fields. 94 | // If opts.Client is nil the limiter is in-memory; otherwise New panics if the Lua script cannot be loaded. 95 | func New(opts Options) *Limiter { 96 | if opts.Prefix == "" { 97 | opts.Prefix = "LIMIT:" 98 | } 99 | if opts.Max <= 0 { 100 | opts.Max = 100 101 | } 102 | if opts.Duration <= 0 { 103 | opts.Duration = time.Minute 104 | } 105 | if opts.Client == nil { 106 | return newMemoryLimiter(&opts) 107 | } 108 | return newRedisLimiter(&opts) 109 | } 110 | 111 | type abstractLimiter interface { 112 | getLimit(key string, policy ...int) ([]interface{}, error) 113 | removeLimit(key string) error 114 | } 115 | 116 | func newRedisLimiter(opts *Options) *Limiter { 117 | sha1, err := opts.Client.RateScriptLoad(lua) 118 | if err != nil { 119 | panic(err) // New has no error return, so a script-load failure can only panic 120 | } 121 | r := &redisLimiter{ 122 | rc: opts.Client, 123 | sha1: sha1, 124 | max: strconv.FormatInt(int64(opts.Max), 10), 125 | duration: strconv.FormatInt(int64(opts.Duration/time.Millisecond), 10), 126 | } 127 | return &Limiter{r, opts.Prefix} 128 | } 129 | 130 | // Get consumes one request for id and returns the updated limit state; an optional policy of (max, duration-in-milliseconds) pairs overrides Options.Max and Options.Duration.
131 | /* 132 | Get a limiter result with the default policy (Options.Max, Options.Duration): 133 | 134 | userID := "user-123456" 135 | res, err := limiter.Get(userID) 136 | if err == nil { 137 | fmt.Println(res.Reset) // 2016-10-11 21:17:53.362 +0800 CST 138 | fmt.Println(res.Total) // 100 139 | fmt.Println(res.Remaining) // 99 140 | fmt.Println(res.Duration) // 1m0s 141 | } 142 | 143 | Get a limiter result with a custom limiter policy: 144 | 145 | id := "id-123456" 146 | policy := []int{100, 60000, 50, 60000, 50, 120000} 147 | res, err := limiter.Get(id, policy...) 148 | */ 149 | func (l *Limiter) Get(id string, policy ...int) (Result, error) { 150 | if len(policy)%2 == 1 { 151 | return Result{}, errors.New("ratelimiter: must be paired values") 152 | } 153 | 154 | res, err := l.getLimit(l.prefix+id, policy...) 155 | if err != nil { 156 | return Result{}, err 157 | } 158 | 159 | // The assertions are checked so that a RedisClient returning unexpected 160 | // types produces an error instead of a panic inside the library. 161 | var result Result 162 | var ok0, ok1, ok2 bool 163 | switch reset := res[3].(type) { 164 | case time.Time: // result from memory limiter 165 | result.Remaining, ok0 = res[0].(int) 166 | result.Total, ok1 = res[1].(int) 167 | result.Duration, ok2 = res[2].(time.Duration) 168 | result.Reset = reset 169 | case int64: // result from redis limiter: integers, with duration and reset in milliseconds 170 | var remaining, total, duration int64 171 | remaining, ok0 = res[0].(int64) 172 | total, ok1 = res[1].(int64) 173 | duration, ok2 = res[2].(int64) 174 | result.Remaining = int(remaining) 175 | result.Total = int(total) 176 | result.Duration = time.Duration(duration) * time.Millisecond 177 | sec := reset / 1000 178 | result.Reset = time.Unix(sec, (reset-sec*1000)*1e6) 179 | } 180 | if !(ok0 && ok1 && ok2) { 181 | return Result{}, errors.New("ratelimiter: invalid result") 182 | } 183 | return result, nil 184 | } 185 | 186 | // Remove deletes the limiter record for id. 187 | func (l *Limiter) Remove(id string) error { 188 | return l.removeLimit(l.prefix + id) 189 | } 190 | 191 | type redisLimiter struct { 192 | sha1, max, duration string 193 | rc RedisClient 194 | } 195 | 196 | func (r *redisLimiter) removeLimit(key string) error { 197 | return r.rc.RateDel(key) 198 | } 199 | 200 | // getLimit runs the rate-limit script for key. Its ARGV is the current 201 | // timestamp in milliseconds followed by (max, duration-in-ms) pairs: either 202 | // the limiter defaults or the caller's policy.
203 | func (r *redisLimiter) getLimit(key string, policy ...int) ([]interface{}, error) { 204 | keys := []string{key, fmt.Sprintf("{%s}:S", key)} 205 | capacity := 3 206 | length := len(policy) 207 | if length > 2 { 208 | capacity = length + 1 209 | } 210 | 211 | args := make([]interface{}, capacity) 212 | args[0] = genTimestamp() 213 | if length == 0 { 214 | args[1] = r.max 215 | args[2] = r.duration 216 | } else { 217 | for i, val := range policy { 218 | if val <= 0 { 219 | return nil, errors.New("ratelimiter: must be positive integer") 220 | } 221 | args[i+1] = strconv.FormatInt(int64(val), 10) 222 | } 223 | } 224 | 225 | res, err := r.rc.RateEvalSha(r.sha1, keys, args...) 226 | if err != nil && isNoScriptErr(err) { 227 | // try to load lua for cluster client and ring client for nodes changing. 228 | _, err = r.rc.RateScriptLoad(lua) 229 | if err == nil { 230 | res, err = r.rc.RateEvalSha(r.sha1, keys, args...) 231 | } 232 | } 233 | 234 | if err == nil { 235 | arr, ok := res.([]interface{}) 236 | if ok && len(arr) == 4 { 237 | return arr, nil 238 | } 239 | err = errors.New("ratelimiter: invalid result") 240 | } 241 | return nil, err 242 | } 243 | 244 | // genTimestamp returns the current Unix time in milliseconds as a decimal string. 245 | func genTimestamp() string { 246 | return strconv.FormatInt(time.Now().UnixNano()/1e6, 10) 247 | } 248 | 249 | func isNoScriptErr(err error) bool { 250 | return strings.HasPrefix(err.Error(), "NOSCRIPT ") 251 | } 252 | 253 | // lua is a copy of ./ratelimiter.lua; keep the two in sync. 254 | const lua string = ` 255 | -- KEYS[1] target hash key 256 | -- KEYS[2] target status hash key 257 | -- ARGV[n >= 3] current timestamp, max count, duration, max count, duration, ... 
258 | 259 | -- HASH: KEYS[1] 260 | -- field:ct(count) 261 | -- field:lt(limit) 262 | -- field:dn(duration) 263 | -- field:rt(reset) 264 | 265 | local res = {} 266 | local policyCount = (#ARGV - 1) / 2 267 | local limit = redis.call('hmget', KEYS[1], 'ct', 'lt', 'dn', 'rt') 268 | 269 | if limit[1] then 270 | 271 | res[1] = tonumber(limit[1]) - 1 272 | res[2] = tonumber(limit[2]) 273 | res[3] = tonumber(limit[3]) or ARGV[3] 274 | res[4] = tonumber(limit[4]) 275 | 276 | if policyCount > 1 and res[1] == -1 then 277 | redis.call('incr', KEYS[2]) 278 | redis.call('pexpire', KEYS[2], res[3] * 2) 279 | local index = tonumber(redis.call('get', KEYS[2])) 280 | if index == 1 then 281 | redis.call('incr', KEYS[2]) 282 | end 283 | end 284 | 285 | if res[1] >= -1 then 286 | redis.call('hincrby', KEYS[1], 'ct', -1) 287 | else 288 | res[1] = -1 289 | end 290 | 291 | else 292 | 293 | local index = 1 294 | if policyCount > 1 then 295 | index = tonumber(redis.call('get', KEYS[2])) or 1 296 | if index > policyCount then 297 | index = policyCount 298 | end 299 | end 300 | 301 | local total = tonumber(ARGV[index * 2]) 302 | res[1] = total - 1 303 | res[2] = total 304 | res[3] = tonumber(ARGV[index * 2 + 1]) 305 | res[4] = tonumber(ARGV[1]) + res[3] 306 | 307 | redis.call('hmset', KEYS[1], 'ct', res[1], 'lt', res[2], 'dn', res[3], 'rt', res[4]) 308 | redis.call('pexpire', KEYS[1], res[3]) 309 | 310 | end 311 | 312 | return res 313 | ` 314 | -------------------------------------------------------------------------------- /memory_test.go: -------------------------------------------------------------------------------- 1 | package ratelimiter 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | "testing" 7 | "time" 8 | 9 | "sync" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestMemoryRateLimiter(t *testing.T) { 15 | t.Run("ratelimiter with default Options should be", func(t *testing.T) { 16 | assert := assert.New(t) 17 | 18 | limiter := New(Options{}) 19 | id := 
genID() 20 | policy := []int{10, 1000} 21 | 22 | res, err := limiter.Get(id, policy...) 23 | assert.Nil(err) 24 | assert.Equal(10, res.Total) 25 | assert.Equal(9, res.Remaining) 26 | assert.Equal(1000, int(res.Duration/time.Millisecond)) 27 | assert.True(res.Reset.After(time.Now())) 28 | res, err = limiter.Get(id, policy...) 29 | assert.Equal(10, res.Total) 30 | assert.Equal(8, res.Remaining) 31 | }) 32 | 33 | t.Run("ratelimiter with expire should be", func(t *testing.T) { 34 | assert := assert.New(t) 35 | 36 | limiter := New(Options{}) 37 | id := genID() 38 | policy := []int{10, 100} 39 | 40 | res, err := limiter.Get(id, policy...) 41 | assert.Nil(err) 42 | assert.Equal(10, res.Total) 43 | assert.Equal(9, res.Remaining) 44 | res, err = limiter.Get(id, policy...) 45 | assert.Equal(8, res.Remaining) 46 | 47 | time.Sleep(100 * time.Millisecond) 48 | res, err = limiter.Get(id, policy...) 49 | assert.Nil(err) 50 | assert.Equal(10, res.Total) 51 | assert.Equal(9, res.Remaining) 52 | }) 53 | 54 | t.Run("ratelimiter with goroutine should be", func(t *testing.T) { 55 | assert := assert.New(t) 56 | 57 | limiter := New(Options{}) 58 | policy := []int{10, 500} 59 | id := genID() 60 | res, err := limiter.Get(id, policy...) 61 | assert.Nil(err) 62 | assert.Equal(10, res.Total) 63 | assert.Equal(9, res.Remaining) 64 | var wait sync.WaitGroup 65 | wait.Add(100) 66 | for i := 0; i < 100; i++ { 67 | go func() { 68 | limiter.Get(id, policy...) 69 | wait.Done() 70 | }() 71 | } 72 | wait.Wait() 73 | time.Sleep(200 * time.Millisecond) 74 | res, err = limiter.Get(id, policy...) 75 | assert.Nil(err) 76 | assert.Equal(10, res.Total) 77 | assert.Equal(-1, res.Remaining) 78 | }) 79 | 80 | t.Run("ratelimiter with multi-policy should be", func(t *testing.T) { 81 | assert := assert.New(t) 82 | 83 | limiter := New(Options{}) 84 | id := genID() 85 | policy := []int{3, 100, 2, 200} 86 | 87 | res, err := limiter.Get(id, policy...) 
88 | assert.Nil(err) 89 | assert.Equal(3, res.Total) 90 | assert.Equal(2, res.Remaining) 91 | res, err = limiter.Get(id, policy...) 92 | assert.Equal(1, res.Remaining) 93 | res, err = limiter.Get(id, policy...) 94 | assert.Equal(0, res.Remaining) 95 | res, err = limiter.Get(id, policy...) 96 | assert.Equal(-1, res.Remaining) 97 | res, err = limiter.Get(id, policy...) 98 | assert.Equal(-1, res.Remaining) 99 | assert.True(res.Reset.After(time.Now())) 100 | 101 | time.Sleep(res.Duration + time.Millisecond) 102 | res, err = limiter.Get(id, policy...) 103 | assert.Equal(2, res.Total) 104 | assert.Equal(1, res.Remaining) 105 | assert.Equal(time.Millisecond*200, res.Duration) 106 | 107 | res, err = limiter.Get(id, policy...) 108 | assert.Equal(0, res.Remaining) 109 | res, err = limiter.Get(id, policy...) 110 | assert.Equal(-1, res.Remaining) 111 | 112 | time.Sleep(res.Duration + time.Millisecond) 113 | res, err = limiter.Get(id, policy...) 114 | assert.Equal(2, res.Total) 115 | assert.Equal(1, res.Remaining) 116 | assert.Equal(time.Millisecond*200, res.Duration) 117 | }) 118 | 119 | t.Run("ratelimiter with Remove id should be", func(t *testing.T) { 120 | assert := assert.New(t) 121 | 122 | limiter := New(Options{}) 123 | id := genID() 124 | policy := []int{10, 1000} 125 | 126 | res, err := limiter.Get(id, policy...) 127 | assert.Nil(err) 128 | assert.Equal(10, res.Total) 129 | assert.Equal(9, res.Remaining) 130 | limiter.Remove(id) 131 | res, err = limiter.Get(id, policy...) 132 | assert.Equal(10, res.Total) 133 | assert.Equal(9, res.Remaining) 134 | }) 135 | 136 | t.Run("ratelimiter with wrong policy id should be", func(t *testing.T) { 137 | assert := assert.New(t) 138 | 139 | limiter := New(Options{}) 140 | id := genID() 141 | policy := []int{10, 1000, 1} 142 | 143 | res, err := limiter.Get(id, policy...) 
144 | assert.Error(err) 145 | assert.Equal(0, res.Total) 146 | assert.Equal(0, res.Remaining) 147 | }) 148 | 149 | t.Run("ratelimiter with empty policy id should be", func(t *testing.T) { 150 | assert := assert.New(t) 151 | 152 | limiter := New(Options{}) 153 | id := genID() 154 | policy := []int{} 155 | 156 | res, _ := limiter.Get(id, policy...) 157 | assert.Equal(100, res.Total) 158 | assert.Equal(99, res.Remaining) 159 | assert.Equal(time.Minute, res.Duration) 160 | }) 161 | 162 | t.Run("limiter.Get with invalid args", func(t *testing.T) { 163 | assert := assert.New(t) 164 | 165 | limiter := New(Options{}) 166 | id := genID() 167 | _, err := limiter.Get(id, 10) 168 | assert.Equal("ratelimiter: must be paired values", err.Error()) 169 | 170 | _, err2 := limiter.Get(id, -1, 10) 171 | assert.Equal("ratelimiter: must be positive integer", err2.Error()) 172 | 173 | _, err3 := limiter.Get(id, 10, 0) 174 | assert.Equal("ratelimiter: must be positive integer", err3.Error()) 175 | }) 176 | 177 | t.Run("ratelimiter with Clean cache should be", func(t *testing.T) { 178 | assert := assert.New(t) 179 | 180 | opts := Options{} 181 | limiter := &memoryLimiter{ 182 | max: opts.Max, 183 | duration: opts.Duration, 184 | store: make(map[string]*limiterCacheItem), 185 | status: make(map[string]*statusCacheItem), 186 | ticker: time.NewTicker(time.Minute), 187 | } 188 | 189 | id := genID() 190 | policy := []int{10, 100} 191 | 192 | res, _ := limiter.getLimit(id, policy...) 193 | 194 | assert.Equal(10, res[1].(int)) 195 | assert.Equal(9, res[0].(int)) 196 | 197 | time.Sleep(res[2].(time.Duration) + time.Millisecond) 198 | limiter.clean() 199 | res, _ = limiter.getLimit(id, policy...) 200 | assert.Equal(10, res[1].(int)) 201 | assert.Equal(9, res[0].(int)) 202 | 203 | time.Sleep(res[2].(time.Duration)*2 + time.Millisecond) 204 | limiter.clean() 205 | res, _ = limiter.getLimit(id, policy...) 
206 | assert.Equal(10, res[1].(int)) 207 | assert.Equal(9, res[0].(int)) 208 | limiter.ticker = time.NewTicker(time.Millisecond) 209 | go limiter.cleanCache() 210 | time.Sleep(2 * time.Millisecond) 211 | res, _ = limiter.getLimit(id, policy...) 212 | assert.Equal(10, res[1].(int)) 213 | assert.Equal(8, res[0].(int)) 214 | }) 215 | 216 | t.Run("ratelimiter with big goroutine should be", func(t *testing.T) { 217 | assert := assert.New(t) 218 | 219 | limiter := New(Options{}) 220 | policy := []int{1000, 1000} 221 | id := genID() 222 | 223 | var wg sync.WaitGroup 224 | wg.Add(1000) 225 | for i := 0; i < 1000; i++ { 226 | go func() { 227 | newid := genID() 228 | limiter.Get(newid, policy...) 229 | limiter.Get(id, policy...) 230 | wg.Done() 231 | }() 232 | } 233 | wg.Wait() 234 | res, err := limiter.Get(id, policy...) 235 | assert.Nil(err) 236 | assert.Equal(1000, res.Total) 237 | assert.Equal(-1, res.Remaining) 238 | }) 239 | 240 | t.Run("limiter.Get with multi-policy for expired", func(t *testing.T) { 241 | assert := assert.New(t) 242 | limiter := New(Options{}) 243 | 244 | id := genID() 245 | policy := []int{2, 100, 2, 200, 3, 300, 3, 400} 246 | 247 | //First policy 248 | res, err := limiter.Get(id, policy...) 249 | assert.Nil(err) 250 | assert.Equal(2, res.Total) 251 | assert.Equal(1, res.Remaining) 252 | assert.Equal(time.Millisecond*100, res.Duration) 253 | 254 | res, err = limiter.Get(id, policy...) 255 | assert.Equal(0, res.Remaining) 256 | 257 | res, err = limiter.Get(id, policy...) 258 | assert.Equal(-1, res.Remaining) 259 | assert.Equal(time.Millisecond*100, res.Duration) 260 | 261 | //Second policy 262 | time.Sleep(res.Duration + time.Millisecond) 263 | res, err = limiter.Get(id, policy...) 264 | assert.Equal(2, res.Total) 265 | assert.Equal(1, res.Remaining) 266 | assert.Equal(time.Millisecond*200, res.Duration) 267 | 268 | res, err = limiter.Get(id, policy...) 269 | assert.Equal(0, res.Remaining) 270 | 271 | res, err = limiter.Get(id, policy...) 
272 | assert.Equal(-1, res.Remaining) 273 | 274 | //Third policy 275 | time.Sleep(res.Duration + time.Millisecond) 276 | res, err = limiter.Get(id, policy...) 277 | assert.Equal(3, res.Total) 278 | assert.Equal(2, res.Remaining) 279 | assert.Equal(time.Millisecond*300, res.Duration) 280 | 281 | res, err = limiter.Get(id, policy...) 282 | res, err = limiter.Get(id, policy...) 283 | res, err = limiter.Get(id, policy...) 284 | assert.Equal(-1, res.Remaining) 285 | 286 | // restore to First policy after Third policy*2 Duration 287 | time.Sleep(res.Duration*2 + time.Millisecond) 288 | res, err = limiter.Get(id, policy...) 289 | assert.Nil(err) 290 | assert.Equal(2, res.Total) 291 | assert.Equal(1, res.Remaining) 292 | assert.Equal(time.Millisecond*100, res.Duration) 293 | res, err = limiter.Get(id, policy...) 294 | res, err = limiter.Get(id, policy...) 295 | assert.Equal(-1, res.Remaining) 296 | 297 | //Second policy 298 | time.Sleep(res.Duration + time.Millisecond) 299 | res, err = limiter.Get(id, policy...) 300 | assert.Equal(2, res.Total) 301 | assert.Equal(1, res.Remaining) 302 | 303 | res, err = limiter.Get(id, policy...) 304 | assert.Equal(0, res.Remaining) 305 | 306 | res, err = limiter.Get(id, policy...) 307 | assert.Equal(-1, res.Remaining) 308 | assert.Equal(time.Millisecond*200, res.Duration) 309 | 310 | //Third policy 311 | time.Sleep(res.Duration + time.Millisecond) 312 | res, err = limiter.Get(id, policy...) 313 | assert.Equal(3, res.Total) 314 | assert.Equal(2, res.Remaining) 315 | assert.Equal(time.Millisecond*300, res.Duration) 316 | 317 | res, err = limiter.Get(id, policy...) 318 | assert.Equal(1, res.Remaining) 319 | assert.Equal(time.Millisecond*300, res.Duration) 320 | 321 | res, err = limiter.Get(id, policy...) 322 | res, err = limiter.Get(id, policy...) 323 | assert.Equal(-1, res.Remaining) 324 | 325 | //Fourth policy 326 | time.Sleep(res.Duration + time.Millisecond) 327 | res, err = limiter.Get(id, policy...) 
328 | assert.Equal(3, res.Total) 329 | assert.Equal(2, res.Remaining) 330 | assert.Equal(time.Millisecond*400, res.Duration) 331 | 332 | res, err = limiter.Get(id, policy...) 333 | assert.Equal(3, res.Total) 334 | assert.Equal(1, res.Remaining) 335 | 336 | res, err = limiter.Get(id, policy...) 337 | assert.Equal(3, res.Total) 338 | assert.Equal(0, res.Remaining) 339 | 340 | res, err = limiter.Get(id, policy...) 341 | assert.Equal(3, res.Total) 342 | assert.Equal(-1, res.Remaining) 343 | 344 | // restore to First policy after Fourth policy*2 Duration 345 | time.Sleep(res.Duration*2 + time.Millisecond) 346 | res, err = limiter.Get(id, policy...) 347 | assert.Nil(err) 348 | assert.Equal(2, res.Total) 349 | assert.Equal(1, res.Remaining) 350 | assert.Equal(time.Millisecond*100, res.Duration) 351 | }) 352 | t.Run("limiter.Get with multi-policy situation for expired", func(t *testing.T) { 353 | assert := assert.New(t) 354 | 355 | var id = genID() 356 | limiter := New(Options{}) 357 | policy := []int{2, 150, 2, 200, 3, 300, 3, 400} 358 | 359 | //用户访问数在第一个策略限制内 360 | res, err := limiter.Get(id, policy...) 361 | assert.Nil(err) 362 | assert.Equal(2, res.Total) 363 | assert.Equal(1, res.Remaining) 364 | assert.Equal(time.Millisecond*150, res.Duration) 365 | 366 | //第一个策略正常过期,第二次会继续走第一个 367 | time.Sleep(res.Duration + time.Millisecond) 368 | res, err = limiter.Get(id, policy...) 369 | assert.Equal(2, res.Total) 370 | assert.Equal(1, res.Remaining) 371 | assert.Equal(time.Millisecond*150, res.Duration) 372 | 373 | //第一个策略超出 374 | res, err = limiter.Get(id, policy...) 375 | assert.Equal(2, res.Total) 376 | assert.Equal(0, res.Remaining) 377 | assert.Equal(time.Millisecond*150, res.Duration) 378 | 379 | res, err = limiter.Get(id, policy...) 
380 | assert.Equal(2, res.Total) 381 | assert.Equal(-1, res.Remaining) 382 | assert.Equal(time.Millisecond*150, res.Duration) 383 | 384 | // 超出后,等待第一个策略过期。 385 | time.Sleep(res.Duration + time.Millisecond) 386 | // 如果在第一个策略2倍时间内访问,走第二个策略。 如果不在恢复到第一个策略 387 | res, err = limiter.Get(id, policy...) 388 | assert.Equal(2, res.Total) 389 | assert.Equal(1, res.Remaining) 390 | assert.Equal(time.Millisecond*200, res.Duration) 391 | 392 | // 在第二个策略正常过期后,恢复到第一个策略 393 | time.Sleep(res.Duration + time.Millisecond) 394 | res, err = limiter.Get(id, policy...) 395 | assert.Equal(2, res.Total) 396 | assert.Equal(1, res.Remaining) 397 | assert.Equal(time.Millisecond*150, res.Duration) 398 | 399 | //第一个策略又超出 400 | res, err = limiter.Get(id, policy...) 401 | assert.Equal(2, res.Total) 402 | assert.Equal(0, res.Remaining) 403 | assert.Equal(time.Millisecond*150, res.Duration) 404 | res, err = limiter.Get(id, policy...) 405 | assert.Equal(2, res.Total) 406 | assert.Equal(-1, res.Remaining) 407 | assert.Equal(time.Millisecond*150, res.Duration) 408 | 409 | //等待第一个策略过期,然后走第二个策略 410 | time.Sleep(res.Duration + time.Millisecond) 411 | res, err = limiter.Get(id, policy...) 412 | assert.Equal(2, res.Total) 413 | assert.Equal(1, res.Remaining) 414 | assert.Equal(time.Millisecond*200, res.Duration) 415 | 416 | //第二个策略页超出 417 | res, err = limiter.Get(id, policy...) 418 | assert.Equal(2, res.Total) 419 | assert.Equal(0, res.Remaining) 420 | assert.Equal(time.Millisecond*200, res.Duration) 421 | res, err = limiter.Get(id, policy...) 422 | assert.Equal(-1, res.Remaining) 423 | //等待第二个过期,走第三个,然后第三个超出 424 | time.Sleep(res.Duration + time.Millisecond) 425 | res, err = limiter.Get(id, policy...) 426 | assert.Equal(3, res.Total) 427 | assert.Equal(2, res.Remaining) 428 | assert.Equal(time.Millisecond*300, res.Duration) 429 | 430 | res, err = limiter.Get(id, policy...) 
431 | assert.Equal(3, res.Total) 432 | assert.Equal(1, res.Remaining) 433 | assert.Equal(time.Millisecond*300, res.Duration) 434 | res, err = limiter.Get(id, policy...) 435 | 436 | assert.Equal(3, res.Total) 437 | assert.Equal(0, res.Remaining) 438 | assert.Equal(time.Millisecond*300, res.Duration) 439 | res, err = limiter.Get(id, policy...) 440 | assert.Equal(-1, res.Remaining) 441 | 442 | //等待第三个过期,走第四个,然后第四个也过期 443 | time.Sleep(res.Duration + time.Millisecond) 444 | res, err = limiter.Get(id, policy...) 445 | assert.Equal(3, res.Total) 446 | assert.Equal(2, res.Remaining) 447 | assert.Equal(time.Millisecond*400, res.Duration) 448 | 449 | res, err = limiter.Get(id, policy...) 450 | assert.Equal(3, res.Total) 451 | assert.Equal(1, res.Remaining) 452 | 453 | res, err = limiter.Get(id, policy...) 454 | assert.Equal(3, res.Total) 455 | assert.Equal(0, res.Remaining) 456 | 457 | res, err = limiter.Get(id, policy...) 458 | assert.Equal(3, res.Total) 459 | assert.Equal(-1, res.Remaining) 460 | 461 | //等待第四个策略过期,还是走第四个策略,因为还在第三个策略2倍时间内 462 | time.Sleep(res.Duration + time.Millisecond) 463 | res, err = limiter.Get(id, policy...) 464 | assert.Equal(3, res.Total) 465 | assert.Equal(2, res.Remaining) 466 | assert.Equal(time.Millisecond*400, res.Duration) 467 | 468 | //第四个策略第二次过期,恢复走第一个。 469 | time.Sleep(res.Duration + time.Millisecond) 470 | res, err = limiter.Get(id, policy...) 471 | assert.Equal(2, res.Total) 472 | assert.Equal(1, res.Remaining) 473 | assert.Equal(time.Millisecond*150, res.Duration) 474 | 475 | }) 476 | t.Run("limiter.Get with different policy time situation for expired", func(t *testing.T) { 477 | assert := assert.New(t) 478 | 479 | var id = genID() 480 | limiter := New(Options{}) 481 | policy := []int{2, 300, 3, 100} 482 | 483 | //默认走第一个策略 484 | res, err := limiter.Get(id, policy...) 
485 | assert.Nil(err) 486 | assert.Equal(2, res.Total) 487 | assert.Equal(1, res.Remaining) 488 | assert.Equal(time.Millisecond*300, res.Duration) 489 | 490 | //第一个策略超出 491 | res, err = limiter.Get(id, policy...) 492 | res, err = limiter.Get(id, policy...) 493 | assert.Equal(-1, res.Remaining) 494 | assert.Equal(time.Millisecond*300, res.Duration) 495 | 496 | //等待第一个策略过期,然后走第二个策略 497 | time.Sleep(res.Duration + time.Millisecond) 498 | res, err = limiter.Get(id, policy...) 499 | assert.Equal(3, res.Total) 500 | assert.Equal(2, res.Remaining) 501 | assert.Equal(time.Millisecond*100, res.Duration) 502 | 503 | //第一次正常过期 504 | time.Sleep(res.Duration + time.Millisecond) 505 | res, err = limiter.Get(id, policy...) 506 | assert.Equal(3, res.Total) 507 | assert.Equal(2, res.Remaining) 508 | assert.Equal(time.Millisecond*100, res.Duration) 509 | 510 | ///第二次正常过期 511 | time.Sleep(res.Duration + time.Millisecond) 512 | res, err = limiter.Get(id, policy...) 513 | assert.Equal(3, res.Total) 514 | assert.Equal(2, res.Remaining) 515 | assert.Equal(time.Millisecond*100, res.Duration) 516 | 517 | ///第三次正常过期,恢复到第一个 518 | time.Sleep(res.Duration + time.Millisecond) 519 | res, err = limiter.Get(id, policy...) 520 | assert.Nil(err) 521 | assert.Equal(2, res.Total) 522 | assert.Equal(1, res.Remaining) 523 | assert.Equal(time.Millisecond*300, res.Duration) 524 | 525 | //==========然后第一个策略又超出了 526 | res, err = limiter.Get(id, policy...) 527 | res, err = limiter.Get(id, policy...) 528 | assert.Equal(-1, res.Remaining) 529 | assert.Equal(time.Millisecond*300, res.Duration) 530 | 531 | //等待第一个策略过期, 532 | time.Sleep(res.Duration + time.Millisecond) 533 | //走第二个策略(第一次), 534 | res, err = limiter.Get(id, policy...) 535 | assert.Equal(3, res.Total) 536 | assert.Equal(2, res.Remaining) 537 | assert.Equal(time.Millisecond*100, res.Duration) 538 | 539 | // 第二个策略超过, 540 | res, err = limiter.Get(id, policy...) 
541 | assert.Equal(3, res.Total) 542 | assert.Equal(1, res.Remaining) 543 | assert.Equal(time.Millisecond*100, res.Duration) 544 | res, err = limiter.Get(id, policy...) 545 | assert.Equal(3, res.Total) 546 | assert.Equal(0, res.Remaining) 547 | assert.Equal(time.Millisecond*100, res.Duration) 548 | 549 | //等待过期 550 | time.Sleep(res.Duration + time.Millisecond) 551 | 552 | //走第二个策略(第二次),在第二个策略二倍时间内 553 | res, err = limiter.Get(id, policy...) 554 | assert.Equal(3, res.Total) 555 | assert.Equal(2, res.Remaining) 556 | assert.Equal(time.Millisecond*100, res.Duration) 557 | 558 | //第二个策略继续超出,延长2倍时间 559 | res, err = limiter.Get(id, policy...) 560 | assert.Equal(3, res.Total) 561 | assert.Equal(1, res.Remaining) 562 | assert.Equal(time.Millisecond*100, res.Duration) 563 | res, err = limiter.Get(id, policy...) 564 | assert.Equal(3, res.Total) 565 | assert.Equal(0, res.Remaining) 566 | assert.Equal(time.Millisecond*100, res.Duration) 567 | 568 | //等待过期 569 | time.Sleep(res.Duration + time.Millisecond) 570 | //然后走第二个策略,在第二个策略二倍时间内(被延长过)。 如果一直超出被停留在第二次 571 | res, err = limiter.Get(id, policy...) 572 | assert.Equal(3, res.Total) 573 | assert.Equal(2, res.Remaining) 574 | assert.Equal(time.Millisecond*100, res.Duration) 575 | 576 | //第二个策略第二次过期了,没有被延长 577 | time.Sleep(res.Duration + time.Millisecond) 578 | //恢复到第一个 579 | time.Sleep(res.Duration + time.Millisecond) 580 | res, err = limiter.Get(id, policy...) 581 | assert.Nil(err) 582 | assert.Equal(2, res.Total) 583 | assert.Equal(1, res.Remaining) 584 | assert.Equal(time.Millisecond*300, res.Duration) 585 | }) 586 | t.Run("limiter.Get with normal situation for expired", func(t *testing.T) { 587 | assert := assert.New(t) 588 | 589 | var id = genID() 590 | limiter := New(Options{}) 591 | policy := []int{3, 300, 2, 200} 592 | 593 | res, err := limiter.Get(id, policy...) 
594 | assert.Nil(err) 595 | assert.Equal(3, res.Total) 596 | assert.Equal(2, res.Remaining) 597 | assert.Equal(time.Millisecond*300, res.Duration) 598 | 599 | res, err = limiter.Get(id, policy...) 600 | assert.Equal(3, res.Total) 601 | assert.Equal(1, res.Remaining) 602 | assert.Equal(time.Millisecond*300, res.Duration) 603 | 604 | res, err = limiter.Get(id, policy...) 605 | assert.Equal(3, res.Total) 606 | assert.Equal(0, res.Remaining) 607 | assert.Equal(time.Millisecond*300, res.Duration) 608 | res, err = limiter.Get(id, policy...) 609 | assert.Equal(-1, res.Remaining) 610 | 611 | //等待过期,然后走第二个 612 | time.Sleep(res.Duration + time.Millisecond) 613 | res, err = limiter.Get(id, policy...) 614 | assert.Equal(2, res.Total) 615 | assert.Equal(1, res.Remaining) 616 | assert.Equal(time.Millisecond*200, res.Duration) 617 | 618 | //第二策略正常过期 619 | time.Sleep(res.Duration + time.Millisecond) 620 | res, err = limiter.Get(id, policy...) 621 | assert.Equal(2, res.Total) 622 | assert.Equal(1, res.Remaining) 623 | assert.Equal(time.Millisecond*200, res.Duration) 624 | 625 | //第二策略第二次正常过期,恢复到第一个 626 | time.Sleep(res.Duration + time.Millisecond) 627 | res, err = limiter.Get(id, policy...) 
628 | assert.Equal(3, res.Total) 629 | assert.Equal(2, res.Remaining) 630 | assert.Equal(time.Millisecond*300, res.Duration) 631 | 632 | }) 633 | } 634 | func genID() string { 635 | buf := make([]byte, 12) 636 | _, err := rand.Read(buf) 637 | if err != nil { 638 | panic(err) 639 | } 640 | return hex.EncodeToString(buf) 641 | } 642 | -------------------------------------------------------------------------------- /ratelimiter_test.go: -------------------------------------------------------------------------------- 1 | package ratelimiter_test 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | "errors" 7 | "sort" 8 | "sync" 9 | "testing" 10 | "time" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/teambition/ratelimiter-go" 14 | "github.com/go-redis/redis" 15 | ) 16 | 17 | // Implements RedisClient for redis.Client 18 | type redisFailedClient struct { 19 | *redis.Client 20 | } 21 | 22 | func (c *redisFailedClient) RateDel(key string) error { 23 | return c.Del(key).Err() 24 | } 25 | 26 | func (c *redisFailedClient) RateEvalSha(sha1 string, keys []string, args ...interface{}) (interface{}, error) { 27 | return nil, errors.New("NOSCRIPT mock error") 28 | } 29 | 30 | func (c *redisFailedClient) RateScriptLoad(script string) (string, error) { 31 | return c.ScriptLoad(script).Result() 32 | } 33 | 34 | // Implements RedisClient for redis.Client 35 | type redisClient struct { 36 | *redis.Client 37 | } 38 | 39 | func (c *redisClient) RateDel(key string) error { 40 | return c.Del(key).Err() 41 | } 42 | 43 | func (c *redisClient) RateEvalSha(sha1 string, keys []string, args ...interface{}) (interface{}, error) { 44 | return c.EvalSha(sha1, keys, args...).Result() 45 | } 46 | 47 | func (c *redisClient) RateScriptLoad(script string) (string, error) { 48 | return c.ScriptLoad(script).Result() 49 | } 50 | 51 | // Implements RedisClient for redis.ClusterClient 52 | type clusterClient struct { 53 | *redis.ClusterClient 54 | } 55 | 56 | func (c *clusterClient) 
RateDel(key string) error { 57 | return c.Del(key).Err() 58 | } 59 | 60 | func (c *clusterClient) RateEvalSha(sha1 string, keys []string, args ...interface{}) (interface{}, error) { 61 | return c.EvalSha(sha1, keys, args...).Result() 62 | } 63 | 64 | func (c *clusterClient) RateScriptLoad(script string) (string, error) { 65 | var sha1 string 66 | err := c.ForEachMaster(func(client *redis.Client) error { 67 | res, err := client.ScriptLoad(script).Result() 68 | if err == nil { 69 | sha1 = res 70 | } 71 | return err 72 | }) 73 | return sha1, err 74 | } 75 | 76 | // Implements RedisClient for redis.Ring 77 | type ringClient struct { 78 | *redis.Ring 79 | } 80 | 81 | func (c *ringClient) RateDel(key string) error { 82 | return c.Del(key).Err() 83 | } 84 | 85 | func (c *ringClient) RateEvalSha(sha1 string, keys []string, args ...interface{}) (interface{}, error) { 86 | return c.EvalSha(sha1, keys, args...).Result() 87 | } 88 | 89 | func (c *ringClient) RateScriptLoad(script string) (string, error) { 90 | var sha1 string 91 | err := c.ForEachShard(func(client *redis.Client) error { 92 | res, err := client.ScriptLoad(script).Result() 93 | if err == nil { 94 | sha1 = res 95 | } 96 | return err 97 | }) 98 | return sha1, err 99 | } 100 | 101 | func TestRedisRatelimiter(t *testing.T) { 102 | var client = redis.NewClient(&redis.Options{ 103 | Addr: "localhost:6379", 104 | }) 105 | pong, err := client.Ping().Result() 106 | assert.Nil(t, err) 107 | assert.Equal(t, "PONG", pong) 108 | defer client.Close() 109 | 110 | t.Run("ratelimiter.New, With default options", func(t *testing.T) { 111 | assert := assert.New(t) 112 | 113 | var limiter *ratelimiter.Limiter 114 | var id = genID() 115 | t.Run("ratelimiter.New", func(t *testing.T) { 116 | limiter = ratelimiter.New(ratelimiter.Options{Client: &redisClient{client}}) 117 | }) 118 | 119 | t.Run("limiter.Get", func(t *testing.T) { 120 | res, err := limiter.Get(id) 121 | assert.Nil(err) 122 | assert.Equal(res.Total, 100) 123 | 
assert.Equal(res.Remaining, 99) 124 | assert.Equal(res.Duration, time.Duration(60*1e9)) 125 | assert.True(res.Reset.UnixNano() > time.Now().UnixNano()) 126 | 127 | res, err = limiter.Get(id) 128 | assert.Nil(err) 129 | assert.Equal(res.Total, 100) 130 | assert.Equal(res.Remaining, 98) 131 | }) 132 | 133 | t.Run("limiter.Remove", func(t *testing.T) { 134 | err := limiter.Remove(id) 135 | assert.Nil(err) 136 | 137 | err = limiter.Remove(id) 138 | assert.Nil(err) 139 | 140 | res, err := limiter.Get(id) 141 | assert.Nil(err) 142 | assert.Equal(res.Total, 100) 143 | assert.Equal(res.Remaining, 99) 144 | }) 145 | 146 | t.Run("limiter.Get with invalid args", func(t *testing.T) { 147 | _, err := limiter.Get(id, 10) 148 | assert.Equal("ratelimiter: must be paired values", err.Error()) 149 | 150 | _, err2 := limiter.Get(id, -1, 10) 151 | assert.Equal("ratelimiter: must be positive integer", err2.Error()) 152 | 153 | _, err3 := limiter.Get(id, 10, 0) 154 | assert.Equal("ratelimiter: must be positive integer", err3.Error()) 155 | }) 156 | }) 157 | 158 | t.Run("ratelimiter.New, With options", func(t *testing.T) { 159 | assert := assert.New(t) 160 | 161 | var limiter *ratelimiter.Limiter 162 | var id = genID() 163 | t.Run("ratelimiter.New", func(t *testing.T) { 164 | limiter = ratelimiter.New(ratelimiter.Options{ 165 | Client: &redisClient{client}, 166 | Max: 3, 167 | Duration: time.Second, 168 | }) 169 | }) 170 | 171 | t.Run("limiter.Get", func(t *testing.T) { 172 | res, err := limiter.Get(id) 173 | assert.Nil(err) 174 | assert.Equal(3, res.Total) 175 | assert.Equal(2, res.Remaining) 176 | assert.Equal(time.Second, res.Duration) 177 | assert.True(res.Reset.UnixNano() > time.Now().UnixNano()) 178 | assert.True(res.Reset.UnixNano() <= time.Now().Add(time.Second).UnixNano()) 179 | 180 | res, err = limiter.Get(id) 181 | assert.Equal(res.Remaining, 1) 182 | res, err = limiter.Get(id) 183 | assert.Equal(res.Remaining, 0) 184 | res, err = limiter.Get(id) 185 | 
assert.Equal(res.Remaining, -1) 186 | res, err = limiter.Get(id) 187 | assert.Equal(res.Remaining, -1) 188 | }) 189 | 190 | t.Run("limiter.Remove", func(t *testing.T) { 191 | err := limiter.Remove(id) 192 | assert.Nil(err) 193 | 194 | res2, err2 := limiter.Get(id) 195 | assert.Nil(err2) 196 | assert.Equal(res2.Remaining, 2) 197 | }) 198 | 199 | t.Run("limiter.Get with multi-policy", func(t *testing.T) { 200 | id := genID() 201 | policy := []int{2, 100, 2, 200, 1, 300} 202 | 203 | res, err := limiter.Get(id, policy...) 204 | assert.Nil(err) 205 | assert.Equal(2, res.Total) 206 | assert.Equal(1, res.Remaining) 207 | assert.Equal(time.Millisecond*100, res.Duration) 208 | 209 | res, err = limiter.Get(id, policy...) 210 | assert.Nil(err) 211 | assert.Equal(0, res.Remaining) 212 | 213 | res, err = limiter.Get(id, policy...) 214 | assert.Nil(err) 215 | assert.Equal(-1, res.Remaining) 216 | 217 | time.Sleep(res.Duration + time.Millisecond) 218 | res, err = limiter.Get(id, policy...) 219 | assert.Nil(err) 220 | assert.Equal(2, res.Total) 221 | assert.Equal(1, res.Remaining) 222 | assert.Equal(time.Millisecond*200, res.Duration) 223 | 224 | res, err = limiter.Get(id, policy...) 225 | assert.Nil(err) 226 | assert.Equal(0, res.Remaining) 227 | 228 | res, err = limiter.Get(id, policy...) 229 | assert.Nil(err) 230 | assert.Equal(-1, res.Remaining) 231 | 232 | time.Sleep(res.Duration + time.Millisecond) 233 | res, err = limiter.Get(id, policy...) 234 | assert.Nil(err) 235 | assert.Equal(1, res.Total) 236 | assert.Equal(0, res.Remaining) 237 | assert.Equal(time.Millisecond*300, res.Duration) 238 | 239 | res, err = limiter.Get(id, policy...) 240 | assert.Nil(err) 241 | assert.Equal(res.Remaining, -1) 242 | 243 | // restore after double Duration 244 | time.Sleep(res.Duration*2 + time.Millisecond) 245 | res, err = limiter.Get(id, policy...) 
246 | assert.Nil(err) 247 | assert.Equal(2, res.Total) 248 | assert.Equal(1, res.Remaining) 249 | assert.Equal(time.Millisecond*100, res.Duration) 250 | }) 251 | 252 | t.Run("limiter.Get with multi-policy for expired", func(t *testing.T) { 253 | id := genID() 254 | policy := []int{2, 100, 2, 200, 3, 300, 3, 400} 255 | 256 | //First policy 257 | res, err := limiter.Get(id, policy...) 258 | assert.Nil(err) 259 | assert.Equal(2, res.Total) 260 | assert.Equal(1, res.Remaining) 261 | assert.Equal(time.Millisecond*100, res.Duration) 262 | 263 | res, err = limiter.Get(id, policy...) 264 | assert.Equal(0, res.Remaining) 265 | 266 | res, err = limiter.Get(id, policy...) 267 | assert.Equal(-1, res.Remaining) 268 | 269 | //Second policy 270 | time.Sleep(res.Duration + 5*time.Millisecond) 271 | res, err = limiter.Get(id, policy...) 272 | assert.Equal(2, res.Total) 273 | assert.Equal(1, res.Remaining) 274 | assert.Equal(time.Millisecond*200, res.Duration) 275 | 276 | res, err = limiter.Get(id, policy...) 277 | assert.Equal(0, res.Remaining) 278 | 279 | res, err = limiter.Get(id, policy...) 280 | assert.Equal(-1, res.Remaining) 281 | 282 | //Third policy 283 | time.Sleep(res.Duration + time.Millisecond) 284 | res, err = limiter.Get(id, policy...) 285 | assert.Equal(3, res.Total) 286 | assert.Equal(2, res.Remaining) 287 | assert.Equal(time.Millisecond*300, res.Duration) 288 | 289 | res, err = limiter.Get(id, policy...) 290 | res, err = limiter.Get(id, policy...) 291 | res, err = limiter.Get(id, policy...) 292 | assert.Equal(-1, res.Remaining) 293 | 294 | // restore to First policy after Third policy*2 Duration 295 | time.Sleep(res.Duration*2 + time.Millisecond) 296 | res, err = limiter.Get(id, policy...) 297 | assert.Nil(err) 298 | assert.Equal(2, res.Total) 299 | assert.Equal(1, res.Remaining) 300 | assert.Equal(time.Millisecond*100, res.Duration) 301 | res, err = limiter.Get(id, policy...) 302 | res, err = limiter.Get(id, policy...) 
303 | assert.Equal(-1, res.Remaining) 304 | 305 | //Second policy 306 | time.Sleep(res.Duration + time.Millisecond) 307 | res, err = limiter.Get(id, policy...) 308 | assert.Equal(2, res.Total) 309 | assert.Equal(1, res.Remaining) 310 | 311 | res, err = limiter.Get(id, policy...) 312 | assert.Equal(0, res.Remaining) 313 | 314 | res, err = limiter.Get(id, policy...) 315 | assert.Equal(-1, res.Remaining) 316 | assert.Equal(time.Millisecond*200, res.Duration) 317 | 318 | //Third policy 319 | time.Sleep(res.Duration + time.Millisecond) 320 | res, err = limiter.Get(id, policy...) 321 | assert.Equal(3, res.Total) 322 | assert.Equal(2, res.Remaining) 323 | assert.Equal(time.Millisecond*300, res.Duration) 324 | 325 | res, err = limiter.Get(id, policy...) 326 | assert.Equal(1, res.Remaining) 327 | assert.Equal(time.Millisecond*300, res.Duration) 328 | 329 | res, err = limiter.Get(id, policy...) 330 | res, err = limiter.Get(id, policy...) 331 | assert.Equal(-1, res.Remaining) 332 | 333 | //Fourth policy 334 | time.Sleep(res.Duration + time.Millisecond) 335 | res, err = limiter.Get(id, policy...) 336 | assert.Equal(3, res.Total) 337 | assert.Equal(2, res.Remaining) 338 | assert.Equal(time.Millisecond*400, res.Duration) 339 | 340 | res, err = limiter.Get(id, policy...) 341 | assert.Equal(3, res.Total) 342 | assert.Equal(1, res.Remaining) 343 | 344 | res, err = limiter.Get(id, policy...) 345 | assert.Equal(3, res.Total) 346 | assert.Equal(0, res.Remaining) 347 | 348 | res, err = limiter.Get(id, policy...) 349 | assert.Equal(3, res.Total) 350 | assert.Equal(-1, res.Remaining) 351 | 352 | // restore to First policy after Fourth policy*2 Duration 353 | time.Sleep(res.Duration*2 + time.Millisecond) 354 | res, err = limiter.Get(id, policy...) 
355 | assert.Nil(err) 356 | assert.Equal(2, res.Total) 357 | assert.Equal(1, res.Remaining) 358 | assert.Equal(time.Millisecond*100, res.Duration) 359 | }) 360 | }) 361 | t.Run("limiter.Get with multi-policy situation for expired", func(t *testing.T) { 362 | assert := assert.New(t) 363 | 364 | var id = genID() 365 | 366 | limiter := ratelimiter.New(ratelimiter.Options{ 367 | Client: &redisClient{client}, 368 | }) 369 | 370 | policy := []int{2, 150, 2, 200, 3, 300, 3, 400} 371 | 372 | //用户访问数在第一个策略限制内 373 | res, err := limiter.Get(id, policy...) 374 | assert.Nil(err) 375 | assert.Equal(2, res.Total) 376 | assert.Equal(1, res.Remaining) 377 | assert.Equal(time.Millisecond*150, res.Duration) 378 | 379 | //第一个策略正常过期,第二次会继续走第一个 380 | time.Sleep(res.Duration + time.Millisecond) 381 | res, err = limiter.Get(id, policy...) 382 | assert.Equal(2, res.Total) 383 | assert.Equal(1, res.Remaining) 384 | assert.Equal(time.Millisecond*150, res.Duration) 385 | 386 | //第一个策略超出 387 | res, err = limiter.Get(id, policy...) 388 | assert.Equal(2, res.Total) 389 | assert.Equal(0, res.Remaining) 390 | assert.Equal(time.Millisecond*150, res.Duration) 391 | 392 | res, err = limiter.Get(id, policy...) 393 | assert.Equal(2, res.Total) 394 | assert.Equal(-1, res.Remaining) 395 | assert.Equal(time.Millisecond*150, res.Duration) 396 | 397 | // 超出后,等待第一个策略过期。 398 | time.Sleep(res.Duration + time.Millisecond) 399 | // 如果在第一个策略2倍时间内访问,走第二个策略。 如果不在恢复到第一个策略 400 | res, err = limiter.Get(id, policy...) 401 | assert.Equal(2, res.Total) 402 | assert.Equal(1, res.Remaining) 403 | assert.Equal(time.Millisecond*200, res.Duration) 404 | 405 | // 在第二个策略正常过期后,恢复到第一个策略 406 | time.Sleep(res.Duration + time.Millisecond) 407 | res, err = limiter.Get(id, policy...) 408 | assert.Equal(2, res.Total) 409 | assert.Equal(1, res.Remaining) 410 | assert.Equal(time.Millisecond*150, res.Duration) 411 | 412 | //第一个策略又超出 413 | res, err = limiter.Get(id, policy...) 
414 | assert.Equal(2, res.Total) 415 | assert.Equal(0, res.Remaining) 416 | assert.Equal(time.Millisecond*150, res.Duration) 417 | res, err = limiter.Get(id, policy...) 418 | assert.Equal(2, res.Total) 419 | assert.Equal(-1, res.Remaining) 420 | assert.Equal(time.Millisecond*150, res.Duration) 421 | 422 | //等待第一个策略过期,然后走第二个策略 423 | time.Sleep(res.Duration + time.Millisecond) 424 | res, err = limiter.Get(id, policy...) 425 | assert.Equal(2, res.Total) 426 | assert.Equal(1, res.Remaining) 427 | assert.Equal(time.Millisecond*200, res.Duration) 428 | 429 | //第二个策略页超出 430 | res, err = limiter.Get(id, policy...) 431 | assert.Equal(2, res.Total) 432 | assert.Equal(0, res.Remaining) 433 | assert.Equal(time.Millisecond*200, res.Duration) 434 | res, err = limiter.Get(id, policy...) 435 | assert.Equal(-1, res.Remaining) 436 | 437 | //等待第二个过期,走第三个,然后第三个超出 438 | time.Sleep(res.Duration + time.Millisecond) 439 | res, err = limiter.Get(id, policy...) 440 | assert.Equal(3, res.Total) 441 | assert.Equal(2, res.Remaining) 442 | assert.Equal(time.Millisecond*300, res.Duration) 443 | 444 | res, err = limiter.Get(id, policy...) 445 | assert.Equal(3, res.Total) 446 | assert.Equal(1, res.Remaining) 447 | assert.Equal(time.Millisecond*300, res.Duration) 448 | res, err = limiter.Get(id, policy...) 449 | 450 | assert.Equal(3, res.Total) 451 | assert.Equal(0, res.Remaining) 452 | assert.Equal(time.Millisecond*300, res.Duration) 453 | res, err = limiter.Get(id, policy...) 454 | assert.Equal(-1, res.Remaining) 455 | 456 | //等待第三个过期,走第四个,然后第四个也过期 457 | time.Sleep(res.Duration + time.Millisecond) 458 | res, err = limiter.Get(id, policy...) 459 | assert.Equal(3, res.Total) 460 | assert.Equal(2, res.Remaining) 461 | assert.Equal(time.Millisecond*400, res.Duration) 462 | 463 | res, err = limiter.Get(id, policy...) 464 | assert.Equal(3, res.Total) 465 | assert.Equal(1, res.Remaining) 466 | 467 | res, err = limiter.Get(id, policy...) 
468 | assert.Equal(3, res.Total) 469 | assert.Equal(0, res.Remaining) 470 | 471 | res, err = limiter.Get(id, policy...) 472 | assert.Equal(3, res.Total) 473 | assert.Equal(-1, res.Remaining) 474 | 475 | //等待第四个策略过期,还是走第四个策略,因为还在第三个策略2倍时间内 476 | time.Sleep(res.Duration + time.Millisecond) 477 | res, err = limiter.Get(id, policy...) 478 | assert.Equal(3, res.Total) 479 | assert.Equal(2, res.Remaining) 480 | assert.Equal(time.Millisecond*400, res.Duration) 481 | 482 | //第四个策略第二次过期,恢复走第一个。 483 | time.Sleep(res.Duration + time.Millisecond) 484 | res, err = limiter.Get(id, policy...) 485 | assert.Equal(2, res.Total) 486 | assert.Equal(1, res.Remaining) 487 | assert.Equal(time.Millisecond*150, res.Duration) 488 | 489 | }) 490 | t.Run("limiter.Get with different policy time situation for expired", func(t *testing.T) { 491 | assert := assert.New(t) 492 | 493 | var id = genID() 494 | 495 | limiter := ratelimiter.New(ratelimiter.Options{ 496 | Client: &redisClient{client}, 497 | }) 498 | 499 | policy := []int{2, 300, 3, 100} 500 | 501 | //默认走第一个策略 502 | res, err := limiter.Get(id, policy...) 503 | assert.Nil(err) 504 | assert.Equal(2, res.Total) 505 | assert.Equal(1, res.Remaining) 506 | assert.Equal(time.Millisecond*300, res.Duration) 507 | 508 | //第一个策略超出 509 | res, err = limiter.Get(id, policy...) 510 | res, err = limiter.Get(id, policy...) 511 | assert.Equal(-1, res.Remaining) 512 | assert.Equal(time.Millisecond*300, res.Duration) 513 | 514 | //等待第一个策略过期, 然后走第二个策略, 515 | time.Sleep(res.Duration + time.Millisecond) 516 | res, err = limiter.Get(id, policy...) 517 | assert.Equal(3, res.Total) 518 | assert.Equal(2, res.Remaining) 519 | assert.Equal(time.Millisecond*100, res.Duration) 520 | 521 | //第一次正常过期, 522 | time.Sleep(res.Duration + time.Millisecond) 523 | res, err = limiter.Get(id, policy...) 
524 | assert.Equal(3, res.Total) 525 | assert.Equal(2, res.Remaining) 526 | assert.Equal(time.Millisecond*100, res.Duration) 527 | 528 | ///第二次正常过期 529 | time.Sleep(res.Duration + time.Millisecond) 530 | res, err = limiter.Get(id, policy...) 531 | assert.Equal(3, res.Total) 532 | assert.Equal(2, res.Remaining) 533 | assert.Equal(time.Millisecond*100, res.Duration) 534 | 535 | ///第三次正常过期,恢复到第一个 536 | time.Sleep(res.Duration + time.Millisecond) 537 | res, err = limiter.Get(id, policy...) 538 | assert.Nil(err) 539 | assert.Equal(2, res.Total) 540 | assert.Equal(1, res.Remaining) 541 | assert.Equal(time.Millisecond*300, res.Duration) 542 | 543 | //==========然后第一个策略又超出了 544 | res, err = limiter.Get(id, policy...) 545 | res, err = limiter.Get(id, policy...) 546 | assert.Equal(-1, res.Remaining) 547 | assert.Equal(time.Millisecond*300, res.Duration) 548 | 549 | //等待第一个策略过期, 550 | time.Sleep(res.Duration + time.Millisecond) 551 | //走第二个策略(第一次), 552 | res, err = limiter.Get(id, policy...) 553 | assert.Equal(3, res.Total) 554 | assert.Equal(2, res.Remaining) 555 | assert.Equal(time.Millisecond*100, res.Duration) 556 | 557 | // 第二个策略超过, 558 | res, err = limiter.Get(id, policy...) 559 | assert.Equal(3, res.Total) 560 | assert.Equal(1, res.Remaining) 561 | assert.Equal(time.Millisecond*100, res.Duration) 562 | res, err = limiter.Get(id, policy...) 563 | assert.Equal(3, res.Total) 564 | assert.Equal(0, res.Remaining) 565 | assert.Equal(time.Millisecond*100, res.Duration) 566 | 567 | //等待过期 568 | time.Sleep(res.Duration + time.Millisecond) 569 | 570 | //走第二个策略(第二次),在第二个策略二倍时间内 571 | res, err = limiter.Get(id, policy...) 572 | assert.Equal(3, res.Total) 573 | assert.Equal(2, res.Remaining) 574 | assert.Equal(time.Millisecond*100, res.Duration) 575 | 576 | //第二个策略继续超出,延长2倍时间 577 | res, err = limiter.Get(id, policy...) 
578 | assert.Equal(3, res.Total) 579 | assert.Equal(1, res.Remaining) 580 | assert.Equal(time.Millisecond*100, res.Duration) 581 | res, err = limiter.Get(id, policy...) 582 | assert.Equal(3, res.Total) 583 | assert.Equal(0, res.Remaining) 584 | assert.Equal(time.Millisecond*100, res.Duration) 585 | 586 | //等待过期 587 | time.Sleep(res.Duration + time.Millisecond) 588 | //然后走第二个策略,在第二个策略二倍时间内(被延长过)。 如果一直超出被停留在第二次 589 | res, err = limiter.Get(id, policy...) 590 | assert.Equal(3, res.Total) 591 | assert.Equal(2, res.Remaining) 592 | assert.Equal(time.Millisecond*100, res.Duration) 593 | 594 | //第二个策略第二次过期了,没有被延长 595 | time.Sleep(res.Duration + time.Millisecond) 596 | //恢复到第一个 597 | time.Sleep(res.Duration + time.Millisecond) 598 | res, err = limiter.Get(id, policy...) 599 | assert.Nil(err) 600 | assert.Equal(2, res.Total) 601 | assert.Equal(1, res.Remaining) 602 | assert.Equal(time.Millisecond*300, res.Duration) 603 | }) 604 | t.Run("limiter.Get with normal situation for expired", func(t *testing.T) { 605 | assert := assert.New(t) 606 | 607 | var id = genID() 608 | limiter := ratelimiter.New(ratelimiter.Options{ 609 | Client: &redisClient{client}, 610 | }) 611 | 612 | policy := []int{3, 300, 2, 200} 613 | 614 | res, err := limiter.Get(id, policy...) 615 | assert.Nil(err) 616 | assert.Equal(3, res.Total) 617 | assert.Equal(2, res.Remaining) 618 | assert.Equal(time.Millisecond*300, res.Duration) 619 | 620 | res, err = limiter.Get(id, policy...) 621 | assert.Equal(3, res.Total) 622 | assert.Equal(1, res.Remaining) 623 | assert.Equal(time.Millisecond*300, res.Duration) 624 | 625 | res, err = limiter.Get(id, policy...) 626 | assert.Equal(3, res.Total) 627 | assert.Equal(0, res.Remaining) 628 | assert.Equal(time.Millisecond*300, res.Duration) 629 | res, err = limiter.Get(id, policy...) 630 | assert.Equal(-1, res.Remaining) 631 | 632 | //等待过期,然后走第二个 633 | time.Sleep(res.Duration + time.Millisecond) 634 | res, err = limiter.Get(id, policy...) 
635 | assert.Equal(2, res.Total) 636 | assert.Equal(1, res.Remaining) 637 | assert.Equal(time.Millisecond*200, res.Duration) 638 | 639 | //第二策略正常过期 640 | time.Sleep(res.Duration + time.Millisecond) 641 | res, err = limiter.Get(id, policy...) 642 | assert.Equal(2, res.Total) 643 | assert.Equal(1, res.Remaining) 644 | assert.Equal(time.Millisecond*200, res.Duration) 645 | 646 | //第二策略第二次正常过期,恢复到第一个 647 | time.Sleep(res.Duration + time.Millisecond) 648 | res, err = limiter.Get(id, policy...) 649 | assert.Equal(3, res.Total) 650 | assert.Equal(2, res.Remaining) 651 | assert.Equal(time.Millisecond*300, res.Duration) 652 | 653 | }) 654 | t.Run("ratelimiter.New, Chaos", func(t *testing.T) { 655 | t.Run("10 limiters work for one id", func(t *testing.T) { 656 | assert := assert.New(t) 657 | 658 | var wg sync.WaitGroup 659 | var id = genID() 660 | var result = NewResult(make([]int, 10000)) 661 | var redisOptions = redis.Options{Addr: "localhost:6379"} 662 | 663 | var worker = func(c *redis.Client, l *ratelimiter.Limiter) { 664 | defer wg.Done() 665 | defer c.Close() 666 | 667 | for i := 0; i < 1000; i++ { 668 | res, err := l.Get(id) 669 | assert.Nil(err) 670 | result.Push(res.Remaining) 671 | } 672 | } 673 | 674 | wg.Add(10) 675 | for i := 0; i < 10; i++ { 676 | client := redis.NewClient(&redisOptions) 677 | limiter := ratelimiter.New(ratelimiter.Options{Client: &redisClient{client}, Max: 9998}) 678 | go worker(client, limiter) 679 | } 680 | 681 | wg.Wait() 682 | s := result.Value() 683 | sort.Ints(s) // [-1 -1 0 1 2 3 ... 
9997 9997] 684 | assert.Equal(s[0], -1) 685 | for i := 1; i < 10000; i++ { 686 | assert.Equal(s[i], i-2) 687 | } 688 | }) 689 | }) 690 | 691 | t.Run("ratelimiter.New with redis ring, Chaos", func(t *testing.T) { 692 | t.Run("10 limiters work for one id", func(t *testing.T) { 693 | t.Skip("Can't run in travis") 694 | assert := assert.New(t) 695 | 696 | var wg sync.WaitGroup 697 | var id = genID() 698 | var result = NewResult(make([]int, 10000)) 699 | var redisOptions = redis.RingOptions{Addrs: map[string]string{ 700 | "a": "localhost:6379", 701 | "b": "localhost:6380", 702 | }} 703 | 704 | var worker = func(c *redis.Ring, l *ratelimiter.Limiter) { 705 | defer wg.Done() 706 | defer c.Close() 707 | 708 | for i := 0; i < 1000; i++ { 709 | res, err := l.Get(id) 710 | assert.Nil(err) 711 | result.Push(res.Remaining) 712 | } 713 | } 714 | 715 | wg.Add(10) 716 | for i := 0; i < 10; i++ { 717 | client := redis.NewRing(&redisOptions) 718 | limiter := ratelimiter.New(ratelimiter.Options{Client: &ringClient{client}, Max: 9998}) 719 | go worker(client, limiter) 720 | } 721 | 722 | wg.Wait() 723 | s := result.Value() 724 | sort.Ints(s) // [-1 -1 0 1 2 3 ... 
9997 9997] 725 | assert.Equal(s[0], -1) 726 | for i := 1; i < 10000; i++ { 727 | assert.Equal(s[i], i-2) 728 | } 729 | }) 730 | }) 731 | 732 | t.Run("ratelimiter.New with redis cluster, Chaos", func(t *testing.T) { 733 | t.Run("10 limiters work for one id", func(t *testing.T) { 734 | t.Skip("Can't run in travis") 735 | assert := assert.New(t) 736 | 737 | var wg sync.WaitGroup 738 | var id = genID() 739 | var result = NewResult(make([]int, 10000)) 740 | var redisOptions = redis.ClusterOptions{Addrs: []string{ 741 | "localhost:7000", 742 | "localhost:7001", 743 | "localhost:7002", 744 | "localhost:7003", 745 | "localhost:7004", 746 | "localhost:7005", 747 | }} 748 | 749 | var worker = func(c *redis.ClusterClient, l *ratelimiter.Limiter) { 750 | defer wg.Done() 751 | defer c.Close() 752 | 753 | for i := 0; i < 1000; i++ { 754 | res, err := l.Get(id) 755 | assert.Nil(err) 756 | result.Push(res.Remaining) 757 | } 758 | } 759 | 760 | wg.Add(10) 761 | for i := 0; i < 10; i++ { 762 | client := redis.NewClusterClient(&redisOptions) 763 | limiter := ratelimiter.New(ratelimiter.Options{Client: &clusterClient{client}, Max: 9998}) 764 | go worker(client, limiter) 765 | } 766 | 767 | wg.Wait() 768 | s := result.Value() 769 | sort.Ints(s) // [-1 -1 0 1 2 3 ... 
9997 9997] 770 | assert.Equal(s[0], -1) 771 | for i := 1; i < 10000; i++ { 772 | assert.Equal(s[i], i-2) 773 | } 774 | }) 775 | }) 776 | t.Run("ratelimiter with no redis machine should be", func(t *testing.T) { 777 | assert := assert.New(t) 778 | var client = redis.NewClient(&redis.Options{ 779 | Addr: "localhost:6399", 780 | }) 781 | assert.Panics(func() { 782 | ratelimiter.New(ratelimiter.Options{Client: &redisClient{client}}) 783 | }) 784 | }) 785 | t.Run("ratelimiter with redisFailedClient should be", func(t *testing.T) { 786 | assert := assert.New(t) 787 | 788 | var limiter *ratelimiter.Limiter 789 | var client = redis.NewClient(&redis.Options{ 790 | Addr: "localhost:6379", 791 | }) 792 | /* NOTE(review): limiter is only assigned inside the nested subtest below; if that subtest does not complete, limiter stays nil and the limiter.Get call that follows may panic. */ 793 | t.Run("ratelimiter.New", func(t *testing.T) { 794 | limiter = ratelimiter.New(ratelimiter.Options{Client: &redisFailedClient{client}}) 795 | }) 796 | policy := []int{2, 100, 2, 200, 1, 300} 797 | id := genID() 798 | res, err := limiter.Get(id, policy...) 799 | assert.Equal("NOSCRIPT mock error", err.Error()) 800 | /* NOTE(review): err.Error() above panics if Get returns a nil error; asserting that err is non-nil first would produce a clearer test failure. */ 801 | assert.Equal(0, res.Total) 802 | assert.Equal(0, res.Remaining) 803 | assert.Equal(time.Duration(0), res.Duration) 804 | 805 | }) 806 | } 807 | /* genID returns a random test identifier: 12 bytes from rand.Read, hex-encoded into a 24-character string. It panics if rand.Read fails, so callers need no error handling. */ 808 | func genID() string { 809 | buf := make([]byte, 12) 810 | _, err := rand.Read(buf) 811 | if err != nil { 812 | panic(err) 813 | } 814 | return hex.EncodeToString(buf) 815 | } 816 | /* Result collects the int values pushed by the concurrent Chaos worker goroutines into a pre-sized slice. i is the next write index, len caches len(s) as the overflow limit, and m serializes Push calls. */ 817 | type Result struct { 818 | i int 819 | len int 820 | s []int 821 | m sync.Mutex 822 | } 823 | /* NewResult wraps s as-is (it is not copied) and records len(s) as the maximum number of values Push may store. The returned struct embeds a sync.Mutex, so it should not be copied after first use. */ 824 | func NewResult(s []int) Result { 825 | return Result{s: s, len: len(s)} 826 | } 827 | /* Push stores val at the next free index while holding r.m. It panics with "Result overflow" once len values have been stored. NOTE(review): that panic is raised while r.m is still locked because Unlock is not deferred; in a worker goroutine an unrecovered panic terminates the test binary anyway. */ 828 | func (r *Result) Push(val int) { 829 | r.m.Lock() 830 | if r.i == r.len { 831 | panic(errors.New("Result overflow")) 832 | } 833 | r.s[r.i] = val 834 | r.i++ 835 | r.m.Unlock() 836 | } 837 | /* Value returns the whole backing slice, including any slots Push never wrote, and reads r.s without taking r.m; the visible callers only use it after wg.Wait(), once all pushes have finished. */ 838 | func (r *Result) Value() []int { 839 | return r.s 840 | } 841 | --------------------------------------------------------------------------------