├── .github ├── CODEOWNERS ├── FUNDING.yml └── workflows │ └── ci.yml ├── .golangci.yml ├── LICENSE ├── README.md ├── _example ├── go.mod ├── go.sum └── main.go ├── go.mod ├── go.sum ├── middleware ├── cache.go ├── cache │ ├── cache.go │ ├── cache_test.go │ └── options.go ├── cache_test.go ├── circuit_breaker.go ├── circuit_breaker_test.go ├── concurrent.go ├── concurrent_test.go ├── header.go ├── header_test.go ├── logger │ ├── logger.go │ └── logger_test.go ├── middleware.go ├── mocks │ ├── cache.go │ ├── circuit_breaker.go │ ├── logger.go │ ├── repeater.go │ └── roundtripper.go ├── repeater.go ├── repeater_test.go ├── retry.go └── retry_test.go ├── requester.go └── requester_test.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # These owners will be the default owners for everything in the repo. 2 | # Unless a later match takes precedence, @umputun will be requested for 3 | # review when someone opens a pull request. 4 | 5 | * @umputun 6 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [umputun] 2 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | tags: 7 | pull_request: 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: set up go 1.19 15 | uses: actions/setup-go@v3 16 | with: 17 | go-version: "1.19" 18 | id: go 19 | 20 | - name: checkout 21 | uses: actions/checkout@v3 22 | 23 | - name: build and test 24 | run: | 25 | go get -v 26 | go test -timeout=60s -covermode=count -coverprofile=$GITHUB_WORKSPACE/profile.cov_tmp ./... 
27 | cat $GITHUB_WORKSPACE/profile.cov_tmp | grep -v "_mock.go" | grep -v "mocks" > $GITHUB_WORKSPACE/profile.cov 28 | go build -race 29 | env: 30 | GO111MODULE: "on" 31 | TZ: "America/Chicago" 32 | 33 | - name: golangci-lint 34 | uses: golangci/golangci-lint-action@v3 35 | with: 36 | version: latest 37 | 38 | - name: install goveralls 39 | run: GO111MODULE=off go get -u -v github.com/mattn/goveralls 40 | 41 | - name: submit coverage 42 | run: $(go env GOPATH)/bin/goveralls -service="github" -coverprofile=$GITHUB_WORKSPACE/profile.cov 43 | env: 44 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters-settings: 2 | govet: 3 | shadow: true 4 | golint: 5 | min-confidence: 0.6 6 | gocyclo: 7 | min-complexity: 15 8 | maligned: 9 | suggest-new: true 10 | dupl: 11 | threshold: 100 12 | goconst: 13 | min-len: 2 14 | min-occurrences: 2 15 | misspell: 16 | locale: US 17 | lll: 18 | line-length: 140 19 | gocritic: 20 | enabled-tags: 21 | - performance 22 | - style 23 | - experimental 24 | disabled-checks: 25 | - wrapperFunc 26 | - hugeParam 27 | - rangeValCopy 28 | 29 | linters: 30 | disable-all: true 31 | enable: 32 | - revive 33 | - govet 34 | - unconvert 35 | - gosec 36 | - unparam 37 | - unused 38 | - typecheck 39 | - ineffassign 40 | - stylecheck 41 | - gochecknoinits 42 | - gocritic 43 | - nakedret 44 | - gosimple 45 | - prealloc 46 | - gofmt 47 | 48 | fast: false 49 | 50 | 51 | run: 52 | concurrency: 4 53 | 54 | issues: 55 | exclude-dirs: 56 | - vendor 57 | exclude-rules: 58 | - text: "should have a package comment, unless it's in another file for this package" 59 | linters: 60 | - golint 61 | - text: "exitAfterDefer:" 62 | linters: 63 | - gocritic 64 | - text: "whyNoLint: include an explanation for nolint directive" 65 | linters: 66 | - gocritic 67 | - text: 
"go.mongodb.org/mongo-driver/bson/primitive.E" 68 | linters: 69 | - govet 70 | - text: "weak cryptographic primitive" 71 | linters: 72 | - gosec 73 | - text: "integer overflow conversion" 74 | linters: 75 | - gosec 76 | - text: "should have a package comment" 77 | linters: 78 | - revive 79 | - text: "at least one file in a package should have a package comment" 80 | linters: 81 | - stylecheck 82 | - text: "commentedOutCode: may want to remove commented-out code" 83 | linters: 84 | - gocritic 85 | - text: "unnamedResult: consider giving a name to these results" 86 | linters: 87 | - gocritic 88 | - text: "var-naming: don't use an underscore in package name" 89 | linters: 90 | - revive 91 | - text: "should not use underscores in package names" 92 | linters: 93 | - stylecheck 94 | - text: "struct literal uses unkeyed fields" 95 | linters: 96 | - govet 97 | - linters: 98 | - unparam 99 | - unused 100 | - revive 101 | path: _test\.go$ 102 | text: "unused-parameter" 103 | exclude-use-default: false 104 | 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Umputun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Requester 2 | 3 | [![Build Status](https://github.com/go-pkgz/requester/workflows/build/badge.svg)](https://github.com/go-pkgz/requester/actions) [![Coverage Status](https://coveralls.io/repos/github/go-pkgz/requester/badge.svg?branch=main)](https://coveralls.io/github/go-pkgz/requester?branch=main) [![Go Reference](https://pkg.go.dev/badge/github.com/go-pkgz/requester.svg)](https://pkg.go.dev/github.com/go-pkgz/requester) 4 | 5 | 6 | The package provides a very thin wrapper (no external dependencies) for `http.Client`, allowing the use of layers (middlewares) at the `http.RoundTripper` level. The goal is to maintain the way users leverage the stdlib HTTP client while adding a few useful extras on top of the standard `http.Client`. 
7 | 8 | _Please note: this is not a replacement for `http.Client`, but rather a companion library._ 9 | 10 | ```go 11 | rq := requester.New( // make the requester 12 | http.Client{Timeout: 5*time.Second}, // set http client 13 | middleware.MaxConcurrent(8), // maximum number of concurrent requests 14 | middleware.JSON, // set json headers 15 | middleware.Header("X-AUTH", "123456789"), // set some auth header 16 | logger.New(logger.Std).Middleware, // enable logging via the stdlib logger 17 | ) 18 | 19 | req, err := http.NewRequest("GET", "http://example.com/api", nil) // create the usual http.Request (check err in real code) 20 | req.Header.Set("foo", "bar") // do the usual things with request, for example set some custom headers 21 | resp, err := rq.Do(req) // instead of client.Do call requester.Do 22 | ``` 23 | 24 | 25 | ## Install and update 26 | 27 | `go get -u github.com/go-pkgz/requester` 28 | 29 | ## Overview 30 | 31 | *Built-in middlewares:* 32 | 33 | - `Header` - appends user-defined headers to all requests. 34 | - `MaxConcurrent` - sets maximum concurrency 35 | - `Retry` - sets retry on errors and status codes 36 | - `JSON` - sets headers `"Content-Type": "application/json"` and `"Accept": "application/json"` 37 | - `BasicAuth(user, passwd string)` - adds HTTP Basic Authentication 38 | 39 | *Interfaces for external middlewares:* 40 | 41 | - `Repeater` - sets repeater to retry failed requests. Doesn't provide a repeater implementation but wraps one. Compatible with any repeater (for example [go-pkgz/repeater](https://github.com/go-pkgz/repeater)) implementing the single-method interface `Do(ctx context.Context, fun func() error, errors ...error) (err error)`. 42 | - `Cache` - sets any `LoadingCache` implementation to be used for request/response caching. Doesn't provide a cache but wraps one.
Compatible with any cache (for example, the family of caches from [go-pkgz/lcw](https://github.com/go-pkgz/lcw)) implementing a single-method interface `Get(key string, fn func() (interface{}, error)) (val interface{}, err error)` 43 | - `Logger` - sets logger, compatible with any implementation of a single-method interface `Logf(format string, args ...interface{})`, for example [go-pkgz/lgr](https://github.com/go-pkgz/lgr) 44 | - `CircuitBreaker` - sets circuit breaker, interface compatible with [sony/gobreaker](https://github.com/sony/gobreaker) 45 | 46 | Users can add any custom middleware. All it needs is a handler `RoundTripperHandler func(http.RoundTripper) http.RoundTripper`. 47 | A convenient functional adapter, `middleware.RoundTripperFunc`, is provided. 48 | 49 | See usage examples in [_example](https://github.com/go-pkgz/requester/tree/master/_example) 50 | 51 | ### Header middleware 52 | 53 | `Header` middleware adds a user-defined header to all requests. It expects the header key and value. For example: 54 | 55 | ```go 56 | rq := requester.New(http.Client{}, middleware.Header("X-Auth", "123456789")) 57 | ``` 58 | ### MaxConcurrent middleware 59 | 60 | `MaxConcurrent` middleware can be used to limit the concurrency of a given requester and limit overall concurrency for multiple requesters. For the first case, `MaxConcurrent(N)` should be created in the requester chain of middlewares. For example, `rq := requester.New(http.Client{Timeout: 3 * time.Second}, middleware.MaxConcurrent(8))`. To make it global, `MaxConcurrent` should be created once, outside the chain, and passed into each requester. For example: 61 | 62 | ```go 63 | mc := middleware.MaxConcurrent(16) 64 | rq1 := requester.New(http.Client{Timeout: 3 * time.Second}, mc) 65 | rq2 := requester.New(http.Client{Timeout: 1 * time.Second}, middleware.JSON, mc) 66 | ``` 67 | ### Retry middleware 68 | 69 | Retry middleware provides a flexible retry mechanism with different backoff strategies.
By default, it retries on network errors and 5xx responses. 70 | 71 | ```go 72 | // retry 3 times with exponential backoff, starting from 100ms 73 | rq := requester.New(http.Client{}, middleware.Retry(3, 100*time.Millisecond)) 74 | 75 | // retry with custom options 76 | rq := requester.New(http.Client{}, middleware.Retry(3, 100*time.Millisecond, 77 | middleware.RetryWithBackoff(middleware.BackoffLinear), // use linear backoff 78 | middleware.RetryMaxDelay(5*time.Second), // cap maximum delay 79 | middleware.RetryWithJitter(0.1), // add 10% randomization 80 | middleware.RetryOnCodes(503, 502), // retry only on specific codes 81 | // or middleware.RetryExcludeCodes(404, 401), // alternatively, retry on all except these codes 82 | )) 83 | ``` 84 | 85 | Default configuration: 86 | - 3 attempts 87 | - Initial delay: 100ms 88 | - Max delay: 30s 89 | - Exponential backoff 90 | - 10% jitter 91 | - Retries on 5xx status codes 92 | 93 | Retry Options: 94 | - `RetryWithBackoff(t BackoffType)` - set backoff strategy (Constant, Linear, or Exponential) 95 | - `RetryMaxDelay(d time.Duration)` - cap the maximum delay between retries 96 | - `RetryWithJitter(f float64)` - add randomization to delays (0-1.0 factor) 97 | - `RetryOnCodes(codes ...int)` - retry only on specific status codes 98 | - `RetryExcludeCodes(codes ...int)` - retry on all codes except specified 99 | 100 | Note: `RetryOnCodes` and `RetryExcludeCodes` are mutually exclusive and can't be used together. 101 | 102 | ### Cache middleware 103 | 104 | Cache middleware provides an **in-memory caching layer** for HTTP responses. It improves performance by avoiding repeated network calls for the same request. 
105 | 106 | #### **Basic Usage** 107 | 108 | ```go 109 | rq := requester.New(http.Client{}, middleware.Cache()) 110 | ``` 111 | 112 | By default: 113 | 114 | - Only GET requests are cached 115 | - TTL (Time-To-Live) is 5 minutes 116 | - Maximum cache size is 1000 entries 117 | - Caches only HTTP 200 responses 118 | 119 | 120 | #### **Cache Configuration Options** 121 | 122 | ```go 123 | rq := requester.New(http.Client{}, middleware.Cache( 124 | middleware.CacheTTL(10*time.Minute), // change TTL to 10 minutes 125 | middleware.CacheSize(500), // limit cache to 500 entries 126 | middleware.CacheMethods(http.MethodGet, http.MethodPost), // allow caching for GET and POST 127 | middleware.CacheStatuses(200, 201, 204), // cache only responses with these status codes 128 | middleware.CacheWithBody, // include request body in cache key 129 | middleware.CacheWithHeaders("Authorization", "X-Custom-Header"), // include selected headers in cache key 130 | )) 131 | ``` 132 | 133 | #### Cache Key Composition 134 | 135 | By default, the cache key is generated using: 136 | 137 | - HTTP **method** 138 | - Full **URL** 139 | - (Optional) **Headers** (if `CacheWithHeaders` is enabled) 140 | - (Optional) **Body** (if `CacheWithBody` is enabled) 141 | 142 | For example, enabling `CacheWithHeaders("Authorization")` will cache the same URL differently **for each unique Authorization token**. 143 | 144 | #### Cache Eviction Strategy 145 | 146 | - **Entries expire** when the TTL is reached. 147 | - **If the cache reaches its maximum size**, the **oldest entry is evicted** (FIFO order). 148 | 149 | 150 | #### Cache Limitations 151 | 152 | - **Only caches complete HTTP responses.** Streaming responses are **not** supported. 153 | - **Does not cache responses with status codes other than 200** (unless explicitly allowed). 154 | - **Uses in-memory storage**, meaning the cache **resets on application restart**. 
155 | 156 | 157 | ### JSON middleware 158 | 159 | `JSON` middleware sets headers `"Content-Type": "application/json"` and `"Accept": "application/json"`. 160 | 161 | ```go 162 | rq := requester.New(http.Client{}, middleware.JSON) 163 | ``` 164 | 165 | ### BasicAuth middleware 166 | 167 | `BasicAuth` middleware adds HTTP Basic Authentication to all requests. It expects a username and password. For example: 168 | 169 | ```go 170 | rq := requester.New(http.Client{}, middleware.BasicAuth("user", "passwd")) 171 | ``` 172 | 173 | ---- 174 | 175 | ### Logging middleware interface 176 | 177 | A logger should implement the `Logger` interface with a single method `Logf(format string, args ...interface{})`. 178 | For convenience, func type `LoggerFunc` is provided as an adapter to allow the use of ordinary functions as `Logger`. 179 | 180 | Two basic implementations are included: 181 | 182 | - `NoOpLogger` do-nothing logger (default) 183 | - `StdLogger` wrapper for stdlib logger. 184 | 185 | Logging options: 186 | 187 | - `Prefix(prefix string)` sets prefix for each logged line 188 | - `WithBody` - allows request's body logging 189 | - `WithHeaders` - allows request's headers logging 190 | 191 | Note: If logging is enabled, it will log the URL and method, and may log headers and the request body. This may affect application security, for example, if a request passes sensitive information in the body or headers. In this case, consider turning logging off or providing your own logger that suppresses everything you need to hide. 192 | 193 | 194 | Note on `MaxConcurrent`: if the concurrency limit has been reached, a request waits until a slot is released. 195 | 196 | ### Cache middleware interface 197 | 198 | Cache expects the `LoadingCache` interface to implement a single method: `Get(key string, fn func() (interface{}, error)) (val interface{}, err error)`. [LCW](https://github.com/go-pkgz/lcw/) can be used directly, and in order to adapt other caches, see the provided `LoadingCacheFunc`.
199 | 200 | #### Caching Key and Allowed Requests 201 | 202 | By default, only `GET` calls are cached. This can be changed with the `Methods(methods ...string)` option. The default key is composed of the full URL. 203 | 204 | Several options define what part of the request will be used for the key: 205 | 206 | - `KeyWithHeaders` - adds all headers to a key 207 | - `KeyWithHeadersIncluded(headers ...string)` - adds only the listed headers 208 | - `KeyWithHeadersExcluded(headers ...string)` - adds all headers except the listed ones 209 | - `KeyWithBody` - adds the request's body, limited to the first 16KB of the body 210 | - `KeyFunc` - any custom logic provided by the caller 211 | 212 | Example: `cache.New(lruCache, cache.Methods("GET", "POST"), cache.KeyFunc(func(r *http.Request) string { return r.Host }))` 213 | 214 | #### cache and streaming response 215 | 216 | `Cache` is **not compatible** with HTTP streaming mode. In practice this is rare, but enabling `Cache` will effectively transform a streaming response into a regular "get it all" response. This is because the cache has to read the response body fully to save it, so streaming technically still works, but the client receives all the data at once. 217 | 218 | ### Repeater middleware interface 219 | 220 | `Repeater` expects a single-method interface `Do(ctx context.Context, fun func() error, errors ...error) (err error)`. [repeater](https://github.com/go-pkgz/repeater) can be used directly. 221 | 222 | By default, the repeater will retry on any error and any status code >= 400. However, the user can pass `failOnCodes` to explicitly define which status codes should be treated as errors and retry only on those. For example: `Repeater(repeaterSvc, 500, 400)` repeats requests on 500 and 400 statuses only. 223 | 224 | In a special case where the user wants to retry only on the underlying transport errors (network, timeouts, etc.) and not on any status codes, `Repeater(repeaterSvc, 0)` can be used.
225 | 226 | ### User-Defined Middlewares 227 | 228 | Users can add any additional handlers (middleware) to the chain. Each middleware provides `middleware.RoundTripperHandler` and 229 | can alter the request or implement any other custom functionality. 230 | 231 | Example of a handler removing a particular header: 232 | 233 | ```go 234 | maskHeader := func(next http.RoundTripper) http.RoundTripper { 235 | fn := func(req *http.Request) (*http.Response, error) { 236 | req.Header.Del("deleteme") 237 | return next.RoundTrip(req) 238 | } 239 | return middleware.RoundTripperFunc(fn) 240 | } 241 | 242 | rq := requester.New(http.Client{}, maskHeader) 243 | ``` 244 | 245 | ## Adding middleware to requester 246 | There are 3 ways to add middleware(s): 247 | 248 | - Pass it to the `New` constructor, e.g. `requester.New(http.Client{}, middleware.MaxConcurrent(8), middleware.Header("foo", "bar"))` 249 | - Add after construction with the `Use` method 250 | - Create a new, inherited requester by using `With`: 251 | 252 | ```go 253 | rq := requester.New(http.Client{}, middleware.Header("foo", "bar")) // make requester enforcing header foo:bar 254 | resp, err := rq.Do(some_http_req) // send a request 255 | 256 | rqLimited := rq.With(middleware.MaxConcurrent(8)) // make requester from rq (foo:bar enforced) and add 8 max concurrency 257 | resp, err = rqLimited.Do(some_http_req) 258 | ``` 259 | 260 | ## Getting http.Client with all middlewares 261 | 262 | For convenience, the requester's `Client()` method returns `*http.Client` with all middlewares injected. From this point, the user can call `Do` on this client, and it will invoke the request with all the middlewares. 263 | 264 | ## Helpers and adapters 265 | 266 | - `CircuitBreakerFunc func(req func() (interface{}, error)) (interface{}, error)` - adapter to allow the use of ordinary functions as CircuitBreakerSvc. 267 | - `logger.Func func(format string, args ...interface{})` - functional adapter for `logger.Service`.
268 | - `cache.ServiceFunc func(key string, fn func() (interface{}, error)) (interface{}, error)` - functional adapter for `cache.Service`. 269 | - `RoundTripperFunc func(*http.Request) (*http.Response, error)` - functional adapter for RoundTripperHandler 270 | -------------------------------------------------------------------------------- /_example/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-pkgz/requester/_example 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/go-pkgz/lcw v0.8.1 7 | github.com/go-pkgz/repeater v1.1.3 8 | github.com/go-pkgz/requester v1.0.0 9 | ) 10 | 11 | require ( 12 | github.com/go-redis/redis/v7 v7.4.0 // indirect 13 | github.com/google/uuid v1.1.2 // indirect 14 | github.com/hashicorp/errwrap v1.0.0 // indirect 15 | github.com/hashicorp/go-multierror v1.1.0 // indirect 16 | github.com/hashicorp/golang-lru v0.5.4 // indirect 17 | github.com/pkg/errors v0.9.1 // indirect 18 | ) 19 | 20 | replace github.com/go-pkgz/requester => ../ 21 | -------------------------------------------------------------------------------- /_example/go.sum: -------------------------------------------------------------------------------- 1 | github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 h1:45bxf7AZMwWcqkLzDAQugVEwedisr5nRJ1r+7LYnv0U= 2 | github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= 3 | github.com/alicebob/miniredis/v2 v2.11.4 h1:GsuyeunTx7EllZBU3/6Ji3dhMQZDpC9rLf1luJ+6M5M= 4 | github.com/alicebob/miniredis/v2 v2.11.4/go.mod h1:VL3UDEfAH59bSa7MuHMuFToxkqyHh69s/WUbYlOAuyg= 5 | github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= 6 | github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= 7 | github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= 8 | 
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 10 | github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 11 | github.com/go-pkgz/lcw v0.8.1 h1:Bpt2yYTE1J8hIhz8tjdm1WPOgH13eo5iTNsXyop7cMQ= 12 | github.com/go-pkgz/lcw v0.8.1/go.mod h1:Xw0/ZfApATgbjVPYRZO4XHdWyxAjErDWDWJ7TLlw1Vc= 13 | github.com/go-pkgz/repeater v1.1.3 h1:q6+JQF14ESSy28Dd7F+wRelY4F+41HJ0LEy/szNnMiE= 14 | github.com/go-pkgz/repeater v1.1.3/go.mod h1:hVTavuO5x3Gxnu8zW7d6sQBfAneKV8X2FjU48kGfpKw= 15 | github.com/go-redis/redis/v7 v7.4.0 h1:7obg6wUoj05T0EpY0o8B59S9w5yeMWql7sw2kwNW1x4= 16 | github.com/go-redis/redis/v7 v7.4.0/go.mod h1:JDNMw23GTyLNC4GZu9njt15ctBQVn7xjRfnwdHj/Dcg= 17 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 18 | github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 19 | github.com/gomodule/redigo v1.7.1-0.20190322064113-39e2c31b7ca3/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= 20 | github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= 21 | github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 22 | github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= 23 | github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 24 | github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= 25 | github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= 26 | github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= 27 | github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= 28 | github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= 29 | github.com/hpcloud/tail v1.0.0/go.mod 
h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= 30 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 31 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 32 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 33 | github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 34 | github.com/onsi/ginkgo v1.10.1 h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo= 35 | github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 36 | github.com/onsi/gomega v1.7.0 h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME= 37 | github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= 38 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 39 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 40 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 41 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 42 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 43 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 44 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 45 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 46 | github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb h1:ZkM6LRnq40pR1Ox0hTHlnpkcOTuFIDQpZ1IN8rKKhX0= 47 | github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ= 48 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 49 | golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 50 | golang.org/x/net v0.0.0-20190923162816-aa69164e4478 
h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= 51 | golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 52 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 53 | golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 54 | golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 55 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 56 | golang.org/x/sys v0.0.0-20191010194322-b09406accb47 h1:/XfQ9z7ib8eEJX2hdgFTZJ/ntt0swNk5oYBziWeTCvY= 57 | golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 58 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 59 | golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= 60 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 61 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 62 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 63 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 64 | gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= 65 | gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= 66 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= 67 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= 68 | gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 69 | gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I= 70 | gopkg.in/yaml.v2 v2.2.4/go.mod 
h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 71 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 72 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 73 | -------------------------------------------------------------------------------- /_example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "math/rand" 6 | "net/http" 7 | "net/http/httptest" 8 | "strconv" 9 | "sync" 10 | "sync/atomic" 11 | "time" 12 | 13 | "github.com/go-pkgz/lcw" 14 | "github.com/go-pkgz/repeater" 15 | 16 | "github.com/go-pkgz/requester" 17 | "github.com/go-pkgz/requester/middleware" 18 | "github.com/go-pkgz/requester/middleware/cache" 19 | "github.com/go-pkgz/requester/middleware/logger" 20 | ) 21 | 22 | func main() { 23 | 24 | // start test server 25 | ts := startTestServer() 26 | defer ts.Close() 27 | 28 | // start another test server returning 500 status on 5 first calls, for repeater demo 29 | tsRep := startTestServerFailedFirst() 30 | defer tsRep.Close() 31 | 32 | requestWithHeaders(ts) 33 | requestWithLogging(ts) 34 | requestWithCache(ts) 35 | requestWithCustom(ts) 36 | requestWithLimitConcurrency(ts) 37 | requestWithRepeater(tsRep) 38 | } 39 | 40 | // requestWithHeaders shows how to use requester with middleware altering headers 41 | func requestWithHeaders(ts *httptest.Server) { 42 | log.Printf("requestWithHeaders --------------") 43 | rq := requester.New(http.Client{Timeout: 3 * time.Second}, middleware.JSON) // make requester with JSON headers 44 | // add auth header, user agent and basic auth 45 | rq.Use( 46 | middleware.Header("X-Auth", "very-secret-key"), 47 | middleware.Header("User-Agent", "test-requester"), 48 | middleware.BasicAuth("user", "password"), 49 | middleware.JSON, 50 | ) 51 | 52 | // create http.Request 53 | req, err := http.NewRequest("GET", ts.URL+"/blah", nil) 54 | if err != nil { 55 | 
panic(err) 56 | } 57 | 58 | resp, err := rq.Do(req) 59 | if err != nil { 60 | panic(err) 61 | } 62 | log.Printf("status: %s", resp.Status) 63 | 64 | // alternatively, get http.Client and use it directly 65 | client := rq.Client() 66 | resp, err = client.Get(ts.URL + "/blah") 67 | if err != nil { 68 | panic(err) 69 | } 70 | log.Printf("status: %s", resp.Status) 71 | } 72 | 73 | // requestWithLogging example of logging 74 | func requestWithLogging(ts *httptest.Server) { 75 | log.Printf("requestWithLogging --------------") 76 | 77 | rq := requester.New(http.Client{Timeout: 3 * time.Second}, middleware.JSON) // make requester with JSON headers 78 | // add auth header, user agent and JSON headers 79 | // logging added after setting X-Auth to eliminate leaking it to the logs 80 | rq.Use( 81 | middleware.Header("X-Auth", "very-secret-key"), 82 | logger.New(logger.Std, logger.Prefix("REST"), logger.WithHeaders).Middleware, // uses std logger 83 | middleware.Header("User-Agent", "test-requester"), 84 | middleware.JSON, 85 | ) 86 | 87 | // create http.Request 88 | req, err := http.NewRequest("GET", ts.URL+"/blah", nil) 89 | if err != nil { 90 | panic(err) 91 | } 92 | 93 | resp, err := rq.Do(req) 94 | if err != nil { 95 | panic(err) 96 | } 97 | log.Printf("status: %s", resp.Status) 98 | } 99 | 100 | // requestWithCache example of using request cache 101 | func requestWithCache(ts *httptest.Server) { 102 | log.Printf("requestWithCache --------------") 103 | 104 | cacheService, err := lcw.NewLruCache(lcw.MaxKeys(100)) // make LRU loading cache 105 | if err != nil { 106 | panic(err) 107 | } 108 | 109 | // create cache middleware, allowing GET and POST responses caching 110 | // by default, caching key made from request's URL 111 | cmw := cache.New(cacheService, cache.Methods("GET", "POST")) 112 | 113 | // make requester with caching middleware and logger 114 | rq := requester.New(http.Client{Timeout: 3 * time.Second}, 115 | cmw.Middleware, 116 | logger.New(logger.Std, 
logger.Prefix("REST CACHED"), logger.WithHeaders).Middleware, 117 | ) 118 | 119 | // create http.Request 120 | req, err := http.NewRequest("GET", ts.URL+"/blah", nil) 121 | if err != nil { 122 | panic(err) 123 | } 124 | 125 | resp, err := rq.Do(req) // the first call hits the remote endpoint and cache response 126 | if err != nil { 127 | panic(err) 128 | } 129 | log.Printf("status1: %s", resp.Status) 130 | 131 | // make another call for cached resource, will be fast as result cached 132 | req2, err := http.NewRequest("GET", ts.URL+"/blah", nil) 133 | if err != nil { 134 | panic(err) 135 | } 136 | resp, err = rq.Do(req2) 137 | if err != nil { 138 | panic(err) 139 | } 140 | log.Printf("status2: %s", resp.Status) 141 | } 142 | 143 | // requestWithCustom example of a custom, user provided middleware 144 | func requestWithCustom(ts *httptest.Server) { 145 | log.Printf("requestWithCustom --------------") 146 | 147 | // custom middleware, removes header foo 148 | clearHeaders := func(next http.RoundTripper) http.RoundTripper { 149 | fn := func(r *http.Request) (*http.Response, error) { 150 | r.Header.Del("foo") 151 | return next.RoundTrip(r) 152 | } 153 | return middleware.RoundTripperFunc(fn) 154 | } 155 | 156 | // make requester with clearHeaders 157 | rq := requester.New(http.Client{Timeout: 3 * time.Second}, 158 | logger.New(logger.Std, logger.Prefix("REST CUSTOM"), logger.WithHeaders).Middleware, 159 | clearHeaders, 160 | ) 161 | 162 | // create http.Request 163 | req, err := http.NewRequest("GET", ts.URL+"/blah", nil) 164 | if err != nil { 165 | panic(err) 166 | } 167 | req.Header.Set("foo", "bar") // add foo header 168 | 169 | resp, err := rq.With(clearHeaders).Do(req) // can be used inline 170 | if err != nil { 171 | panic(err) 172 | } 173 | log.Printf("status: %s", resp.Status) 174 | } 175 | 176 | var inFly int32 177 | 178 | // requestWithLimitConcurrency example of concurrency limiter 179 | func requestWithLimitConcurrency(ts *httptest.Server) { 180 | 
log.Printf("requestWithLimitConcurrency --------------") 181 | 182 | // make requester with logger and max concurrency 4 183 | rq := requester.New(http.Client{Timeout: 3 * time.Second}, 184 | logger.New(logger.Std, logger.Prefix("REST CUSTOM"), logger.WithHeaders).Middleware, 185 | middleware.MaxConcurrent(4), 186 | ) 187 | 188 | client := rq.Client() 189 | 190 | // a test checking if concurrent requests limited to 4 191 | var wg sync.WaitGroup 192 | wg.Add(32) 193 | for i := 0; i < 32; i++ { 194 | go func(i int) { 195 | defer wg.Done() 196 | client.Get(ts.URL + "/blah" + strconv.Itoa(i)) 197 | log.Printf("completed: %d, in fly:%d", i, atomic.LoadInt32(&inFly)) 198 | }(i) 199 | } 200 | wg.Wait() 201 | } 202 | 203 | // requestWithRepeater example of repeater usage 204 | func requestWithRepeater(ts *httptest.Server) { 205 | log.Printf("requestWithRepeater --------------") 206 | 207 | rpt := repeater.NewDefault(10, 500*time.Millisecond) // make a repeater with up to 10 calls, 500ms between calls 208 | rq := requester.New(http.Client{}, 209 | // repeat failed call up to 10 times with 500ms delay on networking error or given status codes 210 | middleware.Repeater(rpt, http.StatusInternalServerError, http.StatusBadGateway), 211 | logger.New(logger.Std, logger.Prefix("REST REPT"), logger.WithHeaders).Middleware, 212 | ) 213 | 214 | // create http.Request 215 | req, err := http.NewRequest("GET", ts.URL+"/blah", nil) 216 | if err != nil { 217 | panic(err) 218 | } 219 | 220 | resp, err := rq.Do(req) 221 | if err != nil { 222 | panic(err) 223 | } 224 | log.Printf("status: %s", resp.Status) 225 | } 226 | 227 | func startTestServer() *httptest.Server { 228 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 229 | c := atomic.AddInt32(&inFly, 1) 230 | log.Printf("request: %+v (%d)", r, c) 231 | time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) // simulate random network latency 232 | w.Header().Set("k1", "v1") 233 | 
w.Write([]byte("something")) 234 | atomic.AddInt32(&inFly, -1) 235 | })) 236 | } 237 | 238 | func startTestServerFailedFirst() *httptest.Server { 239 | var n int32 240 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 241 | c := atomic.AddInt32(&inFly, 1) 242 | log.Printf("request: %+v (%d)", r, c) 243 | time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) // simulate random network latency 244 | 245 | if atomic.AddInt32(&n, 1) < 5 { // fail 5 first requests 246 | w.WriteHeader(http.StatusInternalServerError) 247 | return 248 | } 249 | 250 | w.Header().Set("k1", "v1") 251 | w.Write([]byte("something")) 252 | atomic.AddInt32(&inFly, -1) 253 | })) 254 | } 255 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-pkgz/requester 2 | 3 | go 1.19 4 | 5 | require github.com/stretchr/testify v1.10.0 6 | 7 | require ( 8 | github.com/davecgh/go-spew v1.1.1 // indirect 9 | github.com/pmezard/go-difflib v1.0.0 // indirect 10 | gopkg.in/yaml.v3 v3.0.1 // indirect 11 | ) 12 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 6 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 7 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 8 | 
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 9 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 10 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 11 | -------------------------------------------------------------------------------- /middleware/cache.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bytes" 5 | "crypto/sha256" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "sort" 10 | "strings" 11 | "sync" 12 | "time" 13 | ) 14 | 15 | // CacheEntry represents a cached response with metadata 16 | type CacheEntry struct { 17 | body []byte 18 | headers http.Header 19 | status int 20 | createdAt time.Time 21 | } 22 | 23 | // CacheMiddleware implements in-memory cache for HTTP responses with TTL-based eviction 24 | type CacheMiddleware struct { 25 | next http.RoundTripper 26 | ttl time.Duration 27 | maxKeys int 28 | includeBody bool 29 | headers []string 30 | allowedCodes []int 31 | allowedMethods []string 32 | 33 | cache map[string]CacheEntry 34 | keys []string // Maintains insertion order 35 | mu sync.Mutex 36 | } 37 | 38 | // RoundTrip implements http.RoundTripper 39 | func (c *CacheMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { 40 | // check if method is allowed 41 | methodAllowed := false 42 | for _, m := range c.allowedMethods { 43 | if req.Method == m { 44 | methodAllowed = true 45 | break 46 | } 47 | } 48 | if !methodAllowed { 49 | return c.next.RoundTrip(req) 50 | } 51 | 52 | key := c.makeKey(req) // generate cache key based on request 53 | 54 | c.mu.Lock() 55 | // remove expired entries 56 | for len(c.keys) > 0 { 57 | oldestKey := c.keys[0] 58 | if time.Since(c.cache[oldestKey].createdAt) < c.ttl { 59 | break // Stop once we find a non-expired entry 60 | } 61 | delete(c.cache, oldestKey) 62 | c.keys = c.keys[1:] 63 | } 64 | // check cache 65 | entry, 
found := c.cache[key] 66 | c.mu.Unlock() 67 | 68 | if found { 69 | // cache hit - reconstruct response 70 | return &http.Response{ 71 | Status: fmt.Sprintf("%d %s", entry.status, http.StatusText(entry.status)), 72 | StatusCode: entry.status, 73 | Header: entry.headers.Clone(), 74 | Body: io.NopCloser(bytes.NewReader(entry.body)), 75 | Request: req, 76 | Proto: "HTTP/1.1", 77 | ProtoMajor: 1, 78 | ProtoMinor: 1, 79 | ContentLength: int64(len(entry.body)), 80 | }, nil 81 | } 82 | 83 | // fetch fresh response 84 | resp, err := c.next.RoundTrip(req) 85 | if err != nil { 86 | return resp, err 87 | } 88 | 89 | // check if response code is allowed for caching 90 | if !c.shouldCache(resp.StatusCode) { 91 | return resp, nil 92 | } 93 | 94 | // read and store response body 95 | body, err := io.ReadAll(resp.Body) 96 | if err != nil { 97 | return resp, err 98 | } 99 | _ = resp.Body.Close() 100 | resp.Body = io.NopCloser(bytes.NewReader(body)) 101 | 102 | // store in cache 103 | c.mu.Lock() 104 | defer c.mu.Unlock() 105 | 106 | // evict oldest if maxKeys reached 107 | if len(c.cache) >= c.maxKeys { 108 | oldestKey := c.keys[0] 109 | delete(c.cache, oldestKey) 110 | c.keys = c.keys[1:] 111 | } 112 | 113 | // store new entry 114 | c.cache[key] = CacheEntry{body: body, headers: resp.Header.Clone(), status: resp.StatusCode, createdAt: time.Now()} 115 | c.keys = append(c.keys, key) // maintain order of keys for LRU eviction 116 | 117 | return resp, nil 118 | } 119 | 120 | // makeKey generates a cache key based on the request details 121 | func (c *CacheMiddleware) makeKey(req *http.Request) string { 122 | var sb strings.Builder 123 | sb.WriteString(req.Method) 124 | sb.WriteString(":") 125 | sb.WriteString(req.URL.String()) 126 | 127 | if c.includeBody && req.Body != nil { 128 | body, err := io.ReadAll(req.Body) 129 | if err == nil { 130 | sb.Write(body) 131 | req.Body = io.NopCloser(bytes.NewReader(body)) 132 | } 133 | } 134 | 135 | if len(c.headers) > 0 { 136 | var headers 
[]string 137 | for _, h := range c.headers { 138 | if vals := req.Header.Values(h); len(vals) > 0 { 139 | headers = append(headers, fmt.Sprintf("%s:%s", h, strings.Join(vals, ","))) 140 | } 141 | } 142 | sort.Strings(headers) 143 | sb.WriteString(strings.Join(headers, "||")) 144 | } 145 | 146 | hash := sha256.Sum256([]byte(sb.String())) 147 | return fmt.Sprintf("%x", hash) 148 | } 149 | 150 | func (c *CacheMiddleware) shouldCache(code int) bool { 151 | for _, allowed := range c.allowedCodes { 152 | if code == allowed { 153 | return true 154 | } 155 | } 156 | return false 157 | } 158 | 159 | // Cache creates caching middleware with provided options 160 | func Cache(opts ...CacheOption) RoundTripperHandler { 161 | return func(next http.RoundTripper) http.RoundTripper { 162 | c := &CacheMiddleware{ 163 | next: next, 164 | ttl: 5 * time.Minute, 165 | maxKeys: 1000, 166 | allowedCodes: []int{200}, 167 | allowedMethods: []string{http.MethodGet}, 168 | cache: make(map[string]CacheEntry), 169 | keys: make([]string, 0, 1000), 170 | } 171 | 172 | for _, opt := range opts { 173 | opt(c) 174 | } 175 | 176 | return c 177 | } 178 | } 179 | 180 | // CacheOption represents cache middleware options 181 | type CacheOption func(c *CacheMiddleware) 182 | 183 | // CacheTTL sets cache TTL 184 | func CacheTTL(ttl time.Duration) CacheOption { 185 | return func(c *CacheMiddleware) { 186 | c.ttl = ttl 187 | } 188 | } 189 | 190 | // CacheSize sets maximum number of cached entries 191 | func CacheSize(size int) CacheOption { 192 | return func(c *CacheMiddleware) { 193 | c.maxKeys = size 194 | } 195 | } 196 | 197 | // CacheWithBody includes request body in cache key 198 | func CacheWithBody(c *CacheMiddleware) { 199 | c.includeBody = true 200 | } 201 | 202 | // CacheWithHeaders includes specified headers in cache key 203 | func CacheWithHeaders(headers ...string) CacheOption { 204 | return func(c *CacheMiddleware) { 205 | c.headers = headers 206 | } 207 | } 208 | 209 | // CacheStatuses sets 
which response status codes should be cached 210 | func CacheStatuses(codes ...int) CacheOption { 211 | return func(c *CacheMiddleware) { 212 | c.allowedCodes = codes 213 | } 214 | } 215 | 216 | // CacheMethods sets which HTTP methods should be cached 217 | func CacheMethods(methods ...string) CacheOption { 218 | return func(c *CacheMiddleware) { 219 | c.allowedMethods = methods 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /middleware/cache/cache.go: -------------------------------------------------------------------------------- 1 | // Package cache implements middleware for response caching. Request's component used as a key. 2 | package cache 3 | 4 | import ( 5 | "bufio" 6 | "bytes" 7 | "crypto/sha256" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "net/http/httputil" 12 | "sort" 13 | "strings" 14 | 15 | "github.com/go-pkgz/requester/middleware" 16 | ) 17 | 18 | // Middleware for caching responses. The key cache generated from request. 19 | type Middleware struct { 20 | Service 21 | 22 | allowedMethods []string 23 | keyFunc func(r *http.Request) string 24 | keyComponents struct { 25 | body bool 26 | headers struct { 27 | enabled bool 28 | include []string 29 | exclude []string 30 | } 31 | } 32 | 33 | dbg bool 34 | } 35 | 36 | const maxBodySize = 1024 * 16 37 | 38 | // Service defines loading cache interface to be used for caching, matching github.com/go-pkgz/lcw interface 39 | type Service interface { 40 | Get(key string, fn func() (interface{}, error)) (interface{}, error) 41 | } 42 | 43 | // ServiceFunc is an adapter to allow the use of an ordinary functions as the loading cache. 
44 | type ServiceFunc func(key string, fn func() (interface{}, error)) (interface{}, error) 45 | 46 | // Get and/or fill the cached value for a given keyDbg 47 | func (c ServiceFunc) Get(key string, fn func() (interface{}, error)) (interface{}, error) { 48 | return c(key, fn) 49 | } 50 | 51 | // New makes cache middleware for given cache.Service and optional set of params 52 | // By default allowed methods limited to GET only and key for request's URL 53 | func New(svc Service, opts ...func(m *Middleware)) *Middleware { 54 | res := Middleware{Service: svc, allowedMethods: []string{"GET"}} 55 | for _, opt := range opts { 56 | opt(&res) 57 | } 58 | return &res 59 | } 60 | 61 | // Middleware is the middleware wrapper injecting external cache.Service (LoadingCache) into the call chain. 62 | // Key extracted from the request and options defines what part of request should be used for key and what method 63 | // are allowed for caching. 64 | func (m *Middleware) Middleware(next http.RoundTripper) http.RoundTripper { 65 | fn := func(req *http.Request) (resp *http.Response, err error) { 66 | 67 | if m.Service == nil || !m.methodCacheable(req) { 68 | return next.RoundTrip(req) 69 | } 70 | 71 | key, e := m.extractCacheKey(req) 72 | if e != nil { 73 | return nil, fmt.Errorf("cache key: %w", e) 74 | } 75 | 76 | cachedResp, e := m.Get(key, func() (interface{}, error) { 77 | resp, err = next.RoundTrip(req) 78 | if err != nil { 79 | return nil, err 80 | } 81 | if resp.Body == nil { 82 | return nil, nil 83 | } 84 | return httputil.DumpResponse(resp, true) 85 | }) 86 | 87 | if e != nil { 88 | return nil, fmt.Errorf("cache read for %s: %w", key, e) 89 | } 90 | 91 | body := cachedResp.([]byte) 92 | return http.ReadResponse(bufio.NewReader(bytes.NewReader(body)), req) 93 | } 94 | return middleware.RoundTripperFunc(fn) 95 | } 96 | 97 | func (m *Middleware) extractCacheKey(req *http.Request) (key string, err error) { 98 | bodyKey := func() (string, error) { 99 | if req.Body == nil { 100 
| return "", nil 101 | } 102 | reqBody, e := io.ReadAll(io.LimitReader(req.Body, maxBodySize)) 103 | if e != nil { 104 | return "", e 105 | } 106 | _ = req.Body.Close() 107 | req.Body = io.NopCloser(bytes.NewReader(reqBody)) 108 | return string(reqBody), nil 109 | } 110 | 111 | bkey := "" 112 | if m.keyComponents.body && m.keyFunc == nil { 113 | bkey, err = bodyKey() 114 | } 115 | 116 | hkey := "" 117 | if m.keyComponents.headers.enabled && m.keyFunc == nil { 118 | var hh []string 119 | for k, h := range req.Header { 120 | if m.headerAllowed(k) { 121 | hh = append(hh, k+":"+strings.Join(h, "%%")) 122 | } 123 | } 124 | sort.Strings(hh) 125 | hkey = strings.Join(hh, "$$") 126 | } 127 | 128 | if m.keyFunc != nil { 129 | key = m.keyFunc(req) 130 | } else { 131 | key = fmt.Sprintf("%s##%s##%v##%s", req.URL.String(), req.Method, hkey, bkey) 132 | } 133 | if m.dbg { // dbg for testing only, keeps the key human-readable 134 | return key, nil 135 | } 136 | 137 | return fmt.Sprintf("%x", sha256.Sum256([]byte(key))), err 138 | } 139 | 140 | func (m *Middleware) headerAllowed(key string) bool { 141 | if !m.keyComponents.headers.enabled { 142 | return false 143 | } 144 | if len(m.keyComponents.headers.include) > 0 { 145 | for _, h := range m.keyComponents.headers.include { 146 | if strings.EqualFold(key, h) { 147 | return true 148 | } 149 | } 150 | return false 151 | } 152 | 153 | if len(m.keyComponents.headers.exclude) > 0 { 154 | for _, h := range m.keyComponents.headers.exclude { 155 | if strings.EqualFold(key, h) { 156 | return false 157 | } 158 | } 159 | return true 160 | } 161 | return true 162 | } 163 | 164 | func (m *Middleware) methodCacheable(req *http.Request) bool { 165 | for _, m := range m.allowedMethods { 166 | if strings.EqualFold(m, req.Method) { 167 | return true 168 | } 169 | } 170 | return false 171 | } 172 | -------------------------------------------------------------------------------- /middleware/cache/cache_test.go: 
-------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "strconv" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | 16 | "github.com/go-pkgz/requester/middleware/mocks" 17 | ) 18 | 19 | func Test_extractCacheKey(t *testing.T) { 20 | makeReq := func(method, url string, body io.Reader, headers http.Header) *http.Request { 21 | res, err := http.NewRequest(method, url, body) 22 | require.NoError(t, err) 23 | if headers != nil { 24 | res.Header = headers 25 | } 26 | return res 27 | } 28 | 29 | tbl := []struct { 30 | req *http.Request 31 | opts []func(m *Middleware) 32 | keyDbg string 33 | keyHash string 34 | }{ 35 | { 36 | req: makeReq("GET", "http://example.com/1/2?k1=v1&k2=v2", nil, nil), 37 | opts: []func(m *Middleware){}, 38 | keyDbg: "http://example.com/1/2?k1=v1&k2=v2##GET####", 39 | keyHash: "e847b72f947c83d096d71433f6d53202c148242d54150dc275e547f023ff3d5e", 40 | }, 41 | { 42 | req: makeReq("GET", "http://example.com/1/2?k1=v1&k2=v2", nil, 43 | http.Header{"keyDbg": []string{"val1", "val2"}, "k2": []string{"v22"}}), 44 | opts: []func(m *Middleware){KeyWithHeaders}, 45 | keyDbg: "http://example.com/1/2?k1=v1&k2=v2##GET##k2:v22$$keyDbg:val1%%val2##", 46 | keyHash: "7770dca95a1fe3a1dd5719dcc376e2dfa9f64a6c77729c8c98120db5d3ddf6ce", 47 | }, 48 | { 49 | req: makeReq("GET", "http://example.com/1/2?k1=v1&k2=v2", nil, 50 | http.Header{"keyDbg": []string{"val1", "val2"}, "k2": []string{"v22"}}), 51 | opts: []func(m *Middleware){KeyWithHeadersIncluded("k2")}, 52 | keyDbg: "http://example.com/1/2?k1=v1&k2=v2##GET##k2:v22##", 53 | keyHash: "96cdeaac00f84d5e80b9f8e57dceab324ee8d27e44f379c5150f315ba5a61dfb", 54 | }, 55 | { 56 | req: makeReq("GET", "http://example.com/1/2?k1=v1&k2=v2", nil, 57 | http.Header{"keyDbg": []string{"val1", "val2"}, "k2": []string{"v22"}}), 58 
| opts: []func(m *Middleware){KeyWithHeadersExcluded("k2")}, 59 | keyDbg: "http://example.com/1/2?k1=v1&k2=v2##GET##keyDbg:val1%%val2##", 60 | keyHash: "7df35feb246b3cc39d15f2b86825dab6587044e017db5284613ce55b3d30dad5", 61 | }, 62 | { 63 | req: makeReq("GET", "http://example.com/1/2?k1=v1&k2=v2", nil, 64 | http.Header{"keyDbg": []string{"val1", "val2"}, "k2": []string{"v22"}}), 65 | opts: []func(m *Middleware){KeyWithHeadersExcluded("xyz", "abc")}, 66 | keyDbg: "http://example.com/1/2?k1=v1&k2=v2##GET##k2:v22$$keyDbg:val1%%val2##", 67 | keyHash: "7770dca95a1fe3a1dd5719dcc376e2dfa9f64a6c77729c8c98120db5d3ddf6ce", 68 | }, 69 | { 70 | req: makeReq("GET", "http://example.com/1/2?k1=v1&k2=v2", bytes.NewBufferString("something"), 71 | http.Header{"keyDbg": []string{"val1", "val2"}, "k2": []string{"v22"}}), 72 | opts: []func(m *Middleware){KeyWithHeadersExcluded("xyz", "abc")}, 73 | keyDbg: "http://example.com/1/2?k1=v1&k2=v2##GET##k2:v22$$keyDbg:val1%%val2##", 74 | keyHash: "7770dca95a1fe3a1dd5719dcc376e2dfa9f64a6c77729c8c98120db5d3ddf6ce", 75 | }, 76 | { 77 | req: makeReq("GET", "http://example.com/1/2?k1=v1&k2=v2", bytes.NewBufferString("something"), 78 | http.Header{"keyDbg": []string{"val1", "val2"}, "k2": []string{"v22"}}), 79 | opts: []func(m *Middleware){KeyWithHeadersExcluded("xyz", "abc"), KeyWithBody}, 80 | keyDbg: "http://example.com/1/2?k1=v1&k2=v2##GET##k2:v22$$keyDbg:val1%%val2##something", 81 | keyHash: "c77208b375a9df49e97920b5621c9ac8e733a13ab6c74abcef7bc4f052af8d38", 82 | }, 83 | { 84 | req: makeReq("GET", "http://example.com/1/2?k1=v1&k2=v2", nil, nil), 85 | opts: []func(m *Middleware){KeyFunc(func(r *http.Request) string { 86 | return r.Host 87 | })}, 88 | keyDbg: "example.com", 89 | keyHash: "a379a6f6eeafb9a55e378c118034e2751e682fab9f2d30ab13d2125586ce1947", 90 | }, 91 | { 92 | req: makeReq("GET", "http://example.com/1/2?k1=v1&k2=v2", nil, nil), 93 | opts: []func(m *Middleware){KeyFunc(func(r *http.Request) string { 94 | return r.URL.Path 95 | })}, 
96 | keyDbg: "/1/2", 97 | keyHash: "c385023fa5c9b3d71679c9557649b476784a44c2f1f71b6d46a5a65694f688a0", 98 | }, 99 | } 100 | 101 | // nolint scopelint 102 | for i, tt := range tbl { 103 | t.Run(strconv.Itoa(i), func(t *testing.T) { 104 | c := New(nil, tt.opts...) 105 | c.dbg = true 106 | keyDbg, err := c.extractCacheKey(tt.req) 107 | require.NoError(t, err) 108 | assert.Equal(t, tt.keyDbg, keyDbg) 109 | 110 | c.dbg = false 111 | keyHash, err := c.extractCacheKey(tt.req) 112 | require.NoError(t, err) 113 | assert.Equal(t, tt.keyHash, keyHash) 114 | 115 | }) 116 | } 117 | 118 | } 119 | 120 | func TestMiddleware_Handle(t *testing.T) { 121 | cacheMock := mocks.CacheSvc{GetFunc: func(key string, fn func() (interface{}, error)) (interface{}, error) { 122 | return fn() 123 | }} 124 | c := New(&cacheMock) 125 | c.dbg = true 126 | 127 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 128 | w.Header().Set("k1", "v1") 129 | _, err := w.Write([]byte("something")) 130 | require.NoError(t, err) 131 | })) 132 | 133 | client := http.Client{Transport: c.Middleware(http.DefaultTransport)} 134 | req, err := http.NewRequest("GET", ts.URL+"?k=v", http.NoBody) 135 | require.NoError(t, err) 136 | 137 | resp, err := client.Do(req) 138 | require.NoError(t, err) 139 | assert.Equal(t, 200, resp.StatusCode) 140 | 141 | v, err := io.ReadAll(resp.Body) 142 | require.NoError(t, err) 143 | assert.Equal(t, "something", string(v)) 144 | assert.Equal(t, "v1", resp.Header.Get("k1")) 145 | assert.Equal(t, 1, len(cacheMock.GetCalls())) 146 | assert.Contains(t, cacheMock.GetCalls()[0].Key, "?k=v##GET####") 147 | 148 | req, err = http.NewRequest("GET", ts.URL+"?k=v", http.NoBody) 149 | require.NoError(t, err) 150 | 151 | resp, err = client.Do(req) 152 | require.NoError(t, err) 153 | assert.Equal(t, 200, resp.StatusCode) 154 | 155 | v, err = io.ReadAll(resp.Body) 156 | require.NoError(t, err) 157 | assert.Equal(t, "something", string(v)) 158 | assert.Equal(t, 2, 
len(cacheMock.GetCalls())) 159 | assert.Contains(t, cacheMock.GetCalls()[1].Key, "?k=v##GET####") 160 | } 161 | 162 | func TestMiddleware_HandleMethodDisabled(t *testing.T) { 163 | cacheMock := mocks.CacheSvc{GetFunc: func(key string, fn func() (interface{}, error)) (interface{}, error) { 164 | return fn() 165 | }} 166 | c := New(&cacheMock, Methods("PUT")) 167 | c.dbg = true 168 | 169 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 170 | w.Header().Set("k1", "v1") 171 | _, err := w.Write([]byte("something")) 172 | require.NoError(t, err) 173 | })) 174 | 175 | client := http.Client{Transport: c.Middleware(http.DefaultTransport)} 176 | req, err := http.NewRequest("GET", ts.URL+"?k=v", http.NoBody) 177 | require.NoError(t, err) 178 | 179 | resp, err := client.Do(req) 180 | require.NoError(t, err) 181 | assert.Equal(t, 200, resp.StatusCode) 182 | assert.Equal(t, 0, len(cacheMock.GetCalls())) 183 | 184 | req, err = http.NewRequest("PUT", ts.URL+"?k=v", http.NoBody) 185 | require.NoError(t, err) 186 | resp, err = client.Do(req) 187 | require.NoError(t, err) 188 | assert.Equal(t, 200, resp.StatusCode) 189 | assert.Equal(t, 1, len(cacheMock.GetCalls())) 190 | } 191 | 192 | func TestMiddleware_EdgeCases(t *testing.T) { 193 | 194 | t.Run("nil service", func(t *testing.T) { 195 | c := New(nil) 196 | req, err := http.NewRequest("GET", "http://example.com", http.NoBody) 197 | require.NoError(t, err) 198 | resp, err := c.Middleware(http.DefaultTransport).RoundTrip(req) 199 | require.NoError(t, err) 200 | assert.NotNil(t, resp) 201 | }) 202 | 203 | t.Run("large body", func(t *testing.T) { 204 | c := New(nil, KeyWithBody) 205 | originalBody := strings.Repeat("a", maxBodySize-1) 206 | body := bytes.NewBufferString(originalBody) 207 | req, err := http.NewRequest("POST", "http://example.com", body) 208 | require.NoError(t, err) 209 | key, err := c.extractCacheKey(req) 210 | require.NoError(t, err) 211 | assert.NotEmpty(t, key) 212 | 213 | 
// verify key was generated with truncated body but original body is still readable 214 | data, err := io.ReadAll(req.Body) 215 | require.NoError(t, err) 216 | assert.Equal(t, originalBody, string(data)) 217 | 218 | // get key again with same input 219 | body = bytes.NewBufferString(originalBody) 220 | req, err = http.NewRequest("POST", "http://example.com", body) 221 | require.NoError(t, err) 222 | key2, err := c.extractCacheKey(req) 223 | require.NoError(t, err) 224 | 225 | // verify keys match even with truncated bodies 226 | assert.Equal(t, key, key2, "keys should match for same content even if truncated") 227 | }) 228 | 229 | t.Run("body read error", func(t *testing.T) { 230 | c := New(nil, KeyWithBody) 231 | errReader := &errorReader{err: errors.New("read error")} 232 | req, err := http.NewRequest("POST", "http://example.com", errReader) 233 | require.NoError(t, err) 234 | _, err = c.extractCacheKey(req) 235 | require.Error(t, err) 236 | assert.Contains(t, err.Error(), "read error") 237 | }) 238 | 239 | t.Run("body and headers", func(t *testing.T) { 240 | c := New(nil, KeyWithBody, KeyWithHeaders) 241 | body := bytes.NewBufferString("test body") 242 | req, err := http.NewRequest("POST", "http://example.com", body) 243 | require.NoError(t, err) 244 | req.Header.Set("Test", "value") 245 | key1, err := c.extractCacheKey(req) 246 | require.NoError(t, err) 247 | 248 | // same request but different header 249 | body = bytes.NewBufferString("test body") 250 | req, err = http.NewRequest("POST", "http://example.com", body) 251 | require.NoError(t, err) 252 | req.Header.Set("Test", "different") 253 | key2, err := c.extractCacheKey(req) 254 | require.NoError(t, err) 255 | 256 | assert.NotEqual(t, key1, key2) 257 | }) 258 | } 259 | 260 | type errorReader struct { 261 | err error 262 | } 263 | 264 | func (e *errorReader) Read(p []byte) (n int, err error) { 265 | return 0, e.err 266 | } 267 | -------------------------------------------------------------------------------- 
/middleware/cache/options.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import "net/http" 4 | 5 | // Methods sets what HTTP methods allowed to be cached, default is "GET" only 6 | func Methods(methods ...string) func(m *Middleware) { 7 | return func(m *Middleware) { 8 | m.allowedMethods = append([]string{}, methods...) 9 | } 10 | } 11 | 12 | // KeyWithHeaders makes all headers to affect caching key 13 | func KeyWithHeaders(m *Middleware) { 14 | m.keyComponents.headers.enabled = true 15 | } 16 | 17 | // KeyWithHeadersIncluded allows some headers to affect caching key 18 | func KeyWithHeadersIncluded(headers ...string) func(m *Middleware) { 19 | return func(m *Middleware) { 20 | m.keyComponents.headers.enabled = true 21 | m.keyComponents.headers.include = append(m.keyComponents.headers.include, headers...) 22 | } 23 | } 24 | 25 | // KeyWithHeadersExcluded make all headers, except passed in to affect caching key 26 | func KeyWithHeadersExcluded(headers ...string) func(m *Middleware) { 27 | return func(m *Middleware) { 28 | m.keyComponents.headers.enabled = true 29 | m.keyComponents.headers.exclude = append(m.keyComponents.headers.exclude, headers...) 
	}
}

// KeyWithBody is an option that makes the whole request body a part of the caching key
// by setting keyComponents.body on the Middleware.
func KeyWithBody(m *Middleware) {
	m.keyComponents.body = true
}

// KeyFunc returns an option that installs fn as the custom caching key function
// (stored in m.keyFunc; how it combines with the other key components is decided elsewhere).
func KeyFunc(fn func(r *http.Request) string) func(m *Middleware) {
	return func(m *Middleware) {
		m.keyFunc = fn
	}
}

-------------------------------------------------------------------------------- /middleware/cache_test.go: --------------------------------------------------------------------------------
package middleware

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/go-pkgz/requester/middleware/mocks"
)

