├── .gitignore ├── go.mod ├── Makefile ├── go.sum ├── roundtripper.go ├── .circleci └── config.yml ├── README.md ├── roundtripper_test.go ├── LICENSE ├── client_test.go └── client.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.iml 3 | *.test 4 | .vscode/ -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-retryablehttp 2 | 3 | require ( 4 | github.com/hashicorp/go-cleanhttp v0.5.2 5 | github.com/hashicorp/go-hclog v0.9.2 6 | ) 7 | 8 | go 1.13 9 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | default: test 2 | 3 | test: 4 | go vet ./... 5 | go test -race ./... 6 | 7 | updatedeps: 8 | go get -f -t -u ./... 9 | go get -f -u ./... 10 | 11 | .PHONY: default test updatedeps 12 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= 4 | github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= 5 | github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= 6 | github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= 7 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 8 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 9 | github.com/stretchr/testify v1.2.2 
h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 10 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 11 | -------------------------------------------------------------------------------- /roundtripper.go: -------------------------------------------------------------------------------- 1 | package retryablehttp 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/url" 7 | "sync" 8 | ) 9 | 10 | // RoundTripper implements the http.RoundTripper interface, using a retrying 11 | // HTTP client to execute requests. 12 | // 13 | // It is important to note that retryablehttp doesn't always act exactly as a 14 | // RoundTripper should. This is highly dependent on the retryable client's 15 | // configuration. 16 | type RoundTripper struct { 17 | // The client to use during requests. If nil, the default retryablehttp 18 | // client and settings will be used. 19 | Client *Client 20 | 21 | // once ensures that the logic to initialize the default client runs at 22 | // most once, in a single thread. 23 | once sync.Once 24 | } 25 | 26 | // init initializes the underlying retryable client. 27 | func (rt *RoundTripper) init() { 28 | if rt.Client == nil { 29 | rt.Client = NewClient() 30 | } 31 | } 32 | 33 | // RoundTrip satisfies the http.RoundTripper interface. 34 | func (rt *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 35 | rt.once.Do(rt.init) 36 | 37 | // Convert the request to be retryable. 38 | retryableReq, err := FromRequest(req) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | // Execute the request. 44 | resp, err := rt.Client.Do(retryableReq) 45 | // If we got an error returned by standard library's `Do` method, unwrap it 46 | // otherwise we will wind up erroneously re-nesting the error. 
47 | if _, ok := err.(*url.Error); ok { 48 | return resp, errors.Unwrap(err) 49 | } 50 | 51 | return resp, err 52 | } 53 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | go: circleci/go@1.1.1 5 | 6 | references: 7 | environments: 8 | tmp: &TEST_RESULTS_PATH /tmp/test-results 9 | 10 | environment: &ENVIRONMENT 11 | TEST_RESULTS: "/tmp/test-results" 12 | 13 | jobs: 14 | run-tests: 15 | executor: 16 | name: go/default 17 | tag: << parameters.go-version >> 18 | parameters: 19 | go-version: 20 | type: string 21 | environment: 22 | TEST_RESULTS: *TEST_RESULTS_PATH 23 | steps: 24 | - checkout 25 | - run: mkdir -p $TEST_RESULTS/go-retryablyhttp 26 | - go/load-cache 27 | - go/mod-download 28 | - go/save-cache 29 | - run: 30 | name: Run go format 31 | command: | 32 | files=$(go fmt ./...) 33 | if [ -n "$files" ]; then 34 | echo "The following file(s) do not conform to go fmt:" 35 | echo "$files" 36 | exit 1 37 | fi 38 | - run: 39 | name: Run tests with gotestsum 40 | command: | 41 | PACKAGE_NAMES=$(go list ./...) 
42 | gotestsum --format=short-verbose --junitfile $TEST_RESULTS/go-retryablyhttp/gotestsum-report.xml -- $PACKAGE_NAMES 43 | - store_test_results: 44 | path: *TEST_RESULTS_PATH 45 | - store_artifacts: 46 | path: *TEST_RESULTS_PATH 47 | 48 | workflows: 49 | go-retryablehttp: 50 | jobs: 51 | - run-tests: 52 | matrix: 53 | parameters: 54 | go-version: ["1.14.2"] 55 | name: test-go-<< matrix.go-version >> 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | go-retryablehttp 2 | ================ 3 | 4 | [![Build Status](http://img.shields.io/travis/hashicorp/go-retryablehttp.svg?style=flat-square)][travis] 5 | [![Go Documentation](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)][godocs] 6 | 7 | [travis]: http://travis-ci.org/hashicorp/go-retryablehttp 8 | [godocs]: http://godoc.org/github.com/hashicorp/go-retryablehttp 9 | 10 | The `retryablehttp` package provides a familiar HTTP client interface with 11 | automatic retries and exponential backoff. It is a thin wrapper over the 12 | standard `net/http` client library and exposes nearly the same public API. This 13 | makes `retryablehttp` very easy to drop into existing programs. 14 | 15 | `retryablehttp` performs automatic retries under certain conditions. Mainly, if 16 | an error is returned by the client (connection errors, etc.), or if a 500-range 17 | response code is received (except 501), then a retry is invoked after a wait 18 | period. Otherwise, the response is returned and left to the caller to 19 | interpret. 20 | 21 | The main difference from `net/http` is that requests which take a request body 22 | (POST/PUT et al.) can have the body provided in a number of ways (some more or 23 | less efficient) that allow "rewinding" the request body if the initial request 24 | fails so that the full request can be attempted again. 
See the 25 | [godoc](http://godoc.org/github.com/hashicorp/go-retryablehttp) for more 26 | details. 27 | 28 | Version 0.6.0 and before are compatible with Go prior to 1.12. From 0.6.1 onward, Go 1.12+ is required. 29 | From 0.6.7 onward, Go 1.13+ is required. 30 | 31 | Example Use 32 | =========== 33 | 34 | Using this library should look almost identical to what you would do with 35 | `net/http`. The most simple example of a GET request is shown below: 36 | 37 | ```go 38 | resp, err := retryablehttp.Get("/foo") 39 | if err != nil { 40 | panic(err) 41 | } 42 | ``` 43 | 44 | The returned response object is an `*http.Response`, the same thing you would 45 | usually get from `net/http`. Had the request failed one or more times, the above 46 | call would block and retry with exponential backoff. 47 | 48 | ## Getting a stdlib `*http.Client` with retries 49 | 50 | It's possible to convert a `*retryablehttp.Client` directly to a `*http.Client`. 51 | This makes use of retryablehttp broadly applicable with minimal effort. Simply 52 | configure a `*retryablehttp.Client` as you wish, and then call `StandardClient()`: 53 | 54 | ```go 55 | retryClient := retryablehttp.NewClient() 56 | retryClient.RetryMax = 10 57 | 58 | standardClient := retryClient.StandardClient() // *http.Client 59 | ``` 60 | 61 | For more usage and examples see the 62 | [godoc](http://godoc.org/github.com/hashicorp/go-retryablehttp). 63 | -------------------------------------------------------------------------------- /roundtripper_test.go: -------------------------------------------------------------------------------- 1 | package retryablehttp 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io/ioutil" 7 | "net" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "reflect" 12 | "sync/atomic" 13 | "testing" 14 | ) 15 | 16 | func TestRoundTripper_implements(t *testing.T) { 17 | // Compile-time proof of interface satisfaction. 
18 | var _ http.RoundTripper = &RoundTripper{} 19 | } 20 | 21 | func TestRoundTripper_init(t *testing.T) { 22 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 23 | w.WriteHeader(200) 24 | })) 25 | defer ts.Close() 26 | 27 | // Start with a new empty RoundTripper. 28 | rt := &RoundTripper{} 29 | 30 | // RoundTrip once. 31 | req, _ := http.NewRequest("GET", ts.URL, nil) 32 | if _, err := rt.RoundTrip(req); err != nil { 33 | t.Fatal(err) 34 | } 35 | 36 | // Check that the Client was initialized. 37 | if rt.Client == nil { 38 | t.Fatal("expected rt.Client to be initialized") 39 | } 40 | 41 | // Save the Client for later comparison. 42 | initialClient := rt.Client 43 | 44 | // RoundTrip again. 45 | req, _ = http.NewRequest("GET", ts.URL, nil) 46 | if _, err := rt.RoundTrip(req); err != nil { 47 | t.Fatal(err) 48 | } 49 | 50 | // Check that the underlying Client is unchanged. 51 | if rt.Client != initialClient { 52 | t.Fatalf("expected %v, got %v", initialClient, rt.Client) 53 | } 54 | } 55 | 56 | func TestRoundTripper_RoundTrip(t *testing.T) { 57 | var reqCount int32 = 0 58 | 59 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 60 | reqNo := atomic.AddInt32(&reqCount, 1) 61 | if reqNo < 3 { 62 | w.WriteHeader(404) 63 | } else { 64 | w.WriteHeader(200) 65 | w.Write([]byte("success!")) 66 | } 67 | })) 68 | defer ts.Close() 69 | 70 | // Make a client with some custom settings to verify they are used. 71 | retryClient := NewClient() 72 | retryClient.CheckRetry = func(_ context.Context, resp *http.Response, _ error) (bool, error) { 73 | return resp.StatusCode == 404, nil 74 | } 75 | 76 | // Get the standard client and execute the request. 77 | client := retryClient.StandardClient() 78 | resp, err := client.Get(ts.URL) 79 | if err != nil { 80 | t.Fatal(err) 81 | } 82 | defer resp.Body.Close() 83 | 84 | // Check the response to ensure the client behaved as expected. 
// normalizeError clears the Server field of the *net.DNSError found in err's
// chain, if there is one, and returns err itself (the same value, modified in
// place). The Server field is populated with the DNS server on CI but not
// locally, so it is blanked before errors are compared.
func normalizeError(err error) error {
	if dnsErr := new(*net.DNSError); errors.As(err, dnsErr) {
		(*dnsErr).Server = ""
	}

	return err
}
"Contributor" 8 | 9 | means each individual or legal entity that creates, contributes to the 10 | creation of, or owns Covered Software. 11 | 12 | 1.2. "Contributor Version" 13 | 14 | means the combination of the Contributions of others (if any) used by a 15 | Contributor and that particular Contributor's Contribution. 16 | 17 | 1.3. "Contribution" 18 | 19 | means Covered Software of a particular Contributor. 20 | 21 | 1.4. "Covered Software" 22 | 23 | means Source Code Form to which the initial Contributor has attached the 24 | notice in Exhibit A, the Executable Form of such Source Code Form, and 25 | Modifications of such Source Code Form, in each case including portions 26 | thereof. 27 | 28 | 1.5. "Incompatible With Secondary Licenses" 29 | means 30 | 31 | a. that the initial Contributor has attached the notice described in 32 | Exhibit B to the Covered Software; or 33 | 34 | b. that the Covered Software was made available under the terms of 35 | version 1.1 or earlier of the License, but not also under the terms of 36 | a Secondary License. 37 | 38 | 1.6. "Executable Form" 39 | 40 | means any form of the work other than Source Code Form. 41 | 42 | 1.7. "Larger Work" 43 | 44 | means a work that combines Covered Software with other material, in a 45 | separate file or files, that is not Covered Software. 46 | 47 | 1.8. "License" 48 | 49 | means this document. 50 | 51 | 1.9. "Licensable" 52 | 53 | means having the right to grant, to the maximum extent possible, whether 54 | at the time of the initial grant or subsequently, any and all of the 55 | rights conveyed by this License. 56 | 57 | 1.10. "Modifications" 58 | 59 | means any of the following: 60 | 61 | a. any file in Source Code Form that results from an addition to, 62 | deletion from, or modification of the contents of Covered Software; or 63 | 64 | b. any new file in Source Code Form that contains any Covered Software. 65 | 66 | 1.11. 
"Patent Claims" of a Contributor 67 | 68 | means any patent claim(s), including without limitation, method, 69 | process, and apparatus claims, in any patent Licensable by such 70 | Contributor that would be infringed, but for the grant of the License, 71 | by the making, using, selling, offering for sale, having made, import, 72 | or transfer of either its Contributions or its Contributor Version. 73 | 74 | 1.12. "Secondary License" 75 | 76 | means either the GNU General Public License, Version 2.0, the GNU Lesser 77 | General Public License, Version 2.1, the GNU Affero General Public 78 | License, Version 3.0, or any later versions of those licenses. 79 | 80 | 1.13. "Source Code Form" 81 | 82 | means the form of the work preferred for making modifications. 83 | 84 | 1.14. "You" (or "Your") 85 | 86 | means an individual or a legal entity exercising rights under this 87 | License. For legal entities, "You" includes any entity that controls, is 88 | controlled by, or is under common control with You. For purposes of this 89 | definition, "control" means (a) the power, direct or indirect, to cause 90 | the direction or management of such entity, whether by contract or 91 | otherwise, or (b) ownership of more than fifty percent (50%) of the 92 | outstanding shares or beneficial ownership of such entity. 93 | 94 | 95 | 2. License Grants and Conditions 96 | 97 | 2.1. Grants 98 | 99 | Each Contributor hereby grants You a world-wide, royalty-free, 100 | non-exclusive license: 101 | 102 | a. under intellectual property rights (other than patent or trademark) 103 | Licensable by such Contributor to use, reproduce, make available, 104 | modify, display, perform, distribute, and otherwise exploit its 105 | Contributions, either on an unmodified basis, with Modifications, or 106 | as part of a Larger Work; and 107 | 108 | b. 
under Patent Claims of such Contributor to make, use, sell, offer for 109 | sale, have made, import, and otherwise transfer either its 110 | Contributions or its Contributor Version. 111 | 112 | 2.2. Effective Date 113 | 114 | The licenses granted in Section 2.1 with respect to any Contribution 115 | become effective for each Contribution on the date the Contributor first 116 | distributes such Contribution. 117 | 118 | 2.3. Limitations on Grant Scope 119 | 120 | The licenses granted in this Section 2 are the only rights granted under 121 | this License. No additional rights or licenses will be implied from the 122 | distribution or licensing of Covered Software under this License. 123 | Notwithstanding Section 2.1(b) above, no patent license is granted by a 124 | Contributor: 125 | 126 | a. for any code that a Contributor has removed from Covered Software; or 127 | 128 | b. for infringements caused by: (i) Your and any other third party's 129 | modifications of Covered Software, or (ii) the combination of its 130 | Contributions with other software (except as part of its Contributor 131 | Version); or 132 | 133 | c. under Patent Claims infringed by Covered Software in the absence of 134 | its Contributions. 135 | 136 | This License does not grant any rights in the trademarks, service marks, 137 | or logos of any Contributor (except as may be necessary to comply with 138 | the notice requirements in Section 3.4). 139 | 140 | 2.4. Subsequent Licenses 141 | 142 | No Contributor makes additional grants as a result of Your choice to 143 | distribute the Covered Software under a subsequent version of this 144 | License (see Section 10.2) or under the terms of a Secondary License (if 145 | permitted under the terms of Section 3.3). 146 | 147 | 2.5. 
Representation 148 | 149 | Each Contributor represents that the Contributor believes its 150 | Contributions are its original creation(s) or it has sufficient rights to 151 | grant the rights to its Contributions conveyed by this License. 152 | 153 | 2.6. Fair Use 154 | 155 | This License is not intended to limit any rights You have under 156 | applicable copyright doctrines of fair use, fair dealing, or other 157 | equivalents. 158 | 159 | 2.7. Conditions 160 | 161 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in 162 | Section 2.1. 163 | 164 | 165 | 3. Responsibilities 166 | 167 | 3.1. Distribution of Source Form 168 | 169 | All distribution of Covered Software in Source Code Form, including any 170 | Modifications that You create or to which You contribute, must be under 171 | the terms of this License. You must inform recipients that the Source 172 | Code Form of the Covered Software is governed by the terms of this 173 | License, and how they can obtain a copy of this License. You may not 174 | attempt to alter or restrict the recipients' rights in the Source Code 175 | Form. 176 | 177 | 3.2. Distribution of Executable Form 178 | 179 | If You distribute Covered Software in Executable Form then: 180 | 181 | a. such Covered Software must also be made available in Source Code Form, 182 | as described in Section 3.1, and You must inform recipients of the 183 | Executable Form how they can obtain a copy of such Source Code Form by 184 | reasonable means in a timely manner, at a charge no more than the cost 185 | of distribution to the recipient; and 186 | 187 | b. You may distribute such Executable Form under the terms of this 188 | License, or sublicense it under different terms, provided that the 189 | license for the Executable Form does not attempt to limit or alter the 190 | recipients' rights in the Source Code Form under this License. 191 | 192 | 3.3. 
Distribution of a Larger Work 193 | 194 | You may create and distribute a Larger Work under terms of Your choice, 195 | provided that You also comply with the requirements of this License for 196 | the Covered Software. If the Larger Work is a combination of Covered 197 | Software with a work governed by one or more Secondary Licenses, and the 198 | Covered Software is not Incompatible With Secondary Licenses, this 199 | License permits You to additionally distribute such Covered Software 200 | under the terms of such Secondary License(s), so that the recipient of 201 | the Larger Work may, at their option, further distribute the Covered 202 | Software under the terms of either this License or such Secondary 203 | License(s). 204 | 205 | 3.4. Notices 206 | 207 | You may not remove or alter the substance of any license notices 208 | (including copyright notices, patent notices, disclaimers of warranty, or 209 | limitations of liability) contained within the Source Code Form of the 210 | Covered Software, except that You may alter any license notices to the 211 | extent required to remedy known factual inaccuracies. 212 | 213 | 3.5. Application of Additional Terms 214 | 215 | You may choose to offer, and to charge a fee for, warranty, support, 216 | indemnity or liability obligations to one or more recipients of Covered 217 | Software. However, You may do so only on Your own behalf, and not on 218 | behalf of any Contributor. You must make it absolutely clear that any 219 | such warranty, support, indemnity, or liability obligation is offered by 220 | You alone, and You hereby agree to indemnify every Contributor for any 221 | liability incurred by such Contributor as a result of warranty, support, 222 | indemnity or liability terms You offer. You may include additional 223 | disclaimers of warranty and limitations of liability specific to any 224 | jurisdiction. 225 | 226 | 4. 
Inability to Comply Due to Statute or Regulation 227 | 228 | If it is impossible for You to comply with any of the terms of this License 229 | with respect to some or all of the Covered Software due to statute, 230 | judicial order, or regulation then You must: (a) comply with the terms of 231 | this License to the maximum extent possible; and (b) describe the 232 | limitations and the code they affect. Such description must be placed in a 233 | text file included with all distributions of the Covered Software under 234 | this License. Except to the extent prohibited by statute or regulation, 235 | such description must be sufficiently detailed for a recipient of ordinary 236 | skill to be able to understand it. 237 | 238 | 5. Termination 239 | 240 | 5.1. The rights granted under this License will terminate automatically if You 241 | fail to comply with any of its terms. However, if You become compliant, 242 | then the rights granted under this License from a particular Contributor 243 | are reinstated (a) provisionally, unless and until such Contributor 244 | explicitly and finally terminates Your grants, and (b) on an ongoing 245 | basis, if such Contributor fails to notify You of the non-compliance by 246 | some reasonable means prior to 60 days after You have come back into 247 | compliance. Moreover, Your grants from a particular Contributor are 248 | reinstated on an ongoing basis if such Contributor notifies You of the 249 | non-compliance by some reasonable means, this is the first time You have 250 | received notice of non-compliance with this License from such 251 | Contributor, and You become compliant prior to 30 days after Your receipt 252 | of the notice. 253 | 254 | 5.2. 
If You initiate litigation against any entity by asserting a patent 255 | infringement claim (excluding declaratory judgment actions, 256 | counter-claims, and cross-claims) alleging that a Contributor Version 257 | directly or indirectly infringes any patent, then the rights granted to 258 | You by any and all Contributors for the Covered Software under Section 259 | 2.1 of this License shall terminate. 260 | 261 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user 262 | license agreements (excluding distributors and resellers) which have been 263 | validly granted by You or Your distributors under this License prior to 264 | termination shall survive termination. 265 | 266 | 6. Disclaimer of Warranty 267 | 268 | Covered Software is provided under this License on an "as is" basis, 269 | without warranty of any kind, either expressed, implied, or statutory, 270 | including, without limitation, warranties that the Covered Software is free 271 | of defects, merchantable, fit for a particular purpose or non-infringing. 272 | The entire risk as to the quality and performance of the Covered Software 273 | is with You. Should any Covered Software prove defective in any respect, 274 | You (not any Contributor) assume the cost of any necessary servicing, 275 | repair, or correction. This disclaimer of warranty constitutes an essential 276 | part of this License. No use of any Covered Software is authorized under 277 | this License except under this disclaimer. 278 | 279 | 7. 
Limitation of Liability 280 | 281 | Under no circumstances and under no legal theory, whether tort (including 282 | negligence), contract, or otherwise, shall any Contributor, or anyone who 283 | distributes Covered Software as permitted above, be liable to You for any 284 | direct, indirect, special, incidental, or consequential damages of any 285 | character including, without limitation, damages for lost profits, loss of 286 | goodwill, work stoppage, computer failure or malfunction, or any and all 287 | other commercial damages or losses, even if such party shall have been 288 | informed of the possibility of such damages. This limitation of liability 289 | shall not apply to liability for death or personal injury resulting from 290 | such party's negligence to the extent applicable law prohibits such 291 | limitation. Some jurisdictions do not allow the exclusion or limitation of 292 | incidental or consequential damages, so this exclusion and limitation may 293 | not apply to You. 294 | 295 | 8. Litigation 296 | 297 | Any litigation relating to this License may be brought only in the courts 298 | of a jurisdiction where the defendant maintains its principal place of 299 | business and such litigation shall be governed by laws of that 300 | jurisdiction, without reference to its conflict-of-law provisions. Nothing 301 | in this Section shall prevent a party's ability to bring cross-claims or 302 | counter-claims. 303 | 304 | 9. Miscellaneous 305 | 306 | This License represents the complete agreement concerning the subject 307 | matter hereof. If any provision of this License is held to be 308 | unenforceable, such provision shall be reformed only to the extent 309 | necessary to make it enforceable. Any law or regulation which provides that 310 | the language of a contract shall be construed against the drafter shall not 311 | be used to construe this License against a Contributor. 312 | 313 | 314 | 10. Versions of the License 315 | 316 | 10.1. 
New Versions 317 | 318 | Mozilla Foundation is the license steward. Except as provided in Section 319 | 10.3, no one other than the license steward has the right to modify or 320 | publish new versions of this License. Each version will be given a 321 | distinguishing version number. 322 | 323 | 10.2. Effect of New Versions 324 | 325 | You may distribute the Covered Software under the terms of the version 326 | of the License under which You originally received the Covered Software, 327 | or under the terms of any subsequent version published by the license 328 | steward. 329 | 330 | 10.3. Modified Versions 331 | 332 | If you create software not governed by this License, and you want to 333 | create a new license for such software, you may create and use a 334 | modified version of this License if you rename the license and remove 335 | any references to the name of the license steward (except to note that 336 | such modified license differs from this License). 337 | 338 | 10.4. Distributing Source Code Form that is Incompatible With Secondary 339 | Licenses If You choose to distribute Source Code Form that is 340 | Incompatible With Secondary Licenses under the terms of this version of 341 | the License, the notice described in Exhibit B of this License must be 342 | attached. 343 | 344 | Exhibit A - Source Code Form License Notice 345 | 346 | This Source Code Form is subject to the 347 | terms of the Mozilla Public License, v. 348 | 2.0. If a copy of the MPL was not 349 | distributed with this file, You can 350 | obtain one at 351 | http://mozilla.org/MPL/2.0/. 352 | 353 | If it is not possible or desirable to put the notice in a particular file, 354 | then You may include the notice in a location (such as a LICENSE file in a 355 | relevant directory) where a recipient would be likely to look for such a 356 | notice. 357 | 358 | You may add additional accurate notices of copyright ownership. 
359 | 360 | Exhibit B - "Incompatible With Secondary Licenses" Notice 361 | 362 | This Source Code Form is "Incompatible 363 | With Secondary Licenses", as defined by 364 | the Mozilla Public License, v. 2.0. 365 | 366 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package retryablehttp 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "io/ioutil" 10 | "net" 11 | "net/http" 12 | "net/http/httptest" 13 | "net/http/httputil" 14 | "net/url" 15 | "strings" 16 | "sync/atomic" 17 | "testing" 18 | "time" 19 | 20 | "github.com/hashicorp/go-hclog" 21 | ) 22 | 23 | func TestRequest(t *testing.T) { 24 | // Fails on invalid request 25 | _, err := NewRequest("GET", "://foo", nil) 26 | if err == nil { 27 | t.Fatalf("should error") 28 | } 29 | 30 | // Works with no request body 31 | _, err = NewRequest("GET", "http://foo", nil) 32 | if err != nil { 33 | t.Fatalf("err: %v", err) 34 | } 35 | 36 | // Works with request body 37 | body := bytes.NewReader([]byte("yo")) 38 | req, err := NewRequest("GET", "/", body) 39 | if err != nil { 40 | t.Fatalf("err: %v", err) 41 | } 42 | 43 | // Request allows typical HTTP request forming methods 44 | req.Header.Set("X-Test", "foo") 45 | if v, ok := req.Header["X-Test"]; !ok || len(v) != 1 || v[0] != "foo" { 46 | t.Fatalf("bad headers: %v", req.Header) 47 | } 48 | 49 | // Sets the Content-Length automatically for LenReaders 50 | if req.ContentLength != 2 { 51 | t.Fatalf("bad ContentLength: %d", req.ContentLength) 52 | } 53 | } 54 | 55 | func TestFromRequest(t *testing.T) { 56 | // Works with no request body 57 | httpReq, err := http.NewRequest("GET", "http://foo", nil) 58 | if err != nil { 59 | t.Fatalf("err: %v", err) 60 | } 61 | _, err = FromRequest(httpReq) 62 | if err != nil { 63 | t.Fatalf("err: %v", err) 64 | } 65 | 66 | // Works with request body 67 | body := 
// custReader is a bare-bones io.Reader. The usual ways of producing a Reader
// have special cases, so a custom type is used here instead.
type custReader struct {
	val string // content to serve; an empty value is replaced by "hello" on Read
	pos int    // offset into val of the next byte to hand out
}

// Read copies the next bytes of c.val into p and advances the read position
// by the number of bytes copied. An empty c.val is first set to "hello".
// Once the position has reached the end of the content it returns 0 and
// io.EOF; every other call returns the copied count and a nil error.
func (c *custReader) Read(p []byte) (n int, err error) {
	if c.val == "" {
		c.val = "hello"
	}

	if c.pos >= len(c.val) {
		return 0, io.EOF
	}

	// Fill p byte by byte until either p or the remaining content runs out.
	for n < len(p) && c.pos+n < len(c.val) {
		p[n] = c.val[c.pos+n]
		n++
	}
	c.pos += n

	return n, nil
}
| if err != nil { 137 | t.Fatalf("err: %v", err) 138 | } 139 | req.Header.Set("foo", "bar") 140 | 141 | // Track the number of times the logging hook was called 142 | retryCount := -1 143 | 144 | // Create the client. Use short retry windows. 145 | client := NewClient() 146 | client.RetryWaitMin = 10 * time.Millisecond 147 | client.RetryWaitMax = 50 * time.Millisecond 148 | client.RetryMax = 50 149 | client.RequestLogHook = func(logger Logger, req *http.Request, retryNumber int) { 150 | retryCount = retryNumber 151 | 152 | if logger != client.Logger { 153 | t.Fatalf("Client logger was not passed to logging hook") 154 | } 155 | 156 | dumpBytes, err := httputil.DumpRequestOut(req, false) 157 | if err != nil { 158 | t.Fatal("Dumping requests failed") 159 | } 160 | 161 | dumpString := string(dumpBytes) 162 | if !strings.Contains(dumpString, "PUT /v1/foo") { 163 | t.Fatalf("Bad request dump:\n%s", dumpString) 164 | } 165 | } 166 | 167 | // Send the request 168 | var resp *http.Response 169 | doneCh := make(chan struct{}) 170 | errCh := make(chan error, 1) 171 | go func() { 172 | defer close(doneCh) 173 | defer close(errCh) 174 | var err error 175 | resp, err = client.Do(req) 176 | errCh <- err 177 | }() 178 | 179 | select { 180 | case <-doneCh: 181 | t.Fatalf("should retry on error") 182 | case <-time.After(200 * time.Millisecond): 183 | // Client should still be retrying due to connection failure. 184 | } 185 | 186 | // Create the mock handler. First we return a 500-range response to ensure 187 | // that we power through and keep retrying in the face of recoverable 188 | // errors. 
189 | code := int64(500) 190 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 191 | // Check the request details 192 | if r.Method != "PUT" { 193 | t.Fatalf("bad method: %s", r.Method) 194 | } 195 | if r.RequestURI != "/v1/foo" { 196 | t.Fatalf("bad uri: %s", r.RequestURI) 197 | } 198 | 199 | // Check the headers 200 | if v := r.Header.Get("foo"); v != "bar" { 201 | t.Fatalf("bad header: expect foo=bar, got foo=%v", v) 202 | } 203 | 204 | // Check the payload 205 | body, err := ioutil.ReadAll(r.Body) 206 | if err != nil { 207 | t.Fatalf("err: %s", err) 208 | } 209 | expected := []byte("hello") 210 | if !bytes.Equal(body, expected) { 211 | t.Fatalf("bad: %v", body) 212 | } 213 | 214 | w.WriteHeader(int(atomic.LoadInt64(&code))) 215 | }) 216 | 217 | // Create a test server 218 | list, err := net.Listen("tcp", ":28934") 219 | if err != nil { 220 | t.Fatalf("err: %v", err) 221 | } 222 | defer list.Close() 223 | go http.Serve(list, handler) 224 | 225 | // Wait again 226 | select { 227 | case <-doneCh: 228 | t.Fatalf("should retry on 500-range") 229 | case <-time.After(200 * time.Millisecond): 230 | // Client should still be retrying due to 500's. 231 | } 232 | 233 | // Start returning 200's 234 | atomic.StoreInt64(&code, 200) 235 | 236 | // Wait again 237 | select { 238 | case <-doneCh: 239 | case <-time.After(time.Second): 240 | t.Fatalf("timed out") 241 | } 242 | 243 | if resp.StatusCode != 200 { 244 | t.Fatalf("exected 200, got: %d", resp.StatusCode) 245 | } 246 | 247 | if retryCount < 0 { 248 | t.Fatal("request log hook was not called") 249 | } 250 | 251 | err = <-errCh 252 | if err != nil { 253 | t.Fatalf("err: %v", err) 254 | } 255 | } 256 | 257 | func TestClient_Do_WithResponseHandler(t *testing.T) { 258 | // Create the client. Use short retry windows so we fail faster. 
259 | client := NewClient() 260 | client.RetryWaitMin = 10 * time.Millisecond 261 | client.RetryWaitMax = 10 * time.Millisecond 262 | client.RetryMax = 2 263 | 264 | var checks int 265 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 266 | checks++ 267 | if err != nil && strings.Contains(err.Error(), "nonretryable") { 268 | return false, nil 269 | } 270 | return DefaultRetryPolicy(context.TODO(), resp, err) 271 | } 272 | 273 | // Mock server which always responds 200. 274 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 275 | w.WriteHeader(200) 276 | })) 277 | defer ts.Close() 278 | 279 | var shouldSucceed bool 280 | tests := []struct { 281 | name string 282 | handler ResponseHandlerFunc 283 | expectedChecks int // often 2x number of attempts since we check twice 284 | err string 285 | }{ 286 | { 287 | name: "nil handler", 288 | handler: nil, 289 | expectedChecks: 1, 290 | }, 291 | { 292 | name: "handler always succeeds", 293 | handler: func(*http.Response) error { 294 | return nil 295 | }, 296 | expectedChecks: 2, 297 | }, 298 | { 299 | name: "handler always fails in a retryable way", 300 | handler: func(*http.Response) error { 301 | return errors.New("retryable failure") 302 | }, 303 | expectedChecks: 6, 304 | }, 305 | { 306 | name: "handler always fails in a nonretryable way", 307 | handler: func(*http.Response) error { 308 | return errors.New("nonretryable failure") 309 | }, 310 | expectedChecks: 2, 311 | }, 312 | { 313 | name: "handler succeeds on second attempt", 314 | handler: func(*http.Response) error { 315 | if shouldSucceed { 316 | return nil 317 | } 318 | shouldSucceed = true 319 | return errors.New("retryable failure") 320 | }, 321 | expectedChecks: 4, 322 | }, 323 | } 324 | 325 | for _, tt := range tests { 326 | t.Run(tt.name, func(t *testing.T) { 327 | checks = 0 328 | shouldSucceed = false 329 | // Create the request 330 | req, err := NewRequest("GET", ts.URL, 
nil) 331 | if err != nil { 332 | t.Fatalf("err: %v", err) 333 | } 334 | req.SetResponseHandler(tt.handler) 335 | 336 | // Send the request. 337 | _, err = client.Do(req) 338 | if err != nil && !strings.Contains(err.Error(), tt.err) { 339 | t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error()) 340 | } 341 | if err == nil && tt.err != "" { 342 | t.Fatalf("no error, expected: %s", tt.err) 343 | } 344 | 345 | if checks != tt.expectedChecks { 346 | t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks) 347 | } 348 | }) 349 | } 350 | } 351 | 352 | func TestClient_Do_fails(t *testing.T) { 353 | // Mock server which always responds 500. 354 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 355 | w.WriteHeader(500) 356 | })) 357 | defer ts.Close() 358 | 359 | tests := []struct { 360 | name string 361 | cr CheckRetry 362 | err string 363 | }{ 364 | { 365 | name: "default_retry_policy", 366 | cr: DefaultRetryPolicy, 367 | err: "giving up after 3 attempt(s)", 368 | }, 369 | { 370 | name: "error_propagated_retry_policy", 371 | cr: ErrorPropagatedRetryPolicy, 372 | err: "giving up after 3 attempt(s): unexpected HTTP status 500 Internal Server Error", 373 | }, 374 | } 375 | 376 | for _, tt := range tests { 377 | t.Run(tt.name, func(t *testing.T) { 378 | // Create the client. Use short retry windows so we fail faster. 379 | client := NewClient() 380 | client.RetryWaitMin = 10 * time.Millisecond 381 | client.RetryWaitMax = 10 * time.Millisecond 382 | client.CheckRetry = tt.cr 383 | client.RetryMax = 2 384 | 385 | // Create the request 386 | req, err := NewRequest("POST", ts.URL, nil) 387 | if err != nil { 388 | t.Fatalf("err: %v", err) 389 | } 390 | 391 | // Send the request. 
392 | _, err = client.Do(req) 393 | if err == nil || !strings.HasSuffix(err.Error(), tt.err) { 394 | t.Fatalf("expected giving up error, got: %#v", err) 395 | } 396 | }) 397 | } 398 | } 399 | 400 | func TestClient_Get(t *testing.T) { 401 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 402 | if r.Method != "GET" { 403 | t.Fatalf("bad method: %s", r.Method) 404 | } 405 | if r.RequestURI != "/foo/bar" { 406 | t.Fatalf("bad uri: %s", r.RequestURI) 407 | } 408 | w.WriteHeader(200) 409 | })) 410 | defer ts.Close() 411 | 412 | // Make the request. 413 | resp, err := NewClient().Get(ts.URL + "/foo/bar") 414 | if err != nil { 415 | t.Fatalf("err: %v", err) 416 | } 417 | resp.Body.Close() 418 | } 419 | 420 | func TestClient_RequestLogHook(t *testing.T) { 421 | t.Run("RequestLogHook successfully called with default Logger", func(t *testing.T) { 422 | testClientRequestLogHook(t, defaultLogger) 423 | }) 424 | t.Run("RequestLogHook successfully called with nil Logger", func(t *testing.T) { 425 | testClientRequestLogHook(t, nil) 426 | }) 427 | } 428 | 429 | func testClientRequestLogHook(t *testing.T, logger interface{}) { 430 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 431 | if r.Method != "GET" { 432 | t.Fatalf("bad method: %s", r.Method) 433 | } 434 | if r.RequestURI != "/foo/bar" { 435 | t.Fatalf("bad uri: %s", r.RequestURI) 436 | } 437 | w.WriteHeader(200) 438 | })) 439 | defer ts.Close() 440 | 441 | retries := -1 442 | testURIPath := "/foo/bar" 443 | 444 | client := NewClient() 445 | client.Logger = logger 446 | client.RequestLogHook = func(logger Logger, req *http.Request, retry int) { 447 | retries = retry 448 | 449 | if logger != client.Logger { 450 | t.Fatalf("Client logger was not passed to logging hook") 451 | } 452 | 453 | dumpBytes, err := httputil.DumpRequestOut(req, false) 454 | if err != nil { 455 | t.Fatal("Dumping requests failed") 456 | } 457 | 458 | dumpString := 
string(dumpBytes) 459 | if !strings.Contains(dumpString, "GET "+testURIPath) { 460 | t.Fatalf("Bad request dump:\n%s", dumpString) 461 | } 462 | } 463 | 464 | // Make the request. 465 | resp, err := client.Get(ts.URL + testURIPath) 466 | if err != nil { 467 | t.Fatalf("err: %v", err) 468 | } 469 | resp.Body.Close() 470 | 471 | if retries < 0 { 472 | t.Fatal("Logging hook was not called") 473 | } 474 | } 475 | 476 | func TestClient_ResponseLogHook(t *testing.T) { 477 | t.Run("ResponseLogHook successfully called with hclog Logger", func(t *testing.T) { 478 | buf := new(bytes.Buffer) 479 | l := hclog.New(&hclog.LoggerOptions{ 480 | Output: buf, 481 | }) 482 | testClientResponseLogHook(t, l, buf) 483 | }) 484 | t.Run("ResponseLogHook successfully called with nil Logger", func(t *testing.T) { 485 | buf := new(bytes.Buffer) 486 | testClientResponseLogHook(t, nil, buf) 487 | }) 488 | } 489 | 490 | func testClientResponseLogHook(t *testing.T, l interface{}, buf *bytes.Buffer) { 491 | passAfter := time.Now().Add(100 * time.Millisecond) 492 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 493 | if time.Now().After(passAfter) { 494 | w.WriteHeader(200) 495 | w.Write([]byte("test_200_body")) 496 | } else { 497 | w.WriteHeader(500) 498 | w.Write([]byte("test_500_body")) 499 | } 500 | })) 501 | defer ts.Close() 502 | 503 | client := NewClient() 504 | 505 | client.Logger = l 506 | client.RetryWaitMin = 10 * time.Millisecond 507 | client.RetryWaitMax = 10 * time.Millisecond 508 | client.RetryMax = 15 509 | client.ResponseLogHook = func(logger Logger, resp *http.Response) { 510 | if resp.StatusCode == 200 { 511 | successLog := "test_log_pass" 512 | // Log something when we get a 200 513 | if logger != nil { 514 | logger.Printf(successLog) 515 | } else { 516 | buf.WriteString(successLog) 517 | } 518 | } else { 519 | // Log the response body when we get a 500 520 | body, err := ioutil.ReadAll(resp.Body) 521 | if err != nil { 522 | 
t.Fatalf("err: %v", err) 523 | } 524 | failLog := string(body) 525 | if logger != nil { 526 | logger.Printf(failLog) 527 | } else { 528 | buf.WriteString(failLog) 529 | } 530 | } 531 | } 532 | 533 | // Perform the request. Exits when we finally get a 200. 534 | resp, err := client.Get(ts.URL) 535 | if err != nil { 536 | t.Fatalf("err: %v", err) 537 | } 538 | 539 | // Make sure we can read the response body still, since we did not 540 | // read or close it from the response log hook. 541 | body, err := ioutil.ReadAll(resp.Body) 542 | if err != nil { 543 | t.Fatalf("err: %v", err) 544 | } 545 | if string(body) != "test_200_body" { 546 | t.Fatalf("expect %q, got %q", "test_200_body", string(body)) 547 | } 548 | 549 | // Make sure we wrote to the logger on callbacks. 550 | out := buf.String() 551 | if !strings.Contains(out, "test_log_pass") { 552 | t.Fatalf("expect response callback on 200: %q", out) 553 | } 554 | if !strings.Contains(out, "test_500_body") { 555 | t.Fatalf("expect response callback on 500: %q", out) 556 | } 557 | } 558 | 559 | func TestClient_NewRequestWithContext(t *testing.T) { 560 | ctx, cancel := context.WithCancel(context.Background()) 561 | defer cancel() 562 | r, err := NewRequestWithContext(ctx, http.MethodGet, "/abc", nil) 563 | if err != nil { 564 | t.Fatalf("err: %v", err) 565 | } 566 | if r.Context() != ctx { 567 | t.Fatal("Context must be set") 568 | } 569 | } 570 | 571 | func TestClient_RequestWithContext(t *testing.T) { 572 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 573 | w.WriteHeader(200) 574 | w.Write([]byte("test_200_body")) 575 | })) 576 | defer ts.Close() 577 | 578 | req, err := NewRequest(http.MethodGet, ts.URL, nil) 579 | if err != nil { 580 | t.Fatalf("err: %v", err) 581 | } 582 | ctx, cancel := context.WithCancel(req.Request.Context()) 583 | reqCtx := req.WithContext(ctx) 584 | if reqCtx == req { 585 | t.Fatal("WithContext must return a new Request object") 586 | } 587 | 588 | 
client := NewClient() 589 | 590 | called := 0 591 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 592 | called++ 593 | return DefaultRetryPolicy(reqCtx.Request.Context(), resp, err) 594 | } 595 | 596 | cancel() 597 | _, err = client.Do(reqCtx) 598 | 599 | if called != 1 { 600 | t.Fatalf("CheckRetry called %d times, expected 1", called) 601 | } 602 | 603 | e := fmt.Sprintf("GET %s giving up after 1 attempt(s): %s", ts.URL, context.Canceled.Error()) 604 | 605 | if err.Error() != e { 606 | t.Fatalf("Expected err to contain %s, got: %v", e, err) 607 | } 608 | } 609 | 610 | func TestClient_CheckRetry(t *testing.T) { 611 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 612 | http.Error(w, "test_500_body", http.StatusInternalServerError) 613 | })) 614 | defer ts.Close() 615 | 616 | client := NewClient() 617 | 618 | retryErr := errors.New("retryError") 619 | called := 0 620 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 621 | if called < 1 { 622 | called++ 623 | return DefaultRetryPolicy(context.TODO(), resp, err) 624 | } 625 | 626 | return false, retryErr 627 | } 628 | 629 | // CheckRetry should return our retryErr value and stop the retry loop. 
630 | _, err := client.Get(ts.URL) 631 | 632 | if called != 1 { 633 | t.Fatalf("CheckRetry called %d times, expected 1", called) 634 | } 635 | 636 | if err.Error() != fmt.Sprintf("GET %s giving up after 2 attempt(s): retryError", ts.URL) { 637 | t.Fatalf("Expected retryError, got:%v", err) 638 | } 639 | } 640 | 641 | func TestClient_DefaultBackoff(t *testing.T) { 642 | for _, code := range []int{http.StatusTooManyRequests, http.StatusServiceUnavailable} { 643 | t.Run(fmt.Sprintf("http_%d", code), func(t *testing.T) { 644 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 645 | w.Header().Set("Retry-After", "2") 646 | http.Error(w, fmt.Sprintf("test_%d_body", code), code) 647 | })) 648 | defer ts.Close() 649 | 650 | client := NewClient() 651 | 652 | var retryAfter time.Duration 653 | retryable := false 654 | 655 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 656 | retryable, _ = DefaultRetryPolicy(context.Background(), resp, err) 657 | retryAfter = DefaultBackoff(client.RetryWaitMin, client.RetryWaitMax, 1, resp) 658 | return false, nil 659 | } 660 | 661 | _, err := client.Get(ts.URL) 662 | if err != nil { 663 | t.Fatalf("expected no errors since retryable") 664 | } 665 | 666 | if !retryable { 667 | t.Fatal("Since the error is recoverable, the default policy shall return true") 668 | } 669 | 670 | if retryAfter != 2*time.Second { 671 | t.Fatalf("The header Retry-After specified 2 seconds, and shall not be %d seconds", retryAfter/time.Second) 672 | } 673 | }) 674 | } 675 | } 676 | 677 | func TestClient_DefaultRetryPolicy_TLS(t *testing.T) { 678 | ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 679 | w.WriteHeader(200) 680 | })) 681 | defer ts.Close() 682 | 683 | attempts := 0 684 | client := NewClient() 685 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 686 | attempts++ 687 | return 
DefaultRetryPolicy(context.TODO(), resp, err) 688 | } 689 | 690 | _, err := client.Get(ts.URL) 691 | if err == nil { 692 | t.Fatalf("expected x509 error, got nil") 693 | } 694 | if attempts != 1 { 695 | t.Fatalf("expected 1 attempt, got %d", attempts) 696 | } 697 | } 698 | 699 | func TestClient_DefaultRetryPolicy_redirects(t *testing.T) { 700 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 701 | http.Redirect(w, r, "/", http.StatusFound) 702 | })) 703 | defer ts.Close() 704 | 705 | attempts := 0 706 | client := NewClient() 707 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 708 | attempts++ 709 | return DefaultRetryPolicy(context.TODO(), resp, err) 710 | } 711 | 712 | _, err := client.Get(ts.URL) 713 | if err == nil { 714 | t.Fatalf("expected redirect error, got nil") 715 | } 716 | if attempts != 1 { 717 | t.Fatalf("expected 1 attempt, got %d", attempts) 718 | } 719 | } 720 | 721 | func TestClient_DefaultRetryPolicy_invalidscheme(t *testing.T) { 722 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 723 | w.WriteHeader(200) 724 | })) 725 | defer ts.Close() 726 | 727 | attempts := 0 728 | client := NewClient() 729 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 730 | attempts++ 731 | return DefaultRetryPolicy(context.TODO(), resp, err) 732 | } 733 | 734 | url := strings.Replace(ts.URL, "http", "ftp", 1) 735 | _, err := client.Get(url) 736 | if err == nil { 737 | t.Fatalf("expected scheme error, got nil") 738 | } 739 | if attempts != 1 { 740 | t.Fatalf("expected 1 attempt, got %d", attempts) 741 | } 742 | } 743 | 744 | func TestClient_CheckRetryStop(t *testing.T) { 745 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 746 | http.Error(w, "test_500_body", http.StatusInternalServerError) 747 | })) 748 | defer ts.Close() 749 | 750 | client := NewClient() 751 
| 752 | // Verify that this stops retries on the first try, with no errors from the client. 753 | called := 0 754 | client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { 755 | called++ 756 | return false, nil 757 | } 758 | 759 | _, err := client.Get(ts.URL) 760 | 761 | if called != 1 { 762 | t.Fatalf("CheckRetry called %d times, expected 1", called) 763 | } 764 | 765 | if err != nil { 766 | t.Fatalf("Expected no error, got:%v", err) 767 | } 768 | } 769 | 770 | func TestClient_Head(t *testing.T) { 771 | // Mock server which always responds 200. 772 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 773 | if r.Method != "HEAD" { 774 | t.Fatalf("bad method: %s", r.Method) 775 | } 776 | if r.RequestURI != "/foo/bar" { 777 | t.Fatalf("bad uri: %s", r.RequestURI) 778 | } 779 | w.WriteHeader(200) 780 | })) 781 | defer ts.Close() 782 | 783 | // Make the request. 784 | resp, err := NewClient().Head(ts.URL + "/foo/bar") 785 | if err != nil { 786 | t.Fatalf("err: %v", err) 787 | } 788 | resp.Body.Close() 789 | } 790 | 791 | func TestClient_Post(t *testing.T) { 792 | // Mock server which always responds 200. 793 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 794 | if r.Method != "POST" { 795 | t.Fatalf("bad method: %s", r.Method) 796 | } 797 | if r.RequestURI != "/foo/bar" { 798 | t.Fatalf("bad uri: %s", r.RequestURI) 799 | } 800 | if ct := r.Header.Get("Content-Type"); ct != "application/json" { 801 | t.Fatalf("bad content-type: %s", ct) 802 | } 803 | 804 | // Check the payload 805 | body, err := ioutil.ReadAll(r.Body) 806 | if err != nil { 807 | t.Fatalf("err: %s", err) 808 | } 809 | expected := []byte(`{"hello":"world"}`) 810 | if !bytes.Equal(body, expected) { 811 | t.Fatalf("bad: %v", body) 812 | } 813 | 814 | w.WriteHeader(200) 815 | })) 816 | defer ts.Close() 817 | 818 | // Make the request. 
819 | resp, err := NewClient().Post( 820 | ts.URL+"/foo/bar", 821 | "application/json", 822 | strings.NewReader(`{"hello":"world"}`)) 823 | if err != nil { 824 | t.Fatalf("err: %v", err) 825 | } 826 | resp.Body.Close() 827 | } 828 | 829 | func TestClient_PostForm(t *testing.T) { 830 | // Mock server which always responds 200. 831 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 832 | if r.Method != "POST" { 833 | t.Fatalf("bad method: %s", r.Method) 834 | } 835 | if r.RequestURI != "/foo/bar" { 836 | t.Fatalf("bad uri: %s", r.RequestURI) 837 | } 838 | if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { 839 | t.Fatalf("bad content-type: %s", ct) 840 | } 841 | 842 | // Check the payload 843 | body, err := ioutil.ReadAll(r.Body) 844 | if err != nil { 845 | t.Fatalf("err: %s", err) 846 | } 847 | expected := []byte(`hello=world`) 848 | if !bytes.Equal(body, expected) { 849 | t.Fatalf("bad: %v", body) 850 | } 851 | 852 | w.WriteHeader(200) 853 | })) 854 | defer ts.Close() 855 | 856 | // Create the form data. 857 | form, err := url.ParseQuery("hello=world") 858 | if err != nil { 859 | t.Fatalf("err: %v", err) 860 | } 861 | 862 | // Make the request. 
863 | resp, err := NewClient().PostForm(ts.URL+"/foo/bar", form) 864 | if err != nil { 865 | t.Fatalf("err: %v", err) 866 | } 867 | resp.Body.Close() 868 | } 869 | 870 | func TestBackoff(t *testing.T) { 871 | type tcase struct { 872 | min time.Duration 873 | max time.Duration 874 | i int 875 | expect time.Duration 876 | } 877 | cases := []tcase{ 878 | { 879 | time.Second, 880 | 5 * time.Minute, 881 | 0, 882 | time.Second, 883 | }, 884 | { 885 | time.Second, 886 | 5 * time.Minute, 887 | 1, 888 | 2 * time.Second, 889 | }, 890 | { 891 | time.Second, 892 | 5 * time.Minute, 893 | 2, 894 | 4 * time.Second, 895 | }, 896 | { 897 | time.Second, 898 | 5 * time.Minute, 899 | 3, 900 | 8 * time.Second, 901 | }, 902 | { 903 | time.Second, 904 | 5 * time.Minute, 905 | 63, 906 | 5 * time.Minute, 907 | }, 908 | { 909 | time.Second, 910 | 5 * time.Minute, 911 | 128, 912 | 5 * time.Minute, 913 | }, 914 | } 915 | 916 | for _, tc := range cases { 917 | if v := DefaultBackoff(tc.min, tc.max, tc.i, nil); v != tc.expect { 918 | t.Fatalf("bad: %#v -> %s", tc, v) 919 | } 920 | } 921 | } 922 | 923 | func TestClient_BackoffCustom(t *testing.T) { 924 | var retries int32 925 | 926 | client := NewClient() 927 | client.Backoff = func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { 928 | atomic.AddInt32(&retries, 1) 929 | return time.Millisecond * 1 930 | } 931 | 932 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 933 | if atomic.LoadInt32(&retries) == int32(client.RetryMax) { 934 | w.WriteHeader(200) 935 | return 936 | } 937 | w.WriteHeader(500) 938 | })) 939 | defer ts.Close() 940 | 941 | // Make the request. 
942 | resp, err := client.Get(ts.URL + "/foo/bar") 943 | if err != nil { 944 | t.Fatalf("err: %v", err) 945 | } 946 | resp.Body.Close() 947 | if retries != int32(client.RetryMax) { 948 | t.Fatalf("expected retries: %d != %d", client.RetryMax, retries) 949 | } 950 | } 951 | 952 | func TestClient_StandardClient(t *testing.T) { 953 | // Create a retryable HTTP client. 954 | client := NewClient() 955 | 956 | // Get a standard client. 957 | standard := client.StandardClient() 958 | 959 | // Ensure the underlying retrying client is set properly. 960 | if v := standard.Transport.(*RoundTripper).Client; v != client { 961 | t.Fatalf("expected %v, got %v", client, v) 962 | } 963 | } 964 | 965 | func TestClient_RedirectWithBody(t *testing.T) { 966 | // Mock server which always responds 200. 967 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 968 | switch r.RequestURI { 969 | case "/foo/redirect": 970 | w.Header().Set("Location", "/foo/redirected") 971 | w.WriteHeader(307) 972 | case "/foo/redirected": 973 | w.WriteHeader(200) 974 | default: 975 | t.Fatalf("bad uri: %s", r.RequestURI) 976 | } 977 | })) 978 | defer ts.Close() 979 | 980 | client := NewClient() 981 | 982 | // has body 983 | req, err := NewRequest(http.MethodPost, ts.URL+"/foo/redirect", strings.NewReader(`{}`)) 984 | if err != nil { 985 | t.Fatalf("err: %v", err) 986 | } 987 | 988 | resp, err := client.Do(req) 989 | if err != nil { 990 | t.Fatalf("err: %v", err) 991 | } 992 | resp.Body.Close() 993 | 994 | // no body 995 | if err := req.SetBody(nil); err != nil { 996 | t.Fatalf("err: %v", err) 997 | } 998 | 999 | resp, err = client.Do(req) 1000 | if err != nil { 1001 | t.Fatalf("err: %v", err) 1002 | } 1003 | resp.Body.Close() 1004 | } 1005 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Package retryablehttp provides a familiar HTTP 
client interface with 2 | // automatic retries and exponential backoff. It is a thin wrapper over the 3 | // standard net/http client library and exposes nearly the same public API. 4 | // This makes retryablehttp very easy to drop into existing programs. 5 | // 6 | // retryablehttp performs automatic retries under certain conditions. Mainly, if 7 | // an error is returned by the client (connection errors, etc.), or if a 500-range 8 | // response is received, then a retry is invoked. Otherwise, the response is 9 | // returned and left to the caller to interpret. 10 | // 11 | // Requests which take a request body should provide a non-nil function 12 | // parameter. The best choice is to provide either a function satisfying 13 | // ReaderFunc which provides multiple io.Readers in an efficient manner, a 14 | // *bytes.Buffer (the underlying raw byte slice will be used) or a raw byte 15 | // slice. As it is a reference type, and we will wrap it as needed by readers, 16 | // we can efficiently re-use the request body without needing to copy it. If an 17 | // io.Reader (such as a *bytes.Reader) is provided, the full body will be read 18 | // prior to the first request, and will be efficiently re-used for any retries. 19 | // ReadSeeker can be used, but some users have observed occasional data races 20 | // between the net/http library and the Seek functionality of some 21 | // implementations of ReadSeeker, so it should be avoided if possible.
package retryablehttp

import (
	"bytes"
	"context"
	"crypto/x509"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"math"
	"math/rand"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	cleanhttp "github.com/hashicorp/go-cleanhttp"
)

var (
	// Default retry configuration: the minimum and maximum wait between
	// attempts, and the maximum number of retries.
	defaultRetryWaitMin = 1 * time.Second
	defaultRetryWaitMax = 30 * time.Second
	defaultRetryMax     = 4

	// defaultLogger is the logger provided with defaultClient. It writes to
	// standard error using the standard log flags.
	defaultLogger = log.New(os.Stderr, "", log.LstdFlags)

	// defaultClient is used for performing requests without explicitly making
	// a new client. It is purposely private to avoid modifications.
	defaultClient = NewClient()

	// We need to consume response bodies to maintain http connections, but
	// limit the size we consume to respReadLimit (in bytes).
	respReadLimit = int64(4096)

	// A regular expression to match the error returned by net/http when the
	// configured number of redirects is exhausted. This error isn't typed
	// specifically so we resort to matching on the error string.
	redirectsErrorRe = regexp.MustCompile(`stopped after \d+ redirects\z`)

	// A regular expression to match the error returned by net/http when the
	// scheme specified in the URL is invalid. This error isn't typed
	// specifically so we resort to matching on the error string.
	schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)

	// A regular expression to match the error returned by net/http when the
	// TLS certificate is not trusted. This error isn't typed
	// specifically so we resort to matching on the error string.
	notTrustedErrorRe = regexp.MustCompile(`certificate is not trusted`)
)

