├── .github ├── dependabot.yml └── workflows │ ├── actionlint.yml │ ├── pr-gofmt.yaml │ └── pr-unit-tests.yaml ├── .gitignore ├── .go-version ├── .golangci.yml ├── CHANGELOG.md ├── CODEOWNERS ├── LICENSE ├── Makefile ├── README.md ├── cert_error_go119.go ├── cert_error_go120.go ├── client.go ├── client_test.go ├── go.mod ├── go.sum ├── roundtripper.go └── roundtripper_test.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "daily" 8 | 9 | - package-ecosystem: "gomod" 10 | directory: "/" 11 | schedule: 12 | interval: "weekly" -------------------------------------------------------------------------------- /.github/workflows/actionlint.yml: -------------------------------------------------------------------------------- 1 | name: actionlint 2 | 3 | on: 4 | push: 5 | paths: 6 | - .github/** 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | actionlint: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 16 | - name: "Check GitHub workflow files" 17 | uses: docker://docker.mirror.hashicorp.services/rhysd/actionlint:latest 18 | with: 19 | args: -color -------------------------------------------------------------------------------- /.github/workflows/pr-gofmt.yaml: -------------------------------------------------------------------------------- 1 | name: Go format check 2 | on: 3 | pull_request: 4 | types: ['opened', 'synchronize'] 5 | 6 | # Least privilege: this workflow only needs to read the repository contents. 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | run-tests: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 15 | 16 | - uses: actions/setup-go@cdcb36043654635271a94b9a6d1392de5bb323a7 # v5.0.1 17 | with: 18 | go-version-file: ./.go-version 19 | 20 | - name: Run go format 21 | run: |- 22 | files=$(gofmt -s -l .)
23 | if [ -n "$files" ]; then 24 | echo >&2 "The following file(s) are not gofmt compliant:" 25 | echo >&2 "$files" 26 | exit 1 27 | fi 28 | -------------------------------------------------------------------------------- /.github/workflows/pr-unit-tests.yaml: -------------------------------------------------------------------------------- 1 | name: Unit tests 2 | on: 3 | pull_request: 4 | types: ['opened', 'synchronize'] 5 | 6 | # Least privilege: this workflow only needs to read the repository contents. 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | run-tests: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | go-version: ['1.23', '1.22'] 16 | steps: 17 | - name: Checkout Code 18 | uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 19 | 20 | - name: Setup Go 21 | uses: actions/setup-go@cdcb36043654635271a94b9a6d1392de5bb323a7 # v5.0.1 22 | with: 23 | go-version: ${{matrix.go-version}} 24 | 25 | - name: Run golangci-lint 26 | uses: golangci/golangci-lint-action@08e2f20817b15149a52b5b3ebe7de50aff2ba8c5 27 | 28 | - name: Run unit tests and generate coverage report 29 | run: make test 30 | 31 | - name: Upload coverage report 32 | uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 33 | with: 34 | path: coverage.out 35 | name: Coverage-report-${{matrix.go-version}} 36 | 37 | - name: Display coverage test 38 | run: go tool cover -func=coverage.out 39 | 40 | - name: Build Go 41 | run: go build ./... 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.iml 3 | *.test 4 | .vscode/ 5 | coverage.out -------------------------------------------------------------------------------- /.go-version: -------------------------------------------------------------------------------- 1 | 1.23 2 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) HashiCorp, Inc.
2 | # SPDX-License-Identifier: MPL-2.0 3 | 4 | linters: 5 | disable-all: true 6 | enable: 7 | - errcheck 8 | - staticcheck 9 | - gosimple 10 | - govet 11 | output_format: colored-line-number 12 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.7.7 (May 30, 2024) 2 | 3 | BUG FIXES: 4 | 5 | - client: avoid potentially leaking URL-embedded basic authentication credentials in logs (#158) 6 | 7 | ## 0.7.6 (May 9, 2024) 8 | 9 | ENHANCEMENTS: 10 | 11 | - client: support a `RetryPrepare` function for modifying the request before retrying (#216) 12 | - client: support HTTP-date values for `Retry-After` header value (#138) 13 | - client: avoid reading entire body when the body is a `*bytes.Reader` (#197) 14 | 15 | BUG FIXES: 16 | 17 | - client: fix a broken check for invalid server certificate in go 1.20+ (#210) 18 | 19 | ## 0.7.5 (Nov 8, 2023) 20 | 21 | BUG FIXES: 22 | 23 | - client: fixes an issue where the request body is not preserved on temporary redirects or re-established HTTP/2 connections (#207) 24 | 25 | ## 0.7.4 (Jun 6, 2023) 26 | 27 | BUG FIXES: 28 | 29 | - client: fixing an issue where the Content-Type header wouldn't be sent with an empty payload when using HTTP/2 (#194) 30 | 31 | ## 0.7.3 (May 15, 2023) 32 | 33 | Initial release 34 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Each line is a file pattern followed by one or more owners. 2 | # More on CODEOWNERS files: https://help.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners 3 | 4 | # Default owner 5 | * @hashicorp/team-ip-compliance @hashicorp/go-retryablehttp-maintainers 6 | 7 | # Add override rules below. Each line is a file/folder pattern followed by one or more owners. 
8 | # Being an owner means those groups or individuals will be added as reviewers to PRs affecting 9 | # those areas of the code. 10 | # Examples: 11 | # /docs/ @docs-team 12 | # *.js @js-team 13 | # *.go @go-team 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 HashiCorp, Inc. 2 | 3 | Mozilla Public License, version 2.0 4 | 5 | 1. Definitions 6 | 7 | 1.1. "Contributor" 8 | 9 | means each individual or legal entity that creates, contributes to the 10 | creation of, or owns Covered Software. 11 | 12 | 1.2. "Contributor Version" 13 | 14 | means the combination of the Contributions of others (if any) used by a 15 | Contributor and that particular Contributor's Contribution. 16 | 17 | 1.3. "Contribution" 18 | 19 | means Covered Software of a particular Contributor. 20 | 21 | 1.4. "Covered Software" 22 | 23 | means Source Code Form to which the initial Contributor has attached the 24 | notice in Exhibit A, the Executable Form of such Source Code Form, and 25 | Modifications of such Source Code Form, in each case including portions 26 | thereof. 27 | 28 | 1.5. "Incompatible With Secondary Licenses" 29 | means 30 | 31 | a. that the initial Contributor has attached the notice described in 32 | Exhibit B to the Covered Software; or 33 | 34 | b. that the Covered Software was made available under the terms of 35 | version 1.1 or earlier of the License, but not also under the terms of 36 | a Secondary License. 37 | 38 | 1.6. "Executable Form" 39 | 40 | means any form of the work other than Source Code Form. 41 | 42 | 1.7. "Larger Work" 43 | 44 | means a work that combines Covered Software with other material, in a 45 | separate file or files, that is not Covered Software. 46 | 47 | 1.8. "License" 48 | 49 | means this document. 50 | 51 | 1.9. 
"Licensable" 52 | 53 | means having the right to grant, to the maximum extent possible, whether 54 | at the time of the initial grant or subsequently, any and all of the 55 | rights conveyed by this License. 56 | 57 | 1.10. "Modifications" 58 | 59 | means any of the following: 60 | 61 | a. any file in Source Code Form that results from an addition to, 62 | deletion from, or modification of the contents of Covered Software; or 63 | 64 | b. any new file in Source Code Form that contains any Covered Software. 65 | 66 | 1.11. "Patent Claims" of a Contributor 67 | 68 | means any patent claim(s), including without limitation, method, 69 | process, and apparatus claims, in any patent Licensable by such 70 | Contributor that would be infringed, but for the grant of the License, 71 | by the making, using, selling, offering for sale, having made, import, 72 | or transfer of either its Contributions or its Contributor Version. 73 | 74 | 1.12. "Secondary License" 75 | 76 | means either the GNU General Public License, Version 2.0, the GNU Lesser 77 | General Public License, Version 2.1, the GNU Affero General Public 78 | License, Version 3.0, or any later versions of those licenses. 79 | 80 | 1.13. "Source Code Form" 81 | 82 | means the form of the work preferred for making modifications. 83 | 84 | 1.14. "You" (or "Your") 85 | 86 | means an individual or a legal entity exercising rights under this 87 | License. For legal entities, "You" includes any entity that controls, is 88 | controlled by, or is under common control with You. For purposes of this 89 | definition, "control" means (a) the power, direct or indirect, to cause 90 | the direction or management of such entity, whether by contract or 91 | otherwise, or (b) ownership of more than fifty percent (50%) of the 92 | outstanding shares or beneficial ownership of such entity. 93 | 94 | 95 | 2. License Grants and Conditions 96 | 97 | 2.1. 
Grants 98 | 99 | Each Contributor hereby grants You a world-wide, royalty-free, 100 | non-exclusive license: 101 | 102 | a. under intellectual property rights (other than patent or trademark) 103 | Licensable by such Contributor to use, reproduce, make available, 104 | modify, display, perform, distribute, and otherwise exploit its 105 | Contributions, either on an unmodified basis, with Modifications, or 106 | as part of a Larger Work; and 107 | 108 | b. under Patent Claims of such Contributor to make, use, sell, offer for 109 | sale, have made, import, and otherwise transfer either its 110 | Contributions or its Contributor Version. 111 | 112 | 2.2. Effective Date 113 | 114 | The licenses granted in Section 2.1 with respect to any Contribution 115 | become effective for each Contribution on the date the Contributor first 116 | distributes such Contribution. 117 | 118 | 2.3. Limitations on Grant Scope 119 | 120 | The licenses granted in this Section 2 are the only rights granted under 121 | this License. No additional rights or licenses will be implied from the 122 | distribution or licensing of Covered Software under this License. 123 | Notwithstanding Section 2.1(b) above, no patent license is granted by a 124 | Contributor: 125 | 126 | a. for any code that a Contributor has removed from Covered Software; or 127 | 128 | b. for infringements caused by: (i) Your and any other third party's 129 | modifications of Covered Software, or (ii) the combination of its 130 | Contributions with other software (except as part of its Contributor 131 | Version); or 132 | 133 | c. under Patent Claims infringed by Covered Software in the absence of 134 | its Contributions. 135 | 136 | This License does not grant any rights in the trademarks, service marks, 137 | or logos of any Contributor (except as may be necessary to comply with 138 | the notice requirements in Section 3.4). 139 | 140 | 2.4. 
Subsequent Licenses 141 | 142 | No Contributor makes additional grants as a result of Your choice to 143 | distribute the Covered Software under a subsequent version of this 144 | License (see Section 10.2) or under the terms of a Secondary License (if 145 | permitted under the terms of Section 3.3). 146 | 147 | 2.5. Representation 148 | 149 | Each Contributor represents that the Contributor believes its 150 | Contributions are its original creation(s) or it has sufficient rights to 151 | grant the rights to its Contributions conveyed by this License. 152 | 153 | 2.6. Fair Use 154 | 155 | This License is not intended to limit any rights You have under 156 | applicable copyright doctrines of fair use, fair dealing, or other 157 | equivalents. 158 | 159 | 2.7. Conditions 160 | 161 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in 162 | Section 2.1. 163 | 164 | 165 | 3. Responsibilities 166 | 167 | 3.1. Distribution of Source Form 168 | 169 | All distribution of Covered Software in Source Code Form, including any 170 | Modifications that You create or to which You contribute, must be under 171 | the terms of this License. You must inform recipients that the Source 172 | Code Form of the Covered Software is governed by the terms of this 173 | License, and how they can obtain a copy of this License. You may not 174 | attempt to alter or restrict the recipients' rights in the Source Code 175 | Form. 176 | 177 | 3.2. Distribution of Executable Form 178 | 179 | If You distribute Covered Software in Executable Form then: 180 | 181 | a. such Covered Software must also be made available in Source Code Form, 182 | as described in Section 3.1, and You must inform recipients of the 183 | Executable Form how they can obtain a copy of such Source Code Form by 184 | reasonable means in a timely manner, at a charge no more than the cost 185 | of distribution to the recipient; and 186 | 187 | b. 
You may distribute such Executable Form under the terms of this 188 | License, or sublicense it under different terms, provided that the 189 | license for the Executable Form does not attempt to limit or alter the 190 | recipients' rights in the Source Code Form under this License. 191 | 192 | 3.3. Distribution of a Larger Work 193 | 194 | You may create and distribute a Larger Work under terms of Your choice, 195 | provided that You also comply with the requirements of this License for 196 | the Covered Software. If the Larger Work is a combination of Covered 197 | Software with a work governed by one or more Secondary Licenses, and the 198 | Covered Software is not Incompatible With Secondary Licenses, this 199 | License permits You to additionally distribute such Covered Software 200 | under the terms of such Secondary License(s), so that the recipient of 201 | the Larger Work may, at their option, further distribute the Covered 202 | Software under the terms of either this License or such Secondary 203 | License(s). 204 | 205 | 3.4. Notices 206 | 207 | You may not remove or alter the substance of any license notices 208 | (including copyright notices, patent notices, disclaimers of warranty, or 209 | limitations of liability) contained within the Source Code Form of the 210 | Covered Software, except that You may alter any license notices to the 211 | extent required to remedy known factual inaccuracies. 212 | 213 | 3.5. Application of Additional Terms 214 | 215 | You may choose to offer, and to charge a fee for, warranty, support, 216 | indemnity or liability obligations to one or more recipients of Covered 217 | Software. However, You may do so only on Your own behalf, and not on 218 | behalf of any Contributor. 
You must make it absolutely clear that any 219 | such warranty, support, indemnity, or liability obligation is offered by 220 | You alone, and You hereby agree to indemnify every Contributor for any 221 | liability incurred by such Contributor as a result of warranty, support, 222 | indemnity or liability terms You offer. You may include additional 223 | disclaimers of warranty and limitations of liability specific to any 224 | jurisdiction. 225 | 226 | 4. Inability to Comply Due to Statute or Regulation 227 | 228 | If it is impossible for You to comply with any of the terms of this License 229 | with respect to some or all of the Covered Software due to statute, 230 | judicial order, or regulation then You must: (a) comply with the terms of 231 | this License to the maximum extent possible; and (b) describe the 232 | limitations and the code they affect. Such description must be placed in a 233 | text file included with all distributions of the Covered Software under 234 | this License. Except to the extent prohibited by statute or regulation, 235 | such description must be sufficiently detailed for a recipient of ordinary 236 | skill to be able to understand it. 237 | 238 | 5. Termination 239 | 240 | 5.1. The rights granted under this License will terminate automatically if You 241 | fail to comply with any of its terms. However, if You become compliant, 242 | then the rights granted under this License from a particular Contributor 243 | are reinstated (a) provisionally, unless and until such Contributor 244 | explicitly and finally terminates Your grants, and (b) on an ongoing 245 | basis, if such Contributor fails to notify You of the non-compliance by 246 | some reasonable means prior to 60 days after You have come back into 247 | compliance. 
Moreover, Your grants from a particular Contributor are 248 | reinstated on an ongoing basis if such Contributor notifies You of the 249 | non-compliance by some reasonable means, this is the first time You have 250 | received notice of non-compliance with this License from such 251 | Contributor, and You become compliant prior to 30 days after Your receipt 252 | of the notice. 253 | 254 | 5.2. If You initiate litigation against any entity by asserting a patent 255 | infringement claim (excluding declaratory judgment actions, 256 | counter-claims, and cross-claims) alleging that a Contributor Version 257 | directly or indirectly infringes any patent, then the rights granted to 258 | You by any and all Contributors for the Covered Software under Section 259 | 2.1 of this License shall terminate. 260 | 261 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user 262 | license agreements (excluding distributors and resellers) which have been 263 | validly granted by You or Your distributors under this License prior to 264 | termination shall survive termination. 265 | 266 | 6. Disclaimer of Warranty 267 | 268 | Covered Software is provided under this License on an "as is" basis, 269 | without warranty of any kind, either expressed, implied, or statutory, 270 | including, without limitation, warranties that the Covered Software is free 271 | of defects, merchantable, fit for a particular purpose or non-infringing. 272 | The entire risk as to the quality and performance of the Covered Software 273 | is with You. Should any Covered Software prove defective in any respect, 274 | You (not any Contributor) assume the cost of any necessary servicing, 275 | repair, or correction. This disclaimer of warranty constitutes an essential 276 | part of this License. No use of any Covered Software is authorized under 277 | this License except under this disclaimer. 278 | 279 | 7. 
Limitation of Liability 280 | 281 | Under no circumstances and under no legal theory, whether tort (including 282 | negligence), contract, or otherwise, shall any Contributor, or anyone who 283 | distributes Covered Software as permitted above, be liable to You for any 284 | direct, indirect, special, incidental, or consequential damages of any 285 | character including, without limitation, damages for lost profits, loss of 286 | goodwill, work stoppage, computer failure or malfunction, or any and all 287 | other commercial damages or losses, even if such party shall have been 288 | informed of the possibility of such damages. This limitation of liability 289 | shall not apply to liability for death or personal injury resulting from 290 | such party's negligence to the extent applicable law prohibits such 291 | limitation. Some jurisdictions do not allow the exclusion or limitation of 292 | incidental or consequential damages, so this exclusion and limitation may 293 | not apply to You. 294 | 295 | 8. Litigation 296 | 297 | Any litigation relating to this License may be brought only in the courts 298 | of a jurisdiction where the defendant maintains its principal place of 299 | business and such litigation shall be governed by laws of that 300 | jurisdiction, without reference to its conflict-of-law provisions. Nothing 301 | in this Section shall prevent a party's ability to bring cross-claims or 302 | counter-claims. 303 | 304 | 9. Miscellaneous 305 | 306 | This License represents the complete agreement concerning the subject 307 | matter hereof. If any provision of this License is held to be 308 | unenforceable, such provision shall be reformed only to the extent 309 | necessary to make it enforceable. Any law or regulation which provides that 310 | the language of a contract shall be construed against the drafter shall not 311 | be used to construe this License against a Contributor. 312 | 313 | 314 | 10. Versions of the License 315 | 316 | 10.1. 
New Versions 317 | 318 | Mozilla Foundation is the license steward. Except as provided in Section 319 | 10.3, no one other than the license steward has the right to modify or 320 | publish new versions of this License. Each version will be given a 321 | distinguishing version number. 322 | 323 | 10.2. Effect of New Versions 324 | 325 | You may distribute the Covered Software under the terms of the version 326 | of the License under which You originally received the Covered Software, 327 | or under the terms of any subsequent version published by the license 328 | steward. 329 | 330 | 10.3. Modified Versions 331 | 332 | If you create software not governed by this License, and you want to 333 | create a new license for such software, you may create and use a 334 | modified version of this License if you rename the license and remove 335 | any references to the name of the license steward (except to note that 336 | such modified license differs from this License). 337 | 338 | 10.4. Distributing Source Code Form that is Incompatible With Secondary 339 | Licenses If You choose to distribute Source Code Form that is 340 | Incompatible With Secondary Licenses under the terms of this version of 341 | the License, the notice described in Exhibit B of this License must be 342 | attached. 343 | 344 | Exhibit A - Source Code Form License Notice 345 | 346 | This Source Code Form is subject to the 347 | terms of the Mozilla Public License, v. 348 | 2.0. If a copy of the MPL was not 349 | distributed with this file, You can 350 | obtain one at 351 | http://mozilla.org/MPL/2.0/. 352 | 353 | If it is not possible or desirable to put the notice in a particular file, 354 | then You may include the notice in a location (such as a LICENSE file in a 355 | relevant directory) where a recipient would be likely to look for such a 356 | notice. 357 | 358 | You may add additional accurate notices of copyright ownership. 
359 | 360 | Exhibit B - "Incompatible With Secondary Licenses" Notice 361 | 362 | This Source Code Form is "Incompatible 363 | With Secondary Licenses", as defined by 364 | the Mozilla Public License, v. 2.0. 365 | 366 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | default: test 2 | 3 | # Vet, then run the race-enabled tests and write coverage.out (uploaded by CI). 4 | test: 5 | go vet ./... 6 | go test -v -race ./... -coverprofile=coverage.out 7 | 8 | # Upgrade module and test dependencies, then prune go.mod/go.sum. 9 | updatedeps: 10 | go get -t -u ./... 11 | go mod tidy 12 | 13 | .PHONY: default test updatedeps 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | go-retryablehttp 2 | ================ 3 | 4 | [![Build Status](http://img.shields.io/travis/hashicorp/go-retryablehttp.svg?style=flat-square)][travis] 5 | [![Go Documentation](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)][godocs] 6 | 7 | [travis]: http://travis-ci.org/hashicorp/go-retryablehttp 8 | [godocs]: http://godoc.org/github.com/hashicorp/go-retryablehttp 9 | 10 | The `retryablehttp` package provides a familiar HTTP client interface with 11 | automatic retries and exponential backoff. It is a thin wrapper over the 12 | standard `net/http` client library and exposes nearly the same public API. This 13 | makes `retryablehttp` very easy to drop into existing programs. 14 | 15 | `retryablehttp` performs automatic retries under certain conditions. Mainly, if 16 | an error is returned by the client (connection errors, etc.), or if a 500-range 17 | response code is received (except 501), then a retry is invoked after a wait 18 | period. Otherwise, the response is returned and left to the caller to 19 | interpret. 20 | 21 | The main difference from `net/http` is that requests which take a request body 22 | (POST/PUT et.
al) can have the body provided in a number of ways (some more or 23 | less efficient) that allow "rewinding" the request body if the initial request 24 | fails so that the full request can be attempted again. See the 25 | [package documentation](https://pkg.go.dev/github.com/hashicorp/go-retryablehttp) for more 26 | details. 27 | 28 | Version 0.6.0 and before are compatible with Go prior to 1.12. From 0.6.1 onward, Go 1.12+ is required. 29 | From 0.6.7 onward, Go 1.13+ is required. 30 | 31 | Example Use 32 | =========== 33 | 34 | Using this library should look almost identical to what you would do with 35 | `net/http`. The simplest example of a GET request is shown below: 36 | 37 | ```go 38 | resp, err := retryablehttp.Get("https://example.com/foo") 39 | if err != nil { 40 | panic(err) 41 | } 42 | ``` 43 | 44 | The returned response object is an `*http.Response`, the same thing you would 45 | usually get from `net/http`. Had the request failed one or more times, the above 46 | call would block and retry with exponential backoff. 47 | 48 | ## Getting a stdlib `*http.Client` with retries 49 | 50 | It's possible to convert a `*retryablehttp.Client` directly to a `*http.Client`. 51 | This makes use of retryablehttp broadly applicable with minimal effort. Simply 52 | configure a `*retryablehttp.Client` as you wish, and then call `StandardClient()`: 53 | 54 | ```go 55 | retryClient := retryablehttp.NewClient() 56 | retryClient.RetryMax = 10 57 | 58 | standardClient := retryClient.StandardClient() // *http.Client 59 | ``` 60 | 61 | For more usage and examples, see the 62 | [pkg.go.dev](https://pkg.go.dev/github.com/hashicorp/go-retryablehttp). 63 | -------------------------------------------------------------------------------- /cert_error_go119.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | //go:build !go1.20 5 | // +build !go1.20 6 | 7 | package retryablehttp 8 | 9 | import ( 10 | "crypto/x509" 11 | "errors" 12 | ) 13 | 14 | // isCertError reports whether err is, or wraps, an x509.UnknownAuthorityError. 15 | // errors.As is used instead of a bare type assertion so the error is still 16 | // recognized after another layer has wrapped it. 17 | func isCertError(err error) bool { 18 | var certErr x509.UnknownAuthorityError 19 | return errors.As(err, &certErr) 20 | } 21 | -------------------------------------------------------------------------------- /cert_error_go120.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | //go:build go1.20 5 | // +build go1.20 6 | 7 | package retryablehttp 8 | 9 | import ( 10 | "crypto/tls" 11 | "errors" 12 | ) 13 | 14 | // isCertError reports whether err is, or wraps, a *tls.CertificateVerificationError. 15 | // errors.As is used instead of a bare type assertion so the error is still 16 | // recognized after another layer has wrapped it. 17 | func isCertError(err error) bool { 18 | var certErr *tls.CertificateVerificationError 19 | return errors.As(err, &certErr) 20 | } 21 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | // Package retryablehttp provides a familiar HTTP client interface with 5 | // automatic retries and exponential backoff. It is a thin wrapper over the 6 | // standard net/http client library and exposes nearly the same public API. 7 | // This makes retryablehttp very easy to drop into existing programs. 8 | // 9 | // retryablehttp performs automatic retries under certain conditions. Mainly, if 10 | // an error is returned by the client (connection errors etc), or if a 500-range 11 | // response is received, then a retry is invoked. Otherwise, the response is 12 | // returned and left to the caller to interpret. 13 | // 14 | // Requests which take a request body should provide a non-nil function 15 | // parameter. The best choice is to provide either a function satisfying 16 | // ReaderFunc which provides multiple io.Readers in an efficient manner, a 17 | // *bytes.Buffer (the underlying raw byte slice will be used) or a raw byte 18 | // slice.
As it is a reference type, and we will wrap it as needed by readers, 19 | // we can efficiently re-use the request body without needing to copy it. If an 20 | // io.Reader is provided, the full body will be read prior to the first request 21 | // and efficiently re-used for any retries; a *bytes.Reader is not read up front. 22 | // ReadSeeker can be used, but some users have observed occasional data races 23 | // between the net/http library and the Seek functionality of some 24 | // implementations of ReadSeeker, so should be avoided if possible. 25 | package retryablehttp 26 | 27 | import ( 28 | "bytes" 29 | "context" 30 | "fmt" 31 | "io" 32 | "log" 33 | "math" 34 | "math/rand" 35 | "net/http" 36 | "net/url" 37 | "os" 38 | "regexp" 39 | "strconv" 40 | "strings" 41 | "sync" 42 | "time" 43 | 44 | cleanhttp "github.com/hashicorp/go-cleanhttp" 45 | ) 46 | 47 | var ( 48 | // Default retry configuration 49 | defaultRetryWaitMin = 1 * time.Second 50 | defaultRetryWaitMax = 30 * time.Second 51 | defaultRetryMax = 4 52 | 53 | // defaultLogger is the logger provided with defaultClient 54 | defaultLogger = log.New(os.Stderr, "", log.LstdFlags) 55 | 56 | // defaultClient is used for performing requests without explicitly making 57 | // a new client. It is purposely private to avoid modifications. 58 | defaultClient = NewClient() 59 | 60 | // We need to consume response bodies to maintain http connections, but 61 | // limit the size we consume to respReadLimit. 62 | respReadLimit = int64(4096) 63 | 64 | // timeNow sets the function that returns the current time. 65 | // This defaults to time.Now. Changes to this should only be done in tests. 66 | timeNow = time.Now 67 | 68 | // A regular expression to match the error returned by net/http when the 69 | // configured number of redirects is exhausted. This error isn't typed 70 | // specifically so we resort to matching on the error string.
71 | redirectsErrorRe = regexp.MustCompile(`stopped after \d+ redirects\z`) 72 | 73 | // A regular expression to match the error returned by net/http when the 74 | // scheme specified in the URL is invalid. This error isn't typed 75 | // specifically so we resort to matching on the error string. 76 | schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`) 77 | 78 | // A regular expression to match the error returned by net/http when a 79 | // request header or value is invalid. This error isn't typed 80 | // specifically so we resort to matching on the error string. 81 | invalidHeaderErrorRe = regexp.MustCompile(`invalid header`) 82 | 83 | // A regular expression to match the error returned by net/http when the 84 | // TLS certificate is not trusted. This error isn't typed 85 | // specifically so we resort to matching on the error string. 86 | notTrustedErrorRe = regexp.MustCompile(`certificate is not trusted`) 87 | ) 88 | 89 | // ReaderFunc is the type of function that can be given natively to NewRequest. Retries rewind the request body by calling it again, so each call should return a fresh reader over the whole body. 90 | type ReaderFunc func() (io.Reader, error) 91 | 92 | // ResponseHandlerFunc is a type of function that takes in a Response, and does something with it. 93 | // The ResponseHandlerFunc is called when the HTTP client successfully receives a response and the 94 | // CheckRetry function indicates that a retry of the base request is not necessary. 95 | // If an error is returned from this function, the CheckRetry policy will be used to determine 96 | // whether to retry the whole request (including this handler). 97 | // 98 | // Make sure to check status codes! Even if the request was completed it may have a non-2xx status code. 99 | // 100 | // The response body is not automatically closed. It must be closed either by the ResponseHandlerFunc or 101 | // by the caller out-of-band. Failure to do so will result in a memory leak.
102 | type ResponseHandlerFunc func(*http.Response) error 103 | 104 | // LenReader is an interface implemented by many in-memory io.Reader's. Used 105 | // for automatically sending the right Content-Length header when possible. 106 | type LenReader interface { 107 | Len() int 108 | } 109 | 110 | // Request wraps the metadata needed to create HTTP requests. 111 | type Request struct { 112 | // body is a seekable reader over the request body payload. This is 113 | // used to rewind the request data in between retries. 114 | body ReaderFunc 115 | 116 | responseHandler ResponseHandlerFunc 117 | 118 | // Embed an HTTP request directly. This makes a *Request act exactly 119 | // like an *http.Request so that all meta methods are supported. 120 | *http.Request 121 | } 122 | 123 | // WithContext returns wrapped Request with a shallow copy of underlying *http.Request 124 | // with its context changed to ctx. The provided ctx must be non-nil. 125 | func (r *Request) WithContext(ctx context.Context) *Request { 126 | return &Request{ 127 | body: r.body, 128 | responseHandler: r.responseHandler, 129 | Request: r.Request.WithContext(ctx), 130 | } 131 | } 132 | 133 | // SetResponseHandler allows setting the response handler. 134 | func (r *Request) SetResponseHandler(fn ResponseHandlerFunc) { 135 | r.responseHandler = fn 136 | } 137 | 138 | // BodyBytes allows accessing the request body. It is an analogue to 139 | // http.Request's Body variable, but it returns a copy of the underlying data 140 | // rather than consuming it. 141 | // 142 | // This function is not thread-safe; do not call it at the same time as another 143 | // call, or at the same time this request is being used with Client.Do. 
144 | func (r *Request) BodyBytes() ([]byte, error) { 145 | if r.body == nil { 146 | return nil, nil 147 | } 148 | body, err := r.body() 149 | if err != nil { 150 | return nil, err 151 | } 152 | buf := new(bytes.Buffer) 153 | _, err = buf.ReadFrom(body) 154 | if err != nil { 155 | return nil, err 156 | } 157 | return buf.Bytes(), nil 158 | } 159 | 160 | // SetBody allows setting the request body. 161 | // 162 | // It is useful if a new body needs to be set without constructing a new Request. 163 | func (r *Request) SetBody(rawBody interface{}) error { 164 | bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody) 165 | if err != nil { 166 | return err 167 | } 168 | r.body = bodyReader 169 | r.ContentLength = contentLength 170 | if bodyReader != nil { 171 | r.GetBody = func() (io.ReadCloser, error) { 172 | body, err := bodyReader() 173 | if err != nil { 174 | return nil, err 175 | } 176 | if rc, ok := body.(io.ReadCloser); ok { 177 | return rc, nil 178 | } 179 | return io.NopCloser(body), nil 180 | } 181 | } else { 182 | r.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil } 183 | } 184 | return nil 185 | } 186 | 187 | // WriteTo allows copying the request body into a writer. 188 | // 189 | // It writes data to w until there's no more data to write or 190 | // when an error occurs. The return int64 value is the number of bytes 191 | // written. Any error encountered during the write is also returned. 192 | // The signature matches io.WriterTo interface. 
193 | func (r *Request) WriteTo(w io.Writer) (int64, error) { 194 | body, err := r.body() 195 | if err != nil { 196 | return 0, err 197 | } 198 | if c, ok := body.(io.Closer); ok { 199 | defer c.Close() 200 | } 201 | return io.Copy(w, body) 202 | } 203 | 204 | func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, error) { 205 | var bodyReader ReaderFunc 206 | var contentLength int64 207 | 208 | switch body := rawBody.(type) { 209 | // If they gave us a function already, great! Use it. 210 | case ReaderFunc: 211 | bodyReader = body 212 | tmp, err := body() 213 | if err != nil { 214 | return nil, 0, err 215 | } 216 | if lr, ok := tmp.(LenReader); ok { 217 | contentLength = int64(lr.Len()) 218 | } 219 | if c, ok := tmp.(io.Closer); ok { 220 | c.Close() 221 | } 222 | 223 | case func() (io.Reader, error): 224 | bodyReader = body 225 | tmp, err := body() 226 | if err != nil { 227 | return nil, 0, err 228 | } 229 | if lr, ok := tmp.(LenReader); ok { 230 | contentLength = int64(lr.Len()) 231 | } 232 | if c, ok := tmp.(io.Closer); ok { 233 | c.Close() 234 | } 235 | 236 | // If a regular byte slice, we can read it over and over via new 237 | // readers 238 | case []byte: 239 | buf := body 240 | bodyReader = func() (io.Reader, error) { 241 | return bytes.NewReader(buf), nil 242 | } 243 | contentLength = int64(len(buf)) 244 | 245 | // If a bytes.Buffer we can read the underlying byte slice over and 246 | // over 247 | case *bytes.Buffer: 248 | buf := body 249 | bodyReader = func() (io.Reader, error) { 250 | return bytes.NewReader(buf.Bytes()), nil 251 | } 252 | contentLength = int64(buf.Len()) 253 | 254 | // We prioritize *bytes.Reader here because we don't really want to 255 | // deal with it seeking so want it to match here instead of the 256 | // io.ReadSeeker case. 
257 | case *bytes.Reader: 258 | snapshot := *body 259 | bodyReader = func() (io.Reader, error) { 260 | r := snapshot 261 | return &r, nil 262 | } 263 | contentLength = int64(body.Len()) 264 | 265 | // Compat case 266 | case io.ReadSeeker: 267 | raw := body 268 | bodyReader = func() (io.Reader, error) { 269 | _, err := raw.Seek(0, 0) 270 | return io.NopCloser(raw), err 271 | } 272 | if lr, ok := raw.(LenReader); ok { 273 | contentLength = int64(lr.Len()) 274 | } 275 | 276 | // Read all in so we can reset 277 | case io.Reader: 278 | buf, err := io.ReadAll(body) 279 | if err != nil { 280 | return nil, 0, err 281 | } 282 | if len(buf) == 0 { 283 | bodyReader = func() (io.Reader, error) { 284 | return http.NoBody, nil 285 | } 286 | contentLength = 0 287 | } else { 288 | bodyReader = func() (io.Reader, error) { 289 | return bytes.NewReader(buf), nil 290 | } 291 | contentLength = int64(len(buf)) 292 | } 293 | 294 | // No body provided, nothing to do 295 | case nil: 296 | 297 | // Unrecognized type 298 | default: 299 | return nil, 0, fmt.Errorf("cannot handle type %T", rawBody) 300 | } 301 | return bodyReader, contentLength, nil 302 | } 303 | 304 | // FromRequest wraps an http.Request in a retryablehttp.Request 305 | func FromRequest(r *http.Request) (*Request, error) { 306 | bodyReader, _, err := getBodyReaderAndContentLength(r.Body) 307 | if err != nil { 308 | return nil, err 309 | } 310 | // Could assert contentLength == r.ContentLength 311 | return &Request{body: bodyReader, Request: r}, nil 312 | } 313 | 314 | // NewRequest creates a new wrapped request. 315 | func NewRequest(method, url string, rawBody interface{}) (*Request, error) { 316 | return NewRequestWithContext(context.Background(), method, url, rawBody) 317 | } 318 | 319 | // NewRequestWithContext creates a new wrapped request with the provided context. 
320 | // 321 | // The context controls the entire lifetime of a request and its response: 322 | // obtaining a connection, sending the request, and reading the response headers and body. 323 | func NewRequestWithContext(ctx context.Context, method, url string, rawBody interface{}) (*Request, error) { 324 | httpReq, err := http.NewRequestWithContext(ctx, method, url, nil) 325 | if err != nil { 326 | return nil, err 327 | } 328 | 329 | req := &Request{ 330 | Request: httpReq, 331 | } 332 | if err := req.SetBody(rawBody); err != nil { 333 | return nil, err 334 | } 335 | 336 | return req, nil 337 | } 338 | 339 | // Logger interface allows to use other loggers than 340 | // standard log.Logger. 341 | type Logger interface { 342 | Printf(string, ...interface{}) 343 | } 344 | 345 | // LeveledLogger is an interface that can be implemented by any logger or a 346 | // logger wrapper to provide leveled logging. The methods accept a message 347 | // string and a variadic number of key-value pairs. For log.Printf style 348 | // formatting where message string contains a format specifier, use Logger 349 | // interface. 350 | type LeveledLogger interface { 351 | Error(msg string, keysAndValues ...interface{}) 352 | Info(msg string, keysAndValues ...interface{}) 353 | Debug(msg string, keysAndValues ...interface{}) 354 | Warn(msg string, keysAndValues ...interface{}) 355 | } 356 | 357 | // hookLogger adapts an LeveledLogger to Logger for use by the existing hook functions 358 | // without changing the API. 359 | type hookLogger struct { 360 | LeveledLogger 361 | } 362 | 363 | func (h hookLogger) Printf(s string, args ...interface{}) { 364 | h.Info(fmt.Sprintf(s, args...)) 365 | } 366 | 367 | // RequestLogHook allows a function to run before each retry. The HTTP 368 | // request which will be made, and the retry number (0 for the initial 369 | // request) are available to users. The internal logger is exposed to 370 | // consumers. 
371 | type RequestLogHook func(Logger, *http.Request, int) 372 | 373 | // ResponseLogHook is like RequestLogHook, but allows running a function 374 | // on each HTTP response. This function will be invoked at the end of 375 | // every HTTP request executed, regardless of whether a subsequent retry 376 | // needs to be performed or not. If the response body is read or closed 377 | // from this method, this will affect the response returned from Do(). 378 | type ResponseLogHook func(Logger, *http.Response) 379 | 380 | // CheckRetry specifies a policy for handling retries. It is called 381 | // following each request with the response and error values returned by 382 | // the http.Client. If CheckRetry returns false, the Client stops retrying 383 | // and returns the response to the caller. If CheckRetry returns an error, 384 | // that error value is returned in lieu of the error from the request. The 385 | // Client will close any response body when retrying, but if the retry is 386 | // aborted it is up to the CheckRetry callback to properly close any 387 | // response body before returning. 388 | type CheckRetry func(ctx context.Context, resp *http.Response, err error) (bool, error) 389 | 390 | // Backoff specifies a policy for how long to wait between retries. 391 | // It is called after a failing request to determine the amount of time 392 | // that should pass before trying again. 393 | type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration 394 | 395 | // ErrorHandler is called if retries are expired, containing the last status 396 | // from the http library. If not specified, default behavior for the library is 397 | // to close the body and return an error indicating how many tries were 398 | // attempted. If overriding this, be sure to close the body if needed. 399 | type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error) 400 | 401 | // PrepareRetry is called before retry operation. 
It can be used for example to re-sign the request 402 | type PrepareRetry func(req *http.Request) error 403 | 404 | // Client is used to make HTTP requests. It adds additional functionality 405 | // like automatic retries to tolerate minor outages. 406 | type Client struct { 407 | HTTPClient *http.Client // Internal HTTP client. 408 | Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger 409 | 410 | RetryWaitMin time.Duration // Minimum time to wait 411 | RetryWaitMax time.Duration // Maximum time to wait 412 | RetryMax int // Maximum number of retries 413 | 414 | // RequestLogHook allows a user-supplied function to be called 415 | // before each retry. 416 | RequestLogHook RequestLogHook 417 | 418 | // ResponseLogHook allows a user-supplied function to be called 419 | // with the response from each HTTP request executed. 420 | ResponseLogHook ResponseLogHook 421 | 422 | // CheckRetry specifies the policy for handling retries, and is called 423 | // after each request. The default policy is DefaultRetryPolicy. 424 | CheckRetry CheckRetry 425 | 426 | // Backoff specifies the policy for how long to wait between retries 427 | Backoff Backoff 428 | 429 | // ErrorHandler specifies the custom error handler to use, if any 430 | ErrorHandler ErrorHandler 431 | 432 | // PrepareRetry can prepare the request for retry operation, for example re-sign it 433 | PrepareRetry PrepareRetry 434 | 435 | loggerInit sync.Once 436 | clientInit sync.Once 437 | } 438 | 439 | // NewClient creates a new Client with default settings. 
440 | func NewClient() *Client { 441 | return &Client{ 442 | HTTPClient: cleanhttp.DefaultPooledClient(), 443 | Logger: defaultLogger, 444 | RetryWaitMin: defaultRetryWaitMin, 445 | RetryWaitMax: defaultRetryWaitMax, 446 | RetryMax: defaultRetryMax, 447 | CheckRetry: DefaultRetryPolicy, 448 | Backoff: DefaultBackoff, 449 | } 450 | } 451 | 452 | func (c *Client) logger() interface{} { 453 | c.loggerInit.Do(func() { 454 | if c.Logger == nil { 455 | return 456 | } 457 | 458 | switch c.Logger.(type) { 459 | case Logger, LeveledLogger: 460 | // ok 461 | default: 462 | // This should happen in dev when they are setting Logger and work on code, not in prod. 463 | panic(fmt.Sprintf("invalid logger type passed, must be Logger or LeveledLogger, was %T", c.Logger)) 464 | } 465 | }) 466 | 467 | return c.Logger 468 | } 469 | 470 | // DefaultRetryPolicy provides a default callback for Client.CheckRetry, which 471 | // will retry on connection errors and server errors. 472 | func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { 473 | // do not retry on context.Canceled or context.DeadlineExceeded 474 | if ctx.Err() != nil { 475 | return false, ctx.Err() 476 | } 477 | 478 | // don't propagate other errors 479 | shouldRetry, _ := baseRetryPolicy(resp, err) 480 | return shouldRetry, nil 481 | } 482 | 483 | // ErrorPropagatedRetryPolicy is the same as DefaultRetryPolicy, except it 484 | // propagates errors back instead of returning nil. This allows you to inspect 485 | // why it decided to retry or not. 
486 | func ErrorPropagatedRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { 487 | // do not retry on context.Canceled or context.DeadlineExceeded 488 | if ctx.Err() != nil { 489 | return false, ctx.Err() 490 | } 491 | 492 | return baseRetryPolicy(resp, err) 493 | } 494 | 495 | func baseRetryPolicy(resp *http.Response, err error) (bool, error) { 496 | if err != nil { 497 | if v, ok := err.(*url.Error); ok { 498 | // Don't retry if the error was due to too many redirects. 499 | if redirectsErrorRe.MatchString(v.Error()) { 500 | return false, v 501 | } 502 | 503 | // Don't retry if the error was due to an invalid protocol scheme. 504 | if schemeErrorRe.MatchString(v.Error()) { 505 | return false, v 506 | } 507 | 508 | // Don't retry if the error was due to an invalid header. 509 | if invalidHeaderErrorRe.MatchString(v.Error()) { 510 | return false, v 511 | } 512 | 513 | // Don't retry if the error was due to TLS cert verification failure. 514 | if notTrustedErrorRe.MatchString(v.Error()) { 515 | return false, v 516 | } 517 | if isCertError(v.Err) { 518 | return false, v 519 | } 520 | } 521 | 522 | // The error is likely recoverable so retry. 523 | return true, nil 524 | } 525 | 526 | // 429 Too Many Requests is recoverable. Sometimes the server puts 527 | // a Retry-After response header to indicate when the server is 528 | // available to start processing request from client. 529 | if resp.StatusCode == http.StatusTooManyRequests { 530 | return true, nil 531 | } 532 | 533 | // Check the response code. We retry on 500-range responses to allow 534 | // the server time to recover, as 500's are typically not permanent 535 | // errors and may relate to outages on the server side. This will catch 536 | // invalid response codes as well, like 0 and 999. 
537 | if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != http.StatusNotImplemented) { 538 | return true, fmt.Errorf("unexpected HTTP status %s", resp.Status) 539 | } 540 | 541 | return false, nil 542 | } 543 | 544 | // DefaultBackoff provides a default callback for Client.Backoff which 545 | // will perform exponential backoff based on the attempt number and limited 546 | // by the provided minimum and maximum durations. 547 | // 548 | // It also tries to parse Retry-After response header when a http.StatusTooManyRequests 549 | // (HTTP Code 429) is found in the resp parameter. Hence it will return the number of 550 | // seconds the server states it may be ready to process more requests from this client. 551 | func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { 552 | if resp != nil { 553 | if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { 554 | if sleep, ok := parseRetryAfterHeader(resp.Header["Retry-After"]); ok { 555 | return sleep 556 | } 557 | } 558 | } 559 | 560 | mult := math.Pow(2, float64(attemptNum)) * float64(min) 561 | sleep := time.Duration(mult) 562 | if float64(sleep) != mult || sleep > max { 563 | sleep = max 564 | } 565 | return sleep 566 | } 567 | 568 | // parseRetryAfterHeader parses the Retry-After header and returns the 569 | // delay duration according to the spec: https://httpwg.org/specs/rfc7231.html#header.retry-after 570 | // The bool returned will be true if the header was successfully parsed. 571 | // Otherwise, the header was either not present, or was not parseable according to the spec. 
572 | // 573 | // Retry-After headers come in two flavors: Seconds or HTTP-Date 574 | // 575 | // Examples: 576 | // * Retry-After: Fri, 31 Dec 1999 23:59:59 GMT 577 | // * Retry-After: 120 578 | func parseRetryAfterHeader(headers []string) (time.Duration, bool) { 579 | if len(headers) == 0 || headers[0] == "" { 580 | return 0, false 581 | } 582 | header := headers[0] 583 | // Retry-After: 120 584 | if sleep, err := strconv.ParseInt(header, 10, 64); err == nil { 585 | if sleep < 0 { // a negative sleep doesn't make sense 586 | return 0, false 587 | } 588 | return time.Second * time.Duration(sleep), true 589 | } 590 | 591 | // Retry-After: Fri, 31 Dec 1999 23:59:59 GMT 592 | retryTime, err := time.Parse(time.RFC1123, header) 593 | if err != nil { 594 | return 0, false 595 | } 596 | if until := retryTime.Sub(timeNow()); until > 0 { 597 | return until, true 598 | } 599 | // date is in the past 600 | return 0, true 601 | } 602 | 603 | // LinearJitterBackoff provides a callback for Client.Backoff which will 604 | // perform linear backoff based on the attempt number and with jitter to 605 | // prevent a thundering herd. 606 | // 607 | // min and max here are *not* absolute values. The number to be multiplied by 608 | // the attempt number will be chosen at random from between them, thus they are 609 | // bounding the jitter. 610 | // 611 | // For instance: 612 | // * To get strictly linear backoff of one second increasing each retry, set 613 | // both to one second (1s, 2s, 3s, 4s, ...) 614 | // * To get a small amount of jitter centered around one second increasing each 615 | // retry, set to around one second, such as a min of 800ms and max of 1200ms 616 | // (892ms, 2102ms, 2945ms, 4312ms, ...) 617 | // * To get extreme jitter, set to a very wide spread, such as a min of 100ms 618 | // and a max of 20s (15382ms, 292ms, 51321ms, 35234ms, ...) 
func LinearJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	// Attempts are numbered from zero, but the multiplier has to start at one.
	factor := attemptNum + 1

	if min >= max {
		// No usable range to jitter within (or both bounds are equal):
		// scale min directly.
		return min * time.Duration(factor)
	}

	// A fresh source per call is acceptable here.
	rng := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))

	// Pick a random point in [min, max) by scaling a random fraction of the
	// span onto min, then multiply it by the attempt factor.
	offset := int64(rng.Float64() * float64(max-min))
	base := offset + int64(min)
	return time.Duration(base * int64(factor))
}

// PassthroughErrorHandler is an ErrorHandler that returns resp and err
// unchanged, so Do hands back the final attempt's response and error. The
// response body is not closed; the caller must close it.
func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Response, error) {
	return resp, err
}

// Do wraps calling an HTTP method with retries. Every attempt is evaluated by
// c.CheckRetry; retries wait for c.Backoff and stop after c.RetryMax retries.
// On failure the result comes from c.ErrorHandler when set; otherwise the
// response body is drained and an error reporting the attempt count is returned.
649 | func (c *Client) Do(req *Request) (*http.Response, error) { 650 | c.clientInit.Do(func() { 651 | if c.HTTPClient == nil { 652 | c.HTTPClient = cleanhttp.DefaultPooledClient() 653 | } 654 | }) 655 | 656 | logger := c.logger() 657 | 658 | if logger != nil { 659 | switch v := logger.(type) { 660 | case LeveledLogger: 661 | v.Debug("performing request", "method", req.Method, "url", redactURL(req.URL)) 662 | case Logger: 663 | v.Printf("[DEBUG] %s %s", req.Method, redactURL(req.URL)) 664 | } 665 | } 666 | 667 | var resp *http.Response 668 | var attempt int 669 | var shouldRetry bool 670 | var doErr, respErr, checkErr, prepareErr error 671 | 672 | for i := 0; ; i++ { 673 | doErr, respErr, prepareErr = nil, nil, nil 674 | attempt++ 675 | 676 | // Always rewind the request body when non-nil. 677 | if req.body != nil { 678 | body, err := req.body() 679 | if err != nil { 680 | c.HTTPClient.CloseIdleConnections() 681 | return resp, err 682 | } 683 | if c, ok := body.(io.ReadCloser); ok { 684 | req.Body = c 685 | } else { 686 | req.Body = io.NopCloser(body) 687 | } 688 | } 689 | 690 | if c.RequestLogHook != nil { 691 | switch v := logger.(type) { 692 | case LeveledLogger: 693 | c.RequestLogHook(hookLogger{v}, req.Request, i) 694 | case Logger: 695 | c.RequestLogHook(v, req.Request, i) 696 | default: 697 | c.RequestLogHook(nil, req.Request, i) 698 | } 699 | } 700 | 701 | // Attempt the request 702 | resp, doErr = c.HTTPClient.Do(req.Request) 703 | 704 | // Check if we should continue with retries. 
705 | shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr) 706 | if !shouldRetry && doErr == nil && req.responseHandler != nil { 707 | respErr = req.responseHandler(resp) 708 | shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, respErr) 709 | } 710 | 711 | err := doErr 712 | if respErr != nil { 713 | err = respErr 714 | } 715 | if err != nil { 716 | switch v := logger.(type) { 717 | case LeveledLogger: 718 | v.Error("request failed", "error", err, "method", req.Method, "url", redactURL(req.URL)) 719 | case Logger: 720 | v.Printf("[ERR] %s %s request failed: %v", req.Method, redactURL(req.URL), err) 721 | } 722 | } else { 723 | // Call this here to maintain the behavior of logging all requests, 724 | // even if CheckRetry signals to stop. 725 | if c.ResponseLogHook != nil { 726 | // Call the response logger function if provided. 727 | switch v := logger.(type) { 728 | case LeveledLogger: 729 | c.ResponseLogHook(hookLogger{v}, resp) 730 | case Logger: 731 | c.ResponseLogHook(v, resp) 732 | default: 733 | c.ResponseLogHook(nil, resp) 734 | } 735 | } 736 | } 737 | 738 | if !shouldRetry { 739 | break 740 | } 741 | 742 | // We do this before drainBody because there's no need for the I/O if 743 | // we're breaking out 744 | remain := c.RetryMax - i 745 | if remain <= 0 { 746 | break 747 | } 748 | 749 | // We're going to retry, consume any response to reuse the connection. 
750 | if doErr == nil { 751 | c.drainBody(resp.Body) 752 | } 753 | 754 | wait := c.Backoff(c.RetryWaitMin, c.RetryWaitMax, i, resp) 755 | if logger != nil { 756 | desc := fmt.Sprintf("%s %s", req.Method, redactURL(req.URL)) 757 | if resp != nil { 758 | desc = fmt.Sprintf("%s (status: %d)", desc, resp.StatusCode) 759 | } 760 | switch v := logger.(type) { 761 | case LeveledLogger: 762 | v.Debug("retrying request", "request", desc, "timeout", wait, "remaining", remain) 763 | case Logger: 764 | v.Printf("[DEBUG] %s: retrying in %s (%d left)", desc, wait, remain) 765 | } 766 | } 767 | timer := time.NewTimer(wait) 768 | select { 769 | case <-req.Context().Done(): 770 | timer.Stop() 771 | c.HTTPClient.CloseIdleConnections() 772 | return nil, req.Context().Err() 773 | case <-timer.C: 774 | } 775 | 776 | // Make shallow copy of http Request so that we can modify its body 777 | // without racing against the closeBody call in persistConn.writeLoop. 778 | httpreq := *req.Request 779 | req.Request = &httpreq 780 | 781 | if c.PrepareRetry != nil { 782 | if err := c.PrepareRetry(req.Request); err != nil { 783 | prepareErr = err 784 | break 785 | } 786 | } 787 | } 788 | 789 | // this is the closest we have to success criteria 790 | if doErr == nil && respErr == nil && checkErr == nil && prepareErr == nil && !shouldRetry { 791 | return resp, nil 792 | } 793 | 794 | defer c.HTTPClient.CloseIdleConnections() 795 | 796 | var err error 797 | if prepareErr != nil { 798 | err = prepareErr 799 | } else if checkErr != nil { 800 | err = checkErr 801 | } else if respErr != nil { 802 | err = respErr 803 | } else { 804 | err = doErr 805 | } 806 | 807 | if c.ErrorHandler != nil { 808 | return c.ErrorHandler(resp, err, attempt) 809 | } 810 | 811 | // By default, we close the response body and return an error without 812 | // returning the response 813 | if resp != nil { 814 | c.drainBody(resp.Body) 815 | } 816 | 817 | // this means CheckRetry thought the request was a failure, but didn't 818 | 
// communicate why 819 | if err == nil { 820 | return nil, fmt.Errorf("%s %s giving up after %d attempt(s)", 821 | req.Method, redactURL(req.URL), attempt) 822 | } 823 | 824 | return nil, fmt.Errorf("%s %s giving up after %d attempt(s): %w", 825 | req.Method, redactURL(req.URL), attempt, err) 826 | } 827 | 828 | // Try to read the response body so we can reuse this connection. 829 | func (c *Client) drainBody(body io.ReadCloser) { 830 | defer body.Close() 831 | _, err := io.Copy(io.Discard, io.LimitReader(body, respReadLimit)) 832 | if err != nil { 833 | if c.logger() != nil { 834 | switch v := c.logger().(type) { 835 | case LeveledLogger: 836 | v.Error("error reading response body", "error", err) 837 | case Logger: 838 | v.Printf("[ERR] error reading response body: %v", err) 839 | } 840 | } 841 | } 842 | } 843 | 844 | // Get is a shortcut for doing a GET request without making a new client. 845 | func Get(url string) (*http.Response, error) { 846 | return defaultClient.Get(url) 847 | } 848 | 849 | // Get is a convenience helper for doing simple GET requests. 850 | func (c *Client) Get(url string) (*http.Response, error) { 851 | req, err := NewRequest("GET", url, nil) 852 | if err != nil { 853 | return nil, err 854 | } 855 | return c.Do(req) 856 | } 857 | 858 | // Head is a shortcut for doing a HEAD request without making a new client. 859 | func Head(url string) (*http.Response, error) { 860 | return defaultClient.Head(url) 861 | } 862 | 863 | // Head is a convenience method for doing simple HEAD requests. 864 | func (c *Client) Head(url string) (*http.Response, error) { 865 | req, err := NewRequest("HEAD", url, nil) 866 | if err != nil { 867 | return nil, err 868 | } 869 | return c.Do(req) 870 | } 871 | 872 | // Post is a shortcut for doing a POST request without making a new client. 873 | // The bodyType parameter sets the "Content-Type" header of the request. 
874 | func Post(url, bodyType string, body interface{}) (*http.Response, error) { 875 | return defaultClient.Post(url, bodyType, body) 876 | } 877 | 878 | // Post is a convenience method for doing simple POST requests. 879 | // The bodyType parameter sets the "Content-Type" header of the request. 880 | func (c *Client) Post(url, bodyType string, body interface{}) (*http.Response, error) { 881 | req, err := NewRequest("POST", url, body) 882 | if err != nil { 883 | return nil, err 884 | } 885 | req.Header.Set("Content-Type", bodyType) 886 | return c.Do(req) 887 | } 888 | 889 | // PostForm is a shortcut to perform a POST with form data without creating 890 | // a new client. 891 | func PostForm(url string, data url.Values) (*http.Response, error) { 892 | return defaultClient.PostForm(url, data) 893 | } 894 | 895 | // PostForm is a convenience method for doing simple POST operations using 896 | // pre-filled url.Values form data. 897 | func (c *Client) PostForm(url string, data url.Values) (*http.Response, error) { 898 | return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) 899 | } 900 | 901 | // StandardClient returns a stdlib *http.Client with a custom Transport, which 902 | // shims in a *retryablehttp.Client for added retries. 903 | func (c *Client) StandardClient() *http.Client { 904 | return &http.Client{ 905 | Transport: &RoundTripper{Client: c}, 906 | } 907 | } 908 | 909 | // Taken from url.URL#Redacted() which was introduced in go 1.15. 910 | // We can switch to using it directly if we'll bump the minimum required go version. 
911 | func redactURL(u *url.URL) string { 912 | if u == nil { 913 | return "" 914 | } 915 | 916 | ru := *u 917 | if _, has := ru.User.Password(); has { 918 | ru.User = url.UserPassword(ru.User.Username(), "xxxxx") 919 | } 920 | return ru.String() 921 | } 922 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package retryablehttp 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net" 13 | "net/http" 14 | "net/http/httptest" 15 | "net/http/httputil" 16 | "net/url" 17 | "strconv" 18 | "strings" 19 | "sync/atomic" 20 | "testing" 21 | "time" 22 | 23 | "github.com/hashicorp/go-hclog" 24 | ) 25 | 26 | func TestRequest(t *testing.T) { 27 | // Fails on invalid request 28 | _, err := NewRequest("GET", "://foo", nil) 29 | if err == nil { 30 | t.Fatalf("should error") 31 | } 32 | 33 | // Works with no request body 34 | _, err = NewRequest("GET", "http://foo", nil) 35 | if err != nil { 36 | t.Fatalf("err: %v", err) 37 | } 38 | 39 | // Works with request body 40 | body := bytes.NewReader([]byte("yo")) 41 | req, err := NewRequest("GET", "/", body) 42 | if err != nil { 43 | t.Fatalf("err: %v", err) 44 | } 45 | 46 | // Request allows typical HTTP request forming methods 47 | req.Header.Set("X-Test", "foo") 48 | if v, ok := req.Header["X-Test"]; !ok || len(v) != 1 || v[0] != "foo" { 49 | t.Fatalf("bad headers: %v", req.Header) 50 | } 51 | 52 | // Sets the Content-Length automatically for LenReaders 53 | if req.ContentLength != 2 { 54 | t.Fatalf("bad ContentLength: %d", req.ContentLength) 55 | } 56 | } 57 | 58 | func TestFromRequest(t *testing.T) { 59 | // Works with no request body 60 | httpReq, err := http.NewRequest("GET", "http://foo", nil) 61 | if err != nil { 62 | t.Fatalf("err: %v", err) 63 | } 64 | _, err = FromRequest(httpReq) 65 
| if err != nil { 66 | t.Fatalf("err: %v", err) 67 | } 68 | 69 | // Works with request body 70 | body := bytes.NewReader([]byte("yo")) 71 | httpReq, err = http.NewRequest("GET", "/", body) 72 | if err != nil { 73 | t.Fatalf("err: %v", err) 74 | } 75 | req, err := FromRequest(httpReq) 76 | if err != nil { 77 | t.Fatalf("err: %v", err) 78 | } 79 | 80 | // Preserves headers 81 | httpReq.Header.Set("X-Test", "foo") 82 | if v, ok := req.Header["X-Test"]; !ok || len(v) != 1 || v[0] != "foo" { 83 | t.Fatalf("bad headers: %v", req.Header) 84 | } 85 | 86 | // Preserves the Content-Length automatically for LenReaders 87 | if req.ContentLength != 2 { 88 | t.Fatalf("bad ContentLength: %d", req.ContentLength) 89 | } 90 | } 91 | 92 | // Since normal ways we would generate a Reader have special cases, use a 93 | // custom type here 94 | type custReader struct { 95 | val string 96 | pos int 97 | } 98 | 99 | func (c *custReader) Read(p []byte) (n int, err error) { 100 | if c.val == "" { 101 | c.val = "hello" 102 | } 103 | if c.pos >= len(c.val) { 104 | return 0, io.EOF 105 | } 106 | var i int 107 | for i = 0; i < len(p) && i+c.pos < len(c.val); i++ { 108 | p[i] = c.val[i+c.pos] 109 | } 110 | c.pos += i 111 | return i, nil 112 | } 113 | 114 | func TestClient_Do(t *testing.T) { 115 | testBytes := []byte("hello") 116 | // Native func 117 | testClientDo(t, ReaderFunc(func() (io.Reader, error) { 118 | return bytes.NewReader(testBytes), nil 119 | })) 120 | // Native func, different Go type 121 | testClientDo(t, func() (io.Reader, error) { 122 | return bytes.NewReader(testBytes), nil 123 | }) 124 | // []byte 125 | testClientDo(t, testBytes) 126 | // *bytes.Buffer 127 | testClientDo(t, bytes.NewBuffer(testBytes)) 128 | // *bytes.Reader 129 | testClientDo(t, bytes.NewReader(testBytes)) 130 | // io.ReadSeeker 131 | testClientDo(t, strings.NewReader(string(testBytes))) 132 | // io.Reader 133 | testClientDo(t, &custReader{}) 134 | } 135 | 136 | func testClientDo(t *testing.T, body interface{}) 
{ 137 | // Create a request 138 | req, err := NewRequest("PUT", "http://127.0.0.1:28934/v1/foo", body) 139 | if err != nil { 140 | t.Fatalf("err: %v", err) 141 | } 142 | req.Header.Set("foo", "bar") 143 | 144 | // Track the number of times the logging hook was called 145 | retryCount := -1 146 | 147 | // Create the client. Use short retry windows. 148 | client := NewClient() 149 | client.RetryWaitMin = 10 * time.Millisecond 150 | client.RetryWaitMax = 50 * time.Millisecond 151 | client.RetryMax = 50 152 | client.RequestLogHook = func(logger Logger, req *http.Request, retryNumber int) { 153 | retryCount = retryNumber 154 | 155 | if logger != client.Logger { 156 | t.Fatalf("Client logger was not passed to logging hook") 157 | } 158 | 159 | dumpBytes, err := httputil.DumpRequestOut(req, false) 160 | if err != nil { 161 | t.Fatal("Dumping requests failed") 162 | } 163 | 164 | dumpString := string(dumpBytes) 165 | if !strings.Contains(dumpString, "PUT /v1/foo") { 166 | t.Fatalf("Bad request dump:\n%s", dumpString) 167 | } 168 | } 169 | 170 | // Send the request 171 | var resp *http.Response 172 | doneCh := make(chan struct{}) 173 | errCh := make(chan error, 1) 174 | go func() { 175 | defer close(doneCh) 176 | defer close(errCh) 177 | var err error 178 | resp, err = client.Do(req) 179 | errCh <- err 180 | }() 181 | 182 | select { 183 | case <-doneCh: 184 | t.Fatalf("should retry on error") 185 | case <-time.After(200 * time.Millisecond): 186 | // Client should still be retrying due to connection failure. 187 | } 188 | 189 | // Create the mock handler. First we return a 500-range response to ensure 190 | // that we power through and keep retrying in the face of recoverable 191 | // errors. 
192 | code := int64(500) 193 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 194 | // Check the request details 195 | if r.Method != "PUT" { 196 | t.Fatalf("bad method: %s", r.Method) 197 | } 198 | if r.RequestURI != "/v1/foo" { 199 | t.Fatalf("bad uri: %s", r.RequestURI) 200 | } 201 | 202 | // Check the headers 203 | if v := r.Header.Get("foo"); v != "bar" { 204 | t.Fatalf("bad header: expect foo=bar, got foo=%v", v) 205 | } 206 | 207 | // Check the payload 208 | body, err := io.ReadAll(r.Body) 209 | if err != nil { 210 | t.Fatalf("err: %s", err) 211 | } 212 | expected := []byte("hello") 213 | if !bytes.Equal(body, expected) { 214 | t.Fatalf("bad: %v", body) 215 | } 216 | 217 | w.WriteHeader(int(atomic.LoadInt64(&code))) 218 | }) 219 | 220 | // Create a test server 221 | list, err := net.Listen("tcp", ":28934") 222 | if err != nil { 223 | t.Fatalf("err: %v", err) 224 | } 225 | defer list.Close() 226 | errors := make(chan error, 1) 227 | go func() { 228 | err := http.Serve(list, handler) 229 | if err != nil { 230 | errors <- err 231 | return 232 | } 233 | }() 234 | 235 | // Wait again 236 | select { 237 | case <-doneCh: 238 | t.Fatalf("should retry on 500-range") 239 | case <-time.After(200 * time.Millisecond): 240 | // Client should still be retrying due to 500's. 241 | } 242 | 243 | // Start returning 200's 244 | atomic.StoreInt64(&code, 200) 245 | 246 | // Wait again 247 | select { 248 | case <-doneCh: 249 | case <-time.After(time.Second): 250 | t.Fatalf("timed out") 251 | } 252 | 253 | if resp.StatusCode != 200 { 254 | t.Fatalf("exected 200, got: %d", resp.StatusCode) 255 | } 256 | 257 | if retryCount < 0 { 258 | t.Fatal("request log hook was not called") 259 | } 260 | 261 | err = <-errCh 262 | if err != nil { 263 | t.Fatalf("err: %v", err) 264 | } 265 | } 266 | 267 | func TestClient_Do_WithResponseHandler(t *testing.T) { 268 | // Create the client. Use short retry windows so we fail faster. 
269 | client := NewClient() 270 | client.RetryWaitMin = 10 * time.Millisecond 271 | client.RetryWaitMax = 10 * time.Millisecond 272 | client.RetryMax = 2 273 | 274 | var checks int 275 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 276 | checks++ 277 | if err != nil && strings.Contains(err.Error(), "nonretryable") { 278 | return false, nil 279 | } 280 | return DefaultRetryPolicy(context.TODO(), resp, err) 281 | } 282 | 283 | // Mock server which always responds 200. 284 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 285 | w.WriteHeader(200) 286 | })) 287 | defer ts.Close() 288 | 289 | var shouldSucceed bool 290 | tests := []struct { 291 | name string 292 | handler ResponseHandlerFunc 293 | expectedChecks int // often 2x number of attempts since we check twice 294 | err string 295 | }{ 296 | { 297 | name: "nil handler", 298 | handler: nil, 299 | expectedChecks: 1, 300 | }, 301 | { 302 | name: "handler always succeeds", 303 | handler: func(*http.Response) error { 304 | return nil 305 | }, 306 | expectedChecks: 2, 307 | }, 308 | { 309 | name: "handler always fails in a retryable way", 310 | handler: func(*http.Response) error { 311 | return errors.New("retryable failure") 312 | }, 313 | expectedChecks: 6, 314 | }, 315 | { 316 | name: "handler always fails in a nonretryable way", 317 | handler: func(*http.Response) error { 318 | return errors.New("nonretryable failure") 319 | }, 320 | expectedChecks: 2, 321 | }, 322 | { 323 | name: "handler succeeds on second attempt", 324 | handler: func(*http.Response) error { 325 | if shouldSucceed { 326 | return nil 327 | } 328 | shouldSucceed = true 329 | return errors.New("retryable failure") 330 | }, 331 | expectedChecks: 4, 332 | }, 333 | } 334 | 335 | for _, tt := range tests { 336 | t.Run(tt.name, func(t *testing.T) { 337 | checks = 0 338 | shouldSucceed = false 339 | // Create the request 340 | req, err := NewRequest("GET", ts.URL, 
nil) 341 | if err != nil { 342 | t.Fatalf("err: %v", err) 343 | } 344 | req.SetResponseHandler(tt.handler) 345 | 346 | // Send the request. 347 | _, err = client.Do(req) 348 | if err != nil && !strings.Contains(err.Error(), tt.err) { 349 | t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error()) 350 | } 351 | if err == nil && tt.err != "" { 352 | t.Fatalf("no error, expected: %s", tt.err) 353 | } 354 | 355 | if checks != tt.expectedChecks { 356 | t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks) 357 | } 358 | }) 359 | } 360 | } 361 | 362 | func TestClient_Do_WithPrepareRetry(t *testing.T) { 363 | // Create the client. Use short retry windows so we fail faster. 364 | client := NewClient() 365 | client.RetryWaitMin = 10 * time.Millisecond 366 | client.RetryWaitMax = 10 * time.Millisecond 367 | client.RetryMax = 2 368 | 369 | var checks int 370 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 371 | checks++ 372 | if err != nil && strings.Contains(err.Error(), "nonretryable") { 373 | return false, nil 374 | } 375 | return DefaultRetryPolicy(context.TODO(), resp, err) 376 | } 377 | 378 | var prepareChecks int 379 | client.PrepareRetry = func(req *http.Request) error { 380 | prepareChecks++ 381 | req.Header.Set("foo", strconv.Itoa(prepareChecks)) 382 | return nil 383 | } 384 | 385 | // Mock server which always responds 200. 
386 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 387 | w.WriteHeader(200) 388 | })) 389 | defer ts.Close() 390 | 391 | var shouldSucceed bool 392 | tests := []struct { 393 | name string 394 | handler ResponseHandlerFunc 395 | expectedChecks int // often 2x number of attempts since we check twice 396 | expectedPrepareChecks int 397 | err string 398 | }{ 399 | { 400 | name: "nil handler", 401 | handler: nil, 402 | expectedChecks: 1, 403 | expectedPrepareChecks: 0, 404 | }, 405 | { 406 | name: "handler always succeeds", 407 | handler: func(*http.Response) error { 408 | return nil 409 | }, 410 | expectedChecks: 2, 411 | expectedPrepareChecks: 0, 412 | }, 413 | { 414 | name: "handler always fails in a retryable way", 415 | handler: func(*http.Response) error { 416 | return errors.New("retryable failure") 417 | }, 418 | expectedChecks: 6, 419 | expectedPrepareChecks: 2, 420 | }, 421 | { 422 | name: "handler always fails in a nonretryable way", 423 | handler: func(*http.Response) error { 424 | return errors.New("nonretryable failure") 425 | }, 426 | expectedChecks: 2, 427 | expectedPrepareChecks: 0, 428 | }, 429 | { 430 | name: "handler succeeds on second attempt", 431 | handler: func(*http.Response) error { 432 | if shouldSucceed { 433 | return nil 434 | } 435 | shouldSucceed = true 436 | return errors.New("retryable failure") 437 | }, 438 | expectedChecks: 4, 439 | expectedPrepareChecks: 1, 440 | }, 441 | } 442 | 443 | for _, tt := range tests { 444 | t.Run(tt.name, func(t *testing.T) { 445 | checks = 0 446 | prepareChecks = 0 447 | shouldSucceed = false 448 | // Create the request 449 | req, err := NewRequest("GET", ts.URL, nil) 450 | if err != nil { 451 | t.Fatalf("err: %v", err) 452 | } 453 | req.SetResponseHandler(tt.handler) 454 | 455 | // Send the request. 
456 | _, err = client.Do(req) 457 | if err != nil && !strings.Contains(err.Error(), tt.err) { 458 | t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error()) 459 | } 460 | if err == nil && tt.err != "" { 461 | t.Fatalf("no error, expected: %s", tt.err) 462 | } 463 | 464 | if checks != tt.expectedChecks { 465 | t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks) 466 | } 467 | 468 | if prepareChecks != tt.expectedPrepareChecks { 469 | t.Fatalf("expected %d attempts of prepare check, got %d attempts", tt.expectedPrepareChecks, prepareChecks) 470 | } 471 | header := req.Request.Header.Get("foo") 472 | if tt.expectedPrepareChecks == 0 && header != "" { 473 | t.Fatalf("expected no changes to request header 'foo', but got '%s'", header) 474 | } 475 | expectedHeader := strconv.Itoa(tt.expectedPrepareChecks) 476 | if tt.expectedPrepareChecks != 0 && header != expectedHeader { 477 | t.Fatalf("expected changes in request header 'foo' '%s', but got '%s'", expectedHeader, header) 478 | } 479 | 480 | }) 481 | } 482 | } 483 | 484 | func TestClient_Do_fails(t *testing.T) { 485 | // Mock server which always responds 500. 
486 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 487 | w.WriteHeader(500) 488 | })) 489 | defer ts.Close() 490 | 491 | serverUrlWithBasicAuth, err := url.Parse(ts.URL) 492 | if err != nil { 493 | t.Fatalf("failed parsing test server url: %s", ts.URL) 494 | } 495 | serverUrlWithBasicAuth.User = url.UserPassword("user", "pasten") 496 | 497 | tests := []struct { 498 | url string 499 | name string 500 | cr CheckRetry 501 | err string 502 | }{ 503 | { 504 | url: ts.URL, 505 | name: "default_retry_policy", 506 | cr: DefaultRetryPolicy, 507 | err: "giving up after 3 attempt(s)", 508 | }, 509 | { 510 | url: serverUrlWithBasicAuth.String(), 511 | name: "default_retry_policy_url_with_basic_auth", 512 | cr: DefaultRetryPolicy, 513 | err: redactURL(serverUrlWithBasicAuth) + " giving up after 3 attempt(s)", 514 | }, 515 | { 516 | url: ts.URL, 517 | name: "error_propagated_retry_policy", 518 | cr: ErrorPropagatedRetryPolicy, 519 | err: "giving up after 3 attempt(s): unexpected HTTP status 500 Internal Server Error", 520 | }, 521 | } 522 | 523 | for _, tt := range tests { 524 | t.Run(tt.name, func(t *testing.T) { 525 | // Create the client. Use short retry windows so we fail faster. 526 | client := NewClient() 527 | client.RetryWaitMin = 10 * time.Millisecond 528 | client.RetryWaitMax = 10 * time.Millisecond 529 | client.CheckRetry = tt.cr 530 | client.RetryMax = 2 531 | 532 | // Create the request 533 | req, err := NewRequest("POST", tt.url, nil) 534 | if err != nil { 535 | t.Fatalf("err: %v", err) 536 | } 537 | 538 | // Send the request. 
539 | _, err = client.Do(req) 540 | if err == nil || !strings.HasSuffix(err.Error(), tt.err) { 541 | t.Fatalf("expected %#v, got: %#v", tt.err, err) 542 | } 543 | }) 544 | } 545 | } 546 | 547 | func TestClient_Get(t *testing.T) { 548 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 549 | if r.Method != "GET" { 550 | t.Fatalf("bad method: %s", r.Method) 551 | } 552 | if r.RequestURI != "/foo/bar" { 553 | t.Fatalf("bad uri: %s", r.RequestURI) 554 | } 555 | w.WriteHeader(200) 556 | })) 557 | defer ts.Close() 558 | 559 | // Make the request. 560 | resp, err := NewClient().Get(ts.URL + "/foo/bar") 561 | if err != nil { 562 | t.Fatalf("err: %v", err) 563 | } 564 | resp.Body.Close() 565 | } 566 | 567 | func TestClient_RequestLogHook(t *testing.T) { 568 | t.Run("RequestLogHook successfully called with default Logger", func(t *testing.T) { 569 | testClientRequestLogHook(t, defaultLogger) 570 | }) 571 | t.Run("RequestLogHook successfully called with nil Logger", func(t *testing.T) { 572 | testClientRequestLogHook(t, nil) 573 | }) 574 | t.Run("RequestLogHook successfully called with nil typed Logger", func(t *testing.T) { 575 | testClientRequestLogHook(t, Logger(nil)) 576 | }) 577 | t.Run("RequestLogHook successfully called with nil typed LeveledLogger", func(t *testing.T) { 578 | testClientRequestLogHook(t, LeveledLogger(nil)) 579 | }) 580 | } 581 | 582 | func testClientRequestLogHook(t *testing.T, logger interface{}) { 583 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 584 | if r.Method != "GET" { 585 | t.Fatalf("bad method: %s", r.Method) 586 | } 587 | if r.RequestURI != "/foo/bar" { 588 | t.Fatalf("bad uri: %s", r.RequestURI) 589 | } 590 | w.WriteHeader(200) 591 | })) 592 | defer ts.Close() 593 | 594 | retries := -1 595 | testURIPath := "/foo/bar" 596 | 597 | client := NewClient() 598 | client.Logger = logger 599 | client.RequestLogHook = func(logger Logger, req *http.Request, retry int) 
{ 600 | retries = retry 601 | 602 | if logger != client.Logger { 603 | t.Fatalf("Client logger was not passed to logging hook") 604 | } 605 | 606 | dumpBytes, err := httputil.DumpRequestOut(req, false) 607 | if err != nil { 608 | t.Fatal("Dumping requests failed") 609 | } 610 | 611 | dumpString := string(dumpBytes) 612 | if !strings.Contains(dumpString, "GET "+testURIPath) { 613 | t.Fatalf("Bad request dump:\n%s", dumpString) 614 | } 615 | } 616 | 617 | // Make the request. 618 | resp, err := client.Get(ts.URL + testURIPath) 619 | if err != nil { 620 | t.Fatalf("err: %v", err) 621 | } 622 | resp.Body.Close() 623 | 624 | if retries < 0 { 625 | t.Fatal("Logging hook was not called") 626 | } 627 | } 628 | 629 | func TestClient_ResponseLogHook(t *testing.T) { 630 | t.Run("ResponseLogHook successfully called with hclog Logger", func(t *testing.T) { 631 | buf := new(bytes.Buffer) 632 | l := hclog.New(&hclog.LoggerOptions{ 633 | Output: buf, 634 | }) 635 | testClientResponseLogHook(t, l, buf) 636 | }) 637 | t.Run("ResponseLogHook successfully called with nil Logger", func(t *testing.T) { 638 | buf := new(bytes.Buffer) 639 | testClientResponseLogHook(t, nil, buf) 640 | }) 641 | t.Run("ResponseLogHook successfully called with nil typed Logger", func(t *testing.T) { 642 | buf := new(bytes.Buffer) 643 | testClientResponseLogHook(t, Logger(nil), buf) 644 | }) 645 | t.Run("ResponseLogHook successfully called with nil typed LeveledLogger", func(t *testing.T) { 646 | buf := new(bytes.Buffer) 647 | testClientResponseLogHook(t, LeveledLogger(nil), buf) 648 | }) 649 | } 650 | 651 | func testClientResponseLogHook(t *testing.T, l interface{}, buf *bytes.Buffer) { 652 | passAfter := time.Now().Add(100 * time.Millisecond) 653 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 654 | if time.Now().After(passAfter) { 655 | w.WriteHeader(200) 656 | if _, err := w.Write([]byte("test_200_body")); err != nil { 657 | t.Fatalf("failed to write: %v", err) 
658 | } 659 | } else { 660 | w.WriteHeader(500) 661 | if _, err := w.Write([]byte("test_500_body")); err != nil { 662 | t.Fatalf("failed to write: %v", err) 663 | } 664 | } 665 | })) 666 | defer ts.Close() 667 | 668 | client := NewClient() 669 | 670 | client.Logger = l 671 | client.RetryWaitMin = 10 * time.Millisecond 672 | client.RetryWaitMax = 10 * time.Millisecond 673 | client.RetryMax = 15 674 | client.ResponseLogHook = func(logger Logger, resp *http.Response) { 675 | if resp.StatusCode == 200 { 676 | successLog := "test_log_pass" 677 | // Log something when we get a 200 678 | if logger != nil { 679 | logger.Printf(successLog) 680 | } else { 681 | buf.WriteString(successLog) 682 | } 683 | } else { 684 | // Log the response body when we get a 500 685 | body, err := io.ReadAll(resp.Body) 686 | if err != nil { 687 | t.Fatalf("err: %v", err) 688 | } 689 | failLog := string(body) 690 | if logger != nil { 691 | logger.Printf(failLog) 692 | } else { 693 | buf.WriteString(failLog) 694 | } 695 | } 696 | } 697 | 698 | // Perform the request. Exits when we finally get a 200. 699 | resp, err := client.Get(ts.URL) 700 | if err != nil { 701 | t.Fatalf("err: %v", err) 702 | } 703 | 704 | // Make sure we can read the response body still, since we did not 705 | // read or close it from the response log hook. 706 | body, err := io.ReadAll(resp.Body) 707 | if err != nil { 708 | t.Fatalf("err: %v", err) 709 | } 710 | if string(body) != "test_200_body" { 711 | t.Fatalf("expect %q, got %q", "test_200_body", string(body)) 712 | } 713 | 714 | // Make sure we wrote to the logger on callbacks. 
715 | out := buf.String() 716 | if !strings.Contains(out, "test_log_pass") { 717 | t.Fatalf("expect response callback on 200: %q", out) 718 | } 719 | if !strings.Contains(out, "test_500_body") { 720 | t.Fatalf("expect response callback on 500: %q", out) 721 | } 722 | } 723 | 724 | func TestClient_NewRequestWithContext(t *testing.T) { 725 | ctx, cancel := context.WithCancel(context.Background()) 726 | defer cancel() 727 | r, err := NewRequestWithContext(ctx, http.MethodGet, "/abc", nil) 728 | if err != nil { 729 | t.Fatalf("err: %v", err) 730 | } 731 | if r.Context() != ctx { 732 | t.Fatal("Context must be set") 733 | } 734 | } 735 | 736 | func TestClient_RequestWithContext(t *testing.T) { 737 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 738 | w.WriteHeader(200) 739 | if _, err := w.Write([]byte("test_200_body")); err != nil { 740 | t.Fatalf("failed to write: %v", err) 741 | } 742 | })) 743 | defer ts.Close() 744 | 745 | req, err := NewRequest(http.MethodGet, ts.URL, nil) 746 | if err != nil { 747 | t.Fatalf("err: %v", err) 748 | } 749 | ctx, cancel := context.WithCancel(req.Request.Context()) 750 | reqCtx := req.WithContext(ctx) 751 | if reqCtx == req { 752 | t.Fatal("WithContext must return a new Request object") 753 | } 754 | 755 | client := NewClient() 756 | 757 | called := 0 758 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 759 | called++ 760 | return DefaultRetryPolicy(reqCtx.Request.Context(), resp, err) 761 | } 762 | 763 | cancel() 764 | _, err = client.Do(reqCtx) 765 | 766 | if called != 1 { 767 | t.Fatalf("CheckRetry called %d times, expected 1", called) 768 | } 769 | 770 | e := fmt.Sprintf("GET %s giving up after 1 attempt(s): %s", ts.URL, context.Canceled.Error()) 771 | 772 | if err.Error() != e { 773 | t.Fatalf("Expected err to contain %s, got: %v", e, err) 774 | } 775 | } 776 | 777 | func TestClient_CheckRetry(t *testing.T) { 778 | ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 779 | http.Error(w, "test_500_body", http.StatusInternalServerError) 780 | })) 781 | defer ts.Close() 782 | 783 | client := NewClient() 784 | 785 | retryErr := errors.New("retryError") 786 | called := 0 787 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 788 | if called < 1 { 789 | called++ 790 | return DefaultRetryPolicy(context.TODO(), resp, err) 791 | } 792 | 793 | return false, retryErr 794 | } 795 | 796 | // CheckRetry should return our retryErr value and stop the retry loop. 797 | _, err := client.Get(ts.URL) 798 | 799 | if called != 1 { 800 | t.Fatalf("CheckRetry called %d times, expected 1", called) 801 | } 802 | 803 | if err.Error() != fmt.Sprintf("GET %s giving up after 2 attempt(s): retryError", ts.URL) { 804 | t.Fatalf("Expected retryError, got:%v", err) 805 | } 806 | } 807 | 808 | func testStaticTime(t *testing.T) { 809 | timeNow = func() time.Time { 810 | now, err := time.Parse(time.RFC1123, "Fri, 31 Dec 1999 23:59:57 GMT") 811 | if err != nil { 812 | panic(err) 813 | } 814 | return now 815 | } 816 | t.Cleanup(func() { 817 | timeNow = time.Now 818 | }) 819 | } 820 | 821 | func TestParseRetryAfterHeader(t *testing.T) { 822 | testStaticTime(t) 823 | tests := []struct { 824 | name string 825 | headers []string 826 | sleep time.Duration 827 | ok bool 828 | }{ 829 | {"seconds", []string{"2"}, time.Second * 2, true}, 830 | {"date", []string{"Fri, 31 Dec 1999 23:59:59 GMT"}, time.Second * 2, true}, 831 | {"past-date", []string{"Fri, 31 Dec 1999 23:59:00 GMT"}, 0, true}, 832 | {"nil", nil, 0, false}, 833 | {"two-headers", []string{"2", "3"}, time.Second * 2, true}, 834 | {"empty", []string{""}, 0, false}, 835 | {"negative", []string{"-2"}, 0, false}, 836 | {"bad-date", []string{"Fri, 32 Dec 1999 23:59:59 GMT"}, 0, false}, 837 | {"bad-date-format", []string{"badbadbad"}, 0, false}, 838 | } 839 | for _, test := range tests { 840 | 
t.Run(test.name, func(t *testing.T) { 841 | sleep, ok := parseRetryAfterHeader(test.headers) 842 | if ok != test.ok { 843 | t.Fatalf("expected ok=%t, got ok=%t", test.ok, ok) 844 | } 845 | if sleep != test.sleep { 846 | t.Fatalf("expected sleep=%v, got sleep=%v", test.sleep, sleep) 847 | } 848 | }) 849 | } 850 | } 851 | 852 | func TestClient_DefaultBackoff(t *testing.T) { 853 | testStaticTime(t) 854 | tests := []struct { 855 | name string 856 | code int 857 | retryHeader string 858 | }{ 859 | {"http_429_seconds", http.StatusTooManyRequests, "2"}, 860 | {"http_429_date", http.StatusTooManyRequests, "Fri, 31 Dec 1999 23:59:59 GMT"}, 861 | {"http_503_seconds", http.StatusServiceUnavailable, "2"}, 862 | {"http_503_date", http.StatusServiceUnavailable, "Fri, 31 Dec 1999 23:59:59 GMT"}, 863 | } 864 | for _, test := range tests { 865 | t.Run(test.name, func(t *testing.T) { 866 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 867 | w.Header().Set("Retry-After", test.retryHeader) 868 | http.Error(w, fmt.Sprintf("test_%d_body", test.code), test.code) 869 | })) 870 | defer ts.Close() 871 | 872 | client := NewClient() 873 | 874 | var retryAfter time.Duration 875 | retryable := false 876 | 877 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 878 | retryable, _ = DefaultRetryPolicy(context.Background(), resp, err) 879 | retryAfter = DefaultBackoff(client.RetryWaitMin, client.RetryWaitMax, 1, resp) 880 | return false, nil 881 | } 882 | 883 | _, err := client.Get(ts.URL) 884 | if err != nil { 885 | t.Fatalf("expected no errors since retryable") 886 | } 887 | 888 | if !retryable { 889 | t.Fatal("Since the error is recoverable, the default policy shall return true") 890 | } 891 | 892 | if retryAfter != 2*time.Second { 893 | t.Fatalf("The header Retry-After specified 2 seconds, and shall not be %d seconds", retryAfter/time.Second) 894 | } 895 | }) 896 | } 897 | } 898 | 899 | func 
TestClient_DefaultRetryPolicy_TLS(t *testing.T) { 900 | ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 901 | w.WriteHeader(200) 902 | })) 903 | defer ts.Close() 904 | 905 | attempts := 0 906 | client := NewClient() 907 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 908 | attempts++ 909 | return DefaultRetryPolicy(context.TODO(), resp, err) 910 | } 911 | 912 | _, err := client.Get(ts.URL) 913 | if err == nil { 914 | t.Fatalf("expected x509 error, got nil") 915 | } 916 | if attempts != 1 { 917 | t.Fatalf("expected 1 attempt, got %d", attempts) 918 | } 919 | } 920 | 921 | func TestClient_DefaultRetryPolicy_redirects(t *testing.T) { 922 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 923 | http.Redirect(w, r, "/", http.StatusFound) 924 | })) 925 | defer ts.Close() 926 | 927 | attempts := 0 928 | client := NewClient() 929 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 930 | attempts++ 931 | return DefaultRetryPolicy(context.TODO(), resp, err) 932 | } 933 | 934 | _, err := client.Get(ts.URL) 935 | if err == nil { 936 | t.Fatalf("expected redirect error, got nil") 937 | } 938 | if attempts != 1 { 939 | t.Fatalf("expected 1 attempt, got %d", attempts) 940 | } 941 | } 942 | 943 | func TestClient_DefaultRetryPolicy_invalidscheme(t *testing.T) { 944 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 945 | w.WriteHeader(200) 946 | })) 947 | defer ts.Close() 948 | 949 | attempts := 0 950 | client := NewClient() 951 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 952 | attempts++ 953 | return DefaultRetryPolicy(context.TODO(), resp, err) 954 | } 955 | 956 | url := strings.Replace(ts.URL, "http", "ftp", 1) 957 | _, err := client.Get(url) 958 | if err == nil { 959 | t.Fatalf("expected scheme error, got nil") 960 | } 
961 | if attempts != 1 { 962 | t.Fatalf("expected 1 attempt, got %d", attempts) 963 | } 964 | } 965 | 966 | func TestClient_DefaultRetryPolicy_invalidheadername(t *testing.T) { 967 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 968 | w.WriteHeader(200) 969 | })) 970 | defer ts.Close() 971 | 972 | attempts := 0 973 | client := NewClient() 974 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 975 | attempts++ 976 | return DefaultRetryPolicy(context.TODO(), resp, err) 977 | } 978 | 979 | req, err := http.NewRequest(http.MethodGet, ts.URL, nil) 980 | if err != nil { 981 | t.Fatalf("err: %v", err) 982 | } 983 | req.Header.Set("Header-Name-\033", "header value") 984 | _, err = client.StandardClient().Do(req) 985 | if err == nil { 986 | t.Fatalf("expected header error, got nil") 987 | } 988 | if attempts != 1 { 989 | t.Fatalf("expected 1 attempt, got %d", attempts) 990 | } 991 | } 992 | 993 | func TestClient_DefaultRetryPolicy_invalidheadervalue(t *testing.T) { 994 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 995 | w.WriteHeader(200) 996 | })) 997 | defer ts.Close() 998 | 999 | attempts := 0 1000 | client := NewClient() 1001 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 1002 | attempts++ 1003 | return DefaultRetryPolicy(context.TODO(), resp, err) 1004 | } 1005 | 1006 | req, err := http.NewRequest(http.MethodGet, ts.URL, nil) 1007 | if err != nil { 1008 | t.Fatalf("err: %v", err) 1009 | } 1010 | req.Header.Set("Header-Name", "bad header value \033") 1011 | _, err = client.StandardClient().Do(req) 1012 | if err == nil { 1013 | t.Fatalf("expected header value error, got nil") 1014 | } 1015 | if attempts != 1 { 1016 | t.Fatalf("expected 1 attempt, got %d", attempts) 1017 | } 1018 | } 1019 | 1020 | func TestClient_CheckRetryStop(t *testing.T) { 1021 | ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1022 | http.Error(w, "test_500_body", http.StatusInternalServerError) 1023 | })) 1024 | defer ts.Close() 1025 | 1026 | client := NewClient() 1027 | 1028 | // Verify that this stops retries on the first try, with no errors from the client. 1029 | called := 0 1030 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 1031 | called++ 1032 | return false, nil 1033 | } 1034 | 1035 | _, err := client.Get(ts.URL) 1036 | 1037 | if called != 1 { 1038 | t.Fatalf("CheckRetry called %d times, expected 1", called) 1039 | } 1040 | 1041 | if err != nil { 1042 | t.Fatalf("Expected no error, got:%v", err) 1043 | } 1044 | } 1045 | 1046 | func TestClient_Head(t *testing.T) { 1047 | // Mock server which always responds 200. 1048 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1049 | if r.Method != "HEAD" { 1050 | t.Fatalf("bad method: %s", r.Method) 1051 | } 1052 | if r.RequestURI != "/foo/bar" { 1053 | t.Fatalf("bad uri: %s", r.RequestURI) 1054 | } 1055 | w.WriteHeader(200) 1056 | })) 1057 | defer ts.Close() 1058 | 1059 | // Make the request. 1060 | resp, err := NewClient().Head(ts.URL + "/foo/bar") 1061 | if err != nil { 1062 | t.Fatalf("err: %v", err) 1063 | } 1064 | resp.Body.Close() 1065 | } 1066 | 1067 | func TestClient_Post(t *testing.T) { 1068 | // Mock server which always responds 200. 
1069 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1070 | if r.Method != "POST" { 1071 | t.Fatalf("bad method: %s", r.Method) 1072 | } 1073 | if r.RequestURI != "/foo/bar" { 1074 | t.Fatalf("bad uri: %s", r.RequestURI) 1075 | } 1076 | if ct := r.Header.Get("Content-Type"); ct != "application/json" { 1077 | t.Fatalf("bad content-type: %s", ct) 1078 | } 1079 | 1080 | // Check the payload 1081 | body, err := io.ReadAll(r.Body) 1082 | if err != nil { 1083 | t.Fatalf("err: %s", err) 1084 | } 1085 | expected := []byte(`{"hello":"world"}`) 1086 | if !bytes.Equal(body, expected) { 1087 | t.Fatalf("bad: %v", body) 1088 | } 1089 | 1090 | w.WriteHeader(200) 1091 | })) 1092 | defer ts.Close() 1093 | 1094 | // Make the request. 1095 | resp, err := NewClient().Post( 1096 | ts.URL+"/foo/bar", 1097 | "application/json", 1098 | strings.NewReader(`{"hello":"world"}`)) 1099 | if err != nil { 1100 | t.Fatalf("err: %v", err) 1101 | } 1102 | resp.Body.Close() 1103 | } 1104 | 1105 | func TestClient_PostForm(t *testing.T) { 1106 | // Mock server which always responds 200. 1107 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1108 | if r.Method != "POST" { 1109 | t.Fatalf("bad method: %s", r.Method) 1110 | } 1111 | if r.RequestURI != "/foo/bar" { 1112 | t.Fatalf("bad uri: %s", r.RequestURI) 1113 | } 1114 | if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { 1115 | t.Fatalf("bad content-type: %s", ct) 1116 | } 1117 | 1118 | // Check the payload 1119 | body, err := io.ReadAll(r.Body) 1120 | if err != nil { 1121 | t.Fatalf("err: %s", err) 1122 | } 1123 | expected := []byte(`hello=world`) 1124 | if !bytes.Equal(body, expected) { 1125 | t.Fatalf("bad: %v", body) 1126 | } 1127 | 1128 | w.WriteHeader(200) 1129 | })) 1130 | defer ts.Close() 1131 | 1132 | // Create the form data. 
1133 | form, err := url.ParseQuery("hello=world") 1134 | if err != nil { 1135 | t.Fatalf("err: %v", err) 1136 | } 1137 | 1138 | // Make the request. 1139 | resp, err := NewClient().PostForm(ts.URL+"/foo/bar", form) 1140 | if err != nil { 1141 | t.Fatalf("err: %v", err) 1142 | } 1143 | resp.Body.Close() 1144 | } 1145 | 1146 | func TestBackoff(t *testing.T) { 1147 | type tcase struct { 1148 | min time.Duration 1149 | max time.Duration 1150 | i int 1151 | expect time.Duration 1152 | } 1153 | cases := []tcase{ 1154 | { 1155 | time.Second, 1156 | 5 * time.Minute, 1157 | 0, 1158 | time.Second, 1159 | }, 1160 | { 1161 | time.Second, 1162 | 5 * time.Minute, 1163 | 1, 1164 | 2 * time.Second, 1165 | }, 1166 | { 1167 | time.Second, 1168 | 5 * time.Minute, 1169 | 2, 1170 | 4 * time.Second, 1171 | }, 1172 | { 1173 | time.Second, 1174 | 5 * time.Minute, 1175 | 3, 1176 | 8 * time.Second, 1177 | }, 1178 | { 1179 | time.Second, 1180 | 5 * time.Minute, 1181 | 63, 1182 | 5 * time.Minute, 1183 | }, 1184 | { 1185 | time.Second, 1186 | 5 * time.Minute, 1187 | 128, 1188 | 5 * time.Minute, 1189 | }, 1190 | } 1191 | 1192 | for _, tc := range cases { 1193 | if v := DefaultBackoff(tc.min, tc.max, tc.i, nil); v != tc.expect { 1194 | t.Fatalf("bad: %#v -> %s", tc, v) 1195 | } 1196 | } 1197 | } 1198 | 1199 | func TestClient_BackoffCustom(t *testing.T) { 1200 | var retries int32 1201 | 1202 | client := NewClient() 1203 | client.Backoff = func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { 1204 | atomic.AddInt32(&retries, 1) 1205 | return time.Millisecond * 1 1206 | } 1207 | 1208 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1209 | if atomic.LoadInt32(&retries) == int32(client.RetryMax) { 1210 | w.WriteHeader(200) 1211 | return 1212 | } 1213 | w.WriteHeader(500) 1214 | })) 1215 | defer ts.Close() 1216 | 1217 | // Make the request. 
1218 | resp, err := client.Get(ts.URL + "/foo/bar") 1219 | if err != nil { 1220 | t.Fatalf("err: %v", err) 1221 | } 1222 | resp.Body.Close() 1223 | if retries != int32(client.RetryMax) { 1224 | t.Fatalf("expected retries: %d != %d", client.RetryMax, retries) 1225 | } 1226 | } 1227 | 1228 | func TestClient_StandardClient(t *testing.T) { 1229 | // Create a retryable HTTP client. 1230 | client := NewClient() 1231 | 1232 | // Get a standard client. 1233 | standard := client.StandardClient() 1234 | 1235 | // Ensure the underlying retrying client is set properly. 1236 | if v := standard.Transport.(*RoundTripper).Client; v != client { 1237 | t.Fatalf("expected %v, got %v", client, v) 1238 | } 1239 | } 1240 | 1241 | func TestClient_RedirectWithBody(t *testing.T) { 1242 | var redirects int32 1243 | // Mock server which always responds 200. 1244 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1245 | switch r.RequestURI { 1246 | case "/redirect": 1247 | w.Header().Set("Location", "/target") 1248 | w.WriteHeader(http.StatusTemporaryRedirect) 1249 | case "/target": 1250 | atomic.AddInt32(&redirects, 1) 1251 | w.WriteHeader(http.StatusCreated) 1252 | default: 1253 | t.Fatalf("bad uri: %s", r.RequestURI) 1254 | } 1255 | })) 1256 | defer ts.Close() 1257 | 1258 | client := NewClient() 1259 | client.RequestLogHook = func(logger Logger, req *http.Request, retryNumber int) { 1260 | if _, err := req.GetBody(); err != nil { 1261 | t.Fatalf("unexpected error with GetBody: %v", err) 1262 | } 1263 | } 1264 | // create a request with a body 1265 | req, err := NewRequest(http.MethodPost, ts.URL+"/redirect", strings.NewReader(`{"foo":"bar"}`)) 1266 | if err != nil { 1267 | t.Fatalf("err: %v", err) 1268 | } 1269 | 1270 | resp, err := client.Do(req) 1271 | if err != nil { 1272 | t.Fatalf("err: %v", err) 1273 | } 1274 | resp.Body.Close() 1275 | 1276 | if resp.StatusCode != http.StatusCreated { 1277 | t.Fatalf("expected status code 201, got: %d", 
resp.StatusCode) 1278 | } 1279 | 1280 | // now one without a body 1281 | if err := req.SetBody(nil); err != nil { 1282 | t.Fatalf("err: %v", err) 1283 | } 1284 | 1285 | resp, err = client.Do(req) 1286 | if err != nil { 1287 | t.Fatalf("err: %v", err) 1288 | } 1289 | resp.Body.Close() 1290 | 1291 | if resp.StatusCode != http.StatusCreated { 1292 | t.Fatalf("expected status code 201, got: %d", resp.StatusCode) 1293 | } 1294 | 1295 | if atomic.LoadInt32(&redirects) != 2 { 1296 | t.Fatalf("Expected the client to be redirected 2 times, got: %d", atomic.LoadInt32(&redirects)) 1297 | } 1298 | } 1299 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-retryablehttp 2 | 3 | require ( 4 | github.com/hashicorp/go-cleanhttp v0.5.2 5 | github.com/hashicorp/go-hclog v1.6.3 6 | ) 7 | 8 | require ( 9 | github.com/fatih/color v1.16.0 // indirect 10 | github.com/mattn/go-colorable v0.1.13 // indirect 11 | github.com/mattn/go-isatty v0.0.20 // indirect 12 | golang.org/x/sys v0.20.0 // indirect 13 | ) 14 | 15 | go 1.23 16 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= 5 | github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= 6 | github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= 7 | github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= 8 | 
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= 9 | github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= 10 | github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= 11 | github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= 12 | github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= 13 | github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= 14 | github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= 15 | github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= 16 | github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= 17 | github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= 18 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 19 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 22 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 23 | github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= 24 | github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= 25 | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 26 | golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 27 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 28 | golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 29 | golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 30 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 31 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 32 | golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= 33 | golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 34 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 35 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 36 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 37 | -------------------------------------------------------------------------------- /roundtripper.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package retryablehttp 5 | 6 | import ( 7 | "errors" 8 | "net/http" 9 | "net/url" 10 | "sync" 11 | ) 12 | 13 | // RoundTripper implements the http.RoundTripper interface, using a retrying 14 | // HTTP client to execute requests. 15 | // 16 | // It is important to note that retryablehttp doesn't always act exactly as a 17 | // RoundTripper should. This is highly dependent on the retryable client's 18 | // configuration. 19 | type RoundTripper struct { 20 | // The client to use during requests. If nil, the default retryablehttp 21 | // client and settings will be used. 22 | Client *Client 23 | 24 | // once ensures that the logic to initialize the default client runs at 25 | // most once, in a single thread. 26 | once sync.Once 27 | } 28 | 29 | // init initializes the underlying retryable client. 
30 | func (rt *RoundTripper) init() { 31 | if rt.Client == nil { 32 | rt.Client = NewClient() 33 | } 34 | } 35 | 36 | // RoundTrip satisfies the http.RoundTripper interface. 37 | func (rt *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 38 | rt.once.Do(rt.init) 39 | 40 | // Convert the request to be retryable. 41 | retryableReq, err := FromRequest(req) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | // Execute the request. 47 | resp, err := rt.Client.Do(retryableReq) 48 | // If we got an error returned by standard library's `Do` method, unwrap it 49 | // otherwise we will wind up erroneously re-nesting the error. 50 | if _, ok := err.(*url.Error); ok { 51 | return resp, errors.Unwrap(err) 52 | } 53 | 54 | return resp, err 55 | } 56 | -------------------------------------------------------------------------------- /roundtripper_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package retryablehttp 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "io" 10 | "net" 11 | "net/http" 12 | "net/http/httptest" 13 | "net/url" 14 | "reflect" 15 | "sync/atomic" 16 | "testing" 17 | ) 18 | 19 | func TestRoundTripper_implements(t *testing.T) { 20 | // Compile-time proof of interface satisfaction. 21 | var _ http.RoundTripper = &RoundTripper{} 22 | } 23 | 24 | func TestRoundTripper_init(t *testing.T) { 25 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 26 | w.WriteHeader(200) 27 | })) 28 | defer ts.Close() 29 | 30 | // Start with a new empty RoundTripper. 31 | rt := &RoundTripper{} 32 | 33 | // RoundTrip once. 34 | req, _ := http.NewRequest("GET", ts.URL, nil) 35 | if _, err := rt.RoundTrip(req); err != nil { 36 | t.Fatal(err) 37 | } 38 | 39 | // Check that the Client was initialized. 
40 | if rt.Client == nil { 41 | t.Fatal("expected rt.Client to be initialized") 42 | } 43 | 44 | // Save the Client for later comparison. 45 | initialClient := rt.Client 46 | 47 | // RoundTrip again. 48 | req, _ = http.NewRequest("GET", ts.URL, nil) 49 | if _, err := rt.RoundTrip(req); err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | // Check that the underlying Client is unchanged. 54 | if rt.Client != initialClient { 55 | t.Fatalf("expected %v, got %v", initialClient, rt.Client) 56 | } 57 | } 58 | 59 | func TestRoundTripper_RoundTrip(t *testing.T) { 60 | var reqCount int32 = 0 61 | 62 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 63 | reqNo := atomic.AddInt32(&reqCount, 1) 64 | if reqNo < 3 { 65 | w.WriteHeader(404) 66 | } else { 67 | w.WriteHeader(200) 68 | if _, err := w.Write([]byte("success!")); err != nil { 69 | t.Fatalf("failed to write: %v", err) 70 | } 71 | } 72 | })) 73 | defer ts.Close() 74 | 75 | // Make a client with some custom settings to verify they are used. 76 | retryClient := NewClient() 77 | retryClient.CheckRetry = func(_ context.Context, resp *http.Response, _ error) (bool, error) { 78 | return resp.StatusCode == 404, nil 79 | } 80 | 81 | // Get the standard client and execute the request. 82 | client := retryClient.StandardClient() 83 | resp, err := client.Get(ts.URL) 84 | if err != nil { 85 | t.Fatal(err) 86 | } 87 | defer resp.Body.Close() 88 | 89 | // Check the response to ensure the client behaved as expected. 90 | if resp.StatusCode != 200 { 91 | t.Fatalf("expected 200, got %d", resp.StatusCode) 92 | } 93 | if v, err := io.ReadAll(resp.Body); err != nil { 94 | t.Fatal(err) 95 | } else if string(v) != "success!" { 96 | t.Fatalf("expected %q, got %q", "success!", v) 97 | } 98 | } 99 | 100 | func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) { 101 | // Make a client with some custom settings to verify they are used. 
102 | retryClient := NewClient() 103 | retryClient.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 104 | if err != nil { 105 | return true, err 106 | } 107 | 108 | return false, nil 109 | } 110 | 111 | retryClient.ErrorHandler = PassthroughErrorHandler 112 | 113 | expectedError := &url.Error{ 114 | Op: "Get", 115 | URL: "http://999.999.999.999:999/", 116 | Err: &net.OpError{ 117 | Op: "dial", 118 | Net: "tcp", 119 | Err: &net.DNSError{ 120 | Name: "999.999.999.999", 121 | Err: "no such host", 122 | IsNotFound: true, 123 | }, 124 | }, 125 | } 126 | 127 | // Get the standard client and execute the request. 128 | client := retryClient.StandardClient() 129 | _, err := client.Get("http://999.999.999.999:999/") 130 | 131 | // assert expectations 132 | if !reflect.DeepEqual(expectedError, normalizeError(err)) { 133 | t.Fatalf("expected %q, got %q", expectedError, err) 134 | } 135 | } 136 | 137 | func normalizeError(err error) error { 138 | var dnsError *net.DNSError 139 | 140 | if errors.As(err, &dnsError) { 141 | // this field is populated with the DNS server on on CI, but not locally 142 | dnsError.Server = "" 143 | } 144 | 145 | return err 146 | } 147 | --------------------------------------------------------------------------------