// TestCache_BasicCaching covers the default Cache() configuration: a repeated GET is
// served from cache, while POST requests and non-200 responses are not cached.
func TestCache_BasicCaching(t *testing.T) {
	t.Run("caches GET request", func(t *testing.T) {
		var requestCount int32
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 200,
				Header:     http.Header{"X-Test": []string{"value"}},
				Body:       io.NopCloser(strings.NewReader("response body")),
			}, nil
		}}

		h := Cache()(rmock)
		req, err := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody)
		require.NoError(t, err)

		// first request - cache miss
		resp1, err := h.RoundTrip(req)
		require.NoError(t, err)
		body1, err := io.ReadAll(resp1.Body)
		require.NoError(t, err)
		_ = resp1.Body.Close()

		// second request - should be cached
		resp2, err := h.RoundTrip(req)
		require.NoError(t, err)
		body2, err := io.ReadAll(resp2.Body)
		require.NoError(t, err)
		_ = resp2.Body.Close()

		// a single backend hit, yet both responses carry the full status, body and headers
		assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount))
		assert.Equal(t, 200, resp1.StatusCode)
		assert.Equal(t, 200, resp2.StatusCode)
		assert.Equal(t, "response body", string(body1))
		assert.Equal(t, "response body", string(body2))
		assert.Equal(t, "value", resp1.Header.Get("X-Test"))
		assert.Equal(t, "value", resp2.Header.Get("X-Test"))
	})

	t.Run("does not cache POST by default", func(t *testing.T) {
		var requestCount int32
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 200,
				Body:       io.NopCloser(strings.NewReader("response body")),
			}, nil
		}}

		h := Cache()(rmock)
		req, err := http.NewRequest(http.MethodPost, "http://example.com/", http.NoBody)
		require.NoError(t, err)

		// make two requests
		_, err = h.RoundTrip(req)
		require.NoError(t, err)
		_, err = h.RoundTrip(req)
		require.NoError(t, err)

		// both POSTs reached the backend
		assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount))
	})

	t.Run("does not cache non-200 by default", func(t *testing.T) {
		var requestCount int32
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 404,
				Body:       io.NopCloser(strings.NewReader("not found")),
			}, nil
		}}

		h := Cache()(rmock)
		req, err := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody)
		require.NoError(t, err)

		// make two requests
		_, err = h.RoundTrip(req)
		require.NoError(t, err)
		_, err = h.RoundTrip(req)
		require.NoError(t, err)

		// the 404 was not stored, so the second GET hit the backend again
		assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount))
	})
}