// ReaderFunc is the type of function that can be given natively to NewRequest.
// Each call returns a reader over the request body (or an error).
type ReaderFunc func() (io.Reader, error)

// ResponseHandlerFunc is a type of function that takes in a Response, and does something with it.
// The ResponseHandlerFunc is called when the HTTP client successfully receives a response and the
// CheckRetry function indicates that a retry of the base request is not necessary.
// If an error is returned from this function, the CheckRetry policy will be used to determine
// whether to retry the whole request (including this handler).
//
// Make sure to check status codes! Even if the request was completed it may have a non-2xx status code.
//
// The response body is not automatically closed. It must be closed either by the ResponseHandlerFunc or
// by the caller out-of-band. Failure to do so will result in a memory leak.
type ResponseHandlerFunc func(*http.Response) error

// LenReader is an interface implemented by many in-memory io.Reader's. Used
// for automatically sending the right Content-Length header when possible.
type LenReader interface {
	// Len reports a length for the reader; it is used as the request's
	// Content-Length.
	Len() int
}

// Request wraps the metadata needed to create HTTP requests.
type Request struct {
	// body is a function returning a reader over the request body payload;
	// it is nil when the request has no body. This is used to rewind the
	// request data in between retries.
	body ReaderFunc

	// responseHandler is the optional handler installed through
	// SetResponseHandler; it is nil when none has been set.
	responseHandler ResponseHandlerFunc

	// Embed an HTTP request directly. This makes a *Request act exactly
	// like an *http.Request so that all meta methods are supported.
	*http.Request
}

