├── singleflight.go ├── go.mod ├── _example ├── go.mod ├── simple │ └── main.go ├── status-ttl │ └── main.go ├── go.sum └── with-key │ └── main.go ├── Makefile ├── .github └── workflows │ └── ci.yml ├── LICENSE ├── go.sum ├── README.md ├── options.go ├── stampede.go ├── stampede_test.go ├── http.go └── http_test.go /singleflight.go: -------------------------------------------------------------------------------- 1 | package stampede 2 | 3 | import ( 4 | "log/slog" 5 | "net/http" 6 | ) 7 | 8 | func Singleflight(logger *slog.Logger, varyRequestHeaders []string) func(next http.Handler) http.Handler { 9 | handler := Handler(logger, nil, 0, WithSkipCache(true), WithHTTPCacheKeyRequestHeaders(varyRequestHeaders)) 10 | 11 | return func(next http.Handler) http.Handler { 12 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 13 | handler(next).ServeHTTP(w, r) 14 | }) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-chi/stampede 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.2 6 | 7 | require ( 8 | github.com/go-chi/cors v1.2.1 9 | github.com/goware/cachestore2 v0.12.2 10 | github.com/goware/singleflight v0.3.0 11 | github.com/stretchr/testify v1.10.0 12 | github.com/zeebo/xxh3 v1.0.2 13 | ) 14 | 15 | require ( 16 | github.com/davecgh/go-spew v1.1.1 // indirect 17 | github.com/klauspost/cpuid/v2 v2.2.10 // indirect 18 | github.com/pmezard/go-difflib v1.0.0 // indirect 19 | golang.org/x/sys v0.32.0 // indirect 20 | gopkg.in/yaml.v3 v3.0.1 // indirect 21 | ) 22 | -------------------------------------------------------------------------------- /_example/go.mod: -------------------------------------------------------------------------------- 1 | module example 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.2 6 | 7 | replace github.com/go-chi/stampede => ../ 8 | 9 | require ( 10 | 
github.com/go-chi/chi/v5 v5.0.11 11 | github.com/go-chi/stampede v0.5.1 12 | github.com/goware/cachestore-mem v0.2.1 13 | github.com/goware/cachestore2 v0.12.2 14 | ) 15 | 16 | require ( 17 | github.com/elastic/go-freelru v0.16.0 // indirect 18 | github.com/goware/singleflight v0.3.0 // indirect 19 | github.com/klauspost/cpuid/v2 v2.2.10 // indirect 20 | github.com/zeebo/xxh3 v1.0.2 // indirect 21 | golang.org/x/sys v0.32.0 // indirect 22 | ) 23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL = bash -o pipefail 2 | TEST_FLAGS ?= -v -race 3 | 4 | all: 5 | @echo "make " 6 | @echo "" 7 | @echo "commands:" 8 | @echo "" 9 | @echo " + Development:" 10 | @echo " - build" 11 | @echo " - test" 12 | @echo " - todo" 13 | @echo " - clean" 14 | @echo "" 15 | @echo "" 16 | 17 | 18 | ## 19 | ## Development 20 | ## 21 | build: 22 | go build ./... 23 | 24 | clean: 25 | go clean -cache -testcache 26 | 27 | test: test-clean 28 | GOGC=off go test $(TEST_FLAGS) -run=$(TEST) ./... 29 | 30 | test-clean: 31 | GOGC=off go clean -testcache 32 | 33 | bench: 34 | @go test -timeout=25m -bench=. 
35 | 36 | todo: 37 | @git grep TODO -- './*' ':!./vendor/' ':!./Makefile' || : 38 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | name: Test 3 | jobs: 4 | test: 5 | env: 6 | GOPATH: ${{ github.workspace }} 7 | GO111MODULE: on 8 | 9 | defaults: 10 | run: 11 | working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} 12 | 13 | strategy: 14 | matrix: 15 | go-version: [1.23.x,1.24.x] 16 | os: [ubuntu-latest, macos-latest] 17 | 18 | runs-on: ${{ matrix.os }} 19 | 20 | steps: 21 | - name: Checkout code 22 | uses: actions/checkout@v4 23 | with: 24 | path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} 25 | 26 | - name: Install Go 27 | uses: actions/setup-go@v4 28 | with: 29 | go-version: ${{ matrix.go-version }} 30 | 31 | - name: Build 32 | run: | 33 | go build ./... 34 | 35 | - name: Test 36 | run: | 37 | go test -v -race -run= ./... 
38 | -------------------------------------------------------------------------------- /_example/simple/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log/slog" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/go-chi/chi/v5" 9 | "github.com/go-chi/chi/v5/middleware" 10 | "github.com/go-chi/stampede" 11 | memcache "github.com/goware/cachestore-mem" 12 | ) 13 | 14 | func main() { 15 | r := chi.NewRouter() 16 | r.Use(middleware.Logger) 17 | r.Use(middleware.Recoverer) 18 | 19 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 20 | w.Write([]byte("index")) 21 | }) 22 | 23 | cache, err := memcache.NewBackend(1000) 24 | if err != nil { 25 | panic(err) 26 | } 27 | 28 | cacheMiddleware := stampede.Handler( 29 | slog.Default(), cache, 5*time.Second, 30 | stampede.WithHTTPCacheKeyRequestHeaders([]string{"AuthorizatioN"}), 31 | ) 32 | 33 | r.With(cacheMiddleware).Get("/cached", func(w http.ResponseWriter, r *http.Request) { 34 | // processing.. 
35 | time.Sleep(1 * time.Second) 36 | 37 | w.WriteHeader(200) 38 | w.Write([]byte("...hi")) 39 | }) 40 | 41 | http.ListenAndServe(":3333", r) 42 | } 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015-Present https://github.com/go-chi authors 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /_example/status-ttl/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log/slog" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/go-chi/chi/v5" 9 | "github.com/go-chi/chi/v5/middleware" 10 | "github.com/go-chi/stampede" 11 | memcache "github.com/goware/cachestore-mem" 12 | ) 13 | 14 | func main() { 15 | r := chi.NewRouter() 16 | r.Use(middleware.Logger) 17 | r.Use(middleware.Recoverer) 18 | 19 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 20 | w.Write([]byte("index")) 21 | }) 22 | 23 | cache, err := memcache.NewBackend(1000) 24 | if err != nil { 25 | panic(err) 26 | } 27 | 28 | cacheMiddleware := stampede.Handler( 29 | slog.Default(), cache, 5*time.Second, 30 | stampede.WithHTTPStatusTTL(func(status int) time.Duration { 31 | if status == 200 { 32 | return 10 * time.Second 33 | } else if status == 404 { 34 | return 1 * time.Second 35 | } else if status >= 500 { 36 | return 0 // no cache 37 | } else { 38 | return 0 // no cache 39 | } 40 | }), 41 | ) 42 | 43 | r.With(cacheMiddleware).Get("/cached", func(w http.ResponseWriter, r *http.Request) { 44 | // processing.. 
45 | time.Sleep(1 * time.Second) 46 | 47 | if r.URL.Query().Get("error") == "true" { 48 | w.WriteHeader(500) 49 | w.Write([]byte("error")) 50 | return 51 | } 52 | 53 | if r.URL.Query().Get("notfound") == "true" { 54 | w.WriteHeader(404) 55 | w.Write([]byte("notfound")) 56 | return 57 | } 58 | 59 | w.WriteHeader(200) 60 | w.Write([]byte("...hi")) 61 | }) 62 | 63 | http.ListenAndServe(":3333", r) 64 | } 65 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= 4 | github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= 5 | github.com/goware/cachestore2 v0.12.2 h1:04YGXkMwbH1xe82siCO7iaPhetntRABN5fWhBKEzduY= 6 | github.com/goware/cachestore2 v0.12.2/go.mod h1:PR+lXK8UXa/wjKB7mpIj6HtRhC7vbcRXx4b5F1Av/ik= 7 | github.com/goware/singleflight v0.3.0 h1:b+OM844fuHzanOlE84WeI+G8YMksUY636v0bdcAfnHE= 8 | github.com/goware/singleflight v0.3.0/go.mod h1:vcmu9KY0BS9WbA3Pn+WOdUQlwT1CPZJm1Fgaz2l88Dc= 9 | github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= 10 | github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= 11 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 12 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 13 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 14 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 15 | github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= 16 | github.com/zeebo/assert v1.3.0/go.mod 
h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= 17 | github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= 18 | github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= 19 | golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= 20 | golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 21 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 22 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 23 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 24 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 25 | -------------------------------------------------------------------------------- /_example/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/elastic/go-freelru v0.16.0 h1:gG2HJ1WXN2tNl5/p40JS/l59HjvjRhjyAa+oFTRArYs= 4 | github.com/elastic/go-freelru v0.16.0/go.mod h1:bSdWT4M0lW79K8QbX6XY2heQYSCqD7THoYf82pT/H3I= 5 | github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA= 6 | github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= 7 | github.com/goware/cachestore-mem v0.2.1 h1:8ZIFtzpoFlwnPUKuGeazhuV2qzR4Bk7UslEGyXRZp9E= 8 | github.com/goware/cachestore-mem v0.2.1/go.mod h1:0WU95kEa8kmuYSsqOC/fXg/cGVqj5rsTzjUpQgaJHmw= 9 | github.com/goware/cachestore2 v0.12.2 h1:04YGXkMwbH1xe82siCO7iaPhetntRABN5fWhBKEzduY= 10 | github.com/goware/cachestore2 v0.12.2/go.mod h1:PR+lXK8UXa/wjKB7mpIj6HtRhC7vbcRXx4b5F1Av/ik= 11 | github.com/goware/singleflight v0.3.0 h1:b+OM844fuHzanOlE84WeI+G8YMksUY636v0bdcAfnHE= 12 | github.com/goware/singleflight 
v0.3.0/go.mod h1:vcmu9KY0BS9WbA3Pn+WOdUQlwT1CPZJm1Fgaz2l88Dc= 13 | github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= 14 | github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= 15 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 16 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 17 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 18 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 19 | github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= 20 | github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= 21 | github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= 22 | github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= 23 | golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= 24 | golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 25 | golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= 26 | golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 27 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 28 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stampede 2 | 3 | ![](https://github.com/go-chi/stampede/workflows/build/badge.svg?branch=master) 4 | 5 | Prevents cache stampede https://en.wikipedia.org/wiki/Cache_stampede by only running a 6 | single data fetch operation per expired / missing key regardless of number of requests to that key. 
7 | 8 | 9 | ## Example: HTTP Middleware 10 | 11 | ```go 12 | import ( 13 | "log/slog" 14 | "net/http" 15 | "time" 16 | 17 | "github.com/go-chi/chi/v5" 18 | "github.com/go-chi/chi/v5/middleware" 19 | "github.com/go-chi/stampede" 20 | memcache "github.com/goware/cachestore-mem" 21 | ) 22 | 23 | func main() { 24 | r := chi.NewRouter() 25 | r.Use(middleware.Logger) 26 | r.Use(middleware.Recoverer) 27 | 28 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 29 | w.Write([]byte("index")) 30 | }) 31 | 32 | cache, err := memcache.NewBackend(1000) 33 | if err != nil { 34 | panic(err) 35 | } 36 | 37 | cacheMiddleware := stampede.Handler( 38 | slog.Default(), cache, 5*time.Second, 39 | stampede.WithHTTPCacheKeyRequestHeaders([]string{"AuthorizatioN"}), 40 | ) 41 | 42 | r.With(cacheMiddleware).Get("/cached", func(w http.ResponseWriter, r *http.Request) { 43 | // processing.. 44 | time.Sleep(1 * time.Second) 45 | 46 | w.WriteHeader(200) 47 | w.Write([]byte("...hi")) 48 | }) 49 | 50 | http.ListenAndServe(":3333", r) 51 | } 52 | ``` 53 | 54 | 55 | ## Notes 56 | 57 | * Requests passed through the stampede handler will be batched into a single request 58 | when there are parallel requests for the same endpoint/resource. This is also known 59 | as request coalescing. 60 | * Parallel requests for the same endpoint / resource will result in just a single handler call 61 | and the remaining requests will receive the response of the first request handler. 62 | * The response payload for the endpoint / resource will then be cached for up to `ttl` 63 | time duration for subsequent requests, which offers further caching. A `ttl` value of 0 64 | selects the default TTL of 1 minute; pass `stampede.WithSkipCache(true)` if you want the 65 | response to be as fresh as possible and still prevent a stampede scenario on your handler. 66 | * *Security note:* response headers will be the same for all requests, so make sure 67 | to not include anything sensitive or user specific. 
In the case you require user-specific 68 | stampede handlers, make sure you pass a custom cache key function to `stampede.HandlerWithKey` and 69 | split the cache by an account's id. NOTE: we do avoid caching response headers 70 | for CORS, set-cookie and x-ratelimit. 71 | 72 | See the [_example](_example) directory for a variety of examples. 73 | 74 | 75 | ## LICENSE 76 | 77 | MIT 78 | -------------------------------------------------------------------------------- /_example/with-key/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log/slog" 5 | "net/http" 6 | "strings" 7 | "time" 8 | 9 | "github.com/go-chi/chi/v5" 10 | "github.com/go-chi/chi/v5/middleware" 11 | "github.com/go-chi/stampede" 12 | memcache "github.com/goware/cachestore-mem" 13 | ) 14 | 15 | // Example 1: Make two parallel requests: 16 | // First request in first client: 17 | // GET http://localhost:3333/me 18 | // Authorization: Bar 19 | // 20 | // Second request in second client: 21 | // GET http://localhost:3333/me 22 | // Authorization: Bar 23 | // 24 | // -> Result of both queries in one time: 25 | // HTTP/1.1 200 OK 26 | // Content-Length: 14 27 | // Content-Type: text/plain; charset=utf-8 28 | // 29 | // Bearer BarTone 30 | // 31 | // Response code: 200 (OK); Time: 1ms; Content length: 14 bytes 32 | // 33 | // --------------------------------------------------------------- 34 | // 35 | // Example 2: Make two parallel requests: 36 | // First request in first client: 37 | // GET http://localhost:3333/me 38 | // Authorization: Bar 39 | // 40 | // Second request in second client: 41 | // GET http://localhost:3333/me 42 | // Authorization: Foo 43 | // 44 | // -> Result of first: 45 | // HTTP/1.1 200 OK 46 | // Content-Length: 14 47 | // Content-Type: text/plain; charset=utf-8 48 | // 49 | // Bearer Bar 50 | // 51 | // Response code: 200 (OK); Time: 1ms; Content length: 14 bytes 52 | // 53 | // -> Result of second: 54 | // HTTP/1.1 200 
OK 55 | // Content-Length: 14 56 | // Content-Type: text/plain; charset=utf-8 57 | // 58 | // Bearer Foo 59 | // 60 | // Response code: 200 (OK); Time: 1ms; Content length: 14 bytes 61 | 62 | func main() { 63 | r := chi.NewRouter() 64 | r.Use(middleware.Logger) 65 | r.Use(middleware.Recoverer) 66 | 67 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 68 | w.Write([]byte("index")) 69 | }) 70 | 71 | cache, err := memcache.NewBackend(1000) 72 | if err != nil { 73 | panic(err) 74 | } 75 | 76 | // Include anything user specific, e.g. Authorization Token 77 | customCacheKeyFunc := func(r *http.Request) (uint64, error) { 78 | token := r.Header.Get("Authorization") 79 | return stampede.StringToHash(r.Method, strings.ToLower(strings.ToLower(token))), nil 80 | } 81 | 82 | cacheMiddleware := stampede.HandlerWithKey( 83 | slog.Default(), cache, 5*time.Second, customCacheKeyFunc, 84 | stampede.WithHTTPCacheKeyRequestBody(false), 85 | ) 86 | 87 | r.With(cacheMiddleware).Get("/me", func(w http.ResponseWriter, r *http.Request) { 88 | // processing.. 89 | time.Sleep(3 * time.Second) 90 | 91 | w.WriteHeader(200) 92 | w.Write([]byte(r.Header.Get("Authorization"))) 93 | }) 94 | 95 | http.ListenAndServe(":3333", r) 96 | } 97 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package stampede 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | type Options struct { 8 | // TTL is the time-to-live for the cache. NOTE: if this is not set, 9 | // then we use the package-level default of 1 minute. You can override 10 | // by passing `WithTTL(time.Second * 10)` to stampede.Do(), or passing 11 | // `WithSkipCache(true)` to skip the cache entirely. 12 | // 13 | // Default: 1 minute 14 | TTL time.Duration 15 | 16 | // SkipCache is a flag that determines whether the cache should be skipped. 
17 | // If true, the cache will not be used, but the request will still use 18 | // singleflight request coalescing. 19 | // 20 | // Default: false 21 | SkipCache bool 22 | 23 | // HTTPCacheKeyRequestBody is a flag that determines whether the request body 24 | // should be used to generate the cache key. This ensures that requests with the 25 | // same URL but different body content are cached separately. 26 | // 27 | // Default: true 28 | HTTPCacheKeyRequestBody bool 29 | 30 | // HTTPCacheKeyRequestHeaders is a list of request headers whose values will be 31 | // used to generate the cache key. This is useful for varying the cache key based 32 | // on request headers, so that requests with the same URL but different values 33 | // for those headers are cached separately. 34 | // 35 | // Default: [] 36 | HTTPCacheKeyRequestHeaders []string 37 | 38 | // HTTPStatusTTL is a function that returns the time-to-live for a given HTTP 39 | // status code. This allows you to customize the TTL for different HTTP status codes. 40 | // 41 | // Default: nil 42 | HTTPStatusTTL func(status int) time.Duration 43 | } 44 | 45 | // WithTTL sets the TTL for the cache. 46 | // 47 | // Default: 1 minute 48 | func WithTTL(ttl time.Duration) Option { 49 | return func(o *Options) { 50 | o.TTL = ttl 51 | } 52 | } 53 | 54 | // WithSkipCache sets the SkipCache flag. If true, the cache will not be used, 55 | // but the request will still use singleflight request coalescing. 56 | // 57 | // Default: false 58 | func WithSkipCache(skip bool) Option { 59 | return func(o *Options) { 60 | o.SkipCache = skip 61 | } 62 | } 63 | 64 | // WithHTTPStatusTTL sets the HTTPStatusTTL function. This allows you to 65 | // customize the TTL for different HTTP status codes. 66 | // 67 | // Default: nil 68 | func WithHTTPStatusTTL(fn func(status int) time.Duration) Option { 69 | return func(o *Options) { 70 | o.HTTPStatusTTL = fn 71 | } 72 | } 73 | 74 | // WithHTTPCacheKeyRequestBody sets the HTTPCacheKeyRequestBody flag. 
This 75 | // ensures we use the request body contents so we can properly cache different 76 | // requests that have the same URL but different query params or body content. 77 | // 78 | // Default: true 79 | func WithHTTPCacheKeyRequestBody(b bool) Option { 80 | return func(o *Options) { 81 | o.HTTPCacheKeyRequestBody = b 82 | } 83 | } 84 | 85 | // WithHTTPCacheKeyRequestHeaders sets the HTTPCacheKeyRequestHeaders list. 86 | // This is useful for varying the cachekey based on request headers. 87 | // 88 | // Default: [] 89 | func WithHTTPCacheKeyRequestHeaders(headers []string) Option { 90 | return func(o *Options) { 91 | o.HTTPCacheKeyRequestHeaders = headers 92 | } 93 | } 94 | 95 | type Option func(*Options) 96 | 97 | // getOptions returns a new Options with the given ttl and options, 98 | // and also applies default values for any options that are not set. 99 | func getOptions(ttl time.Duration, options ...Option) *Options { 100 | if ttl == 0 { 101 | ttl = DefaultCacheTTL 102 | } 103 | opts := &Options{ 104 | TTL: ttl, 105 | SkipCache: false, 106 | HTTPStatusTTL: nil, 107 | HTTPCacheKeyRequestHeaders: nil, 108 | HTTPCacheKeyRequestBody: true, 109 | } 110 | for _, o := range options { 111 | o(opts) 112 | } 113 | return opts 114 | } 115 | -------------------------------------------------------------------------------- /stampede.go: -------------------------------------------------------------------------------- 1 | package stampede 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log/slog" 7 | "sync" 8 | "time" 9 | 10 | cachestore "github.com/goware/cachestore2" 11 | "github.com/goware/singleflight" 12 | "github.com/zeebo/xxh3" 13 | ) 14 | 15 | const ( 16 | // DefaultCacheTTL is the default TTL for cache entries. 
However, 17 | // you can pass WithTTL(d) to set your own ttl, or pass 18 | // WithSkipCache() to disable caching 19 | DefaultCacheTTL = 1 * time.Minute 20 | ) 21 | 22 | func NewStampede[V any](logger *slog.Logger, cache cachestore.Store[V], options ...Option) *stampede[V] { 23 | opts := &Options{} 24 | for _, o := range options { 25 | o(opts) 26 | } 27 | 28 | return &stampede[V]{ 29 | logger: logger, 30 | cache: cache, 31 | callGroup: singleflight.Group[string, doResult[V]]{}, 32 | options: opts, 33 | } 34 | } 35 | 36 | type stampede[V any] struct { 37 | logger *slog.Logger 38 | cache cachestore.Store[V] 39 | callGroup singleflight.Group[string, doResult[V]] 40 | options *Options 41 | mu sync.RWMutex 42 | } 43 | 44 | type doResult[V any] struct { 45 | Value V 46 | TTL *time.Duration 47 | } 48 | 49 | func (s *stampede[V]) Do(ctx context.Context, key string, fn func() (V, *time.Duration, error), options ...Option) (V, error) { 50 | var opts *Options 51 | if len(options) > 0 { 52 | opts = getOptions(0, options...) 
53 | } else { 54 | opts = s.options 55 | } 56 | 57 | key = fmt.Sprintf("stampede:%s", key) 58 | 59 | if opts.SkipCache || s.cache == nil { 60 | // Singleflight mode only 61 | result, err, _ := s.callGroup.Do(key, func() (doResult[V], error) { 62 | v, ttl, err := fn() 63 | if err != nil { 64 | return doResult[V]{Value: v, TTL: ttl}, err 65 | } 66 | return doResult[V]{Value: v, TTL: ttl}, nil 67 | }) 68 | return result.Value, err 69 | 70 | } else { 71 | // Caching + Singleflight combo mode 72 | s.mu.RLock() 73 | v, ok, err := s.cache.Get(ctx, key) 74 | if err != nil { 75 | s.mu.RUnlock() 76 | return v, err 77 | } 78 | s.mu.RUnlock() 79 | if ok { 80 | // cache hit 81 | return v, nil 82 | } 83 | 84 | result, err, _ := s.callGroup.Do(key, func() (doResult[V], error) { 85 | v, ttl, err := fn() 86 | if err != nil { 87 | return doResult[V]{Value: v, TTL: ttl}, err 88 | } 89 | return doResult[V]{Value: v, TTL: ttl}, nil 90 | }) 91 | 92 | if err != nil { 93 | return result.Value, err 94 | } 95 | 96 | var ttl time.Duration 97 | if result.TTL != nil { 98 | ttl = *result.TTL 99 | } else { 100 | ttl = opts.TTL 101 | } 102 | 103 | // if ttl is 0, don't cache the result 104 | if ttl == 0 { 105 | return result.Value, nil 106 | } 107 | 108 | // cache the result 109 | s.mu.Lock() 110 | err = s.cache.SetEx(ctx, key, result.Value, ttl) 111 | if err != nil { 112 | s.mu.Unlock() 113 | // We log the error here and return the result.Value 114 | s.logger.Error("stampede: fail to set cache value", "err", err) 115 | return result.Value, nil 116 | } 117 | s.mu.Unlock() 118 | return result.Value, nil 119 | } 120 | } 121 | 122 | func (s *stampede[V]) SetOptions(options *Options) { 123 | s.mu.Lock() 124 | defer s.mu.Unlock() 125 | s.options = options 126 | } 127 | 128 | func BytesToHash(b ...[]byte) uint64 { 129 | d := xxh3.New() 130 | if len(b) == 0 { 131 | return 0 132 | } 133 | if len(b) == 1 { 134 | d.Write(b[0]) 135 | } else { 136 | for _, v := range b { 137 | d.Write(v) 138 | } 139 | } 140 
| return d.Sum64() 141 | } 142 | 143 | func StringToHash(s ...string) uint64 { 144 | d := xxh3.New() 145 | if len(s) == 0 { 146 | return 0 147 | } 148 | if len(s) == 1 { 149 | d.WriteString(s[0]) 150 | } else { 151 | for _, v := range s { 152 | d.WriteString(v) 153 | } 154 | } 155 | return d.Sum64() 156 | } 157 | -------------------------------------------------------------------------------- /stampede_test.go: -------------------------------------------------------------------------------- 1 | package stampede_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log/slog" 7 | "runtime" 8 | "strings" 9 | "sync" 10 | "sync/atomic" 11 | "testing" 12 | "time" 13 | 14 | "github.com/go-chi/stampede" 15 | cachestore "github.com/goware/cachestore2" 16 | "github.com/stretchr/testify/assert" 17 | "github.com/stretchr/testify/require" 18 | ) 19 | 20 | func TestSingleflightDo(t *testing.T) { 21 | s := stampede.NewStampede[int](slog.Default(), nil) 22 | 23 | var numCalls atomic.Int64 24 | 25 | var wg sync.WaitGroup 26 | wg.Add(1) 27 | go func() { 28 | defer wg.Done() 29 | 30 | v, err := s.Do(context.Background(), "t1", func() (int, *time.Duration, error) { 31 | numCalls.Add(1) 32 | time.Sleep(1 * time.Second) 33 | return 1, nil, nil 34 | }, stampede.WithTTL(1*time.Second)) 35 | assert.NoError(t, err) 36 | assert.Equal(t, 1, v) 37 | }() 38 | 39 | // slight delay, to ensure first call is in flight 40 | time.Sleep(100 * time.Millisecond) 41 | 42 | for i := 0; i < 10; i++ { 43 | wg.Add(1) 44 | go func() { 45 | defer wg.Done() 46 | 47 | v, err := s.Do(context.Background(), "t1", func() (int, *time.Duration, error) { 48 | numCalls.Add(1) 49 | return i, nil, nil 50 | }) 51 | assert.NoError(t, err) 52 | assert.Equal(t, 1, v) 53 | }() 54 | } 55 | 56 | wg.Wait() 57 | 58 | require.Equal(t, int64(1), numCalls.Load()) 59 | } 60 | 61 | func TestCachedDo(t *testing.T) { 62 | var count uint64 63 | stampede := stampede.NewStampede(slog.Default(), newMockCacheBackend(), 
stampede.WithTTL(5*time.Second)) 64 | 65 | // repeat test multiple times 66 | for x := 0; x < 5; x++ { 67 | // time.Sleep(1 * time.Second) 68 | 69 | var wg sync.WaitGroup 70 | n := 10 71 | ctx := context.Background() 72 | 73 | for i := 0; i < n; i++ { 74 | t.Logf("numGoroutines now %d", runtime.NumGoroutine()) 75 | 76 | wg.Add(1) 77 | go func() { 78 | defer wg.Done() 79 | 80 | val, err := stampede.Do(ctx, "t1", func() (any, *time.Duration, error) { 81 | t.Log("cache.Get(t1, ...)") 82 | 83 | // some extensive op.. 84 | time.Sleep(2 * time.Second) 85 | atomic.AddUint64(&count, 1) 86 | 87 | return "result1", nil, nil 88 | }) 89 | 90 | assert.NoError(t, err) 91 | assert.Equal(t, "result1", val) 92 | }() 93 | } 94 | 95 | wg.Wait() 96 | 97 | // ensure single call 98 | assert.Equal(t, uint64(1), count) 99 | 100 | // confirm same before/after num of goroutines 101 | t.Logf("numGoroutines now %d", runtime.NumGoroutine()) 102 | } 103 | } 104 | 105 | func newMockCacheBackend() cachestore.Backend { 106 | return &mockCacheBackend[any]{ 107 | cache: make(map[string]any), 108 | expiry: make(map[string]int64), 109 | } 110 | } 111 | 112 | type mockCacheBackend[V any] struct { 113 | cache map[string]V 114 | expiry map[string]int64 115 | } 116 | 117 | var _ cachestore.Backend = &mockCacheBackend[any]{} 118 | 119 | func (m *mockCacheBackend[V]) Name() string { 120 | return "mockCacheBackend" 121 | } 122 | 123 | func (m *mockCacheBackend[V]) Options() cachestore.StoreOptions { 124 | return cachestore.StoreOptions{} 125 | } 126 | 127 | func (m *mockCacheBackend[V]) Exists(ctx context.Context, key string) (bool, error) { 128 | _, ok := m.cache[key] 129 | return ok, nil 130 | } 131 | 132 | func (m *mockCacheBackend[V]) Set(ctx context.Context, key string, value V) error { 133 | m.cache[key] = value 134 | return nil 135 | } 136 | 137 | func (m *mockCacheBackend[V]) SetEx(ctx context.Context, key string, value V, ttl time.Duration) error { 138 | m.cache[key] = value 139 | m.expiry[key] = 
time.Now().Unix() + int64(ttl.Seconds()) 140 | return nil 141 | } 142 | 143 | func (m *mockCacheBackend[V]) BatchSet(ctx context.Context, keys []string, values []V) error { 144 | for i, key := range keys { 145 | m.cache[key] = values[i] 146 | } 147 | return nil 148 | } 149 | 150 | func (m *mockCacheBackend[V]) BatchSetEx(ctx context.Context, keys []string, values []V, ttl time.Duration) error { 151 | for i, key := range keys { 152 | m.cache[key] = values[i] 153 | m.expiry[key] = time.Now().Unix() + int64(ttl.Seconds()) 154 | } 155 | return nil 156 | } 157 | 158 | func (m *mockCacheBackend[V]) Get(ctx context.Context, key string) (V, bool, error) { 159 | v, ok := m.cache[key] 160 | if ok { 161 | expiry, ok := m.expiry[key] 162 | if ok && expiry < time.Now().Unix() { 163 | delete(m.cache, key) 164 | delete(m.expiry, key) 165 | var v V 166 | return v, false, nil 167 | } 168 | } 169 | return v, ok, nil 170 | } 171 | 172 | func (m *mockCacheBackend[V]) BatchGet(ctx context.Context, keys []string) ([]V, []bool, error) { 173 | values := make([]V, len(keys)) 174 | exists := make([]bool, len(keys)) 175 | var err error 176 | for i, key := range keys { 177 | values[i], exists[i], err = m.Get(ctx, key) 178 | if err != nil { 179 | return nil, nil, err 180 | } 181 | if exists[i] { 182 | expiry, ok := m.expiry[key] 183 | if ok && expiry < time.Now().Unix() { 184 | exists[i] = false 185 | var v V 186 | values[i] = v 187 | } 188 | } 189 | } 190 | return values, exists, nil 191 | } 192 | 193 | func (m *mockCacheBackend[V]) Delete(ctx context.Context, key string) error { 194 | delete(m.cache, key) 195 | delete(m.expiry, key) 196 | return nil 197 | } 198 | 199 | func (m *mockCacheBackend[V]) DeletePrefix(ctx context.Context, keyPrefix string) error { 200 | for key := range m.cache { 201 | if strings.HasPrefix(key, keyPrefix) { 202 | delete(m.cache, key) 203 | delete(m.expiry, key) 204 | } 205 | } 206 | return nil 207 | } 208 | 209 | func (m *mockCacheBackend[V]) ClearAll(ctx 
context.Context) error { 210 | m.cache = make(map[string]V) 211 | m.expiry = make(map[string]int64) 212 | return nil 213 | } 214 | 215 | func (m *mockCacheBackend[V]) GetOrSetWithLock(ctx context.Context, key string, getter func(context.Context, string) (V, error)) (V, error) { 216 | var v V 217 | return v, fmt.Errorf("not implemented") 218 | } 219 | 220 | func (m *mockCacheBackend[V]) GetOrSetWithLockEx(ctx context.Context, key string, getter func(context.Context, string) (V, error), ttl time.Duration) (V, error) { 221 | var v V 222 | return v, fmt.Errorf("not implemented") 223 | } 224 | -------------------------------------------------------------------------------- /http.go: -------------------------------------------------------------------------------- 1 | package stampede 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "log/slog" 9 | "net/http" 10 | "strings" 11 | "time" 12 | 13 | cachestore "github.com/goware/cachestore2" 14 | ) 15 | 16 | func Handler(logger *slog.Logger, cacheBackend cachestore.Backend, ttl time.Duration, options ...Option) func(next http.Handler) http.Handler { 17 | return HandlerWithKey(logger, cacheBackend, ttl, nil, options...) 18 | } 19 | 20 | func HandlerWithKey(logger *slog.Logger, cacheBackend cachestore.Backend, ttl time.Duration, cacheKeyFunc CacheKeyFunc, options ...Option) func(next http.Handler) http.Handler { 21 | opts := getOptions(ttl, options...) 22 | 23 | // Combine various cache key functions into a single cache key value. 
24 | cacheKeyWithRequestHeaders := cacheKeyWithRequestHeaders(opts.HTTPCacheKeyRequestHeaders) 25 | 26 | comboCacheKeyFunc := func(r *http.Request) (uint64, error) { 27 | var cacheKey1, cacheKey2, cacheKey3, cacheKey4 uint64 28 | var err error 29 | cacheKey1, err = cacheKeyWithRequestURL(r) 30 | if err != nil { 31 | return 0, err 32 | } 33 | if opts.HTTPCacheKeyRequestBody { 34 | cacheKey2, err = cacheKeyWithRequestBody(r) 35 | if err != nil { 36 | return 0, err 37 | } 38 | } 39 | if len(opts.HTTPCacheKeyRequestHeaders) > 0 { 40 | cacheKey3, err = cacheKeyWithRequestHeaders(r) 41 | if err != nil { 42 | return 0, err 43 | } 44 | } 45 | if cacheKeyFunc != nil { 46 | cacheKey4, err = cacheKeyFunc(r) 47 | if err != nil { 48 | return 0, err 49 | } 50 | } 51 | return cacheKey1 + cacheKey2 + cacheKey3 + cacheKey4, nil 52 | } 53 | 54 | var cache cachestore.Store[responseValue] 55 | if cacheBackend != nil { 56 | cache = cachestore.OpenStore[responseValue](cacheBackend) 57 | } 58 | h := stampedeHandler(logger, cache, comboCacheKeyFunc, opts) 59 | 60 | return func(next http.Handler) http.Handler { 61 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 62 | h(next).ServeHTTP(w, r) 63 | }) 64 | } 65 | } 66 | 67 | func cacheKeyWithRequestURL(r *http.Request) (uint64, error) { 68 | return StringToHash(strings.ToLower(r.URL.Path)), nil 69 | } 70 | 71 | func cacheKeyWithRequestBody(r *http.Request) (uint64, error) { 72 | // Skip request body caching for non-POST, PUT, PATCH requests. 73 | // If you'd like to cache these, you can use the `HandlerWithKey` 74 | // function which accepts a custom cache key function. 
75 | if r.Method != "POST" && r.Method != "PUT" && r.Method != "PATCH" { 76 | return 0, nil 77 | } 78 | 79 | // Read the request payload, and then setup buffer for future reader 80 | var err error 81 | var buf []byte 82 | if r.Body != nil { 83 | buf, err = io.ReadAll(r.Body) 84 | if err != nil { 85 | return 0, err 86 | } 87 | r.Body = io.NopCloser(bytes.NewBuffer(buf)) 88 | } 89 | 90 | // Prepare cache key based on the request data payload. 91 | return BytesToHash(buf), nil 92 | } 93 | 94 | func cacheKeyWithRequestHeaders(headers []string) func(r *http.Request) (uint64, error) { 95 | return func(r *http.Request) (uint64, error) { 96 | if len(headers) == 0 { 97 | return 0, nil 98 | } 99 | var keys []string 100 | for _, header := range headers { 101 | v := r.Header.Get(header) 102 | if v == "" { 103 | continue 104 | } 105 | keys = append(keys, fmt.Sprintf("%s:%s", strings.ToLower(header), v)) 106 | } 107 | return StringToHash(keys...), nil 108 | } 109 | } 110 | 111 | type CacheKeyFunc func(r *http.Request) (uint64, error) 112 | 113 | func stampedeHandler(logger *slog.Logger, cache cachestore.Store[responseValue], cacheKeyFunc CacheKeyFunc, options *Options) func(next http.Handler) http.Handler { 114 | stampede := NewStampede(logger, cache) 115 | stampede.SetOptions(options) 116 | 117 | return func(next http.Handler) http.Handler { 118 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 119 | cacheKey, err := cacheKeyFunc(r) 120 | if err != nil { 121 | logger.Warn("stampede: fail to compute cache key", "err", err) 122 | next.ServeHTTP(w, r) 123 | return 124 | } 125 | 126 | firstRequest := false 127 | 128 | cachedVal, err := stampede.Do(context.Background(), fmt.Sprintf("http:%d", cacheKey), func() (responseValue, *time.Duration, error) { 129 | firstRequest = true 130 | buf := bytes.NewBuffer(nil) 131 | ww := &responseWriter{ResponseWriter: w, tee: buf} 132 | 133 | next.ServeHTTP(ww, r) 134 | 135 | val := responseValue{ 136 | Headers: ww.Header(), 
137 | Status: ww.Status(), 138 | Body: buf.Bytes(), 139 | 140 | // the handler may not write header and body in some logic, 141 | // while writing only the body, an attempt is made to write the default header (http.StatusOK) 142 | Skip: !ww.IsValid(), 143 | } 144 | 145 | var ttl *time.Duration 146 | if options.HTTPStatusTTL != nil { 147 | t := options.HTTPStatusTTL(ww.Status()) 148 | ttl = &t 149 | } 150 | 151 | return val, ttl, nil 152 | }) 153 | 154 | if firstRequest { 155 | return 156 | } 157 | 158 | // handle response for subsequent requests 159 | if err != nil { 160 | logger.Error("stampede: fail to get value, serving standard request handler", "err", err) 161 | next.ServeHTTP(w, r) 162 | return 163 | } 164 | 165 | // if the handler did not write a header, then serve the next handler 166 | // a standard request handler 167 | if cachedVal.Skip { 168 | next.ServeHTTP(w, r) 169 | return 170 | } 171 | 172 | // copy headers from the first request to the response writer 173 | respHeader := w.Header() 174 | for k, v := range cachedVal.Headers { 175 | // Prevent certain headers to override the current 176 | // value of that header. This is important when you don't want a 177 | // header to affect all subsequent requests (for instance, when 178 | // working with several CORS domains, you don't want the first domain 179 | // to be recorded an to be printed in all responses). 180 | // Other examples include x-ratelimit or set-cookie. We therefore skip 181 | // returning any header with "x-ratelimit" prefix, "access-control-" prefix, or "set-cookie". 182 | // 183 | // TODO: we can move these options to the `Options` struct, with the below as defaults. 
// responseWriter wraps an http.ResponseWriter, recording the final status code
// and the number of body bytes written, and optionally copying the body into
// tee so the response can be replayed later.
type responseWriter struct {
	http.ResponseWriter
	wroteHeader bool      // a final (non-informational) status has been written
	code        int       // the final status code; 0 until one is written
	bytes       int       // body bytes accepted by the underlying writer
	tee         io.Writer // optional sink that receives a copy of the body
}

// WriteHeader forwards the status code to the underlying writer and records
// it as the response status.
//
// Informational 1xx codes (other than 101 Switching Protocols) are forwarded
// but not recorded: net/http allows any number of them before the single final
// status, so treating one as final would drop the real status that follows.
// Once a final status has been written, further calls are ignored.
func (b *responseWriter) WriteHeader(code int) {
	if b.wroteHeader {
		return
	}
	if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols {
		b.ResponseWriter.WriteHeader(code)
		return
	}
	b.code = code
	b.wroteHeader = true
	b.ResponseWriter.WriteHeader(code)
}

// IsValid reports whether a final status was written and lies in the range
// net/http accepts for WriteHeader (100 through 999 inclusive).
func (b *responseWriter) IsValid() bool {
	return b.wroteHeader && b.code >= 100 && b.code <= 999
}

// Write sends buf to the underlying writer, writing an implicit 200 status
// first if none was written. Only the bytes the underlying writer accepted are
// copied to tee. The underlying writer's error takes precedence; a tee error
// is returned only when the primary write succeeded.
func (b *responseWriter) Write(buf []byte) (int, error) {
	b.maybeWriteHeader()
	n, err := b.ResponseWriter.Write(buf)
	if b.tee != nil {
		_, teeErr := b.tee.Write(buf[:n])
		if err == nil {
			err = teeErr
		}
	}
	b.bytes += n
	return n, err
}

// maybeWriteHeader writes the default 200 status if no final status has been
// written yet.
func (b *responseWriter) maybeWriteHeader() {
	if !b.wroteHeader {
		b.WriteHeader(http.StatusOK)
	}
}

// Status returns the recorded final status code, or 0 if none was written.
func (b *responseWriter) Status() int {
	return b.code
}

// BytesWritten returns the number of body bytes accepted by the underlying
// writer so far.
func (b *responseWriter) BytesWritten() int {
	return b.bytes
}
"github.com/go-chi/cors" 15 | "github.com/go-chi/stampede" 16 | "github.com/stretchr/testify/assert" 17 | "github.com/stretchr/testify/require" 18 | ) 19 | 20 | func TestSingleflightHTTPHandler(t *testing.T) { 21 | // Create a counter to track how many times handlers are called 22 | var callCount int 23 | var mu sync.Mutex 24 | 25 | // Create a test mux 26 | mux := http.NewServeMux() 27 | 28 | // Create the slow handler 29 | endpoint := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 30 | mu.Lock() 31 | callCount++ 32 | mu.Unlock() 33 | 34 | // Simulate processing time 35 | time.Sleep(100 * time.Millisecond) 36 | 37 | w.WriteHeader(http.StatusOK) 38 | w.Write([]byte("slow response")) 39 | }) 40 | 41 | // Apply Handler2 to the slow handler only 42 | wrappedSlowHandler := stampede.Singleflight(slog.Default(), nil)(endpoint) 43 | 44 | // Register the handlers with the mux 45 | mux.Handle("/slow", wrappedSlowHandler) 46 | // mux.Handle("/fast", fastHandler) // Fast handler is not wrapped 47 | 48 | // Create a test server with the mux 49 | server := httptest.NewServer(mux) 50 | defer server.Close() 51 | 52 | // Test concurrent requests to the slow endpoint 53 | var wg sync.WaitGroup 54 | concurrentRequests := 20 55 | 56 | // Reset call count before concurrent test 57 | mu.Lock() 58 | callCount = 0 59 | mu.Unlock() 60 | 61 | for i := 0; i < concurrentRequests; i++ { 62 | wg.Add(1) 63 | go func() { 64 | defer wg.Done() 65 | resp, err := http.Get(server.URL + "/slow") 66 | require.NoError(t, err) 67 | defer resp.Body.Close() 68 | 69 | _, err = io.ReadAll(resp.Body) 70 | require.NoError(t, err) 71 | }() 72 | } 73 | 74 | wg.Wait() 75 | 76 | // Verify call count - should be 1 if singleflight is working 77 | require.Equal(t, 1, callCount) 78 | } 79 | 80 | func TestHTTPCachingHandler(t *testing.T) { 81 | // Create a counter to track how many times handlers are called 82 | var callCount int 83 | var mu sync.Mutex 84 | 85 | // Create a test mux 86 | mux := 
http.NewServeMux() 87 | 88 | // Create the slow handler 89 | slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 90 | mu.Lock() 91 | callCount++ 92 | mu.Unlock() 93 | 94 | // Simulate processing time 95 | time.Sleep(100 * time.Millisecond) 96 | 97 | w.WriteHeader(http.StatusOK) 98 | w.Write([]byte("slow response")) 99 | }) 100 | 101 | // Apply Handler2 to the slow handler only 102 | // cache, _ := memcache.NewBackend(1000, cachestore.WithDefaultKeyExpiry(10*time.Second)) 103 | cache := newMockCacheBackend() 104 | 105 | wrappedSlowHandler := stampede.Handler(slog.Default(), cache, 5*time.Second, 106 | stampede.WithHTTPStatusTTL(func(status int) time.Duration { 107 | switch { 108 | case status >= 200 && status < 300: 109 | return 1 * time.Second 110 | case status >= 400 && status < 500: 111 | return 10 * time.Second 112 | case status == http.StatusNotFound: 113 | return 30 * time.Second // Special case for 404 114 | default: 115 | return 0 116 | } 117 | }), 118 | )(slowHandler) 119 | 120 | // Register the handlers with the mux 121 | mux.Handle("/slow", wrappedSlowHandler) 122 | // mux.Handle("/fast", fastHandler) // Fast handler is not wrapped 123 | 124 | // Create a test server with the mux 125 | server := httptest.NewServer(mux) 126 | defer server.Close() 127 | 128 | // Test concurrent requests to the slow endpoint 129 | var wg sync.WaitGroup 130 | concurrentRequests := 20 131 | 132 | // Reset call count before concurrent test 133 | mu.Lock() 134 | callCount = 0 135 | mu.Unlock() 136 | 137 | for i := 0; i < concurrentRequests; i++ { 138 | wg.Add(1) 139 | go func() { 140 | defer wg.Done() 141 | resp, err := http.Get(server.URL + "/slow") 142 | require.NoError(t, err) 143 | defer resp.Body.Close() 144 | 145 | body, err := io.ReadAll(resp.Body) 146 | require.NoError(t, err) 147 | require.Equal(t, "slow response", string(body)) 148 | }() 149 | } 150 | 151 | wg.Wait() 152 | 153 | // Verify call count - should be 1 if singleflight is working 
154 | require.Equal(t, 1, callCount) 155 | } 156 | 157 | func TestHTTPCachingHandler2(t *testing.T) { 158 | var numRequests = 30 159 | 160 | var hits uint32 161 | var expectedStatus int = 201 162 | var expectedBody = []byte("hi") 163 | 164 | app := func(w http.ResponseWriter, r *http.Request) { 165 | // log.Println("app handler..") 166 | 167 | atomic.AddUint32(&hits, 1) 168 | 169 | hitsNow := atomic.LoadUint32(&hits) 170 | if hitsNow > 1 { 171 | // panic("uh oh") 172 | } 173 | 174 | // time.Sleep(100 * time.Millisecond) // slow handler 175 | w.Header().Set("X-Httpjoin", "test") 176 | w.WriteHeader(expectedStatus) 177 | w.Write(expectedBody) 178 | } 179 | 180 | var count uint32 181 | counter := func(next http.Handler) http.Handler { 182 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 183 | atomic.AddUint32(&count, 1) 184 | next.ServeHTTP(w, r) 185 | atomic.AddUint32(&count, ^uint32(0)) 186 | // log.Println("COUNT:", atomic.LoadUint32(&count)) 187 | }) 188 | } 189 | 190 | recoverer := func(next http.Handler) http.Handler { 191 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 192 | defer func() { 193 | if r := recover(); r != nil { 194 | log.Println("recovered panicing request:", r) 195 | } 196 | }() 197 | next.ServeHTTP(w, r) 198 | }) 199 | } 200 | 201 | cache := newMockCacheBackend() 202 | h := stampede.Handler(slog.Default(), cache, 1*time.Second) 203 | 204 | ts := httptest.NewServer(counter(recoverer(h(http.HandlerFunc(app))))) 205 | defer ts.Close() 206 | 207 | var wg sync.WaitGroup 208 | 209 | for i := 0; i < numRequests; i++ { 210 | wg.Add(1) 211 | go func() { 212 | defer wg.Done() 213 | resp, err := http.Get(ts.URL) 214 | if err != nil { 215 | panic(err) 216 | } 217 | 218 | body, err := io.ReadAll(resp.Body) 219 | if err != nil { 220 | panic(err) 221 | } 222 | defer resp.Body.Close() 223 | 224 | // log.Println("got resp:", resp, "len:", len(body), "body:", string(body)) 225 | 226 | if string(body) != 
string(expectedBody) { 227 | t.Error("expecting response body:", string(expectedBody)) 228 | } 229 | 230 | if resp.StatusCode != expectedStatus { 231 | t.Error("expecting response status:", expectedStatus) 232 | } 233 | 234 | assert.Equal(t, "test", resp.Header.Get("X-Httpjoin"), "expecting x-httpjoin test header") 235 | }() 236 | } 237 | 238 | wg.Wait() 239 | 240 | totalHits := atomic.LoadUint32(&hits) 241 | // if totalHits > 1 { 242 | // t.Error("handler was hit more than once. hits:", totalHits) 243 | // } 244 | log.Println("total hits:", totalHits) 245 | 246 | finalCount := atomic.LoadUint32(&count) 247 | if finalCount > 0 { 248 | t.Error("queue count was expected to be empty, but count:", finalCount) 249 | } 250 | log.Println("final count:", finalCount) 251 | } 252 | 253 | func TestBypassCORSHeaders(t *testing.T) { 254 | var expectedStatus int = 200 255 | var expectedBody = []byte("hi") 256 | 257 | var count uint64 258 | 259 | domains := []string{ 260 | "google.com", 261 | "sequence.build", 262 | "horizon.io", 263 | "github.com", 264 | "ethereum.org", 265 | } 266 | 267 | app := func(w http.ResponseWriter, r *http.Request) { 268 | w.Header().Set("X-Another-Header", "wakka") 269 | w.WriteHeader(expectedStatus) 270 | w.Write(expectedBody) 271 | 272 | atomic.AddUint64(&count, 1) 273 | } 274 | 275 | cache := newMockCacheBackend() 276 | h := stampede.Handler(slog.Default(), cache, 5*time.Second) 277 | c := cors.New(cors.Options{ 278 | AllowedOrigins: domains, 279 | AllowedMethods: []string{"GET"}, 280 | AllowedHeaders: []string{"*"}, 281 | }).Handler 282 | 283 | ts := httptest.NewServer(c(h(http.HandlerFunc(app)))) 284 | defer ts.Close() 285 | 286 | var mu sync.Mutex 287 | 288 | for i := 0; i < 10; i++ { 289 | var wg sync.WaitGroup 290 | var domainsHit = map[string]bool{} 291 | 292 | for _, domain := range domains { 293 | wg.Add(1) 294 | go func(domain string) { 295 | defer wg.Done() 296 | 297 | req, err := http.NewRequest("GET", ts.URL, nil) 298 | assert.NoError(t, 
err) 299 | req.Header.Set("Origin", domain) 300 | 301 | resp, err := http.DefaultClient.Do(req) 302 | if err != nil { 303 | panic(err) 304 | } 305 | 306 | body, err := io.ReadAll(resp.Body) 307 | if err != nil { 308 | panic(err) 309 | } 310 | defer resp.Body.Close() 311 | 312 | if string(body) != string(expectedBody) { 313 | t.Error("expecting response body:", string(expectedBody)) 314 | } 315 | 316 | if resp.StatusCode != expectedStatus { 317 | t.Error("expecting response status:", expectedStatus) 318 | } 319 | 320 | mu.Lock() 321 | domainsHit[resp.Header.Get("Access-Control-Allow-Origin")] = true 322 | mu.Unlock() 323 | 324 | assert.Equal(t, "wakka", resp.Header.Get("X-Another-Header")) 325 | }(domain) 326 | } 327 | 328 | wg.Wait() 329 | 330 | // expect all domains to be returned and recorded in domainsHit 331 | for _, domain := range domains { 332 | assert.True(t, domainsHit[domain]) 333 | } 334 | 335 | // expect to have only one actual hit 336 | assert.Equal(t, uint64(1), count) 337 | } 338 | } 339 | 340 | func TestEmptyHandlerFunc(t *testing.T) { 341 | mux := http.NewServeMux() 342 | cache := newMockCacheBackend() 343 | middleware := stampede.Handler(slog.Default(), cache, 1*time.Hour) 344 | mux.Handle("/", middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 345 | t.Log(r.Method, r.URL) 346 | }))) 347 | 348 | ts := httptest.NewServer(mux) 349 | defer ts.Close() 350 | 351 | { 352 | req, err := http.NewRequest(http.MethodGet, ts.URL, nil) 353 | if err != nil { 354 | t.Fatal(err) 355 | } 356 | resp, err := http.DefaultClient.Do(req) 357 | if err != nil { 358 | t.Fatal(err) 359 | } 360 | defer resp.Body.Close() 361 | t.Log(resp.StatusCode) 362 | } 363 | { 364 | req, err := http.NewRequest(http.MethodGet, ts.URL, nil) 365 | if err != nil { 366 | t.Fatal(err) 367 | } 368 | resp, err := http.DefaultClient.Do(req) 369 | if err != nil { 370 | t.Fatal(err) 371 | } 372 | defer resp.Body.Close() 373 | t.Log(resp.StatusCode) 374 | } 375 | } 376 | 
--------------------------------------------------------------------------------