// TestCache_Options verifies the CacheTTL, CacheSize and CacheMethods options.
func TestCache_Options(t *testing.T) {

	t.Run("respects TTL", func(t *testing.T) {
		var requestCount int32
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 200,
				Body:       io.NopCloser(strings.NewReader("response body")),
			}, nil
		}}

		h := Cache(CacheTTL(50 * time.Millisecond))(rmock)
		req, err := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody)
		require.NoError(t, err)

		// first request
		_, err = h.RoundTrip(req)
		require.NoError(t, err)

		// second request - should be cached
		_, err = h.RoundTrip(req)
		require.NoError(t, err)
		assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount))

		// wait for TTL to expire (sleep is twice the 50ms TTL)
		time.Sleep(100 * time.Millisecond)

		// third request - should hit the backend
		_, err = h.RoundTrip(req)
		require.NoError(t, err)
		assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount))
	})

	t.Run("respects cache size", func(t *testing.T) {
		var requestCount int32
		// the counter is bumped while building the body, so every backend hit returns a distinct payload
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			return &http.Response{
				StatusCode: 200,
				Header:     http.Header{"X-Test": []string{"value"}},
				Body:       io.NopCloser(strings.NewReader(fmt.Sprintf("response %d", atomic.AddInt32(&requestCount, 1)))),
			}, nil
		}}

		h := Cache(CacheSize(2))(rmock)

		// first request - should be cached
		req1, _ := http.NewRequest(http.MethodGet, "http://example.com/1", http.NoBody)
		_, _ = h.RoundTrip(req1) // first call: should hit the backend
		_, _ = h.RoundTrip(req1) // second call: should be served from cache

		// second request - should be cached
		req2, _ := http.NewRequest(http.MethodGet, "http://example.com/2", http.NoBody)
		_, _ = h.RoundTrip(req2) // first call: should hit the backend
		_, _ = h.RoundTrip(req2) // second call: should be served from cache

		// third request - triggers eviction of first request
		req3, _ := http.NewRequest(http.MethodGet, "http://example.com/3", http.NoBody)
		_, _ = h.RoundTrip(req3)

		// first request should be evicted, making a new backend call
		_, _ = h.RoundTrip(req1)

		// backend hits: /1, /2, /3, then /1 again after eviction
		assert.Equal(t, int32(4), atomic.LoadInt32(&requestCount), "First request should be evicted and re-fetched")
	})

	t.Run("respects allowed methods", func(t *testing.T) {
		var requestCount int32
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 200,
				Header:     http.Header{"X-Test": []string{"value"}},
				Body:       io.NopCloser(strings.NewReader("response body")),
			}, nil
		}}

		h := Cache(CacheMethods(http.MethodGet, http.MethodPost))(rmock)

		// GET request should be cached
		req1, _ := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody)
		_, err := h.RoundTrip(req1)
		require.NoError(t, err)
		_, err = h.RoundTrip(req1)
		require.NoError(t, err)
		assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount), "GET requests should be cached")

		// POST request should be cached; the count moves 1 -> 2 because POST to the same URL
		// is a separate cache entry, fetched once and then served from cache
		req2, _ := http.NewRequest(http.MethodPost, "http://example.com/", http.NoBody)
		_, err = h.RoundTrip(req2)
		require.NoError(t, err)
		_, err = h.RoundTrip(req2)
		require.NoError(t, err)
		assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount), "POST requests should use different cache key")

		// PUT request should not be cached
		req3, _ := http.NewRequest(http.MethodPut, "http://example.com/", http.NoBody)
		_, err = h.RoundTrip(req3)
		require.NoError(t, err)
		_, err = h.RoundTrip(req3)
		require.NoError(t, err)
		assert.Equal(t, int32(4), atomic.LoadInt32(&requestCount), "PUT requests should not be cached")
	})
}