// WithContext returns wrapped Request with a shallow copy of underlying *http.Request
// with its context changed to ctx. The provided ctx must be non-nil.
115 | func (r *Request) WithContext(ctx context.Context) *Request { 116 | return &Request{ 117 | body: r.body, 118 | responseHandler: r.responseHandler, 119 | Request: r.Request.WithContext(ctx), 120 | } 121 | } 122 | 123 | // SetResponseHandler allows setting the response handler. 124 | func (r *Request) SetResponseHandler(fn ResponseHandlerFunc) { 125 | r.responseHandler = fn 126 | } 127 | 128 | // BodyBytes allows accessing the request body. It is an analogue to 129 | // http.Request's Body variable, but it returns a copy of the underlying data 130 | // rather than consuming it. 131 | // 132 | // This function is not thread-safe; do not call it at the same time as another 133 | // call, or at the same time this request is being used with Client.Do. 134 | func (r *Request) BodyBytes() ([]byte, error) { 135 | if r.body == nil { 136 | return nil, nil 137 | } 138 | body, err := r.body() 139 | if err != nil { 140 | return nil, err 141 | } 142 | buf := new(bytes.Buffer) 143 | _, err = buf.ReadFrom(body) 144 | if err != nil { 145 | return nil, err 146 | } 147 | return buf.Bytes(), nil 148 | } 149 | 150 | // SetBody allows setting the request body. 151 | // 152 | // It is useful if a new body needs to be set without constructing a new Request. 153 | func (r *Request) SetBody(rawBody interface{}) error { 154 | bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody) 155 | if err != nil { 156 | return err 157 | } 158 | r.body = bodyReader 159 | r.ContentLength = contentLength 160 | if bodyReader != nil { 161 | r.GetBody = func() (io.ReadCloser, error) { 162 | body, err := bodyReader() 163 | if err != nil { 164 | return nil, err 165 | } 166 | if rc, ok := body.(io.ReadCloser); ok { 167 | return rc, nil 168 | } 169 | return ioutil.NopCloser(body), nil 170 | } 171 | } else { 172 | r.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil } 173 | } 174 | return nil 175 | } 176 | 177 | // WriteTo allows copying the request body into a writer. 
178 | // 179 | // It writes data to w until there's no more data to write or 180 | // when an error occurs. The return int64 value is the number of bytes 181 | // written. Any error encountered during the write is also returned. 182 | // The signature matches io.WriterTo interface. 183 | func (r *Request) WriteTo(w io.Writer) (int64, error) { 184 | body, err := r.body() 185 | if err != nil { 186 | return 0, err 187 | } 188 | if c, ok := body.(io.Closer); ok { 189 | defer c.Close() 190 | } 191 | return io.Copy(w, body) 192 | } 193 | 194 | func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, error) { 195 | var bodyReader ReaderFunc 196 | var contentLength int64 197 | 198 | switch body := rawBody.(type) { 199 | // If they gave us a function already, great! Use it. 200 | case ReaderFunc: 201 | bodyReader = body 202 | tmp, err := body() 203 | if err != nil { 204 | return nil, 0, err 205 | } 206 | if lr, ok := tmp.(LenReader); ok { 207 | contentLength = int64(lr.Len()) 208 | } 209 | if c, ok := tmp.(io.Closer); ok { 210 | c.Close() 211 | } 212 | 213 | case func() (io.Reader, error): 214 | bodyReader = body 215 | tmp, err := body() 216 | if err != nil { 217 | return nil, 0, err 218 | } 219 | if lr, ok := tmp.(LenReader); ok { 220 | contentLength = int64(lr.Len()) 221 | } 222 | if c, ok := tmp.(io.Closer); ok { 223 | c.Close() 224 | } 225 | 226 | // If a regular byte slice, we can read it over and over via new 227 | // readers 228 | case []byte: 229 | buf := body 230 | bodyReader = func() (io.Reader, error) { 231 | return bytes.NewReader(buf), nil 232 | } 233 | contentLength = int64(len(buf)) 234 | 235 | // If a bytes.Buffer we can read the underlying byte slice over and 236 | // over 237 | case *bytes.Buffer: 238 | buf := body 239 | bodyReader = func() (io.Reader, error) { 240 | return bytes.NewReader(buf.Bytes()), nil 241 | } 242 | contentLength = int64(buf.Len()) 243 | 244 | // We prioritize *bytes.Reader here because we don't really want to 245 | 
// deal with it seeking so want it to match here instead of the 246 | // io.ReadSeeker case. 247 | case *bytes.Reader: 248 | buf, err := ioutil.ReadAll(body) 249 | if err != nil { 250 | return nil, 0, err 251 | } 252 | bodyReader = func() (io.Reader, error) { 253 | return bytes.NewReader(buf), nil 254 | } 255 | contentLength = int64(len(buf)) 256 | 257 | // Compat case 258 | case io.ReadSeeker: 259 | raw := body 260 | bodyReader = func() (io.Reader, error) { 261 | _, err := raw.Seek(0, 0) 262 | return ioutil.NopCloser(raw), err 263 | } 264 | if lr, ok := raw.(LenReader); ok { 265 | contentLength = int64(lr.Len()) 266 | } 267 | 268 | // Read all in so we can reset 269 | case io.Reader: 270 | buf, err := ioutil.ReadAll(body) 271 | if err != nil { 272 | return nil, 0, err 273 | } 274 | bodyReader = func() (io.Reader, error) { 275 | return bytes.NewReader(buf), nil 276 | } 277 | contentLength = int64(len(buf)) 278 | 279 | // No body provided, nothing to do 280 | case nil: 281 | 282 | // Unrecognized type 283 | default: 284 | return nil, 0, fmt.Errorf("cannot handle type %T", rawBody) 285 | } 286 | return bodyReader, contentLength, nil 287 | } 288 | 289 | // FromRequest wraps an http.Request in a retryablehttp.Request 290 | func FromRequest(r *http.Request) (*Request, error) { 291 | bodyReader, _, err := getBodyReaderAndContentLength(r.Body) 292 | if err != nil { 293 | return nil, err 294 | } 295 | // Could assert contentLength == r.ContentLength 296 | return &Request{body: bodyReader, Request: r}, nil 297 | } 298 | 299 | // NewRequest creates a new wrapped request. 300 | func NewRequest(method, url string, rawBody interface{}) (*Request, error) { 301 | return NewRequestWithContext(context.Background(), method, url, rawBody) 302 | } 303 | 304 | // NewRequestWithContext creates a new wrapped request with the provided context. 
305 | // 306 | // The context controls the entire lifetime of a request and its response: 307 | // obtaining a connection, sending the request, and reading the response headers and body. 308 | func NewRequestWithContext(ctx context.Context, method, url string, rawBody interface{}) (*Request, error) { 309 | httpReq, err := http.NewRequestWithContext(ctx, method, url, nil) 310 | if err != nil { 311 | return nil, err 312 | } 313 | 314 | req := &Request{Request: httpReq} 315 | if err := req.SetBody(rawBody); err != nil { 316 | return nil, err 317 | } 318 | 319 | return req, nil 320 | } 321 | 322 | // Logger interface allows to use other loggers than 323 | // standard log.Logger. 324 | type Logger interface { 325 | Printf(string, ...interface{}) 326 | } 327 | 328 | // LeveledLogger is an interface that can be implemented by any logger or a 329 | // logger wrapper to provide leveled logging. The methods accept a message 330 | // string and a variadic number of key-value pairs. For log.Printf style 331 | // formatting where message string contains a format specifier, use Logger 332 | // interface. 333 | type LeveledLogger interface { 334 | Error(msg string, keysAndValues ...interface{}) 335 | Info(msg string, keysAndValues ...interface{}) 336 | Debug(msg string, keysAndValues ...interface{}) 337 | Warn(msg string, keysAndValues ...interface{}) 338 | } 339 | 340 | // hookLogger adapts an LeveledLogger to Logger for use by the existing hook functions 341 | // without changing the API. 342 | type hookLogger struct { 343 | LeveledLogger 344 | } 345 | 346 | func (h hookLogger) Printf(s string, args ...interface{}) { 347 | h.Info(fmt.Sprintf(s, args...)) 348 | } 349 | 350 | // RequestLogHook allows a function to run before each retry. The HTTP 351 | // request which will be made, and the retry number (0 for the initial 352 | // request) are available to users. The internal logger is exposed to 353 | // consumers. 
//
// The Logger argument is nil when the client's logger is neither a Logger
// nor a LeveledLogger (for example when no logger is configured).
type RequestLogHook func(Logger, *http.Request, int)

