├── .github
│   └── workflows
│       └── go.yml
├── .gitignore
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── cancelreader.go
├── default.go
├── default_test.go
├── docs
│   ├── default.md
│   └── options.md
├── errors.go
├── errors_test.go
├── go.mod
├── options.go
├── transport.go
└── transport_test.go

/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | # This workflow will build a golang project
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
3 |
4 | name: Go
5 |
6 | on:
7 |   push:
8 |     branches: [ "main" ]
9 |   pull_request:
10 |     branches: [ "main" ]
11 |
12 | jobs:
13 |
14 |   build:
15 |     runs-on: ubuntu-latest
16 |     steps:
17 |     - uses: actions/checkout@v3
18 |
19 |     - name: Set up Go
20 |       uses: actions/setup-go@v4
21 |       with:
22 |         go-version: '1.20'
23 |
24 |     - name: Test
25 |       run: go test -v ./...
26 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Test binary, built with `go test -c`
2 | *.test
3 |
4 | # Output of the go coverage tool, specifically when used with LiteIDE
5 | *.out
6 |
7 | # Go workspace file
8 | go.work
9 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ## v1.0.1
2 |
3 | Change from using a global random source to instead use the top level functions provided by `math/rand`. This only affects code paths using `DefaultDelayFn` or `CustomizedDelayFn`. Note that this package does not seed the top level random source, so it is left up to this package's consumer if that is desired.
4 |
5 | ## v1.0.0
6 |
7 | Initial tagged version.
8 |
9 | Renamed helper functions for setting context keys to be shorter by removing `OnContext` (for example, `SetShouldRetryFnOnContext` -> `SetShouldRetryFn`).
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Feel free to open an issue if you have questions, comments, concerns, or bugs. I'm happy to review PRs, but please start the conversation in issues before requesting a review. This isn't my full-time job, so I may not respond immediately, but I have gotten value from this package and want to make sure it is the best it can be to help others.
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Justin Ricks
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # retryhttp [![Build Status](https://github.com/justinrixx/retryhttp/actions/workflows/go.yml/badge.svg?branch=main)](https://github.com/justinrixx/retryhttp/actions) [![Go Reference](https://pkg.go.dev/badge/github.com/justinrixx/retryhttp.svg)](https://pkg.go.dev/github.com/justinrixx/retryhttp)
2 |
3 | `retryhttp` allows you to add HTTP retries to your service or application with no refactoring at all, just a few lines of configuration where your client is instantiated. This package's goals are:
4 |
5 | - Make adding retries easy, with no refactor required (as stated above)
6 | - [Provide a good starting point for retry behavior](./docs/default.md)
7 | - [Make customizing retry behavior easy](./docs/options.md)
8 | - [Allow for one-off behavior changes without needing multiple HTTP clients](./docs/options.md#example)
9 | - 100% standard library, with no external dependencies (have a peek at `go.mod`)
10 |
11 | ## How it works
12 |
13 | `retryhttp` exports a `Transport` struct which implements [the standard library's `http.RoundTripper` interface](https://pkg.go.dev/net/http#RoundTripper). By performing retries at the `http.RoundTripper` level, retries can be introduced to a service or script with just a few lines of configuration and no changes required to any of the code actually making HTTP requests. Whichever HTTP client you're using very likely has a configurable `http.RoundTripper`, which means this package can be integrated with it.
14 |
15 | `http.RoundTripper`s are also highly composable. By default, this package uses `http.DefaultTransport` as its underlying `RoundTripper`, but you may choose to wrap a customized one that sets `MaxIdleConns`, or even [something like this](https://pkg.go.dev/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp#Transport) that records spans and metrics to instrument your calls.
16 |
17 | ## Get it
18 |
19 | ```bash
20 | $ go get github.com/justinrixx/retryhttp
21 | ```
22 |
23 | ## Example
24 |
25 | ```go
26 | // BEFORE
27 | client := http.Client{
28 | 	// HTTP client options
29 | }
30 |
31 | // AFTER
32 | client := http.Client{
33 | 	Transport: retryhttp.New(
34 | 		// optional retry configurations
35 | 		retryhttp.WithShouldRetryFn(func(attempt retryhttp.Attempt) bool {
36 | 			return attempt.Res != nil && attempt.Res.StatusCode == http.StatusServiceUnavailable
37 | 		}),
38 | 		retryhttp.WithDelayFn(func(attempt retryhttp.Attempt) time.Duration {
39 | 			return expBackoff(attempt.Count) // a backoff function of your own
40 | 		}),
41 | 		retryhttp.WithMaxRetries(2),
42 | 	),
43 | 	// other HTTP client options
44 | }
45 | ```
46 |
47 | This package was inspired by https://github.com/PuerkitoBio/rehttp/, but it aims to remove a couple of footguns, provide widely applicable defaults, and make one-off option overrides easy using context keys.
48 |
49 | ## License
50 |
51 | [MIT](https://github.com/justinrixx/retryhttp/blob/main/LICENSE)
52 |
--------------------------------------------------------------------------------
/cancelreader.go:
--------------------------------------------------------------------------------
1 | package retryhttp
2 |
3 | import (
4 | 	"context"
5 | 	"io"
6 | 	"net/http"
7 | )
8 |
9 | // If a request's context is canceled before the response's body is read, the byte stream
10 | // will be reclaimed by the runtime. This results in a race against the runtime to read
11 | // the body and often ends in an error. Instead of canceling the context before returning
12 | // the response, the cancel call is delayed until Close is called on the response body.
13 | // The response body is replaced with this struct to facilitate this.
14 | // This solution is based on https://github.com/go-kit/kit/issues/773.
15 | type cancelReader struct {
16 | 	io.ReadCloser
17 |
18 | 	cancel context.CancelFunc
19 | }
20 |
21 | func (cr cancelReader) Close() error {
22 | 	cr.cancel()
23 | 	return cr.ReadCloser.Close()
24 | }
25 |
26 | func injectCancelReader(res *http.Response, cancel context.CancelFunc) *http.Response {
27 | 	if res == nil {
28 | 		return nil
29 | 	}
30 |
31 | 	res.Body = cancelReader{
32 | 		ReadCloser: res.Body,
33 | 		cancel:     cancel,
34 | 	}
35 | 	return res
36 | }
37 |
--------------------------------------------------------------------------------
/default.go:
--------------------------------------------------------------------------------
1 | package retryhttp
2 |
3 | import (
4 | 	"math"
5 | 	"math/rand"
6 | 	"net/http"
7 | 	"strconv"
8 | 	"time"
9 | )
10 |
11 | // CustomizedShouldRetryFnOptions are used to tweak the behavior of [CustomizedShouldRetryFn].
12 | type CustomizedShouldRetryFnOptions struct {
13 | 	IdempotentMethods    []string
14 | 	RetryableStatusCodes []int
15 | }
16 |
17 | // CustomizedDelayFnOptions are used to tweak the behavior of [CustomizedDelayFn].
18 | // Base and Cap are used in calculating exponential backoff: min(base * (2 ** i), cap)
19 | // JitterMagnitude determines the maximum portion of delay specified by Retry-After to
20 | // add or subtract as jitter.
21 | // [DefaultDelayFn] uses base=250ms, cap=10s, jitter magnitude=0.333
22 | type CustomizedDelayFnOptions struct {
23 | 	Base            time.Duration
24 | 	Cap             time.Duration
25 | 	JitterMagnitude float64
26 | }
27 |
28 | // DefaultShouldRetryFn is a sane default starting point for a should retry policy.
29 | // Not all HTTP requests should be retried. If a request succeeded or failed in a way
30 | // that is not likely to change on retry (is deterministic), a retry is wasteful.
31 | // Idempotency should also be taken into account when retrying: retrying a non-idempotent
32 | // request can, for example, result in duplicate resources being created.
33 | // DefaultShouldRetryFn's behavior is that:
34 | // - Requests that failed with a DNS error never reached the target server, and are therefore safe to retry.
35 | // - If a timeout error occurred and the request is guessed to be idempotent, it is retried.
36 | // - If a 429 status is returned or the Retry-After response header is included, it is retried.
37 | // - If the status code is retryable and the request is guessed to be idempotent, it is retried.
38 | //
39 | // Default retryable status codes are [http.StatusBadGateway] and [http.StatusServiceUnavailable].
40 | // Idempotency is guessed based on the inclusion of the Idempotency-Key or X-Idempotency-Key
41 | // header, or an idempotent method (as defined in RFC 9110).
42 | var DefaultShouldRetryFn = CustomizedShouldRetryFn(CustomizedShouldRetryFnOptions{
43 | 	// https://www.rfc-editor.org/rfc/rfc9110.html#name-idempotent-methods
44 | 	IdempotentMethods: []string{
45 | 		http.MethodGet,
46 | 		http.MethodHead,
47 | 		http.MethodOptions,
48 | 		http.MethodTrace,
49 | 		http.MethodPut,
50 | 		http.MethodDelete,
51 | 	},
52 | 	RetryableStatusCodes: []int{http.StatusBadGateway, http.StatusServiceUnavailable},
53 | })
54 |
55 | // CustomizedShouldRetryFn has the same logic as [DefaultShouldRetryFn] but it allows for
56 | // specifying which status codes should be assumed retryable and which methods should be
57 | // guessed idempotent. This is useful if the default behavior is desired, with small tweaks.
58 | func CustomizedShouldRetryFn(options CustomizedShouldRetryFnOptions) func(attempt Attempt) bool {
59 | 	idempotentMethods := map[string]bool{}
60 | 	retryableStatusCodes := map[int]bool{}
61 |
62 | 	for _, method := range options.IdempotentMethods {
63 | 		idempotentMethods[method] = true
64 | 	}
65 | 	for _, status := range options.RetryableStatusCodes {
66 | 		retryableStatusCodes[status] = true
67 | 	}
68 |
69 | 	return func(attempt Attempt) bool {
70 | 		idempotent := guessIdempotent(attempt.Req, idempotentMethods)
71 |
72 | 		if attempt.Err != nil {
73 | 			// dns errors are safe to retry
74 | 			if IsDNSErr(attempt.Err) {
75 | 				return true
76 | 			}
77 |
78 | 			return idempotent && IsTimeoutErr(attempt.Err)
79 | 		}
80 |
81 | 		// the server is signaling that it expects a retry
82 | 		if attempt.Res.StatusCode == http.StatusTooManyRequests || attempt.Res.Header.Get("Retry-After") != "" {
83 | 			return true
84 | 		}
85 |
86 | 		return idempotent && retryableStatusCodes[attempt.Res.StatusCode]
87 | 	}
88 | }
89 |
90 | // DefaultDelayFn is a sane default starting point for a delay policy. It respects
91 | // the [Retry-After] response header if present. This header is used by the destination
92 | // service to communicate when the next attempt is appropriate. It can be either
93 | // an integer (specifying the number of seconds to wait) or an HTTP date from which
94 | // a duration is calculated. Once a base duration is determined, plus or minus up to
95 | // 1/3 of that value is added as jitter.
96 | // If the Retry-After header is not present or cannot be parsed, the "[full jitter]"
97 | // exponential backoff algorithm is used with base=250ms and cap=10s.
98 | // Note that the top level functions of math/rand are used to produce random values.
99 | // If determinism is desired, or if it is unacceptable on a Go version prior to 1.20
100 | // (where the top level source is seeded deterministically), package consumers may wish to call [rand.Seed].
101 | //
102 | // [Retry-After]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
103 | // [full jitter]: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
104 | var DefaultDelayFn = CustomizedDelayFn(CustomizedDelayFnOptions{
105 | 	Base:            time.Millisecond * 250,
106 | 	Cap:             time.Second * 10,
107 | 	JitterMagnitude: 0.333,
108 | })
109 |
110 | // CustomizedDelayFn has the same logic as [DefaultDelayFn] but it allows for specifying
111 | // the exponential backoff's base and maximum, as well as the fraction used to calculate
112 | // jitter.
113 | func CustomizedDelayFn(options CustomizedDelayFnOptions) func(attempt Attempt) time.Duration {
114 | 	return func(attempt Attempt) time.Duration {
115 | 		// check for a retry-after header
116 | 		if attempt.Res != nil && attempt.Res.Header.Get("Retry-After") != "" {
117 | 			retryAfterStr := attempt.Res.Header.Get("Retry-After")
118 |
119 | 			// try parsing as an integer
120 | 			// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#delay-seconds
121 | 			i, err := strconv.Atoi(retryAfterStr)
122 | 			if err == nil {
123 | 				return addJitter(time.Duration(i)*time.Second, options.JitterMagnitude)
124 | 			}
125 |
126 | 			// try parsing as date
127 | 			// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#http-date
128 | 			t, err := time.Parse(http.TimeFormat, retryAfterStr)
129 | 			if err == nil {
130 | 				return addJitter(time.Until(t), options.JitterMagnitude)
131 | 			}
132 | 		}
133 |
134 | 		// fall back to exponential backoff
135 | 		return expBackoff(attempt.Count, options.Base, options.Cap)
136 | 	}
137 | }
138 |
139 | // based on "full jitter": https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
140 | func expBackoff(attempt int, base time.Duration, cap time.Duration) time.Duration {
141 | 	exp := math.Pow(2, float64(attempt-1))
142 | 	v := float64(base) * exp
143 | 	return time.Duration(
144 | 		rand.Int63n(int64(math.Min(float64(cap), v))),
145 | 	)
146 | }
147 |
148 | // addJitter randomly adds or subtracts up to magnitude*d (about 1/3 of d with the default magnitude)
149 | func addJitter(d time.Duration, magnitude float64) time.Duration {
150 | 	f := float64(d)
151 | 	mj := f * magnitude
152 |
153 | 	// randomness determines jitter magnitude
154 | 	j := rand.Float64() * mj
155 |
156 | 	// randomness determines if jitter is added or subtracted
157 | 	coin := rand.Float64()
158 | 	if coin < 0.5 {
159 | 		return time.Duration(f + j)
160 | 	}
161 |
162 | 	return time.Duration(f - j)
163 | }
164 |
165 | func guessIdempotent(req *http.Request, idempotentMethods map[string]bool) bool {
166 | 	if req.Header.Get("Idempotency-Key") != "" || req.Header.Get("X-Idempotency-Key") != "" {
167 | 		return true
168 | 	}
169 |
170 | 	return idempotentMethods[req.Method]
171 | }
172 |
--------------------------------------------------------------------------------
/default_test.go:
--------------------------------------------------------------------------------
1 | package retryhttp_test
2 |
3 | import (
4 | 	"net"
5 | 	"net/http"
6 | 	"testing"
7 | 	"time"
8 |
9 | 	"github.com/justinrixx/retryhttp"
10 | )
11 |
12 | func TestDefaultShouldRetryFn(t *testing.T) {
13 | 	tests := []struct {
14 | 		name    string
15 | 		attempt retryhttp.Attempt
16 | 		want    bool
17 | 	}{
18 | 		{
19 | 			name: "should retry on dns error for a GET request",
20 | 			attempt: retryhttp.Attempt{
21 | 				Count: 1,
22 | 				Req: &http.Request{
23 | 					Method: http.MethodGet,
24 | 				},
25 | 				Err: &net.DNSError{
26 | 					IsNotFound: true,
27 | 				},
28 | 			},
29 | 			want: true,
30 | 		},
31 | 		{
32 | 			name: "should retry on dns error for a POST request",
33 | 			attempt: retryhttp.Attempt{
34 | 				Count: 1,
35 | 				Req: &http.Request{
36 | 					Method: http.MethodPost,
37 | 				},
38 | 				Err: &net.DNSError{
39 | 					IsNotFound: true,
40 | 				},
41 | 			},
42 | 			want: true,
43 | 		},
44 | 		{
45 | 			name: "should retry idempotent requests that timed out",
46 | 			attempt: retryhttp.Attempt{
47 | 				Count: 1,
48 | 				Req: &http.Request{
49 | 					Method: http.MethodGet,
50 | 				},
51 | 				Err: &net.OpError{
52 | 					Err: timeoutErr{},
53 | 				},
54 | 			},
55 | 			want: true,
56 | 		},
57 | 		{
58 | 			name: "should not retry non-idempotent requests that timed out",
59 | 			attempt: retryhttp.Attempt{
60 | 				Count: 1,
61 | 				Req: &http.Request{
62 | 					Method: http.MethodPost,
63 | 				},
64 | 				Err: &net.OpError{
65 | 					Err: timeoutErr{},
66 | 				},
67 | 			},
68 | 			want: false,
69 | 		},
70 | 		{
71 | 			name: "should recognize requests with idempotency key headers as idempotent",
72 | 			attempt: retryhttp.Attempt{
73 | 				Count: 1,
74 | 				Req: &http.Request{
75 | 					Method: http.MethodPost,
76 | 					Header: http.Header{"Idempotency-Key": []string{"foobar"}},
77 | 				},
78 | 				Err: &net.OpError{
79 | 					Err: timeoutErr{},
80 | 				},
81 | 			},
82 | 			want: true,
83 | 		},
84 | 		{
85 | 			name: "should retry on 429 status",
86 | 			attempt: retryhttp.Attempt{
87 | 				Count: 1,
88 | 				Req: &http.Request{
89 | 					Method: http.MethodGet,
90 | 				},
91 | 				Res: &http.Response{
92 | 					StatusCode: http.StatusTooManyRequests,
93 | 				},
94 | 			},
95 | 			want: true,
96 | 		},
97 | 		{
98 | 			name: "should retry on 429 status even for non-idempotent methods",
99 | 			attempt: retryhttp.Attempt{
100 | 				Count: 1,
101 | 				Req: &http.Request{
102 | 					Method: http.MethodPost,
103 | 				},
104 | 				Res: &http.Response{
105 | 					StatusCode: http.StatusTooManyRequests,
106 | 				},
107 | 			},
108 | 			want: true,
109 | 		},
110 | 		{
111 | 			name: "should retry if retry-after header is present",
112 | 			attempt: retryhttp.Attempt{
113 | 				Count: 1,
114 | 				Req: &http.Request{
115 | 					Method: http.MethodGet,
116 | 				},
117 | 				Res: &http.Response{
118 | 					StatusCode: http.StatusInternalServerError,
119 | 					Header:     http.Header{"Retry-After": []string{"3"}},
120 | 				},
121 | 			},
122 | 			want: true,
123 | 		},
124 | 		{
125 | 			name: "should retry if retry-after header is present even for non-idempotent methods",
126 | 			attempt: retryhttp.Attempt{
127 | 				Count: 1,
128 | 				Req: &http.Request{
129 | 					Method: http.MethodPost,
130 | 				},
131 | 				Res: &http.Response{
132 | 					StatusCode: http.StatusInternalServerError,
133 | 					Header:     http.Header{"Retry-After": []string{"3"}},
134 | 				},
135 | 			},
136 | 			want: true,
137 | 		},
138 | 		{
139 | 			name: "should not retry if status is not retryable even if guessed idempotent",
140 | 			attempt: retryhttp.Attempt{
141 | 				Count: 1,
142 | 				Req: &http.Request{
143 | 					Method: http.MethodGet,
144 | 				},
145 | 				Res: &http.Response{
146 | 					StatusCode: http.StatusInternalServerError,
147 | 				},
148 | 			},
149 | 			want: false,
150 | 		},
151 | 		{
152 | 			name: "should not retry if request is guessed non-idempotent, even if status code is retryable",
153 | 			attempt: retryhttp.Attempt{
154 | 				Count: 1,
155 | 				Req: &http.Request{
156 | 					Method: http.MethodPost,
157 | 				},
158 | 				Res: &http.Response{
159 | 					StatusCode: http.StatusServiceUnavailable,
160 | 				},
161 | 			},
162 | 			want: false,
163 | 		},
164 | 		{
165 | 			name: "should not retry if request is guessed non-idempotent and status code is not retryable",
166 | 			attempt: retryhttp.Attempt{
167 | 				Count: 1,
168 | 				Req: &http.Request{
169 | 					Method: http.MethodPost,
170 | 				},
171 | 				Res: &http.Response{
172 | 					StatusCode: http.StatusNotFound,
173 | 				},
174 | 			},
175 | 			want: false,
176 | 		},
177 | 	}
178 | 	for _, tt := range tests {
179 | 		t.Run(tt.name, func(t *testing.T) {
180 | 			actual := retryhttp.DefaultShouldRetryFn(tt.attempt)
181 | 			if actual != tt.want {
182 | 				t.Errorf("actual != expected: got %t, want %t", actual, tt.want)
183 | 			}
184 | 		})
185 | 	}
186 | }
187 |
188 | func TestDefaultDelayFn(t *testing.T) {
189 | 	tests := []struct {
190 | 		name       string
191 | 		retryAfter string
192 | 		attempt    int
193 | 		wantLow    time.Duration
194 | 		wantHigh   time.Duration
195 | 	}{
196 | 		{
197 | 			name:       "should respect retry-after when provided as 1s",
198 | 			retryAfter: "1",
199 | 			attempt:    1,
200 | 			wantLow:    time.Millisecond * 666,
201 | 			wantHigh:   time.Millisecond * 1333,
202 | 		},
203 | 		{
204 | 			name:       "should respect retry-after when provided as 2s",
205 | 			retryAfter: "2",
206 | 			attempt:    1,
207 | 			wantLow:    time.Millisecond * 1333,
208 | 			wantHigh:   time.Millisecond * 2666,
209 | 		},
210 | 		{
211 | 			name:       "should respect retry-after when provided as 10s",
212 | 			retryAfter: "10",
213 | 			attempt:    1,
214 | 			wantLow:    time.Millisecond * 6666,
215 | 			wantHigh:   time.Millisecond * 13333,
216 | 		},
217 | 		{
218 | 			name:       "should respect retry-after when provided as date 2s in the future",
219 | 			retryAfter: time.Now().UTC().Add(time.Second * 2).Format(http.TimeFormat),
220 | 			attempt:    1,
221 | 			wantLow:    time.Millisecond * 666,
222 | 			wantHigh:   time.Millisecond * 2666,
223 | 		},
224 | 		{
225 | 			name:       "should respect retry-after when provided as date 10s in the future",
226 | 			retryAfter: time.Now().UTC().Add(time.Second * 10).Format(http.TimeFormat),
227 | 			attempt:    1,
228 | 			wantLow:    time.Millisecond * 5555,
229 | 			wantHigh:   time.Millisecond * 13333,
230 | 		},
231 | 		{
232 | 			name:       "should respect retry-after when provided as date 2h in the past",
233 | 			retryAfter: time.Now().UTC().Add(time.Hour * -2).Format(http.TimeFormat),
234 | 			attempt:    1,
235 | 			wantLow:    time.Minute * -160,
236 | 			wantHigh:   time.Minute * -80,
237 | 		},
238 | 		// retry-after with non-numeric / non-date value
239 | 		{
240 | 			name:       "should fall back to exponential backoff when retry-after header is malformed",
241 | 			retryAfter: "not a date",
242 | 			attempt:    1,
243 | 			wantLow:    0,
244 | 			wantHigh:   time.Millisecond * 250,
245 | 		},
246 | 		// exp backoff with varying values
247 | 		{
248 | 			name:     "should return result consistent with exponential backoff on attempt 1",
249 | 			attempt:  1,
250 | 			wantLow:  0,
251 | 			wantHigh: time.Millisecond * 250,
252 | 		},
253 | 		{
254 | 			name:     "should return result consistent with exponential backoff on attempt 2",
255 | 			attempt:  2,
256 | 			wantLow:  0,
257 | 			wantHigh: time.Millisecond * 500,
258 | 		},
259 | 		{
260 | 			name:     "should return result consistent with exponential backoff on attempt 3",
261 | 			attempt:  3,
262 | 			wantLow:  0,
263 | 			wantHigh: time.Second,
264 | 		},
265 | 		{
266 | 			name:     "should return result consistent with exponential backoff on attempt 4",
267 | 			attempt:  4,
268 | 			wantLow:  0,
269 | 			wantHigh: time.Second * 2,
270 | 		},
271 | 		{
272 | 			name:     "should return result consistent with exponential backoff on attempt 5",
273 | 			attempt:  5,
274 | 			wantLow:  0,
275 | 			wantHigh: time.Second * 4,
276 | 		},
277 | 		{
278 | 			name:     "should return result consistent with exponential backoff on attempt 6",
279 | 			attempt:  6,
280 | 			wantLow:  0,
281 | 			wantHigh: time.Second * 8,
282 | 		},
283 | 		{
284 | 			name:     "exponential backoff should be capped by attempt 7",
285 | 			attempt:  7,
286 | 			wantLow:  0,
287 | 			wantHigh: time.Second * 10,
288 | 		},
289 | 		{
290 | 			name:     "exponential backoff should be capped beyond attempt 7",
291 | 			attempt:  100,
292 | 			wantLow:  0,
293 | 			wantHigh: time.Second * 10,
294 | 		},
295 | 		{
296 | 			name:     "exponential backoff should be capped beyond attempt 7",
297 | 			attempt:  100,
298 | 			wantLow:  0,
299 | 			wantHigh: time.Second * 10,
300 | 		},
301 | 	}
302 | 	for _, tt := range tests {
303 | 		t.Run(tt.name, func(t *testing.T) {
304 | 			res := http.Response{
305 | 				Header: http.Header{},
306 | 			}
307 | 			if tt.retryAfter != "" {
308 | 				res.Header.Set("Retry-After", tt.retryAfter)
309 | 			}
310 | 			actual := retryhttp.DefaultDelayFn(retryhttp.Attempt{
311 | 				Count: tt.attempt,
312 | 				Res:   &res,
313 | 			})
314 | 			if actual < tt.wantLow {
315 | 				t.Errorf("actual less than expected range; expected between %s and %s, got %s", tt.wantLow, tt.wantHigh, actual)
316 | 			}
317 | 			if actual > tt.wantHigh {
318 | 				t.Errorf("actual greater than expected range; expected between %s and %s, got %s", tt.wantLow, tt.wantHigh, actual)
319 | 			}
320 | 		})
321 | 	}
322 | }
323 |
--------------------------------------------------------------------------------
/docs/default.md:
--------------------------------------------------------------------------------
1 | # Default behaviors
2 |
3 | The default behaviors of this package are mostly implemented in `DefaultShouldRetryFn` and `DefaultDelayFn`, which you may choose to read for yourself. A high level description of their logic follows.
4 |
5 | ## `DefaultShouldRetryFn`
6 |
7 | - If an error occurred (non-nil `attempt.Err`, nil `attempt.Res`) and that error is a DNS error, the request is retried. This is safe because the request never reached the target server: it failed at the DNS lookup.
8 | - If an error occurred and that error is a common timeout error (see `IsTimeoutErr`), the request is retried only if it is guessed to be idempotent[^1].
9 | - If no error occurred and a non-nil response was returned, the request is retried if the response indicates the server expects a retry[^2].
10 | - If the request is guessed idempotent[^1] and the status code is 502 or 503, the request is retried.
11 | - Otherwise, the request is not retried.
12 |
13 | The methods considered idempotent and the status codes considered retryable can be tweaked by using `CustomizedShouldRetryFn` instead.
14 |
15 | ## `DefaultDelayFn`
16 |
17 | - If the `Retry-After` header is provided, a wait duration is derived from its value. This field may be a non-negative integer representing seconds, or an HTTP date. Once a duration is obtained, a random jitter of up to one third ($\frac{1}{3}$) of that duration is added to or subtracted from it.
18 | - If no `Retry-After` header is provided (or its value cannot be parsed), exponential backoff with jitter is used. The algorithm used [is described here](https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/) as "full jitter". The backoff window starts at 250ms, doubles with each attempt, and is capped at 10s.
19 |
20 | The jitter magnitude, exponential base, and exponential backoff cap can be tweaked by using `CustomizedDelayFn` instead.
21 |
22 | [^1]: A request is guessed idempotent if it uses an [idempotent HTTP method](https://www.rfc-editor.org/rfc/rfc9110.html#name-idempotent-methods) or includes the `X-Idempotency-Key` or `Idempotency-Key` header.
23 | [^2]: A status code of 429 indicates the server did not process the request and expects the caller to retry after some delay. Similarly, the `Retry-After` response header indicates the request should be retried after a delay.
--------------------------------------------------------------------------------
/docs/options.md:
--------------------------------------------------------------------------------
1 | # Options
2 |
3 | `retryhttp.Transport` can be customized with several options. In general, each option that can be specified at creation time has an equivalent helper function for overriding the option using the request `Context`. An option set on the `Context` takes precedence over an option set on the `Transport`.
4 |
5 | | Option | Context Equivalent | Default Value | Description |
6 | | ------ | ------------------ | ------------- | ----------- |
7 | | `WithTransport` | none | `http.DefaultTransport` | The internal `http.RoundTripper` to use for requests. |
8 | | `WithShouldRetryFn` | `SetShouldRetryFn` | `DefaultShouldRetryFn` | The `ShouldRetryFn` that determines if a request should be retried. `DefaultShouldRetryFn` is a good starting point. If you're only looking to make minor tweaks, `CustomizedShouldRetryFn` may be appropriate. |
9 | | `WithDelayFn` | `SetDelayFn` | `DefaultDelayFn` | The `DelayFn` that determines how long to delay between retries. If `DefaultDelayFn` doesn't fit your use case, `CustomizedDelayFn` may be appropriate. |
10 | | `WithMaxRetries` | `SetMaxRetries` | 3 | The maximum number of retries to make. Note that this is the number of _retries_, not _attempts_, so a `MaxRetries` of 3 means up to 4 total attempts: 1 initial attempt and 3 retries. Note also that if your `ShouldRetryFn` returns `false`, a retry will not be made even if `MaxRetries` has not been exhausted. |
11 | | `WithPreventRetryWithBody` | `SetPreventRetryWithBody` | `false` | Whether to prevent retrying requests that have an HTTP body. Any request that has any chance of needing a retry must buffer its body into memory so that it can be replayed in subsequent attempts. This may or may not be appropriate for certain use cases, which is why this option is provided. |
12 | | `WithAttemptTimeout` | `SetAttemptTimeout` | No timeout | A per-attempt timeout to be used. This differs from an overall timeout in that the timeout is reset for each attempt. Without a per-attempt timeout, the overall timeout could be exhausted in a single attempt with no time left for subsequent retries. Providing `time.Duration(0)` here removes the timeout. |
13 |
14 | ## Example
15 |
16 | ```go
17 | client := http.Client{
18 | 	Transport: retryhttp.New(
19 | 		retryhttp.WithShouldRetryFn(func(attempt retryhttp.Attempt) bool {
20 | 			// only retry HTTP 418 statuses
21 | 			if attempt.Res != nil && attempt.Res.StatusCode == http.StatusTeapot {
22 | 				return true
23 | 			}
24 | 			return false
25 | 		}),
26 | 		retryhttp.WithMaxRetries(2),
27 | 		retryhttp.WithAttemptTimeout(time.Second * 10),
28 | 	),
29 | }
30 |
31 | ctx := context.TODO()
32 | ctx = retryhttp.SetShouldRetryFn(ctx, func(attempt retryhttp.Attempt) bool {
33 | 	// retry any error
34 | 	if attempt.Err != nil {
35 | 		return true
36 | 	}
37 | 	return false
38 | })
39 | ctx = retryhttp.SetMaxRetries(ctx, 1) // only 1 retry
40 | ctx = retryhttp.SetAttemptTimeout(ctx, 0) // remove attempt timeout
41 |
42 | req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
43 | ...
44 |
45 | // add augmented context to the request: retries will abide by the overrides
46 | // instead of the original configurations
47 | res, err := client.Do(req.WithContext(ctx))
48 | ...
49 | ```
50 |
--------------------------------------------------------------------------------
/errors.go:
--------------------------------------------------------------------------------
1 | package retryhttp
2 |
3 | import (
4 | 	"errors"
5 | 	"net"
6 | )
7 |
8 | // IsDNSErr is used to determine if an error from an attempt is due to DNS. Requests that
9 | // failed with a DNS error never reached the target server, so they are safe to retry.
10 | func IsDNSErr(err error) bool {
11 | 	var dnse *net.DNSError
12 | 	return errors.As(err, &dnse)
13 | }
14 |
15 | // IsTimeoutErr is used to determine if an error from an attempt is due to a common timeout.
16 | // This includes network timeouts or the context deadline being exceeded.
17 | func IsTimeoutErr(err error) bool {
18 | 	var netErr net.Error
19 | 	return errors.As(err, &netErr) && netErr.Timeout()
20 | }
21 |
--------------------------------------------------------------------------------
/errors_test.go:
--------------------------------------------------------------------------------
1 | package retryhttp_test
2 |
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"net"
7 | 	"testing"
8 |
9 | 	"github.com/justinrixx/retryhttp"
10 | )
11 |
12 | type timeoutErr struct{}
13 |
14 | func (t timeoutErr) Error() string { return "timeout error" }
15 | func (t timeoutErr) Timeout() bool { return true }
16 |
17 | func TestIsDNSErr(t *testing.T) {
18 | 	tests := []struct {
19 | 		name string
20 | 		err  error
21 | 		want bool
22 | 	}{
23 | 		{
24 | 			name: "returns true for a dns timeout error",
25 | 			err: &net.DNSError{
26 | 				IsTimeout: true,
27 | 			},
28 | 			want: true,
29 | 		},
30 | 		{
31 | 			name: "returns true for a dns not found error",
32 | 			err: &net.DNSError{
33 | 				IsNotFound: true,
34 | 			},
35 | 			want: true,
36 | 		},
37 | 		{
38 | 			name: "returns true for generic dns error",
39 | 			err:  &net.DNSError{},
40 | 			want: true,
41 | 		},
42 | 		{
43 | 			name: "returns false for non dns error",
44 | 			err:  errors.New("fake error"),
45 | 			want: false,
46 | 		},
47 | 	}
48 | 	for _, tt := range tests {
49 | 		t.Run(tt.name, func(t *testing.T) {
50 | 			if got := retryhttp.IsDNSErr(tt.err); got != tt.want {
51 | 				t.Errorf("IsDNSErr() = %v, want %v", got, tt.want)
52 | 			}
53 | 		})
54 | 	}
55 | }
56 |
57 | func TestIsTimeoutErr(t *testing.T) {
58 | 	tests := []struct {
59 | 		name string
60 | 		err  error
61 | 		want bool
62 | 	}{
63 | 		{
64 | 			name: "returns true for an error that timed out",
65 | 			err: &net.OpError{
66 | 				Err: timeoutErr{},
67 | 			},
68 | 			want: true,
69 | 		},
70 | 		{
71 | 			name: "returns true for a context deadline exceeded error",
72 | 			err:  context.DeadlineExceeded,
73 | 			want: true,
74 | 		},
75 | 		{
76 | 			name: "returns false for non-timeout error",
77 | 			err:  errors.New("fake error"),
78 | 			want: false,
79 | 		},
80 | 	}
81 | 	for _, tt := range tests {
82 | 		t.Run(tt.name, func(t *testing.T) {
83 | 			if got := retryhttp.IsTimeoutErr(tt.err); got != tt.want {
84 | 				t.Errorf("IsTimeoutErr() = %v, want %v", got, tt.want)
85 | 			}
86 | 		})
87 | 	}
88 | }
89 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/justinrixx/retryhttp
2 |
3 | go 1.20
4 |
--------------------------------------------------------------------------------
/options.go:
--------------------------------------------------------------------------------
1 | package retryhttp
2 |
3 | import (
4 | 	"context"
5 | 	"net/http"
6 | 	"time"
7 | )
8 |
9 | type (
10 | 	maxRetriesContextKeyType           string
11 | 	shouldRetryFnContextKeyType        string
12 | 	delayFnContextKeyType              string
13 | 	preventRetryWithBodyContextKeyType string
14 | 	attemptTimeoutContextKeyType       string
15 | )
16 |
17 | const (
18 | 	maxRetriesContextKey           = maxRetriesContextKeyType("maxRetries")
19 | 	shouldRetryFnContextKey        = shouldRetryFnContextKeyType("shouldRetryFn")
20 | 	delayFnContextKey              = delayFnContextKeyType("delayFn")
21 | 	preventRetryWithBodyContextKey = preventRetryWithBodyContextKeyType("preventRetryWithBody")
22 | 	attemptTimeoutContextKey       = attemptTimeoutContextKeyType("attemptTimeout")
23 | )
24 |
25 | // WithTransport configures the internal [http.RoundTripper] that a Transport uses to make requests.
26 | // This is often [http.DefaultTransport], but it could be anything else.
27 | func WithTransport(transport http.RoundTripper) func(*Transport) { 28 | return func(t *Transport) { 29 | t.rt = transport 30 | } 31 | } 32 | 33 | // WithMaxRetries configures the maximum number of retries a Transport is allowed to make. 34 | // If not set, defaults to [DefaultMaxRetries]. Note that this number does not include the 35 | // initial attempt, so if this is configured as 3, there could be up to 4 total attempts. 36 | func WithMaxRetries(maxRetries int) func(*Transport) { 37 | return func(t *Transport) { 38 | t.maxRetries = &maxRetries 39 | } 40 | } 41 | 42 | // WithShouldRetryFn configures the [ShouldRetryFn] callback to use. 43 | func WithShouldRetryFn(shouldRetryFn ShouldRetryFn) func(*Transport) { 44 | return func(t *Transport) { 45 | t.shouldRetryFn = shouldRetryFn 46 | } 47 | } 48 | 49 | // WithDelayFn configures the [DelayFn] callback to use. 50 | func WithDelayFn(delayFn DelayFn) func(*Transport) { 51 | return func(t *Transport) { 52 | t.delayFn = delayFn 53 | } 54 | } 55 | 56 | // WithPreventRetryWithBody configures whether to prevent retries on requests that 57 | // have bodies. This may be desirable because any request that has a chance of 58 | // requiring a retry must have its body buffered into memory by Transport in case 59 | // it needs to be replayed on subsequent attempts. It is up to package consumers 60 | // to determine if and when this behavior is appropriate. 61 | func WithPreventRetryWithBody(preventRetryWithBody bool) func(*Transport) { 62 | return func(t *Transport) { 63 | t.preventRetryWithBody = preventRetryWithBody 64 | } 65 | } 66 | 67 | // WithAttemptTimeout configures a per-attempt timeout to be used in requests. A 68 | // per-attempt timeout differs from an overall timeout in that it applies to and is 69 | // reset in each individual attempt rather than all attempts and delays combined. 70 | // If using an overall timeout along with a per-attempt timeout, the stricter of 71 | // the two takes precedence. 
72 | func WithAttemptTimeout(attemptTimeout time.Duration) func(*Transport) { 73 | return func(t *Transport) { 74 | t.attemptTimeout = attemptTimeout 75 | } 76 | } 77 | 78 | // SetMaxRetries can be used to override the settings on a Transport. 79 | // Any request made with the returned context will have its MaxRetries setting 80 | // overridden with the provided value. 81 | func SetMaxRetries(ctx context.Context, maxRetries int) context.Context { 82 | return context.WithValue(ctx, maxRetriesContextKey, maxRetries) 83 | } 84 | 85 | // SetShouldRetryFn can be used to override the settings on a Transport. 86 | // Any request made with the returned context will have its [ShouldRetryFn] overridden with 87 | // the provided value. 88 | func SetShouldRetryFn(ctx context.Context, shouldRetryFn ShouldRetryFn) context.Context { 89 | return context.WithValue(ctx, shouldRetryFnContextKey, shouldRetryFn) 90 | } 91 | 92 | // SetDelayFn can be used to override the settings on a Transport. 93 | // Any request made with the returned context will have its [DelayFn] overridden with 94 | // the provided value. 95 | func SetDelayFn(ctx context.Context, delayFn DelayFn) context.Context { 96 | return context.WithValue(ctx, delayFnContextKey, delayFn) 97 | } 98 | 99 | // SetPreventRetryWithBody can be used to override the settings on a 100 | // Transport. Any request made with the returned context will have its 101 | // PreventRetryWithBody setting overridden with the provided value. 102 | func SetPreventRetryWithBody(ctx context.Context, preventRetryWithBody bool) context.Context { 103 | return context.WithValue(ctx, preventRetryWithBodyContextKey, preventRetryWithBody) 104 | } 105 | 106 | // SetAttemptTimeout can be used to override the settings on a Transport. 107 | // Any request made with the returned context will have its AttemptTimeout setting 108 | // overridden with the provided value. A value of 0 disables the per-attempt timeout.
109 | func SetAttemptTimeout(ctx context.Context, attemptTimeout time.Duration) context.Context { 110 | return context.WithValue(ctx, attemptTimeoutContextKey, attemptTimeout) 111 | } 112 | 113 | func getMaxRetriesFromContext(ctx context.Context) (int, bool) { 114 | val, ok := ctx.Value(maxRetriesContextKey).(int) 115 | return val, ok 116 | } 117 | 118 | func getShouldRetryFnFromContext(ctx context.Context) (ShouldRetryFn, bool) { 119 | val, ok := ctx.Value(shouldRetryFnContextKey).(ShouldRetryFn) 120 | return val, ok 121 | } 122 | 123 | func getDelayFnFromContext(ctx context.Context) (DelayFn, bool) { 124 | val, ok := ctx.Value(delayFnContextKey).(DelayFn) 125 | return val, ok 126 | } 127 | 128 | func getPreventRetryWithBodyFromContext(ctx context.Context) (bool, bool) { 129 | val, ok := ctx.Value(preventRetryWithBodyContextKey).(bool) 130 | return val, ok 131 | } 132 | 133 | func getAttemptTimeoutFromContext(ctx context.Context) (time.Duration, bool) { 134 | val, ok := ctx.Value(attemptTimeoutContextKey).(time.Duration) 135 | return val, ok 136 | } 137 | -------------------------------------------------------------------------------- /transport.go: -------------------------------------------------------------------------------- 1 | package retryhttp 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | // DefaultMaxRetries is the default maximum retries setting. This can be configured using 15 | // [WithMaxRetries]. 16 | const DefaultMaxRetries = 3 17 | 18 | var ( 19 | // ErrBufferingBody is a sentinel that signals an error before the request was sent. Since 20 | // request body streams can only be consumed once, they must be buffered into memory before 21 | // the first attempt. If an error occurs during that buffering process, it is returned 22 | // in a new error wrapping this sentinel. A caller can identify this case using 23 | // errors.Is(err, ErrBufferingBody).
24 | ErrBufferingBody = errors.New("error buffering body before first attempt") 25 | 26 | // ErrSeekingBody is a sentinel that signals an error preparing for a new attempt by 27 | // rewinding the stream back to the beginning. If an error occurs during that seek, it is 28 | // returned in a new error wrapping this sentinel. A caller can identify this case using 29 | // errors.Is(err, ErrSeekingBody). 30 | ErrSeekingBody = errors.New("error seeking body buffer back to beginning after attempt") 31 | 32 | // ErrRetriesExhausted is a sentinel that signals all retry attempts have been exhausted. 33 | // When the final attempt fails with an error, that error is joined with this sentinel to 34 | // provide clearer context about why the request failed. A caller can identify this case using 35 | // errors.Is(err, ErrRetriesExhausted). 36 | ErrRetriesExhausted = errors.New("max retries exhausted") 37 | ) 38 | 39 | type ( 40 | // Attempt is a collection of information used by [ShouldRetryFn] and [DelayFn] to determine 41 | // if a retry is appropriate and if so how long to delay. 42 | Attempt struct { 43 | // Count represents how many attempts have been made. This includes the initial attempt. 44 | Count int 45 | 46 | // Req is the original HTTP request passed to RoundTrip. 47 | Req *http.Request 48 | 49 | // Res is the HTTP response returned by the attempt. This may be nil if a non-nil error 50 | // occurred. Since the response body is a stream, if you need to inspect it you are 51 | // responsible for buffering it into memory and replacing Res.Body with a reader over the 52 | // buffered bytes so the caller can still read it if no further attempt is made. 53 | Res *http.Response 54 | 55 | // Err is an optional error that may have occurred during the HTTP round trip. 56 | Err error 57 | } 58 | 59 | // ShouldRetryFn is a callback type consulted by [Transport] to determine if another attempt 60 | // should be made after the current one.
61 | ShouldRetryFn func(attempt Attempt) bool 62 | 63 | // DelayFn is a callback type consulted by [Transport] to determine how long to wait before 64 | // the next attempt. 65 | DelayFn func(attempt Attempt) time.Duration 66 | 67 | // Transport implements [http.RoundTripper] and can be configured with many options. See 68 | // the documentation for the [New] function. 69 | Transport struct { 70 | rt http.RoundTripper 71 | maxRetries *int // pointer to differentiate between 0 and unset 72 | shouldRetryFn ShouldRetryFn 73 | delayFn DelayFn 74 | preventRetryWithBody bool 75 | attemptTimeout time.Duration 76 | initOnce sync.Once 77 | } 78 | ) 79 | 80 | // New is used to construct a new [Transport], configured with any desired options. 81 | // These options include [WithTransport], [WithMaxRetries], [WithShouldRetryFn], 82 | // [WithDelayFn], [WithPreventRetryWithBody], and [WithAttemptTimeout]. Any number of options 83 | // may be provided. If the same option is provided multiple times, the last one takes precedence. 84 | func New(options ...func(*Transport)) *Transport { 85 | tr := &Transport{} 86 | 87 | for _, option := range options { 88 | option(tr) 89 | } 90 | 91 | return tr 92 | } 93 | 94 | func (t *Transport) init() { 95 | if t.rt == nil { 96 | t.rt = http.DefaultTransport 97 | } 98 | if t.shouldRetryFn == nil { 99 | t.shouldRetryFn = DefaultShouldRetryFn 100 | } 101 | if t.delayFn == nil { 102 | t.delayFn = DefaultDelayFn 103 | } 104 | 105 | if t.maxRetries == nil { 106 | tmp := DefaultMaxRetries 107 | t.maxRetries = &tmp 108 | } 109 | } 110 | 111 | // RoundTrip performs the actual HTTP round trip for a request. It performs setup 112 | // and retries, but delegates the actual HTTP round trip to [Transport]'s internal 113 | // roundtripper. This is not intended to be called directly, but rather implements 114 | // the [http.RoundTripper] interface so that it can be passed to an [http.Client] as 115 | // its internal Transport.
116 | func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { 117 | t.initOnce.Do(t.init) 118 | 119 | var attemptCount int 120 | ctx := req.Context() 121 | 122 | maxRetries := *t.maxRetries 123 | ctxRetries, set := getMaxRetriesFromContext(ctx) 124 | if set { 125 | maxRetries = ctxRetries 126 | } 127 | 128 | shouldRetryFn := t.shouldRetryFn 129 | ctxShouldRetryFn, set := getShouldRetryFnFromContext(ctx) 130 | if set { 131 | shouldRetryFn = ctxShouldRetryFn 132 | } 133 | 134 | delayFn := t.delayFn 135 | ctxDelayFn, set := getDelayFnFromContext(ctx) 136 | if set { 137 | delayFn = ctxDelayFn 138 | } 139 | 140 | preventRetryWithBody := t.preventRetryWithBody 141 | ctxPreventRetry, set := getPreventRetryWithBodyFromContext(ctx) 142 | if set { 143 | preventRetryWithBody = ctxPreventRetry 144 | } 145 | 146 | preventRetry := req.Body != nil && req.Body != http.NoBody && preventRetryWithBody 147 | 148 | // if body is present, it must be buffered if there is any chance of a retry 149 | // since it can only be consumed once. 
150 | var br *bytes.Reader 151 | if req.Body != nil && req.Body != http.NoBody && !preventRetry { 152 | var buf bytes.Buffer 153 | if _, err := io.Copy(&buf, req.Body); err != nil { 154 | req.Body.Close() 155 | return nil, fmt.Errorf("%w: %s", ErrBufferingBody, err) 156 | } 157 | req.Body.Close() 158 | 159 | br = bytes.NewReader(buf.Bytes()) 160 | req.Body = io.NopCloser(br) 161 | } 162 | 163 | attemptTimeout := t.attemptTimeout 164 | ctxAttemptTimeout, set := getAttemptTimeoutFromContext(ctx) 165 | if set { 166 | attemptTimeout = ctxAttemptTimeout 167 | } 168 | 169 | for { 170 | // set per-attempt timeout if needed 171 | var cancel context.CancelFunc = func() {} 172 | reqWithTimeout := req 173 | if attemptTimeout != 0 { 174 | var timeoutCtx context.Context 175 | timeoutCtx, cancel = context.WithTimeout(ctx, attemptTimeout) 176 | reqWithTimeout = req.WithContext(timeoutCtx) 177 | } 178 | 179 | // the actual round trip 180 | res, err := t.rt.RoundTrip(reqWithTimeout) 181 | attemptCount++ 182 | 183 | if preventRetry { 184 | return injectCancelReader(res, cancel), err 185 | } 186 | 187 | if attemptCount-1 >= maxRetries { 188 | if err != nil { 189 | err = errors.Join(ErrRetriesExhausted, err) 190 | } 191 | return injectCancelReader(res, cancel), err 192 | } 193 | 194 | attempt := Attempt{ 195 | Count: attemptCount, 196 | Req: req, 197 | Res: res, 198 | Err: err, 199 | } 200 | 201 | shouldRetry := shouldRetryFn(attempt) 202 | if !shouldRetry { 203 | return injectCancelReader(res, cancel), err 204 | } 205 | 206 | delay := delayFn(attempt) 207 | if br != nil { 208 | if _, serr := br.Seek(0, 0); serr != nil { 209 | return injectCancelReader(res, cancel), fmt.Errorf("%w: %s", ErrSeekingBody, serr) 210 | } 211 | reqWithTimeout.Body = io.NopCloser(br) 212 | } 213 | 214 | if res != nil { 215 | _, _ = io.Copy(io.Discard, res.Body) 216 | res.Body.Close() 217 | } 218 | 219 | // going for another attempt, cancel the context of the attempt that was just made 220 | cancel() 221 | 222
| select { 223 | case <-time.After(delay): 224 | // do nothing, just loop again 225 | case <-req.Context().Done(): // happens if the parent context expires 226 | return nil, req.Context().Err() 227 | } 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /transport_test.go: -------------------------------------------------------------------------------- 1 | package retryhttp_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "io" 8 | "net/http" 9 | "net/http/httptest" 10 | "reflect" 11 | "sync" 12 | "testing" 13 | "time" 14 | 15 | "github.com/justinrixx/retryhttp" 16 | ) 17 | 18 | func TestTransport_RoundTrip(t *testing.T) { 19 | type fields struct { 20 | tr *retryhttp.Transport 21 | method string 22 | body io.Reader 23 | ctxFn func(context.Context) context.Context 24 | responseCodes func(int) int 25 | responseBodies func(int) []byte 26 | } 27 | tests := []struct { 28 | name string 29 | fields fields 30 | wantAttemptCount int 31 | wantStatus int 32 | wantErr bool 33 | expReqBody []byte 34 | expResBody []byte 35 | }{ 36 | { 37 | name: "should retry the appropriate number of times with default configurations", 38 | fields: fields{ 39 | tr: retryhttp.New( 40 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 41 | return 0 42 | }), 43 | ), 44 | method: http.MethodGet, 45 | responseCodes: func(_ int) int { 46 | return http.StatusTooManyRequests 47 | }, 48 | }, 49 | wantAttemptCount: 4, 50 | wantStatus: http.StatusTooManyRequests, 51 | }, 52 | { 53 | name: "should not retry on success", 54 | fields: fields{ 55 | tr: retryhttp.New( 56 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 57 | return 0 58 | }), 59 | ), 60 | method: http.MethodGet, 61 | responseCodes: func(_ int) int { 62 | return http.StatusOK 63 | }, 64 | }, 65 | wantAttemptCount: 1, 66 | wantStatus: http.StatusOK, 67 | }, 68 | { 69 | name: "should not retry beyond success", 70 | fields: fields{ 71 | tr: 
retryhttp.New( 72 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 73 | return 0 74 | }), 75 | ), 76 | method: http.MethodGet, 77 | responseCodes: func(i int) int { 78 | if i > 1 { 79 | return http.StatusOK 80 | } 81 | return http.StatusTooManyRequests 82 | }, 83 | }, 84 | wantAttemptCount: 3, 85 | wantStatus: http.StatusOK, 86 | }, 87 | { 88 | name: "should respect custom ShouldRetryFn", 89 | fields: fields{ 90 | tr: retryhttp.New( 91 | retryhttp.WithShouldRetryFn(func(attempt retryhttp.Attempt) bool { 92 | return attempt.Res != nil && attempt.Res.StatusCode == http.StatusTeapot 93 | }), 94 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 95 | return 0 96 | }), 97 | ), 98 | method: http.MethodGet, 99 | responseCodes: func(i int) int { 100 | return http.StatusTeapot 101 | }, 102 | }, 103 | wantAttemptCount: 4, 104 | wantStatus: http.StatusTeapot, 105 | }, 106 | { 107 | name: "should respect custom MaxRetries", 108 | fields: fields{ 109 | tr: retryhttp.New( 110 | retryhttp.WithMaxRetries(2), 111 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 112 | return 0 113 | }), 114 | ), 115 | method: http.MethodGet, 116 | responseCodes: func(_ int) int { 117 | return http.StatusTooManyRequests 118 | }, 119 | }, 120 | wantAttemptCount: 3, 121 | wantStatus: http.StatusTooManyRequests, 122 | }, 123 | { 124 | name: "should retry requests with bodies", 125 | fields: fields{ 126 | tr: retryhttp.New( 127 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 128 | return 0 129 | }), 130 | ), 131 | method: http.MethodPost, 132 | body: bytes.NewReader([]byte(`this is the request body`)), 133 | responseCodes: func(i int) int { 134 | if i < 2 { 135 | return http.StatusTooManyRequests 136 | } 137 | return http.StatusOK 138 | }, 139 | responseBodies: func(i int) []byte { 140 | if i < 2 { 141 | return nil 142 | } 143 | return []byte(`foo bar baz it's all ok`) 144 | }, 145 | }, 146 | wantAttemptCount: 3, 147 | wantStatus: 
http.StatusOK, 148 | expReqBody: []byte(`this is the request body`), 149 | expResBody: []byte(`foo bar baz it's all ok`), 150 | }, 151 | { 152 | name: "should prevent retry of requests with bodies when enabled", 153 | fields: fields{ 154 | tr: retryhttp.New( 155 | retryhttp.WithPreventRetryWithBody(true), 156 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 157 | return 0 158 | }), 159 | ), 160 | method: http.MethodPost, 161 | body: bytes.NewReader([]byte(`this is the request body`)), 162 | responseCodes: func(i int) int { 163 | return http.StatusTooManyRequests 164 | }, 165 | }, 166 | wantAttemptCount: 1, 167 | wantStatus: http.StatusTooManyRequests, 168 | expReqBody: []byte(`this is the request body`), 169 | }, 170 | { 171 | name: "should respect MaxRetries context key override", 172 | fields: fields{ 173 | tr: retryhttp.New( 174 | retryhttp.WithMaxRetries(0), // transport says 0 retries 175 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 176 | return 0 177 | }), 178 | ), 179 | method: http.MethodGet, 180 | ctxFn: func(ctx context.Context) context.Context { // context overrides retry count 181 | return retryhttp.SetMaxRetries(ctx, 3) 182 | }, 183 | responseCodes: func(_ int) int { 184 | return http.StatusTooManyRequests 185 | }, 186 | }, 187 | wantAttemptCount: 4, 188 | wantStatus: http.StatusTooManyRequests, 189 | }, 190 | { 191 | name: "should respect ShouldRetryFn context key override", 192 | fields: fields{ 193 | tr: retryhttp.New( 194 | retryhttp.WithShouldRetryFn(func(_ retryhttp.Attempt) bool { 195 | return true 196 | }), 197 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 198 | return 0 199 | }), 200 | ), 201 | method: http.MethodGet, 202 | ctxFn: func(ctx context.Context) context.Context { 203 | return retryhttp.SetShouldRetryFn(ctx, func(_ retryhttp.Attempt) bool { 204 | return false 205 | }) 206 | }, 207 | responseCodes: func(_ int) int { 208 | return http.StatusOK 209 | }, 210 | }, 211 | 
wantAttemptCount: 1, 212 | wantStatus: http.StatusOK, 213 | }, 214 | { 215 | name: "should respect prevent retry with body context key override", 216 | fields: fields{ 217 | tr: retryhttp.New( 218 | retryhttp.WithPreventRetryWithBody(true), 219 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 220 | return 0 221 | }), 222 | ), 223 | method: http.MethodPost, 224 | body: bytes.NewReader([]byte(`this is the request body`)), 225 | ctxFn: func(ctx context.Context) context.Context { 226 | return retryhttp.SetPreventRetryWithBody(ctx, false) 227 | }, 228 | responseCodes: func(i int) int { 229 | if i == 0 { 230 | return http.StatusTooManyRequests 231 | } 232 | return http.StatusOK 233 | }, 234 | }, 235 | wantAttemptCount: 2, 236 | wantStatus: http.StatusOK, 237 | expReqBody: []byte(`this is the request body`), 238 | }, 239 | { 240 | name: "should retry the appropriate number of times with default DelayFn", 241 | fields: fields{ 242 | // actually sleeps which makes tests longer 243 | tr: retryhttp.New( 244 | retryhttp.WithMaxRetries(1), 245 | ), 246 | method: http.MethodGet, 247 | responseCodes: func(_ int) int { 248 | return http.StatusTooManyRequests 249 | }, 250 | }, 251 | wantAttemptCount: 2, 252 | wantStatus: http.StatusTooManyRequests, 253 | }, 254 | { 255 | name: "should retry the appropriate number of times with delayFn overridden to 0", 256 | fields: fields{ 257 | tr: retryhttp.New(), 258 | method: http.MethodGet, 259 | ctxFn: func(ctx context.Context) context.Context { 260 | return retryhttp.SetDelayFn(ctx, func(_ retryhttp.Attempt) time.Duration { 261 | return 0 262 | }) 263 | }, 264 | responseCodes: func(_ int) int { 265 | return http.StatusTooManyRequests 266 | }, 267 | }, 268 | wantAttemptCount: 4, 269 | wantStatus: http.StatusTooManyRequests, 270 | }, 271 | } 272 | for _, tt := range tests { 273 | t.Run(tt.name, func(t *testing.T) { 274 | attemptCount := 0 275 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { 276 | if tt.expReqBody != nil { 277 | body, err := io.ReadAll(r.Body) 278 | if err != nil { 279 | t.Errorf("error reading request body stream: %s", err) 280 | } 281 | r.Body.Close() 282 | 283 | if !reflect.DeepEqual(body, tt.expReqBody) { 284 | t.Errorf("request body does not match expected. got %s, want %s", string(body), string(tt.expReqBody)) 285 | } 286 | } 287 | 288 | w.WriteHeader(tt.fields.responseCodes(attemptCount)) 289 | if tt.fields.responseBodies != nil { 290 | b := tt.fields.responseBodies(attemptCount) 291 | if b != nil { 292 | w.Write(b) 293 | } 294 | } 295 | 296 | attemptCount++ 297 | })) 298 | defer ts.Close() 299 | 300 | client := http.Client{ 301 | Transport: tt.fields.tr, 302 | } 303 | 304 | req, err := http.NewRequest(tt.fields.method, ts.URL, tt.fields.body) 305 | if err != nil { 306 | t.Errorf("error creating request: %s", err) 307 | } 308 | 309 | ctx := context.Background() 310 | if tt.fields.ctxFn != nil { 311 | ctx = tt.fields.ctxFn(ctx) 312 | } 313 | 314 | res, err := client.Do(req.WithContext(ctx)) 315 | if (err != nil) != tt.wantErr { 316 | t.Errorf("Transport.RoundTrip() error = %v, wantErr %v", err, tt.wantErr) 317 | return 318 | } 319 | if attemptCount != tt.wantAttemptCount { 320 | t.Errorf("unexpected attempt count: got %d, want %d", attemptCount, tt.wantAttemptCount) 321 | } 322 | 323 | if tt.wantStatus > 0 { 324 | if res == nil { 325 | t.Errorf("unexpected status: want %d, got nil response", tt.wantStatus) 326 | } else if res.StatusCode != tt.wantStatus { 327 | t.Errorf("unexpected status: got %d, want %d", res.StatusCode, tt.wantStatus) 328 | } 329 | } 330 | if tt.expResBody != nil { 331 | body, err := io.ReadAll(res.Body) 332 | if err != nil { 333 | t.Errorf("unexpected error reading response body: %s", err) 334 | } 335 | res.Body.Close() 336 | if !reflect.DeepEqual(body, tt.expResBody) { 337 | t.Errorf("unexpected response body: got %s, want %s", string(body), string(tt.expResBody)) 338 | } 339 | } 340 | }) 
341 | } 342 | } 343 | 344 | func TestRetriesExhaustedError(t *testing.T) { 345 | // Create a server that closes connections to trigger errors 346 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 347 | // Close connection immediately to trigger EOF 348 | hj, ok := w.(http.Hijacker) 349 | if !ok { 350 | t.Fatal("server doesn't support hijacking") 351 | } 352 | conn, _, err := hj.Hijack() 353 | if err != nil { 354 | t.Fatal(err) 355 | } 356 | conn.Close() 357 | })) 358 | defer ts.Close() 359 | 360 | t.Run("wraps the final error with ErrRetriesExhausted", func(t *testing.T) { 361 | tr := retryhttp.New( 362 | retryhttp.WithMaxRetries(2), 363 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 364 | return 0 365 | }), 366 | retryhttp.WithShouldRetryFn(func(attempt retryhttp.Attempt) bool { 367 | return attempt.Err != nil // Ensures EOF errors are retried 368 | }), 369 | ) 370 | 371 | client := http.Client{ 372 | Transport: tr, 373 | } 374 | 375 | req, err := http.NewRequest(http.MethodGet, ts.URL, nil) 376 | if err != nil { 377 | t.Fatalf("error creating request: %s", err) 378 | } 379 | 380 | _, err = client.Do(req) 381 | if err == nil { 382 | t.Fatal("expected error but got nil") 383 | } 384 | 385 | // Check that the error is wrapped with ErrRetriesExhausted 386 | if !errors.Is(err, retryhttp.ErrRetriesExhausted) { 387 | t.Errorf("expected error to be wrapped with ErrRetriesExhausted, got: %v", err) 388 | } 389 | }) 390 | } 391 | 392 | func TestPerAttemptTimeout(t *testing.T) { 393 | mu := sync.Mutex{} 394 | attemptCount := 0 395 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 396 | mu.Lock() 397 | attemptCount++ 398 | mu.Unlock() 399 | time.Sleep(time.Millisecond * 100) // sleep longer than the timeout 400 | w.WriteHeader(http.StatusTooManyRequests) 401 | })) 402 | defer ts.Close() 403 | 404 | tr := retryhttp.New( 405 | 
retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 406 |
return 0 407 | }), 408 | retryhttp.WithAttemptTimeout(time.Millisecond*10), // really short timeout 409 | ) 410 | 411 | client := http.Client{ 412 | Transport: tr, 413 | } 414 | 415 | req, err := http.NewRequest(http.MethodGet, ts.URL, nil) 416 | if err != nil { 417 | t.Fatalf("error creating request: %s", err) 418 | } 419 | 420 | _, err = client.Do(req) 421 | mu.Lock() 422 | if attemptCount != 4 { 423 | t.Fatalf("attempt count does not match expected; got %d, want %d", attemptCount, 4) 424 | } 425 | if err == nil { 426 | t.Fatal("expected error from request but got nil") 427 | } 428 | 429 | attemptCount = 0 430 | mu.Unlock() 431 | 432 | // override per-attempt timeout with context 433 | ctx := retryhttp.SetAttemptTimeout(context.Background(), 0) 434 | res, err := client.Do(req.WithContext(ctx)) 435 | mu.Lock() 436 | if attemptCount != 4 { 437 | t.Fatalf("attempt count does not match expected; got %d, want %d", attemptCount, 4) 438 | } 439 | if err != nil { 440 | t.Fatalf("expected nil error but got %s", err) 441 | } 442 | if res.StatusCode != http.StatusTooManyRequests { 443 | t.Fatalf("unexpected status code; got %d, want %d", res.StatusCode, http.StatusTooManyRequests) 444 | } 445 | 446 | attemptCount = 0 447 | mu.Unlock() 448 | } 449 | 450 | // TestParentContextDeadline verifies that retries stop once the parent context's deadline passes. 451 | func TestParentContextDeadline(t *testing.T) { 452 | mu := sync.Mutex{} 453 | attemptCount := 0 454 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 455 | mu.Lock() 456 | attemptCount++ 457 | mu.Unlock() 458 | time.Sleep(time.Millisecond * 100) // sleep longer than the deadline 459 | w.WriteHeader(http.StatusTooManyRequests) 460 | })) 461 | defer ts.Close() 462 | 463 | tr := retryhttp.New( 464 | retryhttp.WithDelayFn(func(_ retryhttp.Attempt) time.Duration { 465 | return 0 466 | }), 467 | ) 468 | 469 | client := http.Client{ 470 | Transport: tr, 471 | } 472 | 473 | req, err := 
http.NewRequest(http.MethodGet, ts.URL, nil) 474 | if err != nil
{ 475 | t.Fatalf("error creating request: %s", err) 476 | } 477 | 478 | // really short deadline 479 | ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*10)) 480 | defer cancel() 481 | _, err = client.Do(req.WithContext(ctx)) 482 | mu.Lock() 483 | if attemptCount != 1 { 484 | t.Fatalf("attempt count does not match expected; got %d, want %d", attemptCount, 1) 485 | } 486 | if err == nil { 487 | t.Fatal("expected error from request but got nil") 488 | } 489 | } 490 | --------------------------------------------------------------------------------