// TestCache_Keys verifies which request parts make up the cache key: the URL with the
// default configuration, plus selected headers (CacheWithHeaders) or the body (CacheWithBody).
func TestCache_Keys(t *testing.T) {
	t.Run("different URLs get different cache entries", func(t *testing.T) {
		var requestCount int32
		// echo the request path so each response identifies which URL produced it
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 200,
				Body:       io.NopCloser(strings.NewReader(r.URL.Path)),
			}, nil
		}}

		h := Cache()(rmock)

		req1, err := http.NewRequest(http.MethodGet, "http://example.com/1", http.NoBody)
		require.NoError(t, err)
		resp1, err := h.RoundTrip(req1)
		require.NoError(t, err)
		body1, err := io.ReadAll(resp1.Body)
		require.NoError(t, err)
		err = resp1.Body.Close()
		require.NoError(t, err)

		req2, err := http.NewRequest(http.MethodGet, "http://example.com/2", http.NoBody)
		require.NoError(t, err)
		resp2, err := h.RoundTrip(req2)
		require.NoError(t, err)
		body2, err := io.ReadAll(resp2.Body)
		require.NoError(t, err)
		err = resp2.Body.Close()
		require.NoError(t, err)

		assert.Equal(t, "/1", string(body1))
		assert.Equal(t, "/2", string(body2))
		assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount))
	})

	t.Run("includes headers in cache key when configured", func(t *testing.T) {
		var requestCount int32
		// echo the X-Test header so each response identifies which request produced it
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 200,
				Body:       io.NopCloser(strings.NewReader(r.Header.Get("X-Test"))),
			}, nil
		}}

		h := Cache(CacheWithHeaders("X-Test"))(rmock)
		req1, err := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody)
		require.NoError(t, err)
		req1.Header.Set("X-Test", "value1")
		resp1, err := h.RoundTrip(req1)
		require.NoError(t, err)
		body1, err := io.ReadAll(resp1.Body)
		require.NoError(t, err)
		err = resp1.Body.Close()
		require.NoError(t, err)

		req2, err := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody)
		require.NoError(t, err)
		req2.Header.Set("X-Test", "value2")
		resp2, err := h.RoundTrip(req2)
		require.NoError(t, err)
		body2, err := io.ReadAll(resp2.Body)
		require.NoError(t, err)
		err = resp2.Body.Close()
		require.NoError(t, err)

		assert.Equal(t, "value1", string(body1))
		assert.Equal(t, "value2", string(body2))
		assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount))
	})

	t.Run("includes body in cache key when configured", func(t *testing.T) {
		var requestCount int32
		// echo the request body back; require is safe here because RoundTrip runs on the test goroutine
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			body, err := io.ReadAll(r.Body)
			require.NoError(t, err)
			return &http.Response{
				StatusCode: 200,
				Body:       io.NopCloser(bytes.NewReader(body)),
			}, nil
		}}

		h := Cache(CacheWithBody)(rmock)
		req1, err := http.NewRequest(http.MethodGet, "http://example.com/", strings.NewReader("body1"))
		require.NoError(t, err)
		resp1, err := h.RoundTrip(req1)
		require.NoError(t, err)
		body1, err := io.ReadAll(resp1.Body)
		require.NoError(t, err)
		err = resp1.Body.Close()
		require.NoError(t, err)

		req2, err := http.NewRequest(http.MethodGet, "http://example.com/", strings.NewReader("body2"))
		require.NoError(t, err)
		resp2, err := h.RoundTrip(req2)
		require.NoError(t, err)
		body2, err := io.ReadAll(resp2.Body)
		require.NoError(t, err)
		err = resp2.Body.Close()
		require.NoError(t, err)

		assert.Equal(t, "body1", string(body1))
		assert.Equal(t, "body2", string(body2))
		assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount))
	})
}

// TestCache_EdgeCases covers TTL expiry, a single-entry cache, status-code filtering
// and header-based cache keys.
func TestCache_EdgeCases(t *testing.T) {

	t.Run("expired cache entry should be ignored", func(t *testing.T) {
		var requestCount int32
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 200,
				Body:       io.NopCloser(strings.NewReader("fresh response")),
			}, nil
		}}

		h := Cache(CacheTTL(50 * time.Millisecond))(rmock)

		req, err := http.NewRequest(http.MethodGet, "http://example.com/expired", http.NoBody)
		require.NoError(t, err, "failed to create request")

		_, err = h.RoundTrip(req)
		require.NoError(t, err, "first request should not fail")
		assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount), "first request should hit backend")

		// sleep past the 50ms TTL
		time.Sleep(100 * time.Millisecond)

		_, err = h.RoundTrip(req)
		require.NoError(t, err, "second request should not fail")
		assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount), "expired cache entry should not be used")
	})

	t.Run("cache size 1 should evict immediately", func(t *testing.T) {
		var requestCount int32
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 200,
				Body:       io.NopCloser(strings.NewReader("cached response")),
			}, nil
		}}

		h := Cache(CacheSize(1))(rmock)

		req1, err := http.NewRequest(http.MethodGet, "http://example.com/1", http.NoBody)
		require.NoError(t, err, "failed to create request 1")
		req2, err := http.NewRequest(http.MethodGet, "http://example.com/2", http.NoBody)
		require.NoError(t, err, "failed to create request 2")
		req3, err := http.NewRequest(http.MethodGet, "http://example.com/3", http.NoBody)
		require.NoError(t, err, "failed to create request 3")

		_, err = h.RoundTrip(req1)
		require.NoError(t, err)

		_, err = h.RoundTrip(req2)
		require.NoError(t, err)

		_, err = h.RoundTrip(req3)
		require.NoError(t, err)

		// /1 was pushed out by /2 and /3, so this is a fourth backend hit
		_, err = h.RoundTrip(req1)
		require.NoError(t, err)

		assert.Equal(t, int32(4), atomic.LoadInt32(&requestCount), "each request should evict the previous one")
	})

	t.Run("only specified status codes should be cached", func(t *testing.T) {
		var requestCount int32
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 202, // not in allowed list
				Body:       io.NopCloser(strings.NewReader("not cached")),
			}, nil
		}}

		h := Cache(CacheStatuses(200, 201, 204))(rmock)

		req, err := http.NewRequest(http.MethodGet, "http://example.com/status", http.NoBody)
		require.NoError(t, err, "failed to create request")

		_, err = h.RoundTrip(req)
		require.NoError(t, err)
		assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount), "first request should hit backend")

		_, err = h.RoundTrip(req)
		require.NoError(t, err)
		assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount), "non-allowed status codes should not be cached")
	})

	t.Run("headers should be included in cache key when configured", func(t *testing.T) {
		var requestCount int32
		// echo the Authorization header so each response identifies which request produced it
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&requestCount, 1)
			return &http.Response{
				StatusCode: 200,
				Body:       io.NopCloser(strings.NewReader(r.Header.Get("Authorization"))),
			}, nil
		}}

		h := Cache(CacheWithHeaders("Authorization"))(rmock)

		req1, err := http.NewRequest(http.MethodGet, "http://example.com/auth", http.NoBody)
		require.NoError(t, err, "failed to create request 1")
req1.Header.Set("Authorization", "Bearer token1") 420 | 421 | resp1, err := h.RoundTrip(req1) 422 | require.NoError(t, err) 423 | body1, err := io.ReadAll(resp1.Body) 424 | require.NoError(t, err) 425 | err = resp1.Body.Close() 426 | require.NoError(t, err) 427 | 428 | resp2, err := h.RoundTrip(req1) // second call should hit cache 429 | require.NoError(t, err) 430 | body2, err := io.ReadAll(resp2.Body) 431 | require.NoError(t, err) 432 | err = resp2.Body.Close() 433 | require.NoError(t, err) 434 | 435 | req2, err := http.NewRequest(http.MethodGet, "http://example.com/auth", http.NoBody) 436 | require.NoError(t, err, "failed to create request 2") 437 | req2.Header.Set("Authorization", "Bearer token2") 438 | 439 | resp3, err := h.RoundTrip(req2) 440 | require.NoError(t, err) 441 | body3, err := io.ReadAll(resp3.Body) 442 | require.NoError(t, err) 443 | err = resp3.Body.Close() 444 | require.NoError(t, err) 445 | 446 | assert.Equal(t, "Bearer token1", string(body1), "first request should be cached separately") 447 | assert.Equal(t, "Bearer token1", string(body2), "second request should be served from cache") 448 | assert.Equal(t, "Bearer token2", string(body3), "third request should be a new cache entry") 449 | assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount), "each authorization header should generate a new cache entry") 450 | }) 451 | } 452 | -------------------------------------------------------------------------------- /middleware/circuit_breaker.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | // CircuitBreaker middleware injects external CircuitBreakerSvc into the call chain 9 | func CircuitBreaker(svc CircuitBreakerSvc) RoundTripperHandler { 10 | 11 | return func(next http.RoundTripper) http.RoundTripper { 12 | fn := func(req *http.Request) (*http.Response, error) { 13 | 14 | if svc == nil { 15 | return next.RoundTrip(req) 16 | } 17 | 18 
| resp, e := svc.Execute(func() (interface{}, error) { 19 | return next.RoundTrip(req) 20 | }) 21 | if e != nil { 22 | return nil, fmt.Errorf("circuit breaker: %w", e) 23 | } 24 | return resp.(*http.Response), nil 25 | } 26 | return RoundTripperFunc(fn) 27 | } 28 | } 29 | 30 | // CircuitBreakerSvc is an interface wrapping any function to send a request with circuit breaker. 31 | // can be used with github.com/sony/gobreaker or any similar implementations 32 | type CircuitBreakerSvc interface { 33 | Execute(req func() (interface{}, error)) (interface{}, error) 34 | } 35 | 36 | // CircuitBreakerFunc is an adapter to allow the use of ordinary functions as CircuitBreakerSvc. 37 | type CircuitBreakerFunc func(req func() (interface{}, error)) (interface{}, error) 38 | 39 | // Execute CircuitBreakerFunc 40 | func (c CircuitBreakerFunc) Execute(req func() (interface{}, error)) (interface{}, error) { 41 | return c(req) 42 | } 43 | -------------------------------------------------------------------------------- /middleware/circuit_breaker_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/go-pkgz/requester/middleware/mocks" 11 | ) 12 | 13 | func TestCircuitBreaker(t *testing.T) { 14 | 15 | cbMock := &mocks.CircuitBreakerSvcMock{ 16 | ExecuteFunc: func(req func() (interface{}, error)) (interface{}, error) { 17 | return req() 18 | }, 19 | } 20 | 21 | rmock := &mocks.RoundTripper{ 22 | RoundTripFunc: func(r *http.Request) (*http.Response, error) { 23 | resp := &http.Response{StatusCode: 201} 24 | return resp, nil 25 | }, 26 | } 27 | 28 | req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) 29 | require.NoError(t, err) 30 | 31 | h := CircuitBreaker(cbMock) 32 | 33 | resp, err := h(rmock).RoundTrip(req) 34 | require.NoError(t, err) 35 | 
assert.Equal(t, 201, resp.StatusCode) 36 | 37 | assert.Equal(t, 1, rmock.Calls()) 38 | assert.Equal(t, 1, len(cbMock.ExecuteCalls())) 39 | } 40 | -------------------------------------------------------------------------------- /middleware/concurrent.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | // MaxConcurrent middleware limits the total concurrency for a given requester 8 | func MaxConcurrent(maxLimit int) func(http.RoundTripper) http.RoundTripper { 9 | sema := make(chan struct{}, maxLimit) 10 | return func(next http.RoundTripper) http.RoundTripper { 11 | fn := func(req *http.Request) (*http.Response, error) { 12 | sema <- struct{}{} 13 | defer func() { 14 | <-sema 15 | }() 16 | return next.RoundTrip(req) 17 | } 18 | return RoundTripperFunc(fn) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /middleware/concurrent_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "math/rand" 7 | "net/http" 8 | "sync" 9 | "sync/atomic" 10 | "testing" 11 | "time" 12 | 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | 16 | "github.com/go-pkgz/requester/middleware/mocks" 17 | ) 18 | 19 | func TestMaxConcurrentHandler(t *testing.T) { 20 | var concurrentCount int32 21 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 22 | c := atomic.AddInt32(&concurrentCount, 1) 23 | t.Logf("concurrent: %d", c) 24 | assert.True(t, c <= 8, c) 25 | defer func() { 26 | atomic.AddInt32(&concurrentCount, -1) 27 | }() 28 | resp := &http.Response{StatusCode: 201} 29 | time.Sleep(time.Millisecond * time.Duration(rand.Intn(100))) // nolint 30 | return resp, nil 31 | }} 32 | 33 | req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) 34 | 
require.NoError(t, err) 35 | 36 | h := MaxConcurrent(8) 37 | 38 | var wg sync.WaitGroup 39 | for i := 0; i < 100; i++ { 40 | wg.Add(1) 41 | go func() { 42 | defer wg.Done() 43 | resp, err := h(rmock).RoundTrip(req) 44 | require.NoError(t, err) 45 | assert.Equal(t, 201, resp.StatusCode) 46 | }() 47 | } 48 | wg.Wait() 49 | 50 | assert.Equal(t, 100, rmock.Calls()) 51 | } 52 | 53 | func TestMaxConcurrent_Advanced(t *testing.T) { 54 | 55 | t.Run("context cancellation", func(t *testing.T) { 56 | var active int32 57 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 58 | atomic.AddInt32(&active, 1) 59 | defer atomic.AddInt32(&active, -1) 60 | 61 | select { 62 | case <-r.Context().Done(): 63 | return nil, r.Context().Err() 64 | case <-time.After(100 * time.Millisecond): 65 | return &http.Response{StatusCode: 200}, nil 66 | } 67 | }} 68 | 69 | h := MaxConcurrent(2) 70 | wrapped := h(rmock) 71 | 72 | ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) 73 | defer cancel() 74 | 75 | var wg sync.WaitGroup 76 | for i := 0; i < 5; i++ { 77 | wg.Add(1) 78 | go func() { 79 | defer wg.Done() 80 | req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com/blah", http.NoBody) 81 | _, err := wrapped.RoundTrip(req) 82 | assert.Error(t, err) 83 | assert.True(t, errors.Is(err, context.DeadlineExceeded)) 84 | }() 85 | } 86 | wg.Wait() 87 | 88 | assert.LessOrEqual(t, atomic.LoadInt32(&active), int32(2)) 89 | }) 90 | 91 | t.Run("stress test", func(t *testing.T) { 92 | if testing.Short() { 93 | t.Skip("skipping stress test in short mode") 94 | } 95 | 96 | var ( 97 | active int32 98 | maxActive int32 99 | errs int32 100 | ) 101 | 102 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 103 | defer atomic.AddInt32(&active, -1) 104 | for { 105 | current := atomic.LoadInt32(&active) 106 | stored := atomic.LoadInt32(&maxActive) 107 | if current > stored { 108 | 
					// the CAS fails if another goroutine changed maxActive since the load;
					// the loop then reloads and re-checks
					atomic.CompareAndSwapInt32(&maxActive, stored, current)
				}
				if current <= stored {
					break
				}
			}

			// random delay and random errors
			time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond) //nolint:gosec // no need for secure random here

			if rand.Float32() < 0.1 { //nolint:gosec // no need for secure random here
				// 10% error rate
				atomic.AddInt32(&errs, 1)
				return nil, errors.New("random error")
			}

			return &http.Response{StatusCode: 200}, nil
		}}

		h := MaxConcurrent(5)
		wrapped := h(rmock)

		var wg sync.WaitGroup
		for i := 0; i < 100; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				req, _ := http.NewRequest("GET", "http://example.com/blah", http.NoBody)
				// errors are expected (injected above); only the concurrency ceiling is checked
				_, _ = wrapped.RoundTrip(req)
			}()
		}
		wg.Wait()

		assert.LessOrEqual(t, atomic.LoadInt32(&maxActive), int32(5), "should never exceed max concurrent")
		t.Logf("errors encountered: %d", atomic.LoadInt32(&errs))
	})
}

-------------------------------------------------------------------------------- /middleware/header.go: --------------------------------------------------------------------------------
package middleware

import (
	"net/http"
)

// Header middleware sets header key to value on every outgoing request, replacing
// any values already present for that key (http.Header.Set semantics).
// NOTE(review): the header is written into the caller's *http.Request in place; net/http
// documents that RoundTrip should not modify the request — consider req.Clone if callers
// reuse requests.
func Header(key, value string) func(http.RoundTripper) http.RoundTripper {
	return func(next http.RoundTripper) http.RoundTripper {
		fn := func(req *http.Request) (*http.Response, error) {
			req.Header.Set(key, value)
			return next.RoundTrip(req)
		}
		return RoundTripperFunc(fn)
	}
}

// JSON sets both Content-Type and Accept headers to "application/json", replacing any
// existing values. Like Header, it modifies the passed request in place.
func JSON(next http.RoundTripper) http.RoundTripper {
	fn := func(req *http.Request) (*http.Response, error) {
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", "application/json")
		return next.RoundTrip(req)
	}
	return RoundTripperFunc(fn)
}

// BasicAuth middleware sets the Authorization header for HTTP basic auth via
// req.SetBasicAuth. Like Header, it modifies the passed request in place.
func BasicAuth(user, passwd string) func(http.RoundTripper) http.RoundTripper {
	return func(next http.RoundTripper) http.RoundTripper {
		fn := func(req *http.Request) (*http.Response, error) {
			req.SetBasicAuth(user, passwd)
			return next.RoundTrip(req)
		}
		return RoundTripperFunc(fn)
	}
}

-------------------------------------------------------------------------------- /middleware/header_test.go: --------------------------------------------------------------------------------
package middleware