// ResponseLogHook is like RequestLogHook, but allows running a function
// on each HTTP response. Do invokes it after every attempt that completed
// without an error, regardless of whether a subsequent retry
// needs to be performed or not. If the response body is read or closed
// from this method, this will affect the response returned from Do().
type ResponseLogHook func(Logger, *http.Response)

// CheckRetry specifies a policy for handling retries. It is called
// following each request with the response and error values returned by
// the http.Client. If CheckRetry returns false, the Client stops retrying
// and returns the response to the caller. If CheckRetry returns an error,
// that error value is returned in lieu of the error from the request. The
// Client will close any response body when retrying, but if the retry is
// aborted it is up to the CheckRetry callback to properly close any
// response body before returning.
type CheckRetry func(ctx context.Context, resp *http.Response, err error) (bool, error)

// Backoff specifies a policy for how long to wait between retries.
// It is called after a failing request to determine the amount of time
// that should pass before trying again. Do passes the client's RetryWaitMin
// and RetryWaitMax as min and max, the zero-based attempt index as
// attemptNum, and the response of the failed attempt, which may be nil.
type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration

// ErrorHandler is called when Do gives up without a successful result, for
// example because retries are exhausted or CheckRetry returned an error. It
// receives the last response and error and the number of attempts made. If
// not specified, default behavior for the library is
// to close the body and return an error indicating how many tries were
// attempted. If overriding this, be sure to close the body if needed.
type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error)

