├── go.mod ├── limit_key.go ├── _example ├── go.mod ├── go.sum └── main.go ├── .github └── workflows │ ├── ci.yml │ └── benchmark.yml ├── context.go ├── go.sum ├── LICENSE ├── httprate_test.go ├── local_counter.go ├── local_counter_test.go ├── httprate.go ├── README.md ├── limiter.go └── limiter_test.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-chi/httprate 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.1 6 | 7 | require github.com/zeebo/xxh3 v1.0.2 8 | 9 | require golang.org/x/sys v0.30.0 // indirect 10 | 11 | require ( 12 | github.com/klauspost/cpuid/v2 v2.2.10 // indirect 13 | golang.org/x/sync v0.12.0 14 | ) 15 | -------------------------------------------------------------------------------- /limit_key.go: -------------------------------------------------------------------------------- 1 | package httprate 2 | 3 | import ( 4 | "strconv" 5 | "time" 6 | 7 | "github.com/zeebo/xxh3" 8 | ) 9 | 10 | func LimitCounterKey(key string, window time.Time) uint64 { 11 | h := xxh3.New() 12 | h.WriteString(key) 13 | h.WriteString(strconv.FormatInt(window.Unix(), 10)) 14 | return h.Sum64() 15 | } 16 | -------------------------------------------------------------------------------- /_example/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-chi/httprate/_example 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.1 6 | 7 | replace github.com/go-chi/httprate => ../ 8 | 9 | require ( 10 | github.com/go-chi/chi/v5 v5.1.0 11 | github.com/go-chi/httprate v0.0.0-00010101000000-000000000000 12 | ) 13 | 14 | require ( 15 | github.com/klauspost/cpuid/v2 v2.2.10 // indirect 16 | github.com/zeebo/xxh3 v1.0.2 // indirect 17 | golang.org/x/sys v0.30.0 // indirect 18 | ) 19 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | 
// ctxKey is the private type of the context keys owned by this package, so
// that values stored here cannot collide with keys set by other packages.
type ctxKey int

const (
	// incrementKey holds the amount a single request adds to its counter.
	incrementKey ctxKey = iota
	// requestLimitKey holds a per-request override of the request limit.
	requestLimitKey
)

// WithIncrement returns a copy of ctx carrying value as the amount by which
// the request should increase its rate-limit counter.
func WithIncrement(ctx context.Context, value int) context.Context {
	return context.WithValue(ctx, incrementKey, value)
}

// getIncrement reports the increment stored by WithIncrement, or 1 when ctx
// carries none.
func getIncrement(ctx context.Context) int {
	return intFromContext(ctx, incrementKey, 1)
}

// WithRequestLimit returns a copy of ctx carrying value as the request limit
// to apply to this request.
func WithRequestLimit(ctx context.Context, value int) context.Context {
	return context.WithValue(ctx, requestLimitKey, value)
}

// getRequestLimit reports the limit stored by WithRequestLimit, or 0 when ctx
// carries none.
func getRequestLimit(ctx context.Context) int {
	return intFromContext(ctx, requestLimitKey, 0)
}