import (
	"net/http"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/go-pkgz/requester/middleware/mocks"
)

// TestHeader checks that Header sets the given header before the request reaches the transport.
func TestHeader(t *testing.T) {
	rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
		assert.Equal(t, "v1", r.Header.Get("k1"))
		resp := &http.Response{StatusCode: 201}
		return resp, nil
	}}

	req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody)
	require.NoError(t, err)

	h := Header("k1", "v1")
	resp, err := h(rmock).RoundTrip(req)
	require.NoError(t, err)
	assert.Equal(t, 201, resp.StatusCode)

	assert.Equal(t, 1, rmock.Calls())
}

// TestJSON checks that JSON sets both Content-Type and Accept to application/json.
func TestJSON(t *testing.T) {
	rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		assert.Equal(t, "application/json", r.Header.Get("Accept"))
		resp := &http.Response{StatusCode: 201}
		return resp, nil
	}}

	req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody)
	require.NoError(t, err)

	resp, err := JSON(rmock).RoundTrip(req)
	require.NoError(t, err)
	assert.Equal(t, 201, resp.StatusCode)
	assert.Equal(t, 1, rmock.Calls())
}

// TestBasicAuth checks the Authorization header produced by BasicAuth
// ("dXNlcjpwYXNzd2Q=" is the base64 encoding of "user:passwd").
func TestBasicAuth(t *testing.T) {
	rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
		assert.Equal(t, "Basic dXNlcjpwYXNzd2Q=", r.Header.Get("Authorization"))
		resp := &http.Response{StatusCode: 201}
		return resp, nil
	}}

	req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody)
	require.NoError(t, err)

	resp, err := BasicAuth("user", "passwd")(rmock).RoundTrip(req)
	require.NoError(t, err)
	assert.Equal(t, 201, resp.StatusCode)
	assert.Equal(t, 1, rmock.Calls())
}

// TestHeader_EdgeCases covers header-name case insensitivity, overwriting of existing values,
// special characters in credentials, chaining of the header middlewares and empty values.
func TestHeader_EdgeCases(t *testing.T) {
	t.Run("case insensitive", func(t *testing.T) {
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			assert.Equal(t, "v1", r.Header.Get("key"))
			assert.Equal(t, "v1", r.Header.Get("Key"))
			assert.Equal(t, "v1", r.Header.Get("KEY"))
			return &http.Response{StatusCode: 200}, nil
		}}

		req, err := http.NewRequest("GET", "http://example.com", http.NoBody)
		require.NoError(t, err)

		h := Header("KEY", "v1")
		_, err = h(rmock).RoundTrip(req)
		require.NoError(t, err)
	})

	t.Run("header overwrite", func(t *testing.T) {
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			// header middleware overwrites existing values
			assert.Equal(t, []string{"v2"}, r.Header.Values("key"))
			return &http.Response{StatusCode: 200}, nil
		}}

		req, err := http.NewRequest("GET", "http://example.com", http.NoBody)
		require.NoError(t, err)
		req.Header.Add("key", "v1")

		h := Header("key", "v2")
		_, err = h(rmock).RoundTrip(req)
		require.NoError(t, err)
	})

	t.Run("json headers set", func(t *testing.T) {
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			// JSON middleware sets both Content-Type and Accept, replacing the preset application/xml
			assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
			assert.Equal(t, "application/json", r.Header.Get("Accept"))
			return &http.Response{StatusCode: 200}, nil
		}}

		req, err := http.NewRequest("GET", "http://example.com", http.NoBody)
		require.NoError(t, err)
		req.Header.Set("Content-Type", "application/xml")

		resp, err := JSON(rmock).RoundTrip(req)
		require.NoError(t, err)
		assert.Equal(t, 200, resp.StatusCode)
	})

	t.Run("basic auth headers", func(t *testing.T) {
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			user, pass, ok := r.BasicAuth()
			assert.True(t, ok)
			assert.Equal(t, "user", user)
			assert.Equal(t, "pass123$$!@", pass)
			return &http.Response{StatusCode: 200}, nil
		}}

		req, err := http.NewRequest("GET", "http://example.com", http.NoBody)
		require.NoError(t, err)

		h := BasicAuth("user", "pass123$$!@")
		resp, err := h(rmock).RoundTrip(req)
		require.NoError(t, err)
		assert.Equal(t, 200, resp.StatusCode)
	})

	t.Run("header middleware order", func(t *testing.T) {
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			assert.Equal(t, "v1", r.Header.Get("key1"))
			assert.Equal(t, "v2", r.Header.Get("key2"))
			// JSON (h4) wraps rmock directly, so it runs last before the transport and sets Content-Type
			assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
			user, pass, ok := r.BasicAuth()
			assert.True(t, ok)
			assert.Equal(t, "user", user)
			assert.Equal(t, "pass", pass)
			return &http.Response{StatusCode: 200}, nil
		}}

		req, err := http.NewRequest("GET", "http://example.com", http.NoBody)
		require.NoError(t, err)

		h1 := Header("key1", "v1")
		h2 := Header("key2", "v2")
		h3 := BasicAuth("user", "pass")
		h4 := JSON

		// call order: h1 -> h2 -> h3 -> h4 -> rmock
		resp, err := h1(h2(h3(h4(rmock)))).RoundTrip(req)
		require.NoError(t, err)
		assert.Equal(t, 200, resp.StatusCode)
	})

	t.Run("empty headers", func(t *testing.T) {
		rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) {
			assert.Equal(t, "", r.Header.Get("empty-key"))
			assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
			return &http.Response{StatusCode: 200}, nil
		}}

		req, err := http.NewRequest("GET", "http://example.com", http.NoBody)
		require.NoError(t, err)

		h := Header("empty-key", "")
		_, err = h(JSON(rmock)).RoundTrip(req)
		require.NoError(t, err)
	})
}