// Client is used to make HTTP requests.
It adds additional functionality 385 | // like automatic retries to tolerate minor outages. 386 | type Client struct { 387 | HTTPClient *http.Client // Internal HTTP client. 388 | Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger 389 | 390 | RetryWaitMin time.Duration // Minimum time to wait 391 | RetryWaitMax time.Duration // Maximum time to wait 392 | RetryMax int // Maximum number of retries 393 | 394 | // RequestLogHook allows a user-supplied function to be called 395 | // before each retry. 396 | RequestLogHook RequestLogHook 397 | 398 | // ResponseLogHook allows a user-supplied function to be called 399 | // with the response from each HTTP request executed. 400 | ResponseLogHook ResponseLogHook 401 | 402 | // CheckRetry specifies the policy for handling retries, and is called 403 | // after each request. The default policy is DefaultRetryPolicy. 404 | CheckRetry CheckRetry 405 | 406 | // Backoff specifies the policy for how long to wait between retries 407 | Backoff Backoff 408 | 409 | // ErrorHandler specifies the custom error handler to use, if any 410 | ErrorHandler ErrorHandler 411 | 412 | loggerInit sync.Once 413 | clientInit sync.Once 414 | } 415 | 416 | // NewClient creates a new Client with default settings. 417 | func NewClient() *Client { 418 | return &Client{ 419 | HTTPClient: cleanhttp.DefaultPooledClient(), 420 | Logger: defaultLogger, 421 | RetryWaitMin: defaultRetryWaitMin, 422 | RetryWaitMax: defaultRetryWaitMax, 423 | RetryMax: defaultRetryMax, 424 | CheckRetry: DefaultRetryPolicy, 425 | Backoff: DefaultBackoff, 426 | } 427 | } 428 | 429 | func (c *Client) logger() interface{} { 430 | c.loggerInit.Do(func() { 431 | if c.Logger == nil { 432 | return 433 | } 434 | 435 | switch c.Logger.(type) { 436 | case Logger, LeveledLogger: 437 | // ok 438 | default: 439 | // This should happen in dev when they are setting Logger and work on code, not in prod. 
440 | panic(fmt.Sprintf("invalid logger type passed, must be Logger or LeveledLogger, was %T", c.Logger)) 441 | } 442 | }) 443 | 444 | return c.Logger 445 | } 446 | 447 | // DefaultRetryPolicy provides a default callback for Client.CheckRetry, which 448 | // will retry on connection errors and server errors. 449 | func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { 450 | // do not retry on context.Canceled or context.DeadlineExceeded 451 | if ctx.Err() != nil { 452 | return false, ctx.Err() 453 | } 454 | 455 | // don't propagate other errors 456 | shouldRetry, _ := baseRetryPolicy(resp, err) 457 | return shouldRetry, nil 458 | } 459 | 460 | // ErrorPropagatedRetryPolicy is the same as DefaultRetryPolicy, except it 461 | // propagates errors back instead of returning nil. This allows you to inspect 462 | // why it decided to retry or not. 463 | func ErrorPropagatedRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { 464 | // do not retry on context.Canceled or context.DeadlineExceeded 465 | if ctx.Err() != nil { 466 | return false, ctx.Err() 467 | } 468 | 469 | return baseRetryPolicy(resp, err) 470 | } 471 | 472 | func baseRetryPolicy(resp *http.Response, err error) (bool, error) { 473 | if err != nil { 474 | if v, ok := err.(*url.Error); ok { 475 | // Don't retry if the error was due to too many redirects. 476 | if redirectsErrorRe.MatchString(v.Error()) { 477 | return false, v 478 | } 479 | 480 | // Don't retry if the error was due to an invalid protocol scheme. 481 | if schemeErrorRe.MatchString(v.Error()) { 482 | return false, v 483 | } 484 | 485 | // Don't retry if the error was due to TLS cert verification failure. 486 | if notTrustedErrorRe.MatchString(v.Error()) { 487 | return false, v 488 | } 489 | if _, ok := v.Err.(x509.UnknownAuthorityError); ok { 490 | return false, v 491 | } 492 | } 493 | 494 | // The error is likely recoverable so retry. 
495 | return true, nil 496 | } 497 | 498 | // 429 Too Many Requests is recoverable. Sometimes the server puts 499 | // a Retry-After response header to indicate when the server is 500 | // available to start processing request from client. 501 | if resp.StatusCode == http.StatusTooManyRequests { 502 | return true, nil 503 | } 504 | 505 | // Check the response code. We retry on 500-range responses to allow 506 | // the server time to recover, as 500's are typically not permanent 507 | // errors and may relate to outages on the server side. This will catch 508 | // invalid response codes as well, like 0 and 999. 509 | if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != http.StatusNotImplemented) { 510 | return true, fmt.Errorf("unexpected HTTP status %s", resp.Status) 511 | } 512 | 513 | return false, nil 514 | } 515 | 516 | // DefaultBackoff provides a default callback for Client.Backoff which 517 | // will perform exponential backoff based on the attempt number and limited 518 | // by the provided minimum and maximum durations. 519 | // 520 | // It also tries to parse Retry-After response header when a http.StatusTooManyRequests 521 | // (HTTP Code 429) is found in the resp parameter. Hence it will return the number of 522 | // seconds the server states it may be ready to process more requests from this client. 
//
// The Retry-After value may be a non-negative number of seconds or an
// HTTP-date. A date in the past yields a zero wait. A missing, empty,
// negative, out-of-range or unparseable value is ignored and the exponential
// backoff applies instead.
func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	if resp != nil {
		if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
			if sleep, ok := retryAfterDelay(resp.Header["Retry-After"]); ok {
				return sleep
			}
		}
	}

	mult := math.Pow(2, float64(attemptNum)) * float64(min)
	sleep := time.Duration(mult)
	// The round-trip comparison detects a product that does not fit in a
	// time.Duration; in that case, or when above max, clamp to max.
	if float64(sleep) != mult || sleep > max {
		sleep = max
	}
	return sleep
}