// intFromContext returns the int stored in ctx under key, or fallback when
// the key is absent or holds a value of another type.
func intFromContext(ctx context.Context, key ctxKey, fallback int) int {
	stored, found := ctx.Value(key).(int)
	if !found {
		return fallback
	}
	return stored
}
h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= 5 | github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= 6 | github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= 7 | golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= 8 | golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 9 | golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= 10 | golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 11 | -------------------------------------------------------------------------------- /_example/go.sum: -------------------------------------------------------------------------------- 1 | github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= 2 | github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= 3 | github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= 4 | github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= 5 | github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= 6 | github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= 7 | github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= 8 | github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= 9 | golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= 10 | golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 11 | golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= 12 | golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015-present Peter Kieltyka 
(https://github.com/pkieltyka). 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /httprate_test.go: -------------------------------------------------------------------------------- 1 | package httprate 2 | 3 | import "testing" 4 | 5 | func Test_canonicalizeIP(t *testing.T) { 6 | tests := []struct { 7 | name string 8 | ip string 9 | want string 10 | }{ 11 | { 12 | name: "IPv4 unchanged", 13 | ip: "1.2.3.4", 14 | want: "1.2.3.4", 15 | }, 16 | { 17 | name: "bad IP unchanged", 18 | ip: "not an IP", 19 | want: "not an IP", 20 | }, 21 | { 22 | name: "bad IPv6 unchanged", 23 | ip: "not:an:IP", 24 | want: "not:an:IP", 25 | }, 26 | { 27 | name: "empty string unchanged", 28 | ip: "", 29 | want: "", 30 | }, 31 | { 32 | name: "IPv6 test 1", 33 | ip: "2001:DB8::21f:5bff:febf:ce22:8a2e", 34 | want: "2001:db8:0:21f::", 35 | }, 36 | { 37 | name: "IPv6 test 2", 38 | ip: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", 39 | want: "2001:db8:85a3::", 40 | }, 41 | { 42 | name: "IPv6 test 3", 43 | ip: "fe80::1ff:fe23:4567:890a", 44 | want: "fe80::", 45 | }, 46 | { 47 | name: "IPv6 test 4", 48 | ip: "f:f:f:f:f:f:f:f", 49 | want: "f:f:f:f::", 50 | }, 51 | } 52 | for _, tt := range tests { 53 | t.Run(tt.name, func(t *testing.T) { 54 | if got := canonicalizeIP(tt.ip); got != tt.want { 55 | t.Errorf("canonicalizeIP() = %v, want %v", got, tt.want) 56 | } 57 | }) 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /.github/workflows/benchmark.yml: -------------------------------------------------------------------------------- 1 | name: Bechmark 2 | 3 | on: 4 | pull_request_target: 5 | 6 | permissions: 7 | contents: read 8 | pull-requests: write 9 | 10 | jobs: 11 | benchmark: 12 | name: Benchmark 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Set up Go 16 | uses: actions/setup-go@v5 17 | with: 18 | go-version: ^1.17 19 | 20 | - name: Git clone (master) 21 | uses: actions/checkout@v4 22 | with: 23 | ref: master 24 | 25 | - name: Run benchmark 
(master) 26 | run: go test -bench=. -count=10 -benchmem | tee /tmp/master.txt 27 | 28 | - name: Git clone (PR) 29 | uses: actions/checkout@v4 30 | 31 | - name: Run benchmark (PR) 32 | run: go test -bench=. -count=10 -benchmem | tee /tmp/pr.txt 33 | 34 | - name: Install benchstat 35 | run: go install golang.org/x/perf/cmd/benchstat@latest 36 | 37 | - name: Run benchstat 38 | run: cd /tmp && benchstat master.txt pr.txt | tee /tmp/result.txt 39 | 40 | - name: Comment on PR with benchmark results 41 | uses: actions/github-script@v6 42 | with: 43 | script: | 44 | const fs = require('fs'); 45 | const results = fs.readFileSync('/tmp/result.txt', 'utf8'); 46 | const issue_number = context.payload.pull_request.number; 47 | const { owner, repo } = context.repo; 48 | 49 | await github.rest.issues.createComment({ 50 | owner, 51 | repo, 52 | issue_number, 53 | body: `### Benchmark Results\n\n\`\`\`\n${results}\n\`\`\`` 54 | }); 55 | -------------------------------------------------------------------------------- /_example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "log" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/go-chi/chi/v5" 11 | "github.com/go-chi/chi/v5/middleware" 12 | "github.com/go-chi/httprate" 13 | ) 14 | 15 | func main() { 16 | r := chi.NewRouter() 17 | r.Use(middleware.Logger) 18 | 19 | // Rate-limit all routes at 1000 req/min by IP address. 20 | r.Use(httprate.LimitByIP(1000, time.Minute)) 21 | 22 | r.Route("/admin", func(r chi.Router) { 23 | r.Use(func(next http.Handler) http.Handler { 24 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 25 | // Note: This is a mock middleware to set a userID on the request context 26 | next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "userID", "123"))) 27 | }) 28 | }) 29 | 30 | // Rate-limit admin routes at 10 req/s by userID. 
31 | r.Use(httprate.Limit( 32 | 10, time.Second, 33 | httprate.WithKeyFuncs(func(r *http.Request) (string, error) { 34 | token, _ := r.Context().Value("userID").(string) 35 | return token, nil 36 | }), 37 | )) 38 | 39 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 40 | w.Write([]byte("admin at 10 req/s\n")) 41 | }) 42 | }) 43 | 44 | // Rate-limiter for login endpoint. 45 | loginRateLimiter := httprate.NewRateLimiter(5, time.Minute) 46 | 47 | r.Post("/login", func(w http.ResponseWriter, r *http.Request) { 48 | var payload struct { 49 | Username string `json:"username"` 50 | Password string `json:"password"` 51 | } 52 | err := json.NewDecoder(r.Body).Decode(&payload) 53 | if err != nil || payload.Username == "" || payload.Password == "" { 54 | w.WriteHeader(400) 55 | return 56 | } 57 | 58 | // Rate-limit login at 5 req/min. 59 | if loginRateLimiter.RespondOnLimit(w, r, payload.Username) { 60 | return 61 | } 62 | 63 | w.Write([]byte("login at 5 req/min\n")) 64 | }) 65 | 66 | log.Printf("Serving at localhost:3333") 67 | log.Println() 68 | log.Printf("Try running:") 69 | log.Printf(`curl -v http://localhost:3333?[0-1000]`) 70 | log.Printf(`curl -v http://localhost:3333/admin?[1-12]`) 71 | log.Printf(`curl -v http://localhost:3333/login\?[1-8] --data '{"username":"alice","password":"***"}'`) 72 | 73 | http.ListenAndServe(":3333", r) 74 | } 75 | -------------------------------------------------------------------------------- /local_counter.go: -------------------------------------------------------------------------------- 1 | package httprate 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/zeebo/xxh3" 8 | ) 9 | 10 | // NewLocalLimitCounter creates an instance of localCounter, 11 | // which is an in-memory implementation of http.LimitCounter. 12 | // 13 | // All methods are guaranteed to always return nil error. 
14 | func NewLocalLimitCounter(windowLength time.Duration) *localCounter { 15 | return &localCounter{ 16 | windowLength: windowLength, 17 | latestWindow: time.Now().UTC().Truncate(windowLength), 18 | latestCounters: make(map[uint64]int), 19 | previousCounters: make(map[uint64]int), 20 | } 21 | } 22 | 23 | var _ LimitCounter = (*localCounter)(nil) 24 | 25 | type localCounter struct { 26 | windowLength time.Duration 27 | latestWindow time.Time 28 | latestCounters map[uint64]int 29 | previousCounters map[uint64]int 30 | mu sync.RWMutex 31 | } 32 | 33 | func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount int) error { 34 | c.mu.Lock() 35 | defer c.mu.Unlock() 36 | 37 | c.evict(currentWindow) 38 | 39 | hkey := limitCounterKey(key) 40 | 41 | count, _ := c.latestCounters[hkey] 42 | c.latestCounters[hkey] = count + amount 43 | 44 | return nil 45 | } 46 | 47 | func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) (int, int, error) { 48 | c.mu.RLock() 49 | defer c.mu.RUnlock() 50 | 51 | if c.latestWindow == currentWindow { 52 | curr, _ := c.latestCounters[limitCounterKey(key)] 53 | prev, _ := c.previousCounters[limitCounterKey(key)] 54 | return curr, prev, nil 55 | } 56 | 57 | if c.latestWindow == previousWindow { 58 | prev, _ := c.latestCounters[limitCounterKey(key)] 59 | return 0, prev, nil 60 | } 61 | 62 | return 0, 0, nil 63 | } 64 | 65 | func (c *localCounter) Config(requestLimit int, windowLength time.Duration) { 66 | c.windowLength = windowLength 67 | c.latestWindow = time.Now().UTC().Truncate(windowLength) 68 | } 69 | 70 | func (c *localCounter) Increment(key string, currentWindow time.Time) error { 71 | return c.IncrementBy(key, currentWindow, 1) 72 | } 73 | 74 | func (c *localCounter) evict(currentWindow time.Time) { 75 | if c.latestWindow == currentWindow { 76 | return 77 | } 78 | 79 | previousWindow := currentWindow.Add(-c.windowLength) 80 | if c.latestWindow == previousWindow { 81 | c.latestWindow = 
currentWindow 82 | // Shift the windows without map re-allocation. 83 | clear(c.previousCounters) 84 | c.latestCounters, c.previousCounters = c.previousCounters, c.latestCounters 85 | return 86 | } 87 | 88 | c.latestWindow = currentWindow 89 | 90 | clear(c.previousCounters) 91 | clear(c.latestCounters) 92 | } 93 | 94 | func limitCounterKey(key string) uint64 { 95 | h := xxh3.New() 96 | h.WriteString(key) 97 | return h.Sum64() 98 | } 99 | -------------------------------------------------------------------------------- /local_counter_test.go: -------------------------------------------------------------------------------- 1 | package httprate_test 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/go-chi/httprate" 11 | "golang.org/x/sync/errgroup" 12 | ) 13 | 14 | func TestLocalCounter(t *testing.T) { 15 | limitCounter := httprate.NewLocalLimitCounter(time.Minute) 16 | 17 | currentWindow := time.Now().UTC().Truncate(time.Minute) 18 | previousWindow := currentWindow.Add(-time.Minute) 19 | 20 | type test struct { 21 | name string // In each test do the following: 22 | advanceTime time.Duration // 1. advance time 23 | incrBy int // 2. increase counter 24 | prev int // 3. 
check previous window counter 25 | curr int // and current window counter 26 | } 27 | 28 | tests := []test{ 29 | { 30 | name: "t=0m: init", 31 | prev: 0, 32 | curr: 0, 33 | }, 34 | { 35 | name: "t=0m: increment 1", 36 | incrBy: 1, 37 | prev: 0, 38 | curr: 1, 39 | }, 40 | { 41 | name: "t=0m: increment by 99", 42 | incrBy: 99, 43 | prev: 0, 44 | curr: 100, 45 | }, 46 | { 47 | name: "t=1m: move clock by 1m", 48 | advanceTime: time.Minute, 49 | prev: 100, 50 | curr: 0, 51 | }, 52 | { 53 | name: "t=1m: increment by 20", 54 | incrBy: 20, 55 | prev: 100, 56 | curr: 20, 57 | }, 58 | { 59 | name: "t=1m: increment by 20", 60 | incrBy: 20, 61 | prev: 100, 62 | curr: 40, 63 | }, 64 | { 65 | name: "t=2m: move clock by 1m", 66 | advanceTime: time.Minute, 67 | prev: 40, 68 | curr: 0, 69 | }, 70 | { 71 | name: "t=2m: incr++", 72 | incrBy: 1, 73 | prev: 40, 74 | curr: 1, 75 | }, 76 | { 77 | name: "t=2m: incr+=9", 78 | incrBy: 9, 79 | prev: 40, 80 | curr: 10, 81 | }, 82 | { 83 | name: "t=2m: incr+=20", 84 | incrBy: 20, 85 | prev: 40, 86 | curr: 30, 87 | }, 88 | { 89 | name: "t=4m: move clock by 2m", 90 | advanceTime: 2 * time.Minute, 91 | prev: 0, 92 | curr: 0, 93 | }, 94 | } 95 | 96 | concurrentRequests := 1000 97 | 98 | for _, tt := range tests { 99 | if tt.advanceTime > 0 { 100 | currentWindow = currentWindow.Add(tt.advanceTime) 101 | previousWindow = previousWindow.Add(tt.advanceTime) 102 | } 103 | 104 | if tt.incrBy > 0 { 105 | var g errgroup.Group 106 | for i := 0; i < concurrentRequests; i++ { 107 | i := i 108 | g.Go(func() error { 109 | key := fmt.Sprintf("key:%v", i) 110 | return limitCounter.IncrementBy(key, currentWindow, tt.incrBy) 111 | }) 112 | } 113 | if err := g.Wait(); err != nil { 114 | t.Errorf("%s: %v", tt.name, err) 115 | } 116 | } 117 | 118 | var g errgroup.Group 119 | for i := 0; i < concurrentRequests; i++ { 120 | i := i 121 | g.Go(func() error { 122 | key := fmt.Sprintf("key:%v", i) 123 | curr, prev, err := limitCounter.Get(key, currentWindow, 
previousWindow) 124 | if err != nil { 125 | return fmt.Errorf("%q: %w", key, err) 126 | } 127 | if curr != tt.curr { 128 | return fmt.Errorf("%q: unexpected curr = %v, expected %v", key, curr, tt.curr) 129 | } 130 | if prev != tt.prev { 131 | return fmt.Errorf("%q: unexpected prev = %v, expected %v", key, prev, tt.prev) 132 | } 133 | return nil 134 | }) 135 | } 136 | if err := g.Wait(); err != nil { 137 | t.Errorf("%s: %v", tt.name, err) 138 | } 139 | } 140 | } 141 | 142 | func BenchmarkLocalCounter(b *testing.B) { 143 | limitCounter := httprate.NewLocalLimitCounter(time.Minute) 144 | 145 | currentWindow := time.Now().UTC().Truncate(time.Minute) 146 | previousWindow := currentWindow.Add(-time.Minute) 147 | 148 | b.ResetTimer() 149 | 150 | for i := 0; i < b.N; i++ { 151 | for i := range []int{0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0, 0, 1, 0} { 152 | // Simulate time. 153 | currentWindow.Add(time.Duration(i) * time.Minute) 154 | previousWindow.Add(time.Duration(i) * time.Minute) 155 | 156 | wg := sync.WaitGroup{} 157 | wg.Add(1000) 158 | for i := 0; i < 1000; i++ { 159 | // Simulate concurrent requests with different rate-limit keys. 
160 | go func(i int) { 161 | defer wg.Done() 162 | 163 | _, _, _ = limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) 164 | _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, rand.Intn(100)) 165 | }(i) 166 | } 167 | wg.Wait() 168 | } 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /httprate.go: -------------------------------------------------------------------------------- 1 | package httprate 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | "strings" 7 | "time" 8 | ) 9 | 10 | func Limit(requestLimit int, windowLength time.Duration, options ...Option) func(next http.Handler) http.Handler { 11 | return NewRateLimiter(requestLimit, windowLength, options...).Handler 12 | } 13 | 14 | type KeyFunc func(r *http.Request) (string, error) 15 | type Option func(rl *RateLimiter) 16 | 17 | // Set custom response headers. If empty, the header is omitted. 18 | type ResponseHeaders struct { 19 | Limit string // Default: X-RateLimit-Limit 20 | Remaining string // Default: X-RateLimit-Remaining 21 | Increment string // Default: X-RateLimit-Increment 22 | Reset string // Default: X-RateLimit-Reset 23 | RetryAfter string // Default: Retry-After 24 | } 25 | 26 | func LimitAll(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler { 27 | return Limit(requestLimit, windowLength) 28 | } 29 | 30 | func LimitByIP(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler { 31 | return Limit(requestLimit, windowLength, WithKeyFuncs(KeyByIP)) 32 | } 33 | 34 | func LimitByRealIP(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler { 35 | return Limit(requestLimit, windowLength, WithKeyFuncs(KeyByRealIP)) 36 | } 37 | 38 | func Key(key string) func(r *http.Request) (string, error) { 39 | return func(r *http.Request) (string, error) { 40 | return key, nil 41 | } 42 | } 43 | 44 | func KeyByIP(r *http.Request) (string, 
error) { 45 | ip, _, err := net.SplitHostPort(r.RemoteAddr) 46 | if err != nil { 47 | ip = r.RemoteAddr 48 | } 49 | return canonicalizeIP(ip), nil 50 | } 51 | 52 | func KeyByRealIP(r *http.Request) (string, error) { 53 | var ip string 54 | 55 | if tcip := r.Header.Get("True-Client-IP"); tcip != "" { 56 | ip = tcip 57 | } else if xrip := r.Header.Get("X-Real-IP"); xrip != "" { 58 | ip = xrip 59 | } else if xff := r.Header.Get("X-Forwarded-For"); xff != "" { 60 | i := strings.Index(xff, ", ") 61 | if i == -1 { 62 | i = len(xff) 63 | } 64 | ip = xff[:i] 65 | } else { 66 | var err error 67 | ip, _, err = net.SplitHostPort(r.RemoteAddr) 68 | if err != nil { 69 | ip = r.RemoteAddr 70 | } 71 | } 72 | 73 | return canonicalizeIP(ip), nil 74 | } 75 | 76 | func KeyByEndpoint(r *http.Request) (string, error) { 77 | return r.URL.Path, nil 78 | } 79 | 80 | func WithKeyFuncs(keyFuncs ...KeyFunc) Option { 81 | return func(rl *RateLimiter) { 82 | if len(keyFuncs) > 0 { 83 | rl.keyFn = composedKeyFunc(keyFuncs...) 
84 | } 85 | } 86 | } 87 | 88 | func WithKeyByIP() Option { 89 | return WithKeyFuncs(KeyByIP) 90 | } 91 | 92 | func WithKeyByRealIP() Option { 93 | return WithKeyFuncs(KeyByRealIP) 94 | } 95 | 96 | func WithLimitHandler(h http.HandlerFunc) Option { 97 | return func(rl *RateLimiter) { 98 | rl.onRateLimited = h 99 | } 100 | } 101 | 102 | func WithErrorHandler(h func(http.ResponseWriter, *http.Request, error)) Option { 103 | return func(rl *RateLimiter) { 104 | rl.onError = h 105 | } 106 | } 107 | 108 | func WithLimitCounter(c LimitCounter) Option { 109 | return func(rl *RateLimiter) { 110 | rl.limitCounter = c 111 | } 112 | } 113 | 114 | func WithResponseHeaders(headers ResponseHeaders) Option { 115 | return func(rl *RateLimiter) { 116 | rl.headers = headers 117 | } 118 | } 119 | 120 | func WithNoop() Option { 121 | return func(rl *RateLimiter) {} 122 | } 123 | 124 | func composedKeyFunc(keyFuncs ...KeyFunc) KeyFunc { 125 | return func(r *http.Request) (string, error) { 126 | var key strings.Builder 127 | for i := 0; i < len(keyFuncs); i++ { 128 | k, err := keyFuncs[i](r) 129 | if err != nil { 130 | return "", err 131 | } 132 | key.WriteString(k) 133 | key.WriteRune(':') 134 | } 135 | return key.String(), nil 136 | } 137 | } 138 | 139 | // canonicalizeIP returns a form of ip suitable for comparison to other IPs. 140 | // For IPv4 addresses, this is simply the whole string. 141 | // For IPv6 addresses, this is the /64 prefix. 
// canonicalizeIP returns a form of ip suitable for comparison to other IPs.
// For IPv4 addresses, this is simply the whole string.
// For IPv6 addresses, this is the /64 prefix.
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are returned as the embedded
// IPv4 address; masking them to /64 would turn every one of them into "::"
// and make all such clients share a single key.
// Input that is not a valid IP address is returned unchanged.
func canonicalizeIP(ip string) string {
	// The first '.' or ':' decides the address family, which is how
	// net.ParseIP classifies its input as well.
	for i := 0; i < len(ip); i++ {
		switch ip[i] {
		case '.':
			// IPv4
			return ip
		case ':':
			// IPv6
			parsed := net.ParseIP(ip)
			if parsed == nil {
				return ip
			}
			if v4 := parsed.To4(); v4 != nil {
				return v4.String()
			}
			return parsed.Mask(net.CIDRMask(64, 128)).String()
		}
	}

	// Neither separator found: not an IP address at all.
	return ip
}
17 | 18 | ## Backends 19 | 20 | - [x] Local in-memory backend (default) 21 | - [x] Redis backend: https://github.com/go-chi/httprate-redis 22 | 23 | ## Example 24 | 25 | ```go 26 | package main 27 | 28 | import ( 29 | "net/http" 30 | "time" 31 | 32 | "github.com/go-chi/chi/v5" 33 | "github.com/go-chi/chi/v5/middleware" 34 | "github.com/go-chi/httprate" 35 | ) 36 | 37 | func main() { 38 | r := chi.NewRouter() 39 | r.Use(middleware.Logger) 40 | 41 | // Enable httprate request limiter of 100 requests per minute. 42 | // 43 | // In the code example below, rate-limiting is bound to the request IP address 44 | // via the LimitByIP middleware handler. 45 | // 46 | // To have a single rate-limiter for all requests, use httprate.LimitAll(..). 47 | // 48 | // Please see _example/main.go for more examples, or read the library code. 49 | r.Use(httprate.LimitByIP(100, time.Minute)) 50 | 51 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 52 | w.Write([]byte(".")) 53 | }) 54 | 55 | http.ListenAndServe(":3333", r) 56 | } 57 | ``` 58 | 59 | ## Common use cases 60 | 61 | ### Rate limit by IP and URL path (aka endpoint) 62 | ```go 63 | r.Use(httprate.Limit( 64 | 10, // requests 65 | 10*time.Second, // per duration 66 | httprate.WithKeyFuncs(httprate.KeyByIP, httprate.KeyByEndpoint), 67 | )) 68 | ``` 69 | 70 | ### Rate limit by arbitrary keys 71 | ```go 72 | r.Use(httprate.Limit( 73 | 100, 74 | time.Minute, 75 | // an oversimplified example of rate limiting by a custom header 76 | httprate.WithKeyFuncs(func(r *http.Request) (string, error) { 77 | return r.Header.Get("X-Access-Token"), nil 78 | }), 79 | )) 80 | ``` 81 | 82 | ### Rate limit by request payload 83 | ```go 84 | // Rate-limiter for login endpoint. 
85 | loginRateLimiter := httprate.NewRateLimiter(5, time.Minute) 86 | 87 | r.Post("/login", func(w http.ResponseWriter, r *http.Request) { 88 | var payload struct { 89 | Username string `json:"username"` 90 | Password string `json:"password"` 91 | } 92 | err := json.NewDecoder(r.Body).Decode(&payload) 93 | if err != nil || payload.Username == "" || payload.Password == "" { 94 | w.WriteHeader(400) 95 | return 96 | } 97 | 98 | // Rate-limit login at 5 req/min. 99 | if loginRateLimiter.RespondOnLimit(w, r, payload.Username) { 100 | return 101 | } 102 | 103 | w.Write([]byte("login at 5 req/min\n")) 104 | }) 105 | ``` 106 | 107 | ### Send specific response for rate-limited requests 108 | 109 | The default response is `HTTP 429` with `Too Many Requests` body. You can override it with: 110 | 111 | ```go 112 | r.Use(httprate.Limit( 113 | 10, 114 | time.Minute, 115 | httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { 116 | http.Error(w, `{"error": "Rate-limited. Please, slow down."}`, http.StatusTooManyRequests) 117 | }), 118 | )) 119 | ``` 120 | 121 | ### Send specific response on errors 122 | 123 | An error can be returned by: 124 | - A custom key function provided by `httprate.WithKeyFuncs(customKeyFn)` 125 | - A custom backend provided by `httprateredis.WithRedisLimitCounter(customBackend)` 126 | - The default local in-memory counter is guaranteed not to return any errors 127 | - Backends that fall-back to the local in-memory counter (e.g. 
[httprate-redis](https://github.com/go-chi/httprate-redis)) can choose not to return any errors either 128 | 129 | ```go 130 | r.Use(httprate.Limit( 131 | 10, 132 | time.Minute, 133 | httprate.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, err error) { 134 | http.Error(w, fmt.Sprintf(`{"error": %q}`, err), http.StatusPreconditionRequired) 135 | }), 136 | httprate.WithLimitCounter(customBackend), 137 | )) 138 | ``` 139 | 140 | ### Send custom response headers 141 | 142 | ```go 143 | r.Use(httprate.Limit( 144 | 1000, 145 | time.Minute, 146 | httprate.WithResponseHeaders(httprate.ResponseHeaders{ 147 | Limit: "X-RateLimit-Limit", 148 | Remaining: "X-RateLimit-Remaining", 149 | Reset: "X-RateLimit-Reset", 150 | RetryAfter: "Retry-After", 151 | Increment: "", // omit 152 | }), 153 | )) 154 | ``` 155 | 156 | ### Omit response headers 157 | 158 | ```go 159 | r.Use(httprate.Limit( 160 | 1000, 161 | time.Minute, 162 | httprate.WithResponseHeaders(httprate.ResponseHeaders{}), 163 | )) 164 | ``` 165 | 166 | ## LICENSE 167 | 168 | MIT 169 | -------------------------------------------------------------------------------- /limiter.go: -------------------------------------------------------------------------------- 1 | package httprate 2 | 3 | import ( 4 | "math" 5 | "net/http" 6 | "strconv" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | type LimitCounter interface { 12 | Config(requestLimit int, windowLength time.Duration) 13 | Increment(key string, currentWindow time.Time) error 14 | IncrementBy(key string, currentWindow time.Time, amount int) error 15 | Get(key string, currentWindow, previousWindow time.Time) (int, int, error) 16 | } 17 | 18 | func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *RateLimiter { 19 | rl := &RateLimiter{ 20 | requestLimit: requestLimit, 21 | windowLength: windowLength, 22 | headers: ResponseHeaders{ 23 | Limit: "X-RateLimit-Limit", 24 | Remaining: "X-RateLimit-Remaining", 25 | Increment: 
"X-RateLimit-Increment", 26 | Reset: "X-RateLimit-Reset", 27 | RetryAfter: "Retry-After", 28 | }, 29 | } 30 | 31 | for _, opt := range options { 32 | opt(rl) 33 | } 34 | 35 | if rl.keyFn == nil { 36 | rl.keyFn = Key("*") 37 | } 38 | 39 | if rl.limitCounter == nil { 40 | rl.limitCounter = NewLocalLimitCounter(windowLength) 41 | } else { 42 | rl.limitCounter.Config(requestLimit, windowLength) 43 | } 44 | 45 | if rl.onRateLimited == nil { 46 | rl.onRateLimited = onRateLimited 47 | } 48 | 49 | if rl.onError == nil { 50 | rl.onError = onError 51 | } 52 | 53 | return rl 54 | } 55 | 56 | type RateLimiter struct { 57 | requestLimit int 58 | windowLength time.Duration 59 | keyFn KeyFunc 60 | limitCounter LimitCounter 61 | onRateLimited http.HandlerFunc 62 | onError func(http.ResponseWriter, *http.Request, error) 63 | headers ResponseHeaders 64 | mu sync.Mutex 65 | } 66 | 67 | // OnLimit checks the rate limit for the given key and updates the response headers accordingly. 68 | // If the limit is reached, it returns true, indicating that the request should be halted. Otherwise, 69 | // it increments the request count and returns false. This method does not send an HTTP response, 70 | // so the caller must handle the response themselves or use the RespondOnLimit() method instead. 
71 | func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool { 72 | currentWindow := time.Now().UTC().Truncate(l.windowLength) 73 | ctx := r.Context() 74 | 75 | limit := l.requestLimit 76 | if val := getRequestLimit(ctx); val > 0 { 77 | limit = val 78 | } 79 | setHeader(w, l.headers.Limit, strconv.Itoa(limit)) 80 | setHeader(w, l.headers.Reset, strconv.FormatInt(currentWindow.Add(l.windowLength).Unix(), 10)) 81 | 82 | l.mu.Lock() 83 | _, rateFloat, err := l.calculateRate(key, limit) 84 | if err != nil { 85 | l.mu.Unlock() 86 | l.onError(w, r, err) 87 | return true 88 | } 89 | rate := int(math.Round(rateFloat)) 90 | 91 | increment := getIncrement(r.Context()) 92 | if increment > 1 { 93 | setHeader(w, l.headers.Increment, strconv.Itoa(increment)) 94 | } 95 | 96 | if rate+increment > limit { 97 | setHeader(w, l.headers.Remaining, strconv.Itoa(limit-rate)) 98 | 99 | l.mu.Unlock() 100 | setHeader(w, l.headers.RetryAfter, strconv.Itoa(int(l.windowLength.Seconds()))) // RFC 6585 101 | return true 102 | } 103 | 104 | err = l.limitCounter.IncrementBy(key, currentWindow, increment) 105 | if err != nil { 106 | l.mu.Unlock() 107 | l.onError(w, r, err) 108 | return true 109 | } 110 | l.mu.Unlock() 111 | 112 | setHeader(w, l.headers.Remaining, strconv.Itoa(limit-rate-increment)) 113 | return false 114 | } 115 | 116 | // RespondOnLimit checks the rate limit for the given key and updates the response headers accordingly. 117 | // If the limit is reached, it automatically sends an HTTP response and returns true, signaling the 118 | // caller to halt further request processing. If the limit is not reached, it increments the request 119 | // count and returns false, allowing the request to proceed. 
120 | func (l *RateLimiter) RespondOnLimit(w http.ResponseWriter, r *http.Request, key string) bool { 121 | onLimit := l.OnLimit(w, r, key) 122 | if onLimit { 123 | l.onRateLimited(w, r) 124 | } 125 | return onLimit 126 | } 127 | 128 | func (l *RateLimiter) Counter() LimitCounter { 129 | return l.limitCounter 130 | } 131 | 132 | func (l *RateLimiter) Status(key string) (bool, float64, error) { 133 | return l.calculateRate(key, l.requestLimit) 134 | } 135 | 136 | func (l *RateLimiter) Handler(next http.Handler) http.Handler { 137 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 138 | key, err := l.keyFn(r) 139 | if err != nil { 140 | l.onError(w, r, err) 141 | return 142 | } 143 | 144 | if l.RespondOnLimit(w, r, key) { 145 | return 146 | } 147 | 148 | next.ServeHTTP(w, r) 149 | }) 150 | } 151 | 152 | func (l *RateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) { 153 | now := time.Now().UTC() 154 | currentWindow := now.Truncate(l.windowLength) 155 | previousWindow := currentWindow.Add(-l.windowLength) 156 | 157 | currCount, prevCount, err := l.limitCounter.Get(key, currentWindow, previousWindow) 158 | if err != nil { 159 | return false, 0, err 160 | } 161 | 162 | diff := now.Sub(currentWindow) 163 | rate := float64(prevCount)*(float64(l.windowLength)-float64(diff))/float64(l.windowLength) + float64(currCount) 164 | if rate > float64(requestLimit) { 165 | return false, rate, nil 166 | } 167 | 168 | return true, rate, nil 169 | } 170 | 171 | func setHeader(w http.ResponseWriter, key string, value string) { 172 | if key != "" { 173 | w.Header().Set(key, value) 174 | } 175 | } 176 | 177 | func onRateLimited(w http.ResponseWriter, r *http.Request) { 178 | http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) 179 | } 180 | 181 | func onError(w http.ResponseWriter, r *http.Request, err error) { 182 | http.Error(w, err.Error(), http.StatusPreconditionRequired) 183 | } 184 | 
-------------------------------------------------------------------------------- /limiter_test.go: -------------------------------------------------------------------------------- 1 | package httprate_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "io" 8 | "net/http" 9 | "net/http/httptest" 10 | "strconv" 11 | "strings" 12 | "testing" 13 | "time" 14 | 15 | "github.com/go-chi/httprate" 16 | ) 17 | 18 | func TestLimit(t *testing.T) { 19 | type test struct { 20 | name string 21 | requestsLimit int 22 | windowLength time.Duration 23 | respCodes []int 24 | } 25 | tests := []test{ 26 | { 27 | name: "no-block", 28 | requestsLimit: 3, 29 | windowLength: 4 * time.Second, 30 | respCodes: []int{200, 200, 200}, 31 | }, 32 | { 33 | name: "block", 34 | requestsLimit: 3, 35 | windowLength: 2 * time.Second, 36 | respCodes: []int{200, 200, 200, 429}, 37 | }, 38 | } 39 | for _, tt := range tests { 40 | t.Run(tt.name, func(t *testing.T) { 41 | h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 42 | router := httprate.LimitAll(tt.requestsLimit, tt.windowLength)(h) 43 | 44 | for i, code := range tt.respCodes { 45 | req := httptest.NewRequest("GET", "/", nil) 46 | recorder := httptest.NewRecorder() 47 | router.ServeHTTP(recorder, req) 48 | if respCode := recorder.Result().StatusCode; respCode != code { 49 | t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respCode, code) 50 | } 51 | } 52 | }) 53 | } 54 | } 55 | 56 | func TestWithIncrement(t *testing.T) { 57 | type test struct { 58 | name string 59 | increment int 60 | requestsLimit int 61 | respCodes []int 62 | } 63 | tests := []test{ 64 | { 65 | name: "no limit", 66 | increment: 0, 67 | requestsLimit: 3, 68 | respCodes: []int{200, 200, 200, 200}, 69 | }, 70 | { 71 | name: "increment 1", 72 | increment: 1, 73 | requestsLimit: 3, 74 | respCodes: []int{200, 200, 200, 429}, 75 | }, 76 | { 77 | name: "increment 2", 78 | increment: 2, 79 | requestsLimit: 3, 80 | respCodes: []int{200, 429, 429, 
429}, 81 | }, 82 | { 83 | name: "increment 3", 84 | increment: 3, 85 | requestsLimit: 3, 86 | respCodes: []int{200, 429, 429, 429}, 87 | }, 88 | { 89 | name: "always block", 90 | increment: 4, 91 | requestsLimit: 3, 92 | respCodes: []int{429, 429, 429, 429}, 93 | }, 94 | } 95 | for _, tt := range tests { 96 | t.Run(tt.name, func(t *testing.T) { 97 | h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 98 | router := httprate.LimitAll(tt.requestsLimit, time.Minute)(h) 99 | 100 | for i, code := range tt.respCodes { 101 | req := httptest.NewRequest("GET", "/", nil) 102 | req = req.WithContext(httprate.WithIncrement(req.Context(), tt.increment)) 103 | recorder := httptest.NewRecorder() 104 | router.ServeHTTP(recorder, req) 105 | if respCode := recorder.Result().StatusCode; respCode != code { 106 | t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respCode, code) 107 | } 108 | } 109 | }) 110 | } 111 | } 112 | 113 | func TestResponseHeaders(t *testing.T) { 114 | type test struct { 115 | name string 116 | requestsLimit int 117 | increments []int 118 | respCodes []int 119 | respLimitHeader []string 120 | respRemainingHeader []string 121 | } 122 | tests := []test{ 123 | { 124 | name: "const increments", 125 | requestsLimit: 5, 126 | increments: []int{1, 1, 1, 1, 1, 1}, 127 | respCodes: []int{200, 200, 200, 200, 200, 429}, 128 | respLimitHeader: []string{"5", "5", "5", "5", "5", "5"}, 129 | respRemainingHeader: []string{"4", "3", "2", "1", "0", "0"}, 130 | }, 131 | { 132 | name: "varying increments", 133 | requestsLimit: 5, 134 | increments: []int{2, 2, 1, 2, 10, 1}, 135 | respCodes: []int{200, 200, 200, 429, 429, 429}, 136 | respLimitHeader: []string{"5", "5", "5", "5", "5", "5"}, 137 | respRemainingHeader: []string{"3", "1", "0", "0", "0", "0"}, 138 | }, 139 | { 140 | name: "no limit", 141 | requestsLimit: 5, 142 | increments: []int{0, 0, 0, 0, 0, 0}, 143 | respCodes: []int{200, 200, 200, 200, 200, 200}, 144 | respLimitHeader: []string{"5", "5", "5", "5", 
"5", "5"}, 145 | respRemainingHeader: []string{"5", "5", "5", "5", "5", "5"}, 146 | }, 147 | { 148 | name: "always block", 149 | requestsLimit: 5, 150 | increments: []int{10, 10, 10, 10, 10, 10}, 151 | respCodes: []int{429, 429, 429, 429, 429, 429}, 152 | respLimitHeader: []string{"5", "5", "5", "5", "5", "5"}, 153 | respRemainingHeader: []string{"5", "5", "5", "5", "5", "5"}, 154 | }, 155 | } 156 | for _, tt := range tests { 157 | t.Run(tt.name, func(t *testing.T) { 158 | count := len(tt.increments) 159 | if count != len(tt.respCodes) || count != len(tt.respLimitHeader) || count != len(tt.respRemainingHeader) { 160 | t.Fatalf("invalid test case: increments(%v), respCodes(%v), respLimitHeader(%v) and respRemainingHeaders(%v) must have same size", len(tt.increments), len(tt.respCodes), len(tt.respLimitHeader), len(tt.respRemainingHeader)) 161 | } 162 | 163 | h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 164 | router := httprate.LimitAll(tt.requestsLimit, time.Minute)(h) 165 | 166 | for i := 0; i < count; i++ { 167 | req := httptest.NewRequest("GET", "/", nil) 168 | req = req.WithContext(httprate.WithIncrement(req.Context(), tt.increments[i])) 169 | recorder := httptest.NewRecorder() 170 | router.ServeHTTP(recorder, req) 171 | 172 | if respCode := recorder.Result().StatusCode; respCode != tt.respCodes[i] { 173 | t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respCode, tt.respCodes[i]) 174 | } 175 | 176 | headers := recorder.Result().Header 177 | if limit := headers.Get("X-RateLimit-Limit"); limit != tt.respLimitHeader[i] { 178 | t.Errorf("X-RateLimit-Limit(%v) = %v, want %v", i, limit, tt.respLimitHeader[i]) 179 | } 180 | if remaining := headers.Get("X-RateLimit-Remaining"); remaining != tt.respRemainingHeader[i] { 181 | t.Errorf("X-RateLimit-Remaining(%v) = %v, want %v", i, remaining, tt.respRemainingHeader[i]) 182 | } 183 | 184 | reset := headers.Get("X-RateLimit-Reset") 185 | if resetUnixTime, err := strconv.ParseInt(reset, 10, 64); err 
!= nil || resetUnixTime <= time.Now().Unix() { 186 | t.Errorf("X-RateLimit-Reset(%v) = %v, want unix timestamp in the future", i, reset) 187 | } 188 | } 189 | }) 190 | } 191 | } 192 | 193 | func TestCustomResponseHeaders(t *testing.T) { 194 | type test struct { 195 | name string 196 | headers httprate.ResponseHeaders 197 | } 198 | tests := []test{ 199 | { 200 | name: "no headers", 201 | headers: httprate.ResponseHeaders{ 202 | Limit: "", 203 | Remaining: "", 204 | Reset: "", 205 | RetryAfter: "", 206 | Increment: "", 207 | }, 208 | }, 209 | { 210 | name: "custom headers", 211 | headers: httprate.ResponseHeaders{ 212 | Limit: "RateLimit-Limit", 213 | Remaining: "RateLimit-Remaining", 214 | Reset: "RateLimit-Reset", 215 | RetryAfter: "RateLimit-Retry", 216 | Increment: "", 217 | }, 218 | }, 219 | } 220 | for _, tt := range tests { 221 | t.Run(tt.name, func(t *testing.T) { 222 | h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 223 | router := httprate.Limit( 224 | 1, 225 | time.Minute, 226 | httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { 227 | http.Error(w, "Wow Slow Down Kiddo", 429) 228 | }), 229 | httprate.WithResponseHeaders(tt.headers), 230 | )(h) 231 | 232 | req := httptest.NewRequest("GET", "/", nil) 233 | 234 | // Force Retry-After and X-RateLimit-Increment headers. 
235 | req = req.WithContext(httprate.WithIncrement(req.Context(), 2)) 236 | 237 | recorder := httptest.NewRecorder() 238 | router.ServeHTTP(recorder, req) 239 | 240 | headers := recorder.Result().Header 241 | 242 | for _, header := range []string{ 243 | "X-RateLimit-Limit", 244 | "X-RateLimit-Remaining", 245 | "X-RateLimit-Increment", 246 | "X-RateLimit-Reset", 247 | "Retry-After", 248 | "", // ensure we don't set header with an empty key 249 | } { 250 | if len(headers.Values(header)) != 0 { 251 | t.Errorf("%q header not expected", header) 252 | } 253 | } 254 | 255 | for _, header := range []string{ 256 | tt.headers.Limit, 257 | tt.headers.Remaining, 258 | tt.headers.Increment, 259 | tt.headers.Reset, 260 | tt.headers.RetryAfter, 261 | } { 262 | if header == "" { 263 | continue 264 | } 265 | if h := headers.Get(header); h == "" { 266 | t.Errorf("%q header expected", header) 267 | } 268 | } 269 | }) 270 | } 271 | } 272 | 273 | func TestLimitHandler(t *testing.T) { 274 | type test struct { 275 | name string 276 | requestsLimit int 277 | windowLength time.Duration 278 | responses []struct { 279 | Body string 280 | StatusCode int 281 | } 282 | } 283 | tests := []test{ 284 | { 285 | name: "no-block", 286 | requestsLimit: 3, 287 | windowLength: 4 * time.Second, 288 | responses: []struct { 289 | Body string 290 | StatusCode int 291 | }{ 292 | {Body: "", StatusCode: 200}, 293 | {Body: "", StatusCode: 200}, 294 | {Body: "", StatusCode: 200}, 295 | }, 296 | }, 297 | { 298 | name: "block", 299 | requestsLimit: 3, 300 | windowLength: 2 * time.Second, 301 | responses: []struct { 302 | Body string 303 | StatusCode int 304 | }{ 305 | {Body: "", StatusCode: 200}, 306 | {Body: "", StatusCode: 200}, 307 | {Body: "", StatusCode: 200}, 308 | {Body: "Wow Slow Down Kiddo", StatusCode: 429}, 309 | }, 310 | }, 311 | } 312 | for i, tt := range tests { 313 | t.Run(tt.name, func(t *testing.T) { 314 | h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 315 | router := 
httprate.Limit( 316 | tt.requestsLimit, 317 | tt.windowLength, 318 | httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { 319 | http.Error(w, "Wow Slow Down Kiddo", 429) 320 | }), 321 | )(h) 322 | 323 | for _, expected := range tt.responses { 324 | req := httptest.NewRequest("GET", "/", nil) 325 | recorder := httptest.NewRecorder() 326 | router.ServeHTTP(recorder, req) 327 | result := recorder.Result() 328 | if respStatus := result.StatusCode; respStatus != expected.StatusCode { 329 | t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respStatus, expected.StatusCode) 330 | } 331 | buf := new(bytes.Buffer) 332 | buf.ReadFrom(result.Body) 333 | respBody := strings.TrimSuffix(buf.String(), "\n") 334 | 335 | if respBody != expected.Body { 336 | t.Errorf("resp.Body(%v) = %v, want %v", i, respBody, expected.Body) 337 | } 338 | } 339 | }) 340 | } 341 | } 342 | 343 | func TestLimitIP(t *testing.T) { 344 | type test struct { 345 | name string 346 | requestsLimit int 347 | windowLength time.Duration 348 | reqIp []string 349 | respCodes []int 350 | } 351 | tests := []test{ 352 | { 353 | name: "no-block", 354 | requestsLimit: 3, 355 | windowLength: 2 * time.Second, 356 | reqIp: []string{"1.1.1.1:100", "2.2.2.2:200"}, 357 | respCodes: []int{200, 200}, 358 | }, 359 | { 360 | name: "block-ip", 361 | requestsLimit: 1, 362 | windowLength: 2 * time.Second, 363 | reqIp: []string{"1.1.1.1:100", "1.1.1.1:100", "2.2.2.2:200"}, 364 | respCodes: []int{200, 429, 200}, 365 | }, 366 | { 367 | name: "block-ipv6", 368 | requestsLimit: 1, 369 | windowLength: 2 * time.Second, 370 | reqIp: []string{"2001:DB8::21f:5bff:febf:ce22:1111", "2001:DB8::21f:5bff:febf:ce22:2222", "2002:DB8::21f:5bff:febf:ce22:1111"}, 371 | respCodes: []int{200, 429, 200}, 372 | }, 373 | } 374 | for _, tt := range tests { 375 | t.Run(tt.name, func(t *testing.T) { 376 | h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 377 | router := httprate.LimitByIP(tt.requestsLimit, 
tt.windowLength)(h) 378 | 379 | for i, code := range tt.respCodes { 380 | req := httptest.NewRequest("GET", "/", nil) 381 | req.RemoteAddr = tt.reqIp[i] 382 | recorder := httptest.NewRecorder() 383 | router.ServeHTTP(recorder, req) 384 | if respCode := recorder.Result().StatusCode; respCode != code { 385 | t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respCode, code) 386 | } 387 | } 388 | }) 389 | } 390 | } 391 | 392 | func TestOverrideRequestLimit(t *testing.T) { 393 | h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 394 | router := httprate.Limit( 395 | 3, 396 | time.Minute, 397 | httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { 398 | http.Error(w, "Wow Slow Down Kiddo", 429) 399 | }), 400 | )(h) 401 | 402 | responses := []struct { 403 | StatusCode int 404 | Body string 405 | RequestLimit int // Default: 3 406 | }{ 407 | {StatusCode: 200, Body: ""}, 408 | {StatusCode: 429, Body: "Wow Slow Down Kiddo", RequestLimit: 1}, 409 | {StatusCode: 200, Body: ""}, 410 | {StatusCode: 200, Body: ""}, 411 | {StatusCode: 429, Body: "Wow Slow Down Kiddo"}, 412 | 413 | {StatusCode: 200, Body: "", RequestLimit: 5}, 414 | {StatusCode: 200, Body: "", RequestLimit: 5}, 415 | {StatusCode: 429, Body: "Wow Slow Down Kiddo", RequestLimit: 5}, 416 | } 417 | for i, response := range responses { 418 | ctx := context.Background() 419 | if response.RequestLimit > 0 { 420 | ctx = httprate.WithRequestLimit(ctx, response.RequestLimit) 421 | } 422 | req, err := http.NewRequestWithContext(ctx, "GET", "/", nil) 423 | if err != nil { 424 | t.Errorf("failed = %v", err) 425 | } 426 | 427 | recorder := httptest.NewRecorder() 428 | router.ServeHTTP(recorder, req) 429 | result := recorder.Result() 430 | if respStatus := result.StatusCode; respStatus != response.StatusCode { 431 | t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respStatus, response.StatusCode) 432 | } 433 | body, _ := io.ReadAll(result.Body) 434 | respBody := 
strings.TrimSuffix(string(body), "\n") 435 | 436 | if respBody != response.Body { 437 | t.Errorf("resp.Body(%v) = %q, want %q", i, respBody, response.Body) 438 | } 439 | } 440 | } 441 | 442 | func TestRateLimitPayload(t *testing.T) { 443 | loginRateLimiter := httprate.NewRateLimiter(5, time.Minute) 444 | 445 | h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 446 | var payload struct { 447 | Username string `json:"username"` 448 | Password string `json:"password"` 449 | } 450 | err := json.NewDecoder(r.Body).Decode(&payload) 451 | if err != nil || payload.Username == "" || payload.Password == "" { 452 | w.WriteHeader(400) 453 | return 454 | } 455 | 456 | // Rate-limit login at 5 req/min. 457 | if loginRateLimiter.RespondOnLimit(w, r, payload.Username) { 458 | return 459 | } 460 | 461 | w.Write([]byte("login at 5 req/min\n")) 462 | }) 463 | 464 | responses := []struct { 465 | StatusCode int 466 | Body string 467 | }{ 468 | {StatusCode: 200, Body: "login at 5 req/min"}, 469 | {StatusCode: 200, Body: "login at 5 req/min"}, 470 | {StatusCode: 200, Body: "login at 5 req/min"}, 471 | {StatusCode: 200, Body: "login at 5 req/min"}, 472 | {StatusCode: 200, Body: "login at 5 req/min"}, 473 | {StatusCode: 429, Body: "Too Many Requests"}, 474 | {StatusCode: 429, Body: "Too Many Requests"}, 475 | {StatusCode: 429, Body: "Too Many Requests"}, 476 | } 477 | for i, response := range responses { 478 | req, err := http.NewRequest("GET", "/", strings.NewReader(`{"username":"alice","password":"***"}`)) 479 | if err != nil { 480 | t.Errorf("failed = %v", err) 481 | } 482 | 483 | recorder := httptest.NewRecorder() 484 | h.ServeHTTP(recorder, req) 485 | result := recorder.Result() 486 | if respStatus := result.StatusCode; respStatus != response.StatusCode { 487 | t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respStatus, response.StatusCode) 488 | } 489 | body, _ := io.ReadAll(result.Body) 490 | respBody := strings.TrimSuffix(string(body), "\n") 491 | 492 | if 
string(respBody) != response.Body { 493 | t.Errorf("resp.Body(%v) = %q, want %q", i, respBody, response.Body) 494 | } 495 | } 496 | } 497 | --------------------------------------------------------------------------------