-------------------------------------------------------------------------------- /middleware/logger/logger.go: --------------------------------------------------------------------------------
// Package logger implements middleware for request logging.
2 | package logger 3 | 4 | import ( 5 | "bytes" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "log" 10 | "net/http" 11 | "strings" 12 | "time" 13 | 14 | "github.com/go-pkgz/requester/middleware" 15 | ) 16 | 17 | // Middleware for logging requests 18 | type Middleware struct { 19 | Service 20 | prefix string 21 | body bool 22 | headers bool 23 | } 24 | 25 | // New creates logging middleware with optional parameters turning on logging elements 26 | func New(svc Service, opts ...func(m *Middleware)) *Middleware { 27 | res := Middleware{Service: svc} 28 | for _, opt := range opts { 29 | opt(&res) 30 | } 31 | return &res 32 | } 33 | 34 | // Middleware request logging 35 | func (m Middleware) Middleware(next http.RoundTripper) http.RoundTripper { 36 | fn := func(req *http.Request) (resp *http.Response, err error) { 37 | if m.Service == nil { 38 | return next.RoundTrip(req) 39 | } 40 | 41 | st := time.Now() 42 | logParts := []string{} 43 | if m.prefix != "" { 44 | logParts = append(logParts, m.prefix) 45 | } 46 | logParts = append(logParts, req.Method, req.URL.String()+",") 47 | 48 | headerLog := []byte{} // nolint 49 | if m.headers { 50 | if headerLog, err = json.Marshal(req.Header); err != nil { 51 | headerLog = []byte(fmt.Sprintf("headers: %v", req.Header)) 52 | } 53 | logParts = append(logParts, string(headerLog)+",") 54 | } 55 | 56 | bodyLog := "" 57 | if m.body && req.Body != nil { 58 | body, e := io.ReadAll(req.Body) 59 | if e == nil { 60 | _ = req.Body.Close() 61 | req.Body = io.NopCloser(bytes.NewReader(body)) 62 | bodyLog = " body: " + string(body) 63 | if len(bodyLog) > 1024 { 64 | bodyLog = bodyLog[:1024] + "..." 
65 | } 66 | bodyLog = strings.Replace(bodyLog, "\n", " ", -1) 67 | } 68 | } 69 | if bodyLog != "" { 70 | logParts = append(logParts, bodyLog+",") 71 | } 72 | resp, err = next.RoundTrip(req) 73 | logParts = append(logParts, fmt.Sprintf("time: %v", time.Since(st))) 74 | m.Logf(strings.Join(logParts, " ")) 75 | return resp, err 76 | } 77 | return middleware.RoundTripperFunc(fn) 78 | } 79 | 80 | // Prefix sets logging prefix for each line 81 | func Prefix(prefix string) func(m *Middleware) { 82 | return func(m *Middleware) { 83 | m.prefix = prefix 84 | } 85 | } 86 | 87 | // WithBody enables body logging 88 | func WithBody(m *Middleware) { 89 | m.body = true 90 | } 91 | 92 | // WithHeaders enables headers logging 93 | func WithHeaders(m *Middleware) { 94 | m.headers = true 95 | } 96 | 97 | // Service defined logger interface used everywhere in the package 98 | type Service interface { 99 | Logf(format string, args ...interface{}) 100 | } 101 | 102 | // Func type is an adapter to allow the use of ordinary functions as Service. 103 | type Func func(format string, args ...interface{}) 104 | 105 | // Logf calls f(id) 106 | func (f Func) Logf(format string, args ...interface{}) { f(format, args...) } 107 | 108 | // Std logger sends to std default logger directly 109 | var Std = Func(func(format string, args ...interface{}) { log.Printf(format, args...) 
}) 110 | -------------------------------------------------------------------------------- /middleware/logger/logger_test.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | 14 | "github.com/go-pkgz/requester/middleware/mocks" 15 | ) 16 | 17 | func TestMiddleware_Handle(t *testing.T) { 18 | outBuf := bytes.NewBuffer(nil) 19 | loggerMock := &mocks.LoggerSvc{ 20 | LogfFunc: func(format string, args ...interface{}) { 21 | _, _ = fmt.Fprintf(outBuf, format, args...) 22 | }, 23 | } 24 | l := New(loggerMock) 25 | 26 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 27 | w.Header().Set("k1", "v1") 28 | _, err := w.Write([]byte("something")) 29 | require.NoError(t, err) 30 | })) 31 | 32 | client := http.Client{Transport: l.Middleware(http.DefaultTransport)} 33 | req, err := http.NewRequest("GET", ts.URL+"?k=v", http.NoBody) 34 | require.NoError(t, err) 35 | 36 | resp, err := client.Do(req) 37 | require.NoError(t, err) 38 | assert.Equal(t, 200, resp.StatusCode) 39 | 40 | t.Log(outBuf.String()) 41 | assert.True(t, strings.HasPrefix(outBuf.String(), "GET http://127.0.0.1:")) 42 | assert.Contains(t, outBuf.String(), "time:") 43 | 44 | assert.Equal(t, 1, len(loggerMock.LogfCalls())) 45 | } 46 | 47 | func TestMiddleware_HandleWithOptions(t *testing.T) { 48 | outBuf := bytes.NewBuffer(nil) 49 | loggerMock := &mocks.LoggerSvc{ 50 | LogfFunc: func(format string, args ...interface{}) { 51 | _, _ = fmt.Fprintf(outBuf, format, args...) 
52 | }, 53 | } 54 | l := New(loggerMock, WithBody, WithHeaders, Prefix("HIT")) 55 | 56 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 57 | w.Header().Set("k1", "v1") 58 | _, err := w.Write([]byte("something")) 59 | require.NoError(t, err) 60 | })) 61 | 62 | client := http.Client{Transport: l.Middleware(http.DefaultTransport)} 63 | req, err := http.NewRequest("POST", ts.URL+"?k=v", bytes.NewBufferString("blah1 blah2\nblah3")) 64 | require.NoError(t, err) 65 | req.Header.Set("inkey", "inval") 66 | 67 | resp, err := client.Do(req) 68 | require.NoError(t, err) 69 | assert.Equal(t, 200, resp.StatusCode) 70 | 71 | t.Log(outBuf.String()) 72 | assert.True(t, strings.HasPrefix(outBuf.String(), "HIT POST http://127.0.0.1:")) 73 | assert.Contains(t, outBuf.String(), "time:") 74 | assert.Contains(t, outBuf.String(), `{"Inkey":["inval"]}`) 75 | assert.Contains(t, outBuf.String(), `body: blah1 blah2 blah3`) 76 | 77 | assert.Equal(t, 1, len(loggerMock.LogfCalls())) 78 | } 79 | func TestLogger_EdgeCases(t *testing.T) { 80 | t.Run("non-standard headers", func(t *testing.T) { 81 | outBuf := bytes.NewBuffer(nil) 82 | loggerMock := &mocks.LoggerSvc{ 83 | LogfFunc: func(format string, args ...interface{}) { 84 | _, _ = fmt.Fprintf(outBuf, format, args...) 
85 | }, 86 | } 87 | l := New(loggerMock, WithHeaders) 88 | 89 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 90 | w.Header().Set("k1", "v1") 91 | _, err := w.Write([]byte("ok")) 92 | require.NoError(t, err) 93 | })) 94 | defer ts.Close() 95 | 96 | req, err := http.NewRequest("GET", ts.URL+"?k=v", http.NoBody) 97 | require.NoError(t, err) 98 | 99 | // use complex unicode char sequence that might affect json marshaling 100 | req.Header.Set("Test-Header", "привет世界") 101 | req.Header.Set("Multiple-Values", "val1") 102 | req.Header.Add("Multiple-Values", "val2") 103 | 104 | client := http.Client{Transport: l.Middleware(http.DefaultTransport)} 105 | resp, err := client.Do(req) 106 | require.NoError(t, err) 107 | assert.Equal(t, 200, resp.StatusCode) 108 | 109 | logOutput := outBuf.String() 110 | assert.Contains(t, logOutput, "Test-Header") 111 | assert.Contains(t, logOutput, "привет世界") 112 | assert.Contains(t, logOutput, "Multiple-Values") 113 | assert.Contains(t, logOutput, "val1") 114 | assert.Contains(t, logOutput, "val2") 115 | }) 116 | 117 | t.Run("prefix handling", func(t *testing.T) { 118 | outBuf := bytes.NewBuffer(nil) 119 | loggerMock := &mocks.LoggerSvc{ 120 | LogfFunc: func(format string, args ...interface{}) { 121 | _, _ = fmt.Fprintf(outBuf, format, args...) 
122 | }, 123 | } 124 | l := New(loggerMock, Prefix("TEST-PREFIX")) 125 | 126 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 127 | _, err := w.Write([]byte("ok")) 128 | require.NoError(t, err) 129 | })) 130 | defer ts.Close() 131 | 132 | client := http.Client{Transport: l.Middleware(http.DefaultTransport)} 133 | resp, err := client.Get(ts.URL) 134 | require.NoError(t, err) 135 | assert.Equal(t, 200, resp.StatusCode) 136 | 137 | assert.True(t, strings.HasPrefix(outBuf.String(), "TEST-PREFIX")) 138 | }) 139 | 140 | t.Run("large body truncation", func(t *testing.T) { 141 | outBuf := bytes.NewBuffer(nil) 142 | loggerMock := &mocks.LoggerSvc{ 143 | LogfFunc: func(format string, args ...interface{}) { 144 | _, _ = fmt.Fprintf(outBuf, format, args...) 145 | }, 146 | } 147 | l := New(loggerMock, WithBody) 148 | 149 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 150 | _, err := w.Write([]byte("ok")) 151 | require.NoError(t, err) 152 | })) 153 | defer ts.Close() 154 | 155 | largeBody := strings.Repeat("x", 2000) 156 | req, err := http.NewRequest("POST", ts.URL, strings.NewReader(largeBody)) 157 | require.NoError(t, err) 158 | 159 | client := http.Client{Transport: l.Middleware(http.DefaultTransport)} 160 | resp, err := client.Do(req) 161 | require.NoError(t, err) 162 | assert.Equal(t, 200, resp.StatusCode) 163 | 164 | output := outBuf.String() 165 | assert.Contains(t, output, "...") 166 | assert.True(t, len(output) < len(largeBody)) 167 | }) 168 | 169 | t.Run("multiline body", func(t *testing.T) { 170 | outBuf := bytes.NewBuffer(nil) 171 | loggerMock := &mocks.LoggerSvc{ 172 | LogfFunc: func(format string, args ...interface{}) { 173 | _, _ = fmt.Fprintf(outBuf, format, args...) 
174 | }, 175 | } 176 | l := New(loggerMock, WithBody) 177 | 178 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 179 | _, err := w.Write([]byte("ok")) 180 | require.NoError(t, err) 181 | })) 182 | defer ts.Close() 183 | 184 | bodyContent := "line1\nline2\nline3" 185 | req, err := http.NewRequest("POST", ts.URL, strings.NewReader(bodyContent)) 186 | require.NoError(t, err) 187 | 188 | client := http.Client{Transport: l.Middleware(http.DefaultTransport)} 189 | resp, err := client.Do(req) 190 | require.NoError(t, err) 191 | assert.Equal(t, 200, resp.StatusCode) 192 | 193 | output := outBuf.String() 194 | assert.NotContains(t, output, "\n") 195 | assert.Contains(t, output, "line1 line2 line3") 196 | }) 197 | 198 | t.Run("nil logger", func(t *testing.T) { 199 | l := New(nil, WithBody) 200 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 201 | _, err := w.Write([]byte("ok")) 202 | require.NoError(t, err) 203 | })) 204 | defer ts.Close() 205 | 206 | req, err := http.NewRequest("POST", ts.URL, strings.NewReader("test")) 207 | require.NoError(t, err) 208 | 209 | client := http.Client{Transport: l.Middleware(http.DefaultTransport)} 210 | resp, err := client.Do(req) 211 | require.NoError(t, err) 212 | assert.Equal(t, 200, resp.StatusCode) 213 | }) 214 | } 215 | -------------------------------------------------------------------------------- /middleware/middleware.go: -------------------------------------------------------------------------------- 1 | // Package middleware provides middlewares for htt.Client as RoundTripperHandler 2 | package middleware 3 | 4 | import ( 5 | "net/http" 6 | ) 7 | 8 | //go:generate moq -out mocks/repeater.go -pkg mocks -skip-ensure -with-resets -fmt goimports . RepeaterSvc 9 | //go:generate moq -out mocks/circuit_breaker.go -pkg mocks -skip-ensure -with-resets -fmt goimports . 
CircuitBreakerSvc 10 | //go:generate moq -out mocks/logger.go -pkg mocks -skip-ensure -with-resets -fmt goimports logger Service:LoggerSvc 11 | //go:generate moq -out mocks/cache.go -pkg mocks -skip-ensure -with-resets -fmt goimports cache Service:CacheSvc 12 | 13 | // RoundTripperHandler is a type for middleware handler 14 | type RoundTripperHandler func(http.RoundTripper) http.RoundTripper 15 | 16 | // RoundTripperFunc is a functional adapter for RoundTripperHandler 17 | type RoundTripperFunc func(*http.Request) (*http.Response, error) 18 | 19 | // RoundTrip adopts function to the type 20 | func (rt RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return rt(r) } 21 | -------------------------------------------------------------------------------- /middleware/mocks/cache.go: -------------------------------------------------------------------------------- 1 | // Code generated by moq; DO NOT EDIT. 2 | // github.com/matryer/moq 3 | 4 | package mocks 5 | 6 | import ( 7 | "sync" 8 | ) 9 | 10 | // CacheSvc is a mock implementation of cache.Service. 11 | // 12 | // func TestSomethingThatUsesService(t *testing.T) { 13 | // 14 | // // make and configure a mocked cache.Service 15 | // mockedService := &CacheSvc{ 16 | // GetFunc: func(key string, fn func() (interface{}, error)) (interface{}, error) { 17 | // panic("mock out the Get method") 18 | // }, 19 | // } 20 | // 21 | // // use mockedService in code that requires cache.Service 22 | // // and then make assertions. 23 | // 24 | // } 25 | type CacheSvc struct { 26 | // GetFunc mocks the Get method. 27 | GetFunc func(key string, fn func() (interface{}, error)) (interface{}, error) 28 | 29 | // calls tracks calls to the methods. 30 | calls struct { 31 | // Get holds details about calls to the Get method. 32 | Get []struct { 33 | // Key is the key argument value. 34 | Key string 35 | // Fn is the fn argument value. 
36 | Fn func() (interface{}, error) 37 | } 38 | } 39 | lockGet sync.RWMutex 40 | } 41 | 42 | // Get calls GetFunc. 43 | func (mock *CacheSvc) Get(key string, fn func() (interface{}, error)) (interface{}, error) { 44 | if mock.GetFunc == nil { 45 | panic("CacheSvc.GetFunc: method is nil but Service.Get was just called") 46 | } 47 | callInfo := struct { 48 | Key string 49 | Fn func() (interface{}, error) 50 | }{ 51 | Key: key, 52 | Fn: fn, 53 | } 54 | mock.lockGet.Lock() 55 | mock.calls.Get = append(mock.calls.Get, callInfo) 56 | mock.lockGet.Unlock() 57 | return mock.GetFunc(key, fn) 58 | } 59 | 60 | // GetCalls gets all the calls that were made to Get. 61 | // Check the length with: 62 | // 63 | // len(mockedService.GetCalls()) 64 | func (mock *CacheSvc) GetCalls() []struct { 65 | Key string 66 | Fn func() (interface{}, error) 67 | } { 68 | var calls []struct { 69 | Key string 70 | Fn func() (interface{}, error) 71 | } 72 | mock.lockGet.RLock() 73 | calls = mock.calls.Get 74 | mock.lockGet.RUnlock() 75 | return calls 76 | } 77 | 78 | // ResetGetCalls reset all the calls that were made to Get. 79 | func (mock *CacheSvc) ResetGetCalls() { 80 | mock.lockGet.Lock() 81 | mock.calls.Get = nil 82 | mock.lockGet.Unlock() 83 | } 84 | 85 | // ResetCalls reset all the calls that were made to all mocked methods. 86 | func (mock *CacheSvc) ResetCalls() { 87 | mock.lockGet.Lock() 88 | mock.calls.Get = nil 89 | mock.lockGet.Unlock() 90 | } 91 | -------------------------------------------------------------------------------- /middleware/mocks/circuit_breaker.go: -------------------------------------------------------------------------------- 1 | // Code generated by moq; DO NOT EDIT. 2 | // github.com/matryer/moq 3 | 4 | package mocks 5 | 6 | import ( 7 | "sync" 8 | ) 9 | 10 | // CircuitBreakerSvcMock is a mock implementation of middleware.CircuitBreakerSvc. 
11 | // 12 | // func TestSomethingThatUsesCircuitBreakerSvc(t *testing.T) { 13 | // 14 | // // make and configure a mocked middleware.CircuitBreakerSvc 15 | // mockedCircuitBreakerSvc := &CircuitBreakerSvcMock{ 16 | // ExecuteFunc: func(req func() (interface{}, error)) (interface{}, error) { 17 | // panic("mock out the Execute method") 18 | // }, 19 | // } 20 | // 21 | // // use mockedCircuitBreakerSvc in code that requires middleware.CircuitBreakerSvc 22 | // // and then make assertions. 23 | // 24 | // } 25 | type CircuitBreakerSvcMock struct { 26 | // ExecuteFunc mocks the Execute method. 27 | ExecuteFunc func(req func() (interface{}, error)) (interface{}, error) 28 | 29 | // calls tracks calls to the methods. 30 | calls struct { 31 | // Execute holds details about calls to the Execute method. 32 | Execute []struct { 33 | // Req is the req argument value. 34 | Req func() (interface{}, error) 35 | } 36 | } 37 | lockExecute sync.RWMutex 38 | } 39 | 40 | // Execute calls ExecuteFunc. 41 | func (mock *CircuitBreakerSvcMock) Execute(req func() (interface{}, error)) (interface{}, error) { 42 | if mock.ExecuteFunc == nil { 43 | panic("CircuitBreakerSvcMock.ExecuteFunc: method is nil but CircuitBreakerSvc.Execute was just called") 44 | } 45 | callInfo := struct { 46 | Req func() (interface{}, error) 47 | }{ 48 | Req: req, 49 | } 50 | mock.lockExecute.Lock() 51 | mock.calls.Execute = append(mock.calls.Execute, callInfo) 52 | mock.lockExecute.Unlock() 53 | return mock.ExecuteFunc(req) 54 | } 55 | 56 | // ExecuteCalls gets all the calls that were made to Execute. 
57 | // Check the length with: 58 | // 59 | // len(mockedCircuitBreakerSvc.ExecuteCalls()) 60 | func (mock *CircuitBreakerSvcMock) ExecuteCalls() []struct { 61 | Req func() (interface{}, error) 62 | } { 63 | var calls []struct { 64 | Req func() (interface{}, error) 65 | } 66 | mock.lockExecute.RLock() 67 | calls = mock.calls.Execute 68 | mock.lockExecute.RUnlock() 69 | return calls 70 | } 71 | 72 | // ResetExecuteCalls reset all the calls that were made to Execute. 73 | func (mock *CircuitBreakerSvcMock) ResetExecuteCalls() { 74 | mock.lockExecute.Lock() 75 | mock.calls.Execute = nil 76 | mock.lockExecute.Unlock() 77 | } 78 | 79 | // ResetCalls reset all the calls that were made to all mocked methods. 80 | func (mock *CircuitBreakerSvcMock) ResetCalls() { 81 | mock.lockExecute.Lock() 82 | mock.calls.Execute = nil 83 | mock.lockExecute.Unlock() 84 | } 85 | -------------------------------------------------------------------------------- /middleware/mocks/logger.go: -------------------------------------------------------------------------------- 1 | // Code generated by moq; DO NOT EDIT. 2 | // github.com/matryer/moq 3 | 4 | package mocks 5 | 6 | import ( 7 | "sync" 8 | ) 9 | 10 | // LoggerSvc is a mock implementation of logger.Service. 11 | // 12 | // func TestSomethingThatUsesService(t *testing.T) { 13 | // 14 | // // make and configure a mocked logger.Service 15 | // mockedService := &LoggerSvc{ 16 | // LogfFunc: func(format string, args ...interface{}) { 17 | // panic("mock out the Logf method") 18 | // }, 19 | // } 20 | // 21 | // // use mockedService in code that requires logger.Service 22 | // // and then make assertions. 23 | // 24 | // } 25 | type LoggerSvc struct { 26 | // LogfFunc mocks the Logf method. 27 | LogfFunc func(format string, args ...interface{}) 28 | 29 | // calls tracks calls to the methods. 30 | calls struct { 31 | // Logf holds details about calls to the Logf method. 32 | Logf []struct { 33 | // Format is the format argument value. 
34 | Format string 35 | // Args is the args argument value. 36 | Args []interface{} 37 | } 38 | } 39 | lockLogf sync.RWMutex 40 | } 41 | 42 | // Logf calls LogfFunc. 43 | func (mock *LoggerSvc) Logf(format string, args ...interface{}) { 44 | if mock.LogfFunc == nil { 45 | panic("LoggerSvc.LogfFunc: method is nil but Service.Logf was just called") 46 | } 47 | callInfo := struct { 48 | Format string 49 | Args []interface{} 50 | }{ 51 | Format: format, 52 | Args: args, 53 | } 54 | mock.lockLogf.Lock() 55 | mock.calls.Logf = append(mock.calls.Logf, callInfo) 56 | mock.lockLogf.Unlock() 57 | mock.LogfFunc(format, args...) 58 | } 59 | 60 | // LogfCalls gets all the calls that were made to Logf. 61 | // Check the length with: 62 | // 63 | // len(mockedService.LogfCalls()) 64 | func (mock *LoggerSvc) LogfCalls() []struct { 65 | Format string 66 | Args []interface{} 67 | } { 68 | var calls []struct { 69 | Format string 70 | Args []interface{} 71 | } 72 | mock.lockLogf.RLock() 73 | calls = mock.calls.Logf 74 | mock.lockLogf.RUnlock() 75 | return calls 76 | } 77 | 78 | // ResetLogfCalls reset all the calls that were made to Logf. 79 | func (mock *LoggerSvc) ResetLogfCalls() { 80 | mock.lockLogf.Lock() 81 | mock.calls.Logf = nil 82 | mock.lockLogf.Unlock() 83 | } 84 | 85 | // ResetCalls reset all the calls that were made to all mocked methods. 86 | func (mock *LoggerSvc) ResetCalls() { 87 | mock.lockLogf.Lock() 88 | mock.calls.Logf = nil 89 | mock.lockLogf.Unlock() 90 | } 91 | -------------------------------------------------------------------------------- /middleware/mocks/repeater.go: -------------------------------------------------------------------------------- 1 | // Code generated by moq; DO NOT EDIT. 2 | // github.com/matryer/moq 3 | 4 | package mocks 5 | 6 | import ( 7 | "context" 8 | "sync" 9 | ) 10 | 11 | // RepeaterSvcMock is a mock implementation of middleware.RepeaterSvc. 
12 | // 13 | // func TestSomethingThatUsesRepeaterSvc(t *testing.T) { 14 | // 15 | // // make and configure a mocked middleware.RepeaterSvc 16 | // mockedRepeaterSvc := &RepeaterSvcMock{ 17 | // DoFunc: func(ctx context.Context, fun func() error, errs ...error) error { 18 | // panic("mock out the Do method") 19 | // }, 20 | // } 21 | // 22 | // // use mockedRepeaterSvc in code that requires middleware.RepeaterSvc 23 | // // and then make assertions. 24 | // 25 | // } 26 | type RepeaterSvcMock struct { 27 | // DoFunc mocks the Do method. 28 | DoFunc func(ctx context.Context, fun func() error, errs ...error) error 29 | 30 | // calls tracks calls to the methods. 31 | calls struct { 32 | // Do holds details about calls to the Do method. 33 | Do []struct { 34 | // Ctx is the ctx argument value. 35 | Ctx context.Context 36 | // Fun is the fun argument value. 37 | Fun func() error 38 | // Errs is the errs argument value. 39 | Errs []error 40 | } 41 | } 42 | lockDo sync.RWMutex 43 | } 44 | 45 | // Do calls DoFunc. 46 | func (mock *RepeaterSvcMock) Do(ctx context.Context, fun func() error, errs ...error) error { 47 | if mock.DoFunc == nil { 48 | panic("RepeaterSvcMock.DoFunc: method is nil but RepeaterSvc.Do was just called") 49 | } 50 | callInfo := struct { 51 | Ctx context.Context 52 | Fun func() error 53 | Errs []error 54 | }{ 55 | Ctx: ctx, 56 | Fun: fun, 57 | Errs: errs, 58 | } 59 | mock.lockDo.Lock() 60 | mock.calls.Do = append(mock.calls.Do, callInfo) 61 | mock.lockDo.Unlock() 62 | return mock.DoFunc(ctx, fun, errs...) 63 | } 64 | 65 | // DoCalls gets all the calls that were made to Do. 
66 | // Check the length with: 67 | // 68 | // len(mockedRepeaterSvc.DoCalls()) 69 | func (mock *RepeaterSvcMock) DoCalls() []struct { 70 | Ctx context.Context 71 | Fun func() error 72 | Errs []error 73 | } { 74 | var calls []struct { 75 | Ctx context.Context 76 | Fun func() error 77 | Errs []error 78 | } 79 | mock.lockDo.RLock() 80 | calls = mock.calls.Do 81 | mock.lockDo.RUnlock() 82 | return calls 83 | } 84 | 85 | // ResetDoCalls reset all the calls that were made to Do. 86 | func (mock *RepeaterSvcMock) ResetDoCalls() { 87 | mock.lockDo.Lock() 88 | mock.calls.Do = nil 89 | mock.lockDo.Unlock() 90 | } 91 | 92 | // ResetCalls reset all the calls that were made to all mocked methods. 93 | func (mock *RepeaterSvcMock) ResetCalls() { 94 | mock.lockDo.Lock() 95 | mock.calls.Do = nil 96 | mock.lockDo.Unlock() 97 | } 98 | -------------------------------------------------------------------------------- /middleware/mocks/roundtripper.go: -------------------------------------------------------------------------------- 1 | package mocks 2 | 3 | import ( 4 | "net/http" 5 | "sync/atomic" 6 | ) 7 | 8 | // RoundTripper mock to test other middlewares 9 | type RoundTripper struct { 10 | RoundTripFunc func(*http.Request) (*http.Response, error) 11 | calls int32 12 | } 13 | 14 | // RoundTrip adds to calls count and hit user-provided RoundTripFunc 15 | func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 16 | atomic.AddInt32(&r.calls, 1) 17 | return r.RoundTripFunc(req) 18 | } 19 | 20 | // Calls returns how many time RoundTrip func called 21 | func (r *RoundTripper) Calls() int { 22 | return int(atomic.LoadInt32(&r.calls)) 23 | } 24 | 25 | // ResetCalls resets calls counter 26 | func (r *RoundTripper) ResetCalls() { 27 | atomic.StoreInt32(&r.calls, 0) 28 | } 29 | -------------------------------------------------------------------------------- /middleware/repeater.go: -------------------------------------------------------------------------------- 1 | 
package middleware 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | ) 9 | 10 | // RepeaterSvc defines repeater interface 11 | type RepeaterSvc interface { 12 | Do(ctx context.Context, fun func() error, errs ...error) (err error) 13 | } 14 | 15 | // Repeater sets middleware with provided RepeaterSvc to retry failed requests 16 | func Repeater(repeater RepeaterSvc, failOnCodes ...int) RoundTripperHandler { 17 | 18 | return func(next http.RoundTripper) http.RoundTripper { 19 | 20 | fn := func(req *http.Request) (*http.Response, error) { 21 | if repeater == nil { 22 | return next.RoundTrip(req) 23 | } 24 | 25 | var resp *http.Response 26 | var err error 27 | e := repeater.Do(req.Context(), func() error { 28 | resp, err = next.RoundTrip(req) 29 | if err != nil { 30 | return err 31 | } 32 | // no explicit codes provided, fail on any 4xx or 5xx 33 | if len(failOnCodes) == 0 && resp.StatusCode >= 400 { 34 | return errors.New(resp.Status) 35 | } 36 | // fail on provided codes only 37 | for _, fc := range failOnCodes { 38 | if resp.StatusCode == fc { 39 | return errors.New(resp.Status) 40 | } 41 | } 42 | return nil 43 | }) 44 | if e != nil { 45 | return nil, fmt.Errorf("repeater: %w", e) 46 | } 47 | return resp, nil 48 | } 49 | return RoundTripperFunc(fn) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /middleware/repeater_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "strings" 11 | "sync/atomic" 12 | "testing" 13 | "time" 14 | 15 | "github.com/stretchr/testify/assert" 16 | "github.com/stretchr/testify/require" 17 | 18 | "github.com/go-pkgz/requester/middleware/mocks" 19 | ) 20 | 21 | func TestRepeater_Passed(t *testing.T) { 22 | var count int32 23 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 
24 | if atomic.AddInt32(&count, 1) >= 3 { 25 | resp := &http.Response{StatusCode: 201} 26 | return resp, nil 27 | } 28 | resp := &http.Response{StatusCode: 400, Status: "400 Bad Request"} 29 | return resp, errors.New("blah") 30 | }} 31 | 32 | repeater := &mocks.RepeaterSvcMock{DoFunc: func(ctx context.Context, fun func() error, errs ...error) (err error) { 33 | for i := 0; i < 5; i++ { 34 | if err = fun(); err == nil { 35 | return nil 36 | } 37 | } 38 | return err 39 | }} 40 | 41 | h := Repeater(repeater) 42 | 43 | req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) 44 | require.NoError(t, err) 45 | 46 | resp, err := h(rmock).RoundTrip(req) 47 | require.NoError(t, err) 48 | assert.Equal(t, 201, resp.StatusCode) 49 | 50 | assert.Equal(t, 3, rmock.Calls()) 51 | } 52 | 53 | func TestRepeater_Failed(t *testing.T) { 54 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 55 | resp := &http.Response{StatusCode: 400} 56 | return resp, errors.New("http error") 57 | }} 58 | 59 | repeater := &mocks.RepeaterSvcMock{DoFunc: func(ctx context.Context, fun func() error, errs ...error) (err error) { 60 | for i := 0; i < 5; i++ { 61 | if err = fun(); err == nil { 62 | return nil 63 | } 64 | } 65 | return err 66 | }} 67 | 68 | h := Repeater(repeater) 69 | 70 | req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) 71 | require.NoError(t, err) 72 | 73 | _, err = h(rmock).RoundTrip(req) 74 | require.EqualError(t, err, "repeater: http error") 75 | 76 | assert.Equal(t, 5, rmock.Calls()) 77 | } 78 | 79 | func TestRepeater_FailedStatus(t *testing.T) { 80 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 81 | resp := &http.Response{StatusCode: 400, Status: "400 Bad Request"} 82 | return resp, nil 83 | }} 84 | 85 | repeater := &mocks.RepeaterSvcMock{DoFunc: func(ctx context.Context, fun func() error, errs ...error) (err error) { 86 | for i := 0; i < 5; i++ { 87 | 
if err = fun(); err == nil { 88 | return nil 89 | } 90 | } 91 | return err 92 | }} 93 | t.Run("no codes", func(t *testing.T) { 94 | rmock.ResetCalls() 95 | h := Repeater(repeater) 96 | req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) 97 | require.NoError(t, err) 98 | 99 | _, err = h(rmock).RoundTrip(req) 100 | require.EqualError(t, err, "repeater: 400 Bad Request") 101 | assert.Equal(t, 5, rmock.Calls()) 102 | }) 103 | 104 | t.Run("with codes", func(t *testing.T) { 105 | rmock.ResetCalls() 106 | h := Repeater(repeater, 300, 400, 401) 107 | req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) 108 | require.NoError(t, err) 109 | 110 | _, err = h(rmock).RoundTrip(req) 111 | require.EqualError(t, err, "repeater: 400 Bad Request") 112 | assert.Equal(t, 5, rmock.Calls()) 113 | }) 114 | 115 | t.Run("no codes, no match", func(t *testing.T) { 116 | rmock.ResetCalls() 117 | h := Repeater(repeater, 300, 401) 118 | 119 | req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) 120 | require.NoError(t, err) 121 | 122 | resp, err := h(rmock).RoundTrip(req) 123 | require.NoError(t, err) 124 | assert.Equal(t, 400, resp.StatusCode) 125 | assert.Equal(t, 1, rmock.Calls()) 126 | }) 127 | } 128 | 129 | func TestRepeater_EdgeCases(t *testing.T) { 130 | 131 | t.Run("context cancellation", func(t *testing.T) { 132 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 133 | time.Sleep(50 * time.Millisecond) 134 | return &http.Response{StatusCode: 500}, nil 135 | }} 136 | 137 | repeater := &mocks.RepeaterSvcMock{DoFunc: func(ctx context.Context, fun func() error, errs ...error) error { 138 | for i := 0; i < 5; i++ { 139 | select { 140 | case <-ctx.Done(): 141 | return ctx.Err() 142 | default: 143 | if err := fun(); err == nil { 144 | return nil 145 | } 146 | time.Sleep(10 * time.Millisecond) 147 | } 148 | } 149 | return errors.New("max retries exceeded") 150 | }} 151 | 152 | h := 
Repeater(repeater) 153 | 154 | ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) 155 | defer cancel() 156 | 157 | req, err := http.NewRequestWithContext(ctx, "GET", "http://example.com/blah", http.NoBody) 158 | require.NoError(t, err) 159 | 160 | _, err = h(rmock).RoundTrip(req) 161 | require.Error(t, err) 162 | assert.True(t, errors.Is(err, context.DeadlineExceeded) || 163 | strings.Contains(err.Error(), "context deadline exceeded")) 164 | }) 165 | 166 | t.Run("retry with request body", func(t *testing.T) { 167 | var bodies []string 168 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 169 | if r.Body == nil { 170 | bodies = append(bodies, "") 171 | return &http.Response{StatusCode: 500}, nil 172 | } 173 | body, err := io.ReadAll(r.Body) 174 | require.NoError(t, err) 175 | bodies = append(bodies, string(body)) 176 | // recreate body for next read 177 | r.Body = io.NopCloser(bytes.NewReader(body)) 178 | return &http.Response{StatusCode: 500}, nil 179 | }} 180 | 181 | repeater := &mocks.RepeaterSvcMock{DoFunc: func(ctx context.Context, fun func() error, errs ...error) error { 182 | for i := 0; i < 3; i++ { 183 | if err := fun(); err == nil { 184 | return nil 185 | } 186 | } 187 | return errors.New("max retries") 188 | }} 189 | 190 | h := Repeater(repeater) 191 | 192 | bodyContent := "test body" 193 | req, err := http.NewRequest("POST", "http://example.com/blah", 194 | bytes.NewBufferString(bodyContent)) 195 | require.NoError(t, err) 196 | 197 | _, err = h(rmock).RoundTrip(req) 198 | require.Error(t, err) 199 | assert.Equal(t, 3, len(bodies)) 200 | for _, body := range bodies { 201 | assert.Equal(t, bodyContent, body) 202 | } 203 | }) 204 | 205 | t.Run("status code ranges", func(t *testing.T) { 206 | cases := []struct { 207 | code int 208 | failOnCodes []int 209 | retryCount int 210 | expectError bool 211 | description string 212 | }{ 213 | { 214 | code: 200, failOnCodes: []int{}, 215 | 
retryCount: 1, expectError: false, 216 | description: "success with no explicit codes", 217 | }, 218 | { 219 | code: 404, failOnCodes: []int{}, 220 | retryCount: 5, expectError: true, 221 | description: "4xx with default fail codes", 222 | }, 223 | { 224 | code: 503, failOnCodes: []int{}, 225 | retryCount: 5, expectError: true, 226 | description: "5xx with default fail codes", 227 | }, 228 | { 229 | code: 404, failOnCodes: []int{503}, 230 | retryCount: 1, expectError: false, 231 | description: "4xx not in explicit codes", 232 | }, 233 | { 234 | code: 503, failOnCodes: []int{503}, 235 | retryCount: 5, expectError: true, 236 | description: "5xx in explicit codes", 237 | }, 238 | } 239 | 240 | for _, tc := range cases { 241 | t.Run(tc.description, func(t *testing.T) { 242 | retryCount := 0 243 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 244 | retryCount++ 245 | return &http.Response{ 246 | StatusCode: tc.code, 247 | Status: fmt.Sprintf("%d Status", tc.code), 248 | }, nil 249 | }} 250 | 251 | repeater := &mocks.RepeaterSvcMock{DoFunc: func(ctx context.Context, fun func() error, errs ...error) error { 252 | var lastErr error 253 | for i := 0; i < 5; i++ { 254 | var err error 255 | if err = fun(); err == nil { 256 | return nil 257 | } 258 | lastErr = err 259 | } 260 | return fmt.Errorf("repeater: %w", lastErr) 261 | }} 262 | 263 | h := Repeater(repeater, tc.failOnCodes...) 
264 | req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) 265 | require.NoError(t, err) 266 | 267 | resp, err := h(rmock).RoundTrip(req) 268 | 269 | if tc.expectError { 270 | require.Error(t, err) 271 | assert.Contains(t, err.Error(), fmt.Sprint(tc.code)) 272 | assert.Equal(t, tc.retryCount, retryCount, "unexpected retry count") 273 | } else { 274 | require.NoError(t, err) 275 | require.NotNil(t, resp) 276 | assert.Equal(t, tc.code, resp.StatusCode) 277 | assert.Equal(t, tc.retryCount, retryCount, "unexpected retry count") 278 | } 279 | }) 280 | } 281 | }) 282 | } 283 | -------------------------------------------------------------------------------- /middleware/retry.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "net/http" 8 | "time" 9 | ) 10 | 11 | // BackoffType represents backoff strategy 12 | type BackoffType int 13 | 14 | const ( 15 | // BackoffConstant is a backoff strategy with constant delay 16 | BackoffConstant BackoffType = iota 17 | // BackoffLinear is a backoff strategy with linear delay 18 | BackoffLinear 19 | // BackoffExponential is a backoff strategy with exponential delay 20 | BackoffExponential 21 | ) 22 | 23 | // RetryMiddleware implements a retry mechanism for http requests with configurable backoff strategies. 24 | // It supports constant, linear and exponential backoff with optional jitter for better load distribution. 25 | // By default retries on network errors and 5xx responses. Can be configured to retry on specific status codes 26 | // or to exclude specific codes from retry. 
27 | // 28 | // Default configuration: 29 | // - 3 attempts 30 | // - Initial delay: 100ms 31 | // - Max delay: 30s 32 | // - Exponential backoff 33 | // - 10% jitter 34 | // - Retries on 5xx status codes 35 | type RetryMiddleware struct { 36 | next http.RoundTripper 37 | attempts int 38 | initialDelay time.Duration 39 | maxDelay time.Duration 40 | backoff BackoffType 41 | jitterFactor float64 42 | retryCodes []int 43 | excludeCodes []int 44 | } 45 | 46 | // Retry creates retry middleware with provided options 47 | func Retry(attempts int, initialDelay time.Duration, opts ...RetryOption) RoundTripperHandler { 48 | return func(next http.RoundTripper) http.RoundTripper { 49 | r := &RetryMiddleware{ 50 | next: next, 51 | attempts: attempts, 52 | initialDelay: initialDelay, 53 | maxDelay: 30 * time.Second, 54 | backoff: BackoffExponential, 55 | jitterFactor: 0.1, 56 | } 57 | 58 | for _, opt := range opts { 59 | opt(r) 60 | } 61 | 62 | if len(r.retryCodes) > 0 && len(r.excludeCodes) > 0 { 63 | panic("retry: cannot use both RetryOnCodes and RetryExcludeCodes") 64 | } 65 | 66 | return r 67 | } 68 | } 69 | 70 | // RoundTrip implements http.RoundTripper 71 | func (r *RetryMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { 72 | var lastResponse *http.Response 73 | var lastError error 74 | 75 | for attempt := 0; attempt < r.attempts; attempt++ { 76 | if req.Context().Err() != nil { 77 | return nil, req.Context().Err() 78 | } 79 | 80 | if attempt > 0 { 81 | delay := r.calcDelay(attempt) 82 | select { 83 | case <-req.Context().Done(): 84 | return nil, req.Context().Err() 85 | case <-time.After(delay): 86 | } 87 | } 88 | 89 | resp, err := r.next.RoundTrip(req) 90 | if err != nil { 91 | lastError = err 92 | lastResponse = resp 93 | continue 94 | } 95 | 96 | if !r.shouldRetry(resp) { 97 | return resp, nil 98 | } 99 | 100 | lastResponse = resp 101 | } 102 | 103 | if lastError != nil { 104 | return lastResponse, fmt.Errorf("retry: transport error after %d attempts: 
%w", r.attempts, lastError) 105 | } 106 | return lastResponse, nil 107 | } 108 | 109 | func (r *RetryMiddleware) calcDelay(attempt int) time.Duration { 110 | if attempt == 0 { 111 | return 0 112 | } 113 | 114 | var delay time.Duration 115 | switch r.backoff { 116 | case BackoffConstant: 117 | delay = r.initialDelay 118 | case BackoffLinear: 119 | delay = r.initialDelay * time.Duration(attempt) 120 | case BackoffExponential: 121 | delay = r.initialDelay * time.Duration(math.Pow(2, float64(attempt-1))) 122 | } 123 | 124 | if delay > r.maxDelay { 125 | delay = r.maxDelay 126 | } 127 | 128 | if r.jitterFactor > 0 { 129 | jitter := float64(delay) * r.jitterFactor 130 | delay = time.Duration(float64(delay) + rand.Float64()*jitter - jitter/2) //nolint:gosec // week randomness is acceptable 131 | } 132 | 133 | return delay 134 | } 135 | 136 | func (r *RetryMiddleware) shouldRetry(resp *http.Response) bool { 137 | if len(r.retryCodes) > 0 { 138 | for _, code := range r.retryCodes { 139 | if resp.StatusCode == code { 140 | return true 141 | } 142 | } 143 | return false 144 | } 145 | 146 | if len(r.excludeCodes) > 0 { 147 | for _, code := range r.excludeCodes { 148 | if resp.StatusCode == code { 149 | return false 150 | } 151 | } 152 | return true 153 | } 154 | 155 | return resp.StatusCode >= 500 156 | } 157 | 158 | // RetryOption represents option type for retry middleware 159 | type RetryOption func(r *RetryMiddleware) 160 | 161 | // RetryMaxDelay sets maximum delay between retries 162 | func RetryMaxDelay(d time.Duration) RetryOption { 163 | return func(r *RetryMiddleware) { 164 | r.maxDelay = d 165 | } 166 | } 167 | 168 | // RetryWithBackoff sets backoff strategy 169 | func RetryWithBackoff(t BackoffType) RetryOption { 170 | return func(r *RetryMiddleware) { 171 | r.backoff = t 172 | } 173 | } 174 | 175 | // RetryWithJitter sets how much randomness to add to delay (0-1.0) 176 | func RetryWithJitter(f float64) RetryOption { 177 | return func(r *RetryMiddleware) { 178 | 
r.jitterFactor = f 179 | } 180 | } 181 | 182 | // RetryOnCodes sets status codes that should trigger a retry 183 | func RetryOnCodes(codes ...int) RetryOption { 184 | return func(r *RetryMiddleware) { 185 | r.retryCodes = codes 186 | } 187 | } 188 | 189 | // RetryExcludeCodes sets status codes that should not trigger a retry 190 | func RetryExcludeCodes(codes ...int) RetryOption { 191 | return func(r *RetryMiddleware) { 192 | r.excludeCodes = codes 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /middleware/retry_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | 14 | "github.com/go-pkgz/requester/middleware/mocks" 15 | ) 16 | 17 | func TestRetry_BasicBehavior(t *testing.T) { 18 | t.Run("retries on network error", func(t *testing.T) { 19 | var attemptCount int32 20 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 21 | count := atomic.AddInt32(&attemptCount, 1) 22 | if count < 3 { 23 | return nil, errors.New("network error") 24 | } 25 | return &http.Response{StatusCode: 200}, nil 26 | }} 27 | 28 | h := Retry(3, time.Millisecond)(rmock) 29 | req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) 30 | require.NoError(t, err) 31 | 32 | resp, err := h.RoundTrip(req) 33 | require.NoError(t, err) 34 | assert.Equal(t, 200, resp.StatusCode) 35 | assert.Equal(t, int32(3), atomic.LoadInt32(&attemptCount)) 36 | }) 37 | 38 | t.Run("retries on 5xx status by default", func(t *testing.T) { 39 | var attemptCount int32 40 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 41 | count := atomic.AddInt32(&attemptCount, 1) 42 | if count < 3 { 43 | return 
&http.Response{StatusCode: 503}, nil 44 | } 45 | return &http.Response{StatusCode: 200}, nil 46 | }} 47 | 48 | h := Retry(3, time.Millisecond)(rmock) 49 | req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) 50 | require.NoError(t, err) 51 | 52 | resp, err := h.RoundTrip(req) 53 | require.NoError(t, err) 54 | assert.Equal(t, 200, resp.StatusCode) 55 | assert.Equal(t, int32(3), atomic.LoadInt32(&attemptCount)) 56 | }) 57 | 58 | t.Run("no retry on 4xx by default", func(t *testing.T) { 59 | var attemptCount int32 60 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 61 | atomic.AddInt32(&attemptCount, 1) 62 | return &http.Response{StatusCode: 404}, nil 63 | }} 64 | 65 | h := Retry(3, time.Millisecond)(rmock) 66 | req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) 67 | require.NoError(t, err) 68 | 69 | resp, err := h.RoundTrip(req) 70 | require.NoError(t, err) 71 | assert.Equal(t, 404, resp.StatusCode) 72 | assert.Equal(t, int32(1), atomic.LoadInt32(&attemptCount)) 73 | }) 74 | 75 | t.Run("fails with error after max attempts", func(t *testing.T) { 76 | var attemptCount int32 77 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 78 | atomic.AddInt32(&attemptCount, 1) 79 | return nil, errors.New("persistent network error") 80 | }} 81 | 82 | h := Retry(3, time.Millisecond)(rmock) 83 | req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) 84 | require.NoError(t, err) 85 | 86 | resp, err := h.RoundTrip(req) 87 | assert.Nil(t, resp) 88 | require.Error(t, err) 89 | assert.Contains(t, err.Error(), "retry: transport error after 3 attempts") 90 | assert.Equal(t, int32(3), atomic.LoadInt32(&attemptCount)) 91 | }) 92 | 93 | t.Run("respects request context cancellation", func(t *testing.T) { 94 | var attemptCount int32 95 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 96 | 
atomic.AddInt32(&attemptCount, 1) 97 | return nil, errors.New("network failure") 98 | }} 99 | 100 | ctx, cancel := context.WithCancel(context.Background()) 101 | req, err := http.NewRequestWithContext(ctx, "GET", "http://example.com/", http.NoBody) 102 | require.NoError(t, err) 103 | 104 | h := Retry(5, 50*time.Millisecond)(rmock) 105 | 106 | // Cancel request after first attempt 107 | time.AfterFunc(20*time.Millisecond, cancel) 108 | 109 | _, err = h.RoundTrip(req) 110 | require.Error(t, err) 111 | assert.Contains(t, err.Error(), "context canceled") 112 | assert.Equal(t, int32(1), atomic.LoadInt32(&attemptCount), "should stop retrying after context cancellation") 113 | }) 114 | } 115 | 116 | func TestRetry_BackoffStrategies(t *testing.T) { 117 | tests := []struct { 118 | name string 119 | backoff BackoffType 120 | minDuration time.Duration 121 | maxDuration time.Duration 122 | }{ 123 | { 124 | name: "constant backoff", 125 | backoff: BackoffConstant, 126 | minDuration: 3 * time.Millisecond, // 1ms * 3 127 | maxDuration: 5 * time.Millisecond, // some buffer for execution time 128 | }, 129 | { 130 | name: "linear backoff", 131 | backoff: BackoffLinear, 132 | minDuration: 6 * time.Millisecond, // 1ms + 2ms + 3ms 133 | maxDuration: 8 * time.Millisecond, 134 | }, 135 | { 136 | name: "exponential backoff", 137 | backoff: BackoffExponential, 138 | minDuration: 7 * time.Millisecond, // 1ms + 2ms + 4ms 139 | maxDuration: 9 * time.Millisecond, 140 | }, 141 | } 142 | 143 | for _, tt := range tests { 144 | t.Run(tt.name, func(t *testing.T) { 145 | var attemptCount int32 146 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 147 | count := atomic.AddInt32(&attemptCount, 1) 148 | if count < 4 { 149 | return &http.Response{StatusCode: 503}, nil 150 | } 151 | return &http.Response{StatusCode: 200}, nil 152 | }} 153 | 154 | start := time.Now() 155 | h := Retry(4, time.Millisecond, RetryWithBackoff(tt.backoff))(rmock) 156 | req, err := 
http.NewRequest("GET", "http://example.com/", http.NoBody) 157 | require.NoError(t, err) 158 | 159 | resp, err := h.RoundTrip(req) 160 | duration := time.Since(start) 161 | 162 | require.NoError(t, err) 163 | assert.Equal(t, 200, resp.StatusCode) 164 | assert.Equal(t, int32(4), atomic.LoadInt32(&attemptCount)) 165 | assert.GreaterOrEqual(t, duration, tt.minDuration) 166 | assert.LessOrEqual(t, duration, tt.maxDuration) 167 | }) 168 | } 169 | 170 | t.Run("max delay limits backoff", func(t *testing.T) { 171 | var attemptCount int32 172 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 173 | atomic.AddInt32(&attemptCount, 1) 174 | return &http.Response{StatusCode: 503}, nil 175 | }} 176 | 177 | start := time.Now() 178 | h := Retry(3, 10*time.Millisecond, 179 | RetryMaxDelay(15*time.Millisecond), 180 | RetryWithBackoff(BackoffExponential), 181 | )(rmock) 182 | req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) 183 | require.NoError(t, err) 184 | 185 | _, _ = h.RoundTrip(req) 186 | duration := time.Since(start) 187 | 188 | // With exponential backoff and 10ms initial delay, without max delay 189 | // it would be 10ms + 20ms + 40ms = 70ms, but with max delay of 15ms 190 | // it should be 10ms + 15ms + 15ms = 40ms 191 | assert.Less(t, duration, 50*time.Millisecond) 192 | }) 193 | 194 | t.Run("jitter factor affects delay", func(t *testing.T) { 195 | var callTimes []time.Time 196 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 197 | callTimes = append(callTimes, time.Now()) 198 | return &http.Response{StatusCode: 503}, nil 199 | }} 200 | 201 | h := Retry(3, 10*time.Millisecond, 202 | RetryWithJitter(0.5), 203 | RetryWithBackoff(BackoffConstant), 204 | )(rmock) 205 | req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) 206 | require.NoError(t, err) 207 | 208 | _, _ = h.RoundTrip(req) 209 | 210 | require.Greater(t, len(callTimes), 2) 211 | delay1 := 
callTimes[1].Sub(callTimes[0]) 212 | delay2 := callTimes[2].Sub(callTimes[1]) 213 | // With 0.5 jitter and 10ms delay, delays should be different 214 | assert.NotEqual(t, delay1, delay2) 215 | // But both should be in range 5ms-15ms (10ms ±50%) 216 | assert.Greater(t, delay1, 5*time.Millisecond) 217 | assert.Less(t, delay1, 15*time.Millisecond) 218 | assert.Greater(t, delay2, 5*time.Millisecond) 219 | assert.Less(t, delay2, 15*time.Millisecond) 220 | }) 221 | 222 | t.Run("verifies retry actually introduces delay", func(t *testing.T) { 223 | var attemptCount int32 224 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 225 | count := atomic.AddInt32(&attemptCount, 1) 226 | if count < 4 { 227 | return &http.Response{StatusCode: 503}, nil 228 | } 229 | return &http.Response{StatusCode: 200}, nil 230 | }} 231 | 232 | start := time.Now() 233 | h := Retry(4, 5*time.Millisecond, RetryWithBackoff(BackoffExponential))(rmock) 234 | req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) 235 | require.NoError(t, err) 236 | 237 | resp, err := h.RoundTrip(req) 238 | duration := time.Since(start) 239 | 240 | require.NoError(t, err) 241 | assert.Equal(t, 200, resp.StatusCode) 242 | assert.Equal(t, int32(4), atomic.LoadInt32(&attemptCount)) 243 | 244 | // expected delay: 5ms + 10ms + 20ms = 35ms (exponential backoff) 245 | expectedMin := 30 * time.Millisecond 246 | expectedMax := 40 * time.Millisecond 247 | 248 | assert.Greater(t, duration, expectedMin, "retries should introduce actual delay") 249 | assert.LessOrEqual(t, duration, expectedMax, "delay should not exceed expected range") 250 | }) 251 | } 252 | 253 | func TestRetry_RetryConditions(t *testing.T) { 254 | t.Run("retry specific codes", func(t *testing.T) { 255 | var attemptCount int32 256 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 257 | count := atomic.AddInt32(&attemptCount, 1) 258 | if count < 3 { 259 | return 
&http.Response{StatusCode: 418}, nil // teapot error 260 | } 261 | return &http.Response{StatusCode: 200}, nil 262 | }} 263 | 264 | h := Retry(3, time.Millisecond, RetryOnCodes(418))(rmock) 265 | req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) 266 | require.NoError(t, err) 267 | 268 | resp, err := h.RoundTrip(req) 269 | require.NoError(t, err) 270 | assert.Equal(t, 200, resp.StatusCode) 271 | assert.Equal(t, int32(3), atomic.LoadInt32(&attemptCount)) 272 | }) 273 | 274 | t.Run("exclude codes from retry", func(t *testing.T) { 275 | var attemptCount int32 276 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 277 | count := atomic.AddInt32(&attemptCount, 1) 278 | if count < 3 { 279 | return &http.Response{StatusCode: 404}, nil 280 | } 281 | return &http.Response{StatusCode: 200}, nil 282 | }} 283 | 284 | h := Retry(3, time.Millisecond, RetryExcludeCodes(503, 404))(rmock) 285 | req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) 286 | require.NoError(t, err) 287 | 288 | resp, err := h.RoundTrip(req) 289 | require.NoError(t, err) 290 | assert.Equal(t, 404, resp.StatusCode) 291 | assert.Equal(t, int32(1), atomic.LoadInt32(&attemptCount)) 292 | }) 293 | 294 | t.Run("cannot use both include and exclude codes", func(t *testing.T) { 295 | assert.Panics(t, func() { 296 | rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 297 | return &http.Response{StatusCode: 200}, nil 298 | }} 299 | _ = Retry(3, time.Millisecond, 300 | RetryOnCodes(503), 301 | RetryExcludeCodes(404), 302 | )(rmock) 303 | }) 304 | }) 305 | } 306 | -------------------------------------------------------------------------------- /requester.go: -------------------------------------------------------------------------------- 1 | // Package requester wraps http.Client with a chain of middleware.RoundTripperHandler. 
2 | // Each RoundTripperHandler implements a part of functionality expanding http.Request oar altering 3 | // the flow in some way. Some middlewares set headers, some add logging and caching, some limit concurrency. 4 | // User can provide custom middlewares. 5 | package requester 6 | 7 | import ( 8 | "net/http" 9 | 10 | "github.com/go-pkgz/requester/middleware" 11 | ) 12 | 13 | // Requester provides a wrapper for the standard http.Do request. 14 | type Requester struct { 15 | client http.Client 16 | middlewares []middleware.RoundTripperHandler 17 | } 18 | 19 | // New creates requester with defaults 20 | func New(client http.Client, middlewares ...middleware.RoundTripperHandler) *Requester { 21 | return &Requester{ 22 | client: client, 23 | middlewares: middlewares, 24 | } 25 | } 26 | 27 | // Use adds middleware(s) to the requester chain 28 | func (r *Requester) Use(middlewares ...middleware.RoundTripperHandler) { 29 | r.middlewares = append(r.middlewares, middlewares...) 30 | } 31 | 32 | // With makes a new Requested with inherited middlewares and add passed middleware(s) to the chain 33 | func (r *Requester) With(middlewares ...middleware.RoundTripperHandler) *Requester { 34 | res := &Requester{ 35 | client: r.client, 36 | middlewares: append(r.middlewares, middlewares...), 37 | } 38 | return res 39 | } 40 | 41 | // Client returns http.Client with all middlewares injected 42 | func (r *Requester) Client() *http.Client { 43 | cl := r.client 44 | if cl.Transport == nil { 45 | cl.Transport = http.DefaultTransport 46 | } 47 | for _, handler := range r.middlewares { 48 | cl.Transport = handler(cl.Transport) 49 | } 50 | return &cl 51 | } 52 | 53 | // Do runs http request with optional middleware handlers wrapping the request 54 | func (r *Requester) Do(req *http.Request) (*http.Response, error) { 55 | return r.Client().Do(req) 56 | } 57 | -------------------------------------------------------------------------------- /requester_test.go: 
-------------------------------------------------------------------------------- 1 | package requester 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "io" 8 | "log" 9 | "math/rand" 10 | "net/http" 11 | "net/http/httptest" 12 | "net/url" 13 | "strings" 14 | "sync/atomic" 15 | "testing" 16 | "time" 17 | 18 | "github.com/stretchr/testify/assert" 19 | "github.com/stretchr/testify/require" 20 | 21 | "github.com/go-pkgz/requester/middleware" 22 | "github.com/go-pkgz/requester/middleware/logger" 23 | "github.com/go-pkgz/requester/middleware/mocks" 24 | ) 25 | 26 | func TestRequester_DoSimpleMiddleware(t *testing.T) { 27 | 28 | mw := func(next http.RoundTripper) http.RoundTripper { 29 | fn := func(req *http.Request) (*http.Response, error) { 30 | req.Header.Set("test", "blah") 31 | return next.RoundTrip(req) 32 | } 33 | return middleware.RoundTripperFunc(fn) 34 | } 35 | 36 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 37 | time.Sleep(100 * time.Millisecond) 38 | assert.Equal(t, "blah", r.Header.Get("test")) 39 | _, err := w.Write([]byte("something")) 40 | require.NoError(t, err) 41 | })) 42 | defer ts.Close() 43 | 44 | rq := New(http.Client{Timeout: 1 * time.Second}, mw) 45 | 46 | req, err := http.NewRequest("GET", ts.URL, http.NoBody) 47 | require.NoError(t, err) 48 | 49 | resp, err := rq.Do(req) 50 | require.NoError(t, err) 51 | assert.Equal(t, 200, resp.StatusCode) 52 | body, err := io.ReadAll(resp.Body) 53 | assert.NoError(t, err) 54 | assert.Equal(t, "something", string(body)) 55 | } 56 | 57 | func TestRequester_DoMiddlewareChain(t *testing.T) { 58 | mw1 := func(next http.RoundTripper) http.RoundTripper { 59 | fn := func(r *http.Request) (*http.Response, error) { 60 | r.Header.Set("test", "blah") 61 | return next.RoundTrip(r) 62 | } 63 | return middleware.RoundTripperFunc(fn) 64 | } 65 | mw2 := func(next http.RoundTripper) http.RoundTripper { 66 | fn := func(r *http.Request) (*http.Response, error) { 67 | 
r.Header.Set("test2", "blah2") 68 | return next.RoundTrip(r) 69 | } 70 | return middleware.RoundTripperFunc(fn) 71 | } 72 | 73 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 74 | time.Sleep(100 * time.Millisecond) 75 | assert.Equal(t, "blah", r.Header.Get("test")) 76 | assert.Equal(t, "blah2", r.Header.Get("test2")) 77 | _, err := w.Write([]byte("something")) 78 | require.NoError(t, err) 79 | })) 80 | defer ts.Close() 81 | 82 | rq := New(http.Client{Timeout: 1 * time.Second}) 83 | rq.Use(mw1) 84 | rq.Use(mw2) 85 | 86 | req, err := http.NewRequest("GET", ts.URL, http.NoBody) 87 | require.NoError(t, err) 88 | 89 | resp, err := rq.Do(req) 90 | require.NoError(t, err) 91 | assert.Equal(t, 200, resp.StatusCode) 92 | body, err := io.ReadAll(resp.Body) 93 | assert.NoError(t, err) 94 | assert.Equal(t, "something", string(body)) 95 | } 96 | 97 | func TestRequester_With(t *testing.T) { 98 | 99 | mw1 := func(next http.RoundTripper) http.RoundTripper { 100 | fn := func(r *http.Request) (*http.Response, error) { 101 | r.Header.Set("test", "blah") 102 | return next.RoundTrip(r) 103 | } 104 | return middleware.RoundTripperFunc(fn) 105 | } 106 | mw2 := func(next http.RoundTripper) http.RoundTripper { 107 | fn := func(r *http.Request) (*http.Response, error) { 108 | r.Header.Set("test2", "blah2") 109 | return next.RoundTrip(r) 110 | } 111 | return middleware.RoundTripperFunc(fn) 112 | } 113 | 114 | var count int32 115 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 116 | time.Sleep(100 * time.Millisecond) 117 | assert.Equal(t, "blah", r.Header.Get("test")) 118 | if atomic.LoadInt32(&count) == 0 { 119 | assert.Equal(t, "", r.Header.Get("test2")) 120 | } 121 | if atomic.LoadInt32(&count) == 1 { 122 | assert.Equal(t, "blah2", r.Header.Get("test2")) 123 | } 124 | _, err := w.Write([]byte("something")) 125 | require.NoError(t, err) 126 | atomic.AddInt32(&count, 1) 127 | })) 128 | defer ts.Close() 129 
| 130 | rq := New(http.Client{Timeout: 1 * time.Second}, mw1) 131 | req, err := http.NewRequest("GET", ts.URL, http.NoBody) 132 | require.NoError(t, err) 133 | resp, err := rq.Do(req) 134 | require.NoError(t, err) 135 | assert.Equal(t, 200, resp.StatusCode) 136 | 137 | rq2 := rq.With(mw2) 138 | req, err = http.NewRequest("GET", ts.URL, http.NoBody) 139 | require.NoError(t, err) 140 | resp, err = rq2.Do(req) 141 | require.NoError(t, err) 142 | assert.Equal(t, 200, resp.StatusCode) 143 | } 144 | 145 | func TestRequester_Client(t *testing.T) { 146 | mw := func(next http.RoundTripper) http.RoundTripper { 147 | fn := func(req *http.Request) (*http.Response, error) { 148 | req.Header.Set("test", "blah") 149 | return next.RoundTrip(req) 150 | } 151 | return middleware.RoundTripperFunc(fn) 152 | } 153 | 154 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 155 | time.Sleep(100 * time.Millisecond) 156 | assert.Equal(t, "blah", r.Header.Get("test")) 157 | _, err := w.Write([]byte("something")) 158 | require.NoError(t, err) 159 | })) 160 | defer ts.Close() 161 | 162 | rq := New(http.Client{Timeout: 1 * time.Second}, mw) 163 | resp, err := rq.Client().Get(ts.URL) 164 | require.NoError(t, err) 165 | assert.Equal(t, 200, resp.StatusCode) 166 | body, err := io.ReadAll(resp.Body) 167 | assert.NoError(t, err) 168 | assert.Equal(t, "something", string(body)) 169 | } 170 | 171 | func TestRequester_CustomMiddleware(t *testing.T) { 172 | 173 | maskHeader := func(next http.RoundTripper) http.RoundTripper { 174 | fn := func(req *http.Request) (*http.Response, error) { 175 | req.Header.Del("deleteme") 176 | return next.RoundTrip(req) 177 | } 178 | return middleware.RoundTripperFunc(fn) 179 | } 180 | 181 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 182 | assert.Equal(t, "application/json", r.Header.Get("Content-Type")) 183 | assert.Equal(t, "blah2", r.Header.Get("do-not-deleteme")) 184 | assert.Equal(t, 
"", r.Header.Get("deleteme")) 185 | body, err := io.ReadAll(r.Body) 186 | assert.NoError(t, err) 187 | assert.Contains(t, string(body), "request body") 188 | _, err = w.Write([]byte("something")) 189 | require.NoError(t, err) 190 | time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond) // nolint 191 | })) 192 | defer ts.Close() 193 | 194 | rqMasked := New(http.Client{}, logger.New(logger.Std, logger.WithHeaders).Middleware, middleware.JSON, maskHeader) 195 | req, err := http.NewRequest("POST", ts.URL, bytes.NewBufferString("request body")) 196 | require.NoError(t, err) 197 | req.Header.Set("deleteme", "blah1") 198 | req.Header.Set("do-not-deleteme", "blah2") 199 | resp, err := rqMasked.Do(req) 200 | require.NoError(t, err) 201 | assert.Equal(t, 200, resp.StatusCode) 202 | } 203 | 204 | func TestRequester_DoNotReplaceTransport(t *testing.T) { 205 | remoteTS := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { 206 | t.Fatal("remote should not be reached due to redirect") 207 | })) 208 | defer remoteTS.Close() 209 | 210 | // indicates that the request was caught by the test server, 211 | // to which we are redirecting the request 212 | caughtReq := int32(0) 213 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 214 | atomic.AddInt32(&caughtReq, 1) 215 | assert.Equal(t, "value", r.Header.Get("blah")) 216 | _, err := w.Write([]byte("something")) 217 | require.NoError(t, err) 218 | })) 219 | defer ts.Close() 220 | tsURL, err := url.Parse(ts.URL) 221 | require.NoError(t, err) 222 | 223 | redirectingRoundTripper := middleware.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { 224 | r.URL = tsURL 225 | return http.DefaultTransport.RoundTrip(r) 226 | }) 227 | 228 | rq := New(http.Client{Transport: redirectingRoundTripper}, middleware.Header("blah", "value")) 229 | 230 | req, err := http.NewRequest("GET", remoteTS.URL, http.NoBody) 231 | require.NoError(t, err) 232 | resp, err := 
rq.Do(req) 233 | require.NoError(t, err) 234 | assert.Equal(t, 200, resp.StatusCode) 235 | assert.Greater(t, atomic.LoadInt32(&caughtReq), int32(0)) 236 | 237 | req, err = http.NewRequest("GET", remoteTS.URL, http.NoBody) 238 | require.NoError(t, err) 239 | resp, err = rq.Client().Do(req) 240 | require.NoError(t, err) 241 | assert.Equal(t, 200, resp.StatusCode) 242 | assert.Greater(t, atomic.LoadInt32(&caughtReq), int32(1)) 243 | } 244 | 245 | func TestRequester_TransportHandling(t *testing.T) { 246 | const baseURL = "http://example.com" 247 | 248 | t.Run("custom transport preserved", func(t *testing.T) { 249 | customTransport := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 250 | return &http.Response{StatusCode: 200}, nil 251 | }} 252 | 253 | client := http.Client{Transport: customTransport} 254 | rq := New(client) 255 | 256 | req, err := http.NewRequest(http.MethodGet, baseURL, http.NoBody) 257 | require.NoError(t, err) 258 | resp, err := rq.Do(req) 259 | require.NoError(t, err) 260 | assert.Equal(t, 200, resp.StatusCode) 261 | assert.Equal(t, 1, customTransport.Calls()) 262 | }) 263 | 264 | t.Run("transport reused between calls", func(t *testing.T) { 265 | customTransport := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 266 | return &http.Response{StatusCode: 200}, nil 267 | }} 268 | 269 | client := http.Client{Transport: customTransport} 270 | rq := New(client) 271 | 272 | for i := 0; i < 3; i++ { 273 | req, err := http.NewRequest(http.MethodGet, baseURL, http.NoBody) 274 | require.NoError(t, err) 275 | resp, err := rq.Do(req) 276 | require.NoError(t, err) 277 | assert.Equal(t, 200, resp.StatusCode) 278 | } 279 | assert.Equal(t, 3, customTransport.Calls()) 280 | }) 281 | 282 | t.Run("nil transport uses default", func(t *testing.T) { 283 | client := http.Client{Transport: nil} 284 | rq := New(client) 285 | _, err := http.NewRequest(http.MethodGet, baseURL, http.NoBody) 286 | 
require.NoError(t, err) 287 | cl := rq.Client() 288 | assert.Equal(t, http.DefaultTransport, cl.Transport) 289 | }) 290 | } 291 | 292 | func TestRequester_MiddlewareHandling(t *testing.T) { 293 | const baseURL = "http://example.com" 294 | 295 | t.Run("chaining with With", func(t *testing.T) { 296 | var calls []string 297 | mw1 := func(next http.RoundTripper) http.RoundTripper { 298 | return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 299 | calls = append(calls, "mw1") 300 | return next.RoundTrip(req) 301 | }) 302 | } 303 | mw2 := func(next http.RoundTripper) http.RoundTripper { 304 | return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 305 | calls = append(calls, "mw2") 306 | return next.RoundTrip(req) 307 | }) 308 | } 309 | 310 | transport := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 311 | return &http.Response{StatusCode: 200}, nil 312 | }} 313 | 314 | base := New(http.Client{Transport: transport}, mw1) 315 | r2 := base.With(mw2) 316 | 317 | req, err := http.NewRequest(http.MethodGet, baseURL, http.NoBody) 318 | require.NoError(t, err) 319 | _, err = r2.Do(req) 320 | require.NoError(t, err) 321 | assert.Equal(t, []string{"mw2", "mw1"}, calls) 322 | }) 323 | 324 | t.Run("nil middleware allowed", func(t *testing.T) { 325 | transport := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 326 | return &http.Response{StatusCode: 200}, nil 327 | }} 328 | 329 | rq := New(http.Client{Transport: transport}) 330 | rq.Use() // empty Use call 331 | rq = rq.With() // empty With call 332 | 333 | req, err := http.NewRequest(http.MethodGet, baseURL, http.NoBody) 334 | require.NoError(t, err) 335 | resp, err := rq.Do(req) 336 | require.NoError(t, err) 337 | assert.Equal(t, 200, resp.StatusCode) 338 | assert.Equal(t, 1, transport.Calls()) 339 | }) 340 | 341 | t.Run("chains kept separate", func(t *testing.T) { 342 | var rq1Calls, rq2Calls 
[]string 343 | mw1 := func(next http.RoundTripper) http.RoundTripper { 344 | return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 345 | rq1Calls = append(rq1Calls, "mw1") 346 | return next.RoundTrip(req) 347 | }) 348 | } 349 | mw2 := func(next http.RoundTripper) http.RoundTripper { 350 | return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 351 | rq2Calls = append(rq2Calls, "mw2") 352 | return next.RoundTrip(req) 353 | }) 354 | } 355 | 356 | transport := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 357 | return &http.Response{StatusCode: 200}, nil 358 | }} 359 | 360 | rq1 := New(http.Client{Transport: transport}) 361 | rq2 := New(http.Client{Transport: transport}) 362 | rq1.Use(mw1) 363 | rq2.Use(mw2) 364 | 365 | req, _ := http.NewRequest(http.MethodGet, baseURL, http.NoBody) 366 | _, _ = rq1.Do(req) 367 | _, _ = rq2.Do(req) 368 | 369 | assert.Equal(t, []string{"mw1"}, rq1Calls) 370 | assert.Equal(t, []string{"mw2"}, rq2Calls) 371 | }) 372 | } 373 | 374 | func TestRequester_ErrorHandling(t *testing.T) { 375 | const baseURL = "http://example.com" 376 | 377 | t.Run("error from middleware", func(t *testing.T) { 378 | expectedErr := errors.New("custom error") 379 | errorMW := func(next http.RoundTripper) http.RoundTripper { 380 | return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 381 | return nil, expectedErr 382 | }) 383 | } 384 | 385 | rq := New(http.Client{}) 386 | rq.Use(errorMW) 387 | 388 | req, err := http.NewRequest(http.MethodGet, baseURL, http.NoBody) 389 | require.NoError(t, err) 390 | _, err = rq.Do(req) 391 | assert.ErrorIs(t, err, expectedErr) 392 | }) 393 | 394 | t.Run("error propagation chain", func(t *testing.T) { 395 | var calls []string 396 | mw1 := func(next http.RoundTripper) http.RoundTripper { 397 | return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { 398 | calls = append(calls, 
"mw1-before") 399 | resp, err := next.RoundTrip(req) 400 | if err != nil { 401 | calls = append(calls, "mw1-error") 402 | } 403 | return resp, err 404 | }) 405 | } 406 | 407 | expectedErr := errors.New("transport error") 408 | transport := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 409 | return nil, expectedErr 410 | }} 411 | 412 | rq := New(http.Client{Transport: transport}, mw1) 413 | req, _ := http.NewRequest(http.MethodGet, baseURL, http.NoBody) 414 | _, err := rq.Do(req) 415 | require.Error(t, err) 416 | assert.Equal(t, []string{"mw1-before", "mw1-error"}, calls) 417 | assert.ErrorIs(t, err, expectedErr) 418 | }) 419 | } 420 | 421 | func TestRequester_Timeouts(t *testing.T) { 422 | const baseURL = "http://example.com" 423 | 424 | t.Run("client timeout", func(t *testing.T) { 425 | transport := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 426 | select { 427 | case <-r.Context().Done(): 428 | return nil, r.Context().Err() 429 | case <-time.After(100 * time.Millisecond): 430 | return &http.Response{StatusCode: 200}, nil 431 | } 432 | }} 433 | 434 | client := http.Client{ 435 | Transport: transport, 436 | Timeout: 50 * time.Millisecond, 437 | } 438 | rq := New(client) 439 | 440 | req, err := http.NewRequest(http.MethodGet, baseURL, http.NoBody) 441 | require.NoError(t, err) 442 | _, err = rq.Do(req) 443 | require.Error(t, err) 444 | assert.True(t, strings.Contains(err.Error(), "context deadline exceeded") || 445 | strings.Contains(err.Error(), "Client.Timeout")) 446 | }) 447 | 448 | t.Run("request timeout", func(t *testing.T) { 449 | transport := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { 450 | select { 451 | case <-r.Context().Done(): 452 | return nil, r.Context().Err() 453 | case <-time.After(100 * time.Millisecond): 454 | return &http.Response{StatusCode: 200}, nil 455 | } 456 | }} 457 | 458 | rq := New(http.Client{Transport: transport}) 459 | 
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) 460 | defer cancel() 461 | 462 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, http.NoBody) 463 | require.NoError(t, err) 464 | _, err = rq.Do(req) 465 | require.Error(t, err) 466 | assert.ErrorIs(t, err, context.DeadlineExceeded) 467 | }) 468 | } 469 | 470 | func ExampleNew() { 471 | // make requester, set JSON headers middleware 472 | rq := New(http.Client{Timeout: 3 * time.Second}, middleware.JSON) 473 | 474 | // add auth header, user agent and a custom X-Auth header middlewared 475 | rq.Use( 476 | middleware.Header("X-Auth", "very-secret-key"), 477 | middleware.Header("User-Agent", "test-requester"), 478 | middleware.BasicAuth("user", "password"), 479 | ) 480 | } 481 | 482 | func ExampleRequester_Do() { 483 | rq := New(http.Client{Timeout: 3 * time.Second}) // make new requester 484 | 485 | // add logger, auth header, user agent and JSON headers 486 | rq.Use( 487 | middleware.Header("X-Auth", "very-secret-key"), 488 | logger.New(logger.Std, logger.Prefix("REST"), logger.WithHeaders).Middleware, // uses std logger 489 | middleware.Header("User-Agent", "test-requester"), 490 | middleware.JSON, 491 | ) 492 | 493 | // create http.Request 494 | req, err := http.NewRequest("GET", "http://example.com", http.NoBody) 495 | if err != nil { 496 | panic(err) 497 | } 498 | 499 | // Send request and get reposnse 500 | resp, err := rq.Do(req) 501 | if err != nil { 502 | panic(err) 503 | } 504 | log.Printf("status: %s", resp.Status) 505 | } 506 | 507 | func ExampleRequester_Client() { 508 | // make new requester with some middlewares 509 | rq := New(http.Client{Timeout: 3 * time.Second}, 510 | middleware.JSON, 511 | middleware.Header("User-Agent", "test-requester"), 512 | middleware.BasicAuth("user", "password"), 513 | middleware.MaxConcurrent(4), 514 | ) 515 | 516 | client := rq.Client() // get http.Client 517 | resp, err := client.Get("http://example.com") 518 | if err != nil 
{ 519 | panic(err) 520 | } 521 | log.Printf("status: %s", resp.Status) 522 | } 523 | 524 | func ExampleRequester_With() { 525 | rq1 := New(http.Client{Timeout: 3 * time.Second}, middleware.JSON) // make a requester with JSON middleware 526 | 527 | // make another requester inherited from rq1 with extra middlewares 528 | rq2 := rq1.With(middleware.BasicAuth("user", "password"), middleware.MaxConcurrent(4)) 529 | 530 | // create http.Request 531 | req, err := http.NewRequest("GET", "http://example.com", http.NoBody) 532 | if err != nil { 533 | panic(err) 534 | } 535 | 536 | // send request with rq1 (JSON headers only) 537 | resp, err := rq1.Do(req) 538 | if err != nil { 539 | panic(err) 540 | } 541 | log.Printf("status1: %s", resp.Status) 542 | 543 | // send request with rq2 (JSON headers, basic auth and limiteted concurrecny) 544 | resp, err = rq2.Do(req) 545 | if err != nil { 546 | panic(err) 547 | } 548 | log.Printf("status2: %s", resp.Status) 549 | } 550 | --------------------------------------------------------------------------------