// retryAfterDelay interprets the values of a Retry-After header. Only the
// first value is considered. The boolean result is false when no usable
// delay could be derived.
func retryAfterDelay(values []string) (time.Duration, bool) {
	// The key can be present with no values; indexing would panic.
	if len(values) == 0 || values[0] == "" {
		return 0, false
	}
	raw := values[0]

	// Retry-After: 120
	if secs, err := strconv.ParseInt(raw, 10, 64); err == nil {
		// Reject negative delays and values whose conversion to a
		// time.Duration would overflow.
		if secs < 0 || secs > int64(math.MaxInt64/time.Second) {
			return 0, false
		}
		return time.Duration(secs) * time.Second, true
	}

	// Retry-After: Fri, 31 Dec 1999 23:59:59 GMT
	when, err := http.ParseTime(raw)
	if err != nil {
		return 0, false
	}
	if wait := time.Until(when); wait > 0 {
		return wait, true
	}
	// The date is already in the past: retry without waiting.
	return 0, true
}

// LinearJitterBackoff provides a callback for Client.Backoff which will
// perform linear backoff based on the attempt number and with jitter to
// prevent a thundering herd.
//
// min and max here are *not* absolute values. The number to be multiplied by
// the attempt number will be chosen at random from between them, thus they are
// bounding the jitter.
//
// For instance:
// * To get strictly linear backoff of one second increasing each retry, set
// both to one second (1s, 2s, 3s, 4s, ...)
// * To get a small amount of jitter centered around one second increasing each
// retry, set to around one second, such as a min of 800ms and max of 1200ms
// (892ms, 2102ms, 2945ms, 4312ms, ...)
// * To get extreme jitter, set to a very wide spread, such as a min of 100ms
// and a max of 20s (15382ms, 292ms, 51321ms, 35234ms, ...)
func LinearJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	// attemptNum is zero-based, but the multiplier has to start at 1.
	multiplier := int64(attemptNum + 1)

	// With an empty or inverted range there is nothing to pick from, so the
	// result is simply min scaled by the attempt.
	if max <= min {
		return min * time.Duration(multiplier)
	}

	// A freshly seeded generator per call is fine here.
	rng := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))

	// Take a random fraction of the min..max spread, add it to min, and
	// scale the outcome by the one-based attempt number.
	spread := rng.Float64() * float64(max-min)
	base := int64(spread) + int64(min)
	return time.Duration(base * multiplier)
}

// PassthroughErrorHandler is an ErrorHandler that hands back the response
// and error it was given, i.e. the values from the net/http library for the
// final request. The attempt count is ignored and the body is not closed.
func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Response, error) {
	return resp, err
}

// Do wraps calling an HTTP method with retries.
//
// Every attempt rewinds the request body (when there is one), runs
// RequestLogHook, performs the request through HTTPClient and asks CheckRetry
// whether another attempt is wanted. Before a retry the previous response
// body is drained and Do waits for the duration chosen by Backoff, or until
// the request context is done. The loop ends after RetryMax retries at the
// latest.
//
// On success the response is returned. Otherwise ErrorHandler decides the
// result when set; without it any response body is drained and closed and an
// error stating the number of attempts is returned.
func (c *Client) Do(req *Request) (*http.Response, error) {
	// Lazily fall back to a pooled client when none was configured.
	c.clientInit.Do(func() {
		if c.HTTPClient == nil {
			c.HTTPClient = cleanhttp.DefaultPooledClient()
		}
	})

	logger := c.logger()

	if logger != nil {
		switch v := logger.(type) {
		case LeveledLogger:
			v.Debug("performing request", "method", req.Method, "url", req.URL)
		case Logger:
			v.Printf("[DEBUG] %s %s", req.Method, req.URL)
		}
	}

	var resp *http.Response
	var attempt int
	var shouldRetry bool
	var doErr, respErr, checkErr error

	// i is the zero-based retry index handed to the hooks and to Backoff;
	// attempt counts the attempts started and is used for error reporting.
	for i := 0; ; i++ {
		doErr, respErr = nil, nil
		attempt++

		// Always rewind the request body when non-nil.
		if req.body != nil {
			body, err := req.body()
			if err != nil {
				c.HTTPClient.CloseIdleConnections()
				return resp, err
			}
			// NOTE: this local c shadows the receiver inside the if/else.
			if c, ok := body.(io.ReadCloser); ok {
				req.Body = c
			} else {
				req.Body = ioutil.NopCloser(body)
			}
		}

		// The hook gets a Logger view of whatever logger is configured, or
		// nil when there is none.
		if c.RequestLogHook != nil {
			switch v := logger.(type) {
			case LeveledLogger:
				c.RequestLogHook(hookLogger{v}, req.Request, i)
			case Logger:
				c.RequestLogHook(v, req.Request, i)
			default:
				c.RequestLogHook(nil, req.Request, i)
			}
		}

		// Attempt the request
		resp, doErr = c.HTTPClient.Do(req.Request)

		// Check if we should continue with retries.
		shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)
		// When the transport succeeded and no retry is wanted, give the
		// per-request response handler a chance to fail the attempt, then
		// consult CheckRetry again with the handler's error.
		if !shouldRetry && doErr == nil && req.responseHandler != nil {
			respErr = req.responseHandler(resp)
			shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, respErr)
		}

		// For logging, the response handler's error takes precedence over
		// the transport error.
		err := doErr
		if respErr != nil {
			err = respErr
		}
		if err != nil {
			switch v := logger.(type) {
			case LeveledLogger:
				v.Error("request failed", "error", err, "method", req.Method, "url", req.URL)
			case Logger:
				v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err)
			}
		} else {
			// Call this here to maintain the behavior of logging all requests,
			// even if CheckRetry signals to stop.
			if c.ResponseLogHook != nil {
				// Call the response logger function if provided.
				switch v := logger.(type) {
				case LeveledLogger:
					c.ResponseLogHook(hookLogger{v}, resp)
				case Logger:
					c.ResponseLogHook(v, resp)
				default:
					c.ResponseLogHook(nil, resp)
				}
			}
		}

		if !shouldRetry {
			break
		}

		// We do this before drainBody because there's no need for the I/O if
		// we're breaking out
		remain := c.RetryMax - i
		if remain <= 0 {
			break
		}

		// We're going to retry, consume any response to reuse the connection.
		if doErr == nil {
			c.drainBody(resp.Body)
		}

		wait := c.Backoff(c.RetryWaitMin, c.RetryWaitMax, i, resp)
		if logger != nil {
			desc := fmt.Sprintf("%s %s", req.Method, req.URL)
			if resp != nil {
				desc = fmt.Sprintf("%s (status: %d)", desc, resp.StatusCode)
			}
			switch v := logger.(type) {
			case LeveledLogger:
				v.Debug("retrying request", "request", desc, "timeout", wait, "remaining", remain)
			case Logger:
				v.Printf("[DEBUG] %s: retrying in %s (%d left)", desc, wait, remain)
			}
		}
		// Sleep for the backoff duration, but give up right away if the
		// request context ends first.
		timer := time.NewTimer(wait)
		select {
		case <-req.Context().Done():
			timer.Stop()
			c.HTTPClient.CloseIdleConnections()
			return nil, req.Context().Err()
		case <-timer.C:
		}

		// Make shallow copy of http Request so that we can modify its body
		// without racing against the closeBody call in persistConn.writeLoop.
		httpreq := *req.Request
		req.Request = &httpreq
	}

	// this is the closest we have to success criteria
	if doErr == nil && respErr == nil && checkErr == nil && !shouldRetry {
		return resp, nil
	}

	defer c.HTTPClient.CloseIdleConnections()

	// Pick the error to report: CheckRetry's error first, then the response
	// handler's, then the transport's.
	var err error
	if checkErr != nil {
		err = checkErr
	} else if respErr != nil {
		err = respErr
	} else {
		err = doErr
	}

	if c.ErrorHandler != nil {
		return c.ErrorHandler(resp, err, attempt)
	}

	// By default, we close the response body and return an error without
	// returning the response
	if resp != nil {
		c.drainBody(resp.Body)
	}

	// this means CheckRetry thought the request was a failure, but didn't
	// communicate why
	if err == nil {
		return nil, fmt.Errorf("%s %s giving up after %d attempt(s)",
			req.Method, req.URL, attempt)
	}

	return nil, fmt.Errorf("%s %s giving up after %d attempt(s): %w",
		req.Method, req.URL, attempt, err)
}

// Try to read the response body so we can reuse this connection.
// At most respReadLimit bytes are read and discarded; the body is always
// closed. A read error is logged, if a logger is configured, and otherwise
// ignored.
func (c *Client) drainBody(body io.ReadCloser) {
	defer body.Close()
	_, err := io.Copy(ioutil.Discard, io.LimitReader(body, respReadLimit))
	if err != nil {
		if c.logger() != nil {
			switch v := c.logger().(type) {
			case LeveledLogger:
				v.Error("error reading response body", "error", err)
			case Logger:
				v.Printf("[ERR] error reading response body: %v", err)
			}
		}
	}
}

// Get is a shortcut for doing a GET request without making a new client.
// It delegates to defaultClient.
func Get(url string) (*http.Response, error) {
	return defaultClient.Get(url)
}

// Get is a convenience helper for doing simple GET requests. It builds a
// bodiless GET request for url and runs it through Do.
func (c *Client) Get(url string) (*http.Response, error) {
	req, err := NewRequest("GET", url, nil)
	if err != nil {
		return nil, err
	}
	return c.Do(req)
}

// Head is a shortcut for doing a HEAD request without making a new client.
// It delegates to defaultClient.
func Head(url string) (*http.Response, error) {
	return defaultClient.Head(url)
}

// Head is a convenience method for doing simple HEAD requests. It builds a
// bodiless HEAD request for url and runs it through Do.
func (c *Client) Head(url string) (*http.Response, error) {
	req, err := NewRequest("HEAD", url, nil)
	if err != nil {
		return nil, err
	}
	return c.Do(req)
}

// Post is a shortcut for doing a POST request without making a new client.
// It delegates to defaultClient.
func Post(url, bodyType string, body interface{}) (*http.Response, error) {
	return defaultClient.Post(url, bodyType, body)
}

// Post is a convenience method for doing simple POST requests.
808 | func (c *Client) Post(url, bodyType string, body interface{}) (*http.Response, error) { 809 | req, err := NewRequest("POST", url, body) 810 | if err != nil { 811 | return nil, err 812 | } 813 | req.Header.Set("Content-Type", bodyType) 814 | return c.Do(req) 815 | } 816 | 817 | // PostForm is a shortcut to perform a POST with form data without creating 818 | // a new client. 819 | func PostForm(url string, data url.Values) (*http.Response, error) { 820 | return defaultClient.PostForm(url, data) 821 | } 822 | 823 | // PostForm is a convenience method for doing simple POST operations using 824 | // pre-filled url.Values form data. 825 | func (c *Client) PostForm(url string, data url.Values) (*http.Response, error) { 826 | return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) 827 | } 828 | 829 | // StandardClient returns a stdlib *http.Client with a custom Transport, which 830 | // shims in a *retryablehttp.Client for added retries. 831 | func (c *Client) StandardClient() *http.Client { 832 | return &http.Client{ 833 | Transport: &RoundTripper{Client: c}, 834 | } 835 | } 836 | --------------------------------------------------------------------------------