├── .github ├── FUNDING.yml └── workflows │ ├── ci.yml │ └── label-actions.yml ├── .gitignore ├── .testdata ├── cert.pem ├── key.pem ├── sample-root.pem ├── test-img.png └── text-file.txt ├── BUILD.bazel ├── LICENSE ├── README.md ├── WORKSPACE ├── benchmark_test.go ├── cert_watcher_test.go ├── circuit_breaker.go ├── client.go ├── client_test.go ├── context_test.go ├── curl.go ├── curl_test.go ├── debug.go ├── digest.go ├── digest_test.go ├── go.mod ├── go.sum ├── load_balancer.go ├── load_balancer_test.go ├── middleware.go ├── middleware_test.go ├── multipart.go ├── multipart_test.go ├── redirect.go ├── request.go ├── request_test.go ├── response.go ├── resty.go ├── resty_test.go ├── retry.go ├── retry_test.go ├── sse.go ├── sse_test.go ├── stream.go ├── trace.go ├── transport_dial.go ├── transport_dial_wasm.go ├── util.go └── util_test.go /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [jeevatkm] 2 | custom: ["https://www.paypal.com/donate/?cmd=_donations&business=QWMZG74FW4QYC&lc=US&item_name=Resty+Library+for+Go&currency_code=USD"] 3 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - v3 7 | - v2 8 | paths-ignore: 9 | - '**.md' 10 | - '**.bazel' 11 | - 'WORKSPACE' 12 | pull_request: 13 | branches: 14 | - main 15 | - v3 16 | - v2 17 | paths-ignore: 18 | - '**.md' 19 | - '**.bazel' 20 | - 'WORKSPACE' 21 | 22 | # Allows you to run this workflow manually from the Actions tab 23 | workflow_dispatch: 24 | 25 | jobs: 26 | build: 27 | name: Build 28 | strategy: 29 | matrix: 30 | go: [ 'stable', '1.21.x' ] 31 | os: [ ubuntu-latest ] 32 | 33 | runs-on: ${{ matrix.os }} 34 | 35 | steps: 36 | - name: Checkout 37 | uses: actions/checkout@v4 38 | with: 39 | fetch-depth: 0 40 | 41 | - name: Setup Go 42 | uses:
actions/setup-go@v5 43 | with: 44 | go-version: ${{ matrix.go }} 45 | cache: true 46 | cache-dependency-path: go.sum 47 | 48 | - name: Format 49 | run: diff -u <(echo -n) <(go fmt $(go list ./...)) 50 | 51 | - name: Test 52 | run: go run gotest.tools/gotestsum@latest -f testname -- ./... -race -count=1 -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... -shuffle=on 53 | 54 | - name: Upload coverage to Codecov 55 | if: ${{ matrix.os == 'ubuntu-latest' && matrix.go == 'stable' }} 56 | uses: codecov/codecov-action@v4 57 | with: 58 | token: ${{ secrets.CODECOV_TOKEN }} 59 | file: ./coverage.txt 60 | flags: unittests 61 | -------------------------------------------------------------------------------- /.github/workflows/label-actions.yml: -------------------------------------------------------------------------------- 1 | name: 'Label' 2 | 3 | on: 4 | pull_request: 5 | types: [labeled] 6 | paths-ignore: 7 | - '**.md' 8 | - '**.bazel' 9 | - 'WORKSPACE' 10 | 11 | jobs: 12 | build: 13 | strategy: 14 | matrix: 15 | go: [ 'stable', '1.21.x' ] 16 | os: [ ubuntu-latest ] 17 | 18 | name: Run Build 19 | if: ${{ github.event.label.name == 'run-build' }} 20 | runs-on: ${{ matrix.os }} 21 | 22 | steps: 23 | - name: Checkout 24 | uses: actions/checkout@v4 25 | with: 26 | fetch-depth: 0 27 | 28 | - name: Setup Go 29 | uses: actions/setup-go@v5 30 | with: 31 | go-version: ${{ matrix.go }} 32 | cache: true 33 | cache-dependency-path: go.sum 34 | 35 | - name: Format 36 | run: diff -u <(echo -n) <(go fmt $(go list ./...)) 37 | 38 | - name: Test 39 | run: go run gotest.tools/gotestsum@latest -f testname -- ./... -race -count=1 -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... 
-shuffle=on 40 | 41 | - name: Upload coverage to Codecov 42 | if: ${{ matrix.os == 'ubuntu-latest' && matrix.go == 'stable' }} 43 | uses: codecov/codecov-action@v4 44 | with: 45 | token: ${{ secrets.CODECOV_TOKEN }} 46 | file: ./coverage.txt 47 | flags: unittests 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | coverage.out 27 | coverage.txt 28 | 29 | # Exclude IDE folders 30 | .idea/* 31 | .vscode/* 32 | -------------------------------------------------------------------------------- /.testdata/cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC+jCCAeKgAwIBAgIRAJce5ewsoW44j0qvSABmq7owDQYJKoZIhvcNAQELBQAw 3 | EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0yNTAxMDQwNzA3MTNaFw0yNjAxMDQwNzA3 4 | MTNaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw 5 | ggEKAoIBAQCYkTN1g/0Z3KkS3w0lX9yhZkwiA0obXCeFs7hpRP0p4WlW3uADyXQ5 6 | h2MaYx8OCA7oGU7/dWOPhtE3rgFEz7IwLxcP5d02ukLGlFD69D6KLyTXwCFmvOWQ 7 | 5fbOq4s73WTNDfYSTYNzeujDCjeu/Bk0OVhdxbyZdyrpdm+UBfH8uIDoGeCRXnji 8 | nqG9HNOQx6r/S6FqC5j/7PrVl1i66WlqRzKEJB94uejfujrHq8RjQm/wzEutU5df 9 | C39zEEEx75qQt7Jc0asm1AqAKSq34xn4rVajWrBZ/WudUUizHfaBDP61uPFvPyKW 10 | JDvTSdeoM9TPX0y0cjo6AwSrdLl7flrRAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF 11 | oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC 12 | CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAdHvPQe3EJ4/X6K/bklJUhIfM 13 | KBauH8VMBfri7xLawleKssm7GdiFivSA0g1pArkl8SALBlPqhrx7rwlyyivLTZaR 14 | 
VFvXaQ9eU0zGnSnDnKVz6CX/zn3TKfcgZPEBclayh0ldm7A8xSJWaWbRZ+s9e9x1 15 | XcQTn2KkMZfBDMnGEWQ3KZrClvO5ZfkqSiyzEm9+eF0m0E7ujTyfSVMsPdyldA6U 16 | pHG8omQTyOzJl2I4z7DlS0AEsL0TJHV4iKr9rDei2xQz/wtful5qU/taYp2Y6zMH 17 | 8ytnDldJhmcCwmvtqvK5p6CbkatE7TFyw2CxQJHnQef+Y4W94sSZWg9CGRKDIQ== 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /.testdata/key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCYkTN1g/0Z3KkS 3 | 3w0lX9yhZkwiA0obXCeFs7hpRP0p4WlW3uADyXQ5h2MaYx8OCA7oGU7/dWOPhtE3 4 | rgFEz7IwLxcP5d02ukLGlFD69D6KLyTXwCFmvOWQ5fbOq4s73WTNDfYSTYNzeujD 5 | Cjeu/Bk0OVhdxbyZdyrpdm+UBfH8uIDoGeCRXnjinqG9HNOQx6r/S6FqC5j/7PrV 6 | l1i66WlqRzKEJB94uejfujrHq8RjQm/wzEutU5dfC39zEEEx75qQt7Jc0asm1AqA 7 | KSq34xn4rVajWrBZ/WudUUizHfaBDP61uPFvPyKWJDvTSdeoM9TPX0y0cjo6AwSr 8 | dLl7flrRAgMBAAECggEAJPTPNUEilxgncGXNZmdBJ2uDN536XoRFIpL1MbK/bFyo 9 | yp00QFaVK7ZK4EJwbFKxYbF3vFOwKT0sAsPIlOWGsTtG59fzbOVTdYzJzPBLEef3 10 | kbd9n8hUB3RdA5T0Ji0r1Kv0FlzmYZu9NDmOYXm5lTfq2tQiKj5+i4zf3EhQZLng 11 | 4wVxBT7yQUQcstJv5K1L6HVzunSYtbHx8ZVxmw+tJ4lMCK23KPlvncZZTT8chWdT 12 | 3GOp5nYIHk9E5jQnBnj7p73sxZUCZlb8uhLtdcgAXc4scptEVO+7n5zOaXIv40Oz 13 | yfkESgHcZWAMDvnkxdySHlD38Z2LIKDGbqR6O9wcwQKBgQDBO6fFPXO41nsxdVCB 14 | nhCgL2hsGjaxJzGBLaOJNVMMFRASN3Yqvs4N1Hn7lawRI/FRRffxjLkZfNGEBSF2 15 | OipdvX19Oe2hCZxvwHPoe5sb/Dh6KE7If1hRLOCXg/8E7ADBtAp94dam1WF4Kh6N 16 | Va6+n2YKif2rqye1YtRoUU46iQKBgQDKH/eMcMRUe9IySxHLogidOUwa0X7WrxF/ 17 | PkXGpPbHQtMOJF5cVzh+L+foUKXNM60lgmCH0438GKU7kirC/dVtD/bwE598/XFZ 18 | vnjPV7Adf9vBz9NN8cS/4uEfQYbvTRmrnrQK+ZhOe8hmwjapxqdWrVHNUtvx18vL 19 | qBwR4YjsCQKBgCycMx1MFJ1FludSKCXkcf4pM7hRTPMVE065VJnmn6eYbT9nYnZ3 20 | 2mZC+W5lnXXPkHSs7JLtZAZIVK5f6Nu8je9aQdBZQUz+RQlfquKvNp39WqSJDbcn 21 | /yGudKNGK+fc/Ee74vgw3Tdi57+wKaGDeHY1on8oYFHzj5VGnbb/nknRAoGBAK2Z 22 | hyQ4NmfZcU+A6mfbY0qmS5c9F5OMCZsgAQ374XiDDIK4+dKVlw/KVYRSwBTerXfp 23 | 
4r7GFMzQ3hmsEM4o9YYWkCDiubjAdPp/fYOX7MtpZXWw6euoGzQzyObvgNVHgyTD 24 | yh8jAI1oA1c+t3RaCp+HfRq8b+vnTEI+wN0auF8BAoGBAJmw+GgHCZGpw2XPNu+X 25 | 8kuVGbQYAjTOXhBM4WzZyhfH1TWKLGn7C9YixhE2AW0UWKDvy+6OqPhe8q3KVms3 26 | 8YZ1W+vbUNEZNGE0XrB5ZMXfePiqisCz0jgP9OAuT+ii4aI3MAm3zgCEC6UTMvLq 27 | gNBu3Tcy6udxnUf7czzJDRtE 28 | -----END PRIVATE KEY----- 29 | -------------------------------------------------------------------------------- /.testdata/sample-root.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT 3 | MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i 4 | YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG 5 | EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy 6 | bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB 7 | AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP 8 | VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv 9 | h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE 10 | ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ 11 | EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC 12 | DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7 13 | qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD 14 | VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g 15 | K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI 16 | KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n 17 | ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB 18 | BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY 19 | /iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/ 20 | zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza 21 | HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto 22 | 
WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6 23 | yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx 24 | -----END CERTIFICATE----- 25 | -------------------------------------------------------------------------------- /.testdata/test-img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/go-resty/resty/66256efbe00678090075a7862bf4113c8e183c8d/.testdata/test-img.png -------------------------------------------------------------------------------- /.testdata/text-file.txt: -------------------------------------------------------------------------------- 1 | THIS IS TEXT FILE FOR MULTIPART UPLOAD TEST :) 2 | 3 | - go-resty 4 | -------------------------------------------------------------------------------- /BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@bazel_gazelle//:def.bzl", "gazelle") 2 | load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") 3 | 4 | # gazelle:prefix resty.dev/v3 5 | # gazelle:go_naming_convention import_alias 6 | gazelle(name = "gazelle") 7 | 8 | go_library( 9 | name = "resty", 10 | srcs = [ 11 | "circuit_breaker.go", 12 | "client.go", 13 | "curl.go", 14 | "debug.go", 15 | "digest.go", 16 | "load_balancer.go", 17 | "middleware.go", 18 | "multipart.go", 19 | "redirect.go", 20 | "request.go", 21 | "response.go", 22 | "resty.go", 23 | "retry.go", 24 | "sse.go", 25 | "stream.go", 26 | "trace.go", 27 | "transport_dial.go", 28 | "transport_dial_wasm.go", 29 | "util.go", 30 | ], 31 | importpath = "resty.dev/v3", 32 | visibility = ["//visibility:public"], 33 | deps = ["@org_golang_x_net//publicsuffix:go_default_library"], 34 | ) 35 | 36 | go_test( 37 | name = "resty_test", 38 | srcs = [ 39 | "benchmark_test.go", 40 | "cert_watcher_test.go", 41 | "client_test.go", 42 | "context_test.go", 43 | "curl_test.go", 44 | "digest_test.go", 45 | "load_balancer_test.go", 46 | "middleware_test.go", 47 | "multipart_test.go", 48 | 
"request_test.go", 49 | "resty_test.go", 50 | "retry_test.go", 51 | "sse_test.go", 52 | "util_test.go", 53 | ], 54 | data = glob([".testdata/*"]), 55 | embed = [":resty"], 56 | ) 57 | 58 | alias( 59 | name = "go_default_library", 60 | actual = ":resty", 61 | visibility = ["//visibility:public"], 62 | ) 63 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015-present Jeevanandam M., https://myjeeva.com 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Resty Logo 3 |

4 |

Simple HTTP, REST, and SSE client library for Go

5 | 6 |

Resty Build Status 7 | Resty Code Coverage 8 | Go Report Card 9 | Resty GoDoc 10 | License Mentioned in Awesome Go

11 |

12 |

13 | 14 | 15 | ## Documentation 16 | 17 | Go to https://resty.dev and refer to godoc. 18 | 19 | ## Minimum Go Version 20 | 21 | Use `go1.21` and above. 22 | 23 | ## Support & Donate 24 | 25 | * Sponsor via [GitHub](https://github.com/sponsors/jeevatkm) 26 | * Donate via [PayPal](https://www.paypal.com/donate/?cmd=_donations&business=QWMZG74FW4QYC&lc=US&item_name=Resty+Library+for+Go&currency_code=USD) 27 | 28 | ## Versioning 29 | 30 | Resty releases versions according to [Semantic Versioning](http://semver.org). 31 | 32 | * Resty v3 provides the Go vanity URL `resty.dev/v3`. 33 | * Resty v2 migrated away from the `gopkg.in` service to `github.com/go-resty/resty/v2`. 34 | * Resty has fully adopted `go mod` capabilities since the `v1.10.0` release. 35 | * The Resty v1 series used `gopkg.in` to provide versioning. `gopkg.in/resty.vX` points to the appropriate tagged versions; `X` denotes the version series number, and it's a stable release for production use. For example, `gopkg.in/resty.v0`. 36 | 37 | ## Contribution 38 | 39 | I would welcome your contribution! 40 | 41 | * If you find any improvement or issue you want to fix, feel free to send a pull request. 42 | * The pull requests must include test cases for feature/fix/enhancement with patch coverage of 100%. 43 | * I have done my best to bring pretty good coverage. I would request contributors to do the same for their contribution. 44 | 45 | I always look forward to hearing feedback, appreciation, and real-world usage stories from Resty users on [GitHub Discussions](https://github.com/go-resty/resty/discussions). It means a lot to me. 46 | 47 | ## Creator 48 | 49 | [Jeevanandam M.](https://github.com/jeevatkm) (jeeva@myjeeva.com) 50 | 51 | 52 | ## Contributors 53 | 54 | Have a look at the [Contributors](https://github.com/go-resty/resty/graphs/contributors) page. 55 | 56 | ## License Info 57 | 58 | Resty is released under the MIT [LICENSE](LICENSE).
59 | 60 | Resty [Documentation](https://github.com/go-resty/docs) and website released under Apache-2.0 [LICENSE](https://github.com/go-resty/docs/blob/main/LICENSE). 61 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "resty") 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | 5 | http_archive( 6 | name = "io_bazel_rules_go", 7 | sha256 = "80a98277ad1311dacd837f9b16db62887702e9f1d1c4c9f796d0121a46c8e184", 8 | urls = [ 9 | "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.46.0/rules_go-v0.46.0.zip", 10 | "https://github.com/bazelbuild/rules_go/releases/download/v0.46.0/rules_go-v0.46.0.zip", 11 | ], 12 | ) 13 | 14 | http_archive( 15 | name = "bazel_gazelle", 16 | sha256 = "62ca106be173579c0a167deb23358fdfe71ffa1e4cfdddf5582af26520f1c66f", 17 | urls = [ 18 | "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.23.0/bazel-gazelle-v0.23.0.tar.gz", 19 | "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.23.0/bazel-gazelle-v0.23.0.tar.gz", 20 | ], 21 | ) 22 | 23 | load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") 24 | 25 | go_rules_dependencies() 26 | 27 | go_register_toolchains(version = "1.21") 28 | 29 | load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies") 30 | 31 | gazelle_dependencies() 32 | -------------------------------------------------------------------------------- /benchmark_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "bytes" 10 | "strings" 11 | "testing" 12 | ) 13 | // Benchmark_parseRequestURL_PathParams measures parseRequestURL when path params and raw path params are set on both the client and the request. 14 | func Benchmark_parseRequestURL_PathParams(b *testing.B) { 15 | c := New().SetPathParams(map[string]string{ 16 | "foo": "1", 17 | "bar": "2", 18 | }).SetRawPathParams(map[string]string{ 19 | "foo": "3", 20 | "xyz": "4", 21 | }) 22 | r := c.R().SetPathParams(map[string]string{ 23 | "foo": "5", 24 | "qwe": "6", 25 | }).SetRawPathParams(map[string]string{ 26 | "foo": "7", 27 | "asd": "8", 28 | }) 29 | b.ResetTimer() 30 | for i := 0; i < b.N; i++ { 31 | r.URL = "https://example.com/{foo}/{bar}/{xyz}/{qwe}/{asd}" // reassigned every iteration, presumably because parseRequestURL rewrites r.URL 32 | if err := parseRequestURL(c, r); err != nil { 33 | b.Errorf("parseRequestURL() error = %v", err) 34 | } 35 | } 36 | } 37 | // Benchmark_parseRequestURL_QueryParams measures parseRequestURL when query params are set on both the client and the request. 38 | func Benchmark_parseRequestURL_QueryParams(b *testing.B) { 39 | c := New().SetQueryParams(map[string]string{ 40 | "foo": "1", 41 | "bar": "2", 42 | }) 43 | r := c.R().SetQueryParams(map[string]string{ 44 | "foo": "5", 45 | "qwe": "6", 46 | }) 47 | b.ResetTimer() 48 | for i := 0; i < b.N; i++ { 49 | r.URL = "https://example.com/" // reassigned every iteration, presumably because parseRequestURL rewrites r.URL 50 | if err := parseRequestURL(c, r); err != nil { 51 | b.Errorf("parseRequestURL() error = %v", err) 52 | } 53 | } 54 | } 55 | // Benchmark_parseRequestHeader measures parseRequestHeader when headers are set on both the client and the request. 56 | func Benchmark_parseRequestHeader(b *testing.B) { 57 | c := New() 58 | r := c.R() 59 | c.SetHeaders(map[string]string{ 60 | "foo": "1", // ignored, because of the same header in the request 61 | "bar": "2", 62 | }) 63 | r.SetHeaders(map[string]string{ 64 | "foo": "3", 65 | "xyz": "4", 66 | }) 67 | b.ResetTimer() 68 | for i := 0; i < b.N; i++ { 69 | if err := parseRequestHeader(c, r); err != nil { 70 | b.Errorf("parseRequestHeader() error = %v", err) 71 | } 72 | } 73 | } 74 | // Benchmark_parseRequestBody_string measures parseRequestBody for a string body with content length enabled. 75 | func Benchmark_parseRequestBody_string(b *testing.B) { 76 | c := New() 77 | r := c.R() 78 | r.SetBody("foo").SetContentLength(true) 79 | b.ResetTimer() 80 | for i := 0; i < b.N; i++ { 81 | if err := parseRequestBody(c, r); err != nil { 82 | b.Errorf("parseRequestBody()
error = %v", err) 83 | } 84 | } 85 | } 86 | // Benchmark_parseRequestBody_byte measures parseRequestBody for a []byte body with content length enabled. 87 | func Benchmark_parseRequestBody_byte(b *testing.B) { 88 | c := New() 89 | r := c.R() 90 | r.SetBody([]byte("foo")).SetContentLength(true) 91 | b.ResetTimer() 92 | for i := 0; i < b.N; i++ { 93 | if err := parseRequestBody(c, r); err != nil { 94 | b.Errorf("parseRequestBody() error = %v", err) 95 | } 96 | } 97 | } 98 | // Benchmark_parseRequestBody_reader measures parseRequestBody for a *bytes.Buffer (io.Reader) body. 99 | func Benchmark_parseRequestBody_reader(b *testing.B) { 100 | c := New() 101 | r := c.R() 102 | r.SetBody(bytes.NewBufferString("foo")) // NOTE(review): one buffer is shared by all b.N iterations; if parseRequestBody drains it, later iterations measure an empty body — confirm 103 | b.ResetTimer() 104 | for i := 0; i < b.N; i++ { 105 | if err := parseRequestBody(c, r); err != nil { 106 | b.Errorf("parseRequestBody() error = %v", err) 107 | } 108 | } 109 | } 110 | // Benchmark_parseRequestBody_struct measures parseRequestBody for a struct body with Content-Type set to jsonContentType. 111 | func Benchmark_parseRequestBody_struct(b *testing.B) { 112 | type FooBar struct { 113 | Foo string `json:"foo"` 114 | Bar string `json:"bar"` 115 | } 116 | c := New() 117 | r := c.R() 118 | r.SetBody(FooBar{Foo: "1", Bar: "2"}).SetContentLength(true).SetHeader(hdrContentTypeKey, jsonContentType) 119 | b.ResetTimer() 120 | for i := 0; i < b.N; i++ { 121 | if err := parseRequestBody(c, r); err != nil { 122 | b.Errorf("parseRequestBody() error = %v", err) 123 | } 124 | } 125 | } 126 | // Benchmark_parseRequestBody_struct_xml measures parseRequestBody for a struct body with Content-Type set to "text/xml". 127 | func Benchmark_parseRequestBody_struct_xml(b *testing.B) { 128 | type FooBar struct { 129 | Foo string `xml:"foo"` 130 | Bar string `xml:"bar"` 131 | } 132 | c := New() 133 | r := c.R() 134 | r.SetBody(FooBar{Foo: "1", Bar: "2"}).SetContentLength(true).SetHeader(hdrContentTypeKey, "text/xml") 135 | b.ResetTimer() 136 | for i := 0; i < b.N; i++ { 137 | if err := parseRequestBody(c, r); err != nil { 138 | b.Errorf("parseRequestBody() error = %v", err) 139 | } 140 | } 141 | } 142 | // Benchmark_parseRequestBody_map measures parseRequestBody for a map body with Content-Type set to jsonContentType. 143 | func Benchmark_parseRequestBody_map(b *testing.B) { 144 | c := New() 145 | r := c.R() 146 | r.SetBody(map[string]string{ 147 | "foo": "1", 148 | "bar": "2", 149 | }).SetContentLength(true).SetHeader(hdrContentTypeKey, jsonContentType) 150 | b.ResetTimer() 151 | for i := 0; i < b.N; i++ { 152 | if err :=
parseRequestBody(c, r); err != nil { 153 | b.Errorf("parseRequestBody() error = %v", err) 154 | } 155 | } 156 | } 157 | // Benchmark_parseRequestBody_slice measures parseRequestBody for a []string body with Content-Type set to jsonContentType. 158 | func Benchmark_parseRequestBody_slice(b *testing.B) { 159 | c := New() 160 | r := c.R() 161 | r.SetBody([]string{"1", "2"}).SetContentLength(true).SetHeader(hdrContentTypeKey, jsonContentType) 162 | b.ResetTimer() 163 | for i := 0; i < b.N; i++ { 164 | if err := parseRequestBody(c, r); err != nil { 165 | b.Errorf("parseRequestBody() error = %v", err) 166 | } 167 | } 168 | } 169 | // Benchmark_parseRequestBody_FormData measures parseRequestBody when form data is set on both the client and the request. 170 | func Benchmark_parseRequestBody_FormData(b *testing.B) { 171 | c := New() 172 | r := c.R() 173 | c.SetFormData(map[string]string{"foo": "1", "bar": "2"}) 174 | r.SetFormData(map[string]string{"foo": "3", "baz": "4"}).SetContentLength(true) 175 | b.ResetTimer() 176 | for i := 0; i < b.N; i++ { 177 | if err := parseRequestBody(c, r); err != nil { 178 | b.Errorf("parseRequestBody() error = %v", err) 179 | } 180 | } 181 | } 182 | // Benchmark_parseRequestBody_MultiPart measures parseRequestBody for a POST request combining form data, multipart form data, a file reader and a custom multipart field. 183 | func Benchmark_parseRequestBody_MultiPart(b *testing.B) { 184 | c := New() 185 | r := c.R() 186 | c.SetFormData(map[string]string{"foo": "1", "bar": "2"}) 187 | r.SetFormData(map[string]string{"foo": "3", "baz": "4"}). 188 | SetMultipartFormData(map[string]string{"foo": "5", "xyz": "6"}). 189 | SetFileReader("qwe", "qwe.txt", strings.NewReader("7")). // NOTE(review): this reader and the MultipartField reader below are created once and reused across iterations; confirm they are not drained after the first run 190 | SetMultipartFields( 191 | &MultipartField{ 192 | Name: "sdj", 193 | ContentType: "text/plain", 194 | Reader: strings.NewReader("8"), 195 | }, 196 | ). 197 | SetContentLength(true). 198 | SetMethod(MethodPost) 199 | b.ResetTimer() 200 | for i := 0; i < b.N; i++ { 201 | if err := parseRequestBody(c, r); err != nil { 202 | b.Errorf("parseRequestBody() error = %v", err) 203 | } 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /cert_watcher_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "crypto/rand" 10 | "crypto/rsa" 11 | "crypto/x509" 12 | "crypto/x509/pkix" 13 | "encoding/pem" 14 | "math/big" 15 | "net" 16 | "net/http" 17 | "os" 18 | "path/filepath" 19 | "strings" 20 | "testing" 21 | "time" 22 | ) 23 | // certPaths holds the file paths of the root CA key/certificate and the TLS server key/certificate generated for the test. 24 | type certPaths struct { 25 | RootCAKey string 26 | RootCACert string 27 | TLSKey string 28 | TLSCert string 29 | } 30 | // TestClient_SetRootCertificateWatcher verifies that a client using root certificate watchers keeps succeeding across certificate rotation and recovers after the watched root CA file is deleted and re-created. 31 | func TestClient_SetRootCertificateWatcher(t *testing.T) { 32 | // For this test, we want to: 33 | // - Generate root CA 34 | // - Generate TLS cert signed with root CA 35 | // - Start a Test HTTPS server 36 | // - Create a Resty client with SetRootCertificatesWatcher and SetClientRootCertificatesWatcher 37 | // - Send multiple requests and re-generate the certs periodically to reproduce renewal 38 | 39 | certDir := t.TempDir() 40 | paths := certPaths{ 41 | RootCAKey: filepath.Join(certDir, "root-ca.key"), 42 | RootCACert: filepath.Join(certDir, "root-ca.crt"), 43 | TLSKey: filepath.Join(certDir, "tls.key"), 44 | TLSCert: filepath.Join(certDir, "tls.crt"), 45 | } 46 | 47 | generateCerts(t, paths) 48 | 49 | ts := createTestTLSServer(func(w http.ResponseWriter, r *http.Request) { 50 | w.WriteHeader(http.StatusOK) 51 | }, paths.TLSCert, paths.TLSKey) 52 | defer ts.Close() 53 | 54 | poolingInterval := 100 * time.Millisecond // polling period handed to both watchers via CertWatcherOptions.PoolInterval 55 | 56 | client := NewWithTransportSettings(&TransportSettings{ 57 | // Make sure that a TLS handshake happens for every request 58 | // (otherwise, test may succeed because 1st TLS session is re-used) 59 | DisableKeepAlives: true, 60 | }).SetRootCertificatesWatcher( 61 | &CertWatcherOptions{PoolInterval: poolingInterval}, 62 | paths.RootCACert, 63 | ).SetClientRootCertificatesWatcher( 64 | &CertWatcherOptions{PoolInterval: poolingInterval}, 65 | paths.RootCACert, 66 | ).SetDebug(false) 67 | 68 | url :=
strings.Replace(ts.URL, "127.0.0.1", "localhost", 1) 69 | t.Log("Test URL:", url) 70 | 71 | t.Run("Cert Watcher should handle certs rotation", func(t *testing.T) { 72 | for i := 0; i < 5; i++ { 73 | res, err := client.R().Get(url) 74 | if err != nil { 75 | t.Fatal(err) 76 | } 77 | 78 | assertEqual(t, res.StatusCode(), http.StatusOK) 79 | 80 | if i%2 == 1 { 81 | // Re-generate certs to simulate renewal scenario 82 | generateCerts(t, paths) 83 | time.Sleep(50 * time.Millisecond) // NOTE(review): shorter than the 100ms poolingInterval, so the watcher may not have reloaded yet — confirm the test does not depend on that timing 84 | } 85 | 86 | } 87 | }) 88 | 89 | t.Run("Cert Watcher should recover on failure", func(t *testing.T) { 90 | // Delete root cert and re-create it to ensure that cert watcher is able to recover 91 | 92 | // Re-generate certs to invalidate existing cert 93 | generateCerts(t, paths) 94 | // Delete root cert so that Cert Watcher will fail 95 | err := os.RemoveAll(paths.RootCACert) 96 | assertNil(t, err) 97 | 98 | // Reset TLS config to ensure that previous root cert is not re-used 99 | tr, err := client.HTTPTransport() 100 | assertNil(t, err) 101 | tr.TLSClientConfig = nil 102 | client.SetTransport(tr) 103 | 104 | time.Sleep(50 * time.Millisecond) 105 | 106 | _, err = client.R().Get(url) 107 | // We expect an error since root cert has been deleted 108 | assertNotNil(t, err) 109 | 110 | // Re-generate certs. We expect the cert watcher to reload the new root cert.
111 | generateCerts(t, paths) 112 | time.Sleep(50 * time.Millisecond) 113 | _, err = client.R().Get(url) 114 | assertNil(t, err) 115 | }) 116 | 117 | err := client.Close() 118 | assertNil(t, err) 119 | } 120 | // generateCerts writes a fresh root CA and a TLS server certificate signed by it to paths, failing the test on any error. 121 | func generateCerts(t *testing.T, paths certPaths) { 122 | rootKey, rootCert, err := generateRootCA(paths.RootCAKey, paths.RootCACert) 123 | if err != nil { 124 | t.Fatal(err) 125 | } 126 | 127 | if err := generateTLSCert(paths.TLSKey, paths.TLSCert, rootKey, rootCert); err != nil { 128 | t.Fatal(err) 129 | } 130 | } 131 | 132 | // generateRootCA creates a self-signed root CA valid for 10 hours, writes its key and certificate as PEM files to keyPath and certPath, and returns the key together with the DER-encoded certificate. 133 | func generateRootCA(keyPath, certPath string) (*rsa.PrivateKey, []byte, error) { 134 | // Generate the key for the Root CA 135 | rootKey, err := generateKey() 136 | if err != nil { 137 | return nil, nil, err 138 | } 139 | 140 | // Exclusive upper bound (2^256) for the random serial number 141 | max := new(big.Int).Lsh(big.NewInt(1), 256) // NOTE(review): shadows the Go 1.21 builtin max; RFC 5280 limits serial numbers to 20 octets, which a 256-bit value can exceed 142 | 143 | // Generate a random big.Int 144 | randomBigInt, err := rand.Int(rand.Reader, max) 145 | if err != nil { 146 | return nil, nil, err 147 | } 148 | 149 | // Create the root certificate template 150 | rootCertTemplate := &x509.Certificate{ 151 | SerialNumber: randomBigInt, 152 | Subject: pkix.Name{ 153 | Organization: []string{"YourOrg"}, 154 | Country: []string{"US"}, 155 | Province: []string{"State"}, 156 | Locality: []string{"City"}, 157 | CommonName: "YourRootCA", 158 | }, 159 | NotBefore: time.Now(), 160 | NotAfter: time.Now().Add(time.Hour * 10), 161 | KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, 162 | IsCA: true, 163 | BasicConstraintsValid: true, 164 | } 165 | 166 | // Self-sign the root certificate 167 | rootCert, err := x509.CreateCertificate(rand.Reader, rootCertTemplate, rootCertTemplate, &rootKey.PublicKey, rootKey) 168 | if err != nil { 169 | return nil, nil, err 170 | } 171 | 172 | // Save the Root CA key and certificate 173 | if err := savePEMKey(keyPath, rootKey); err != nil { 174 |
return nil, nil, err 175 | } 176 | if err := savePEMCert(certPath, rootCert); err != nil { 177 | return nil, nil, err 178 | } 179 | 180 | return rootKey, rootCert, nil 181 | } 182 | 183 | // generateTLSCert creates a server certificate for "localhost" and 127.0.0.1, valid for 10 hours and signed by the given root CA (rootCert is DER-encoded), and writes its key and certificate as PEM files to keyPath and certPath. 184 | func generateTLSCert(keyPath, certPath string, rootKey *rsa.PrivateKey, rootCert []byte) error { 185 | // Generate a key for the server 186 | serverKey, err := generateKey() 187 | if err != nil { 188 | return err 189 | } 190 | 191 | // Parse the Root CA certificate 192 | parsedRootCert, err := x509.ParseCertificate(rootCert) 193 | if err != nil { 194 | return err 195 | } 196 | 197 | // Create the server certificate template 198 | serverCertTemplate := &x509.Certificate{ 199 | SerialNumber: big.NewInt(2), 200 | Subject: pkix.Name{ 201 | Organization: []string{"YourOrg"}, 202 | CommonName: "localhost", 203 | }, 204 | NotBefore: time.Now(), 205 | NotAfter: time.Now().Add(time.Hour * 10), 206 | KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, 207 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 208 | IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, 209 | DNSNames: []string{"localhost"}, 210 | } 211 | 212 | // Sign the server certificate with the Root CA 213 | serverCert, err := x509.CreateCertificate(rand.Reader, serverCertTemplate, parsedRootCert, &serverKey.PublicKey, rootKey) 214 | if err != nil { 215 | return err 216 | } 217 | 218 | // Save the server key and certificate 219 | if err := savePEMKey(keyPath, serverKey); err != nil { 220 | return err 221 | } 222 | if err := savePEMCert(certPath, serverCert); err != nil { 223 | return err 224 | } 225 | 226 | return nil 227 | } 228 | // generateKey returns a new 2048-bit RSA private key. 229 | func generateKey() (*rsa.PrivateKey, error) { 230 | return rsa.GenerateKey(rand.Reader, 2048) 231 | } 232 | // savePEMKey writes key to fileName as a PKCS #1 "RSA PRIVATE KEY" PEM block, creating or truncating the file. 233 | func savePEMKey(fileName string, key *rsa.PrivateKey) error { 234 | keyFile, err := os.Create(fileName) 235 | if err != nil { 236 | return err 237 | } 238 | defer keyFile.Close() // NOTE(review): the Close error is discarded 239 | 240 |
privateKeyPEM := &pem.Block{ 241 | Type: "RSA PRIVATE KEY", 242 | Bytes: x509.MarshalPKCS1PrivateKey(key), 243 | } 244 | 245 | return pem.Encode(keyFile, privateKeyPEM) 246 | } 247 | // savePEMCert writes the DER-encoded cert to fileName as a "CERTIFICATE" PEM block, creating or truncating the file. 248 | func savePEMCert(fileName string, cert []byte) error { 249 | certFile, err := os.Create(fileName) // NOTE(review): truncates in place, so a concurrently polling cert watcher could observe a partially written file; write-to-temp plus os.Rename would avoid that — confirm the watcher tolerates it 250 | if err != nil { 251 | return err 252 | } 253 | defer certFile.Close() 254 | 255 | certPEM := &pem.Block{ 256 | Type: "CERTIFICATE", 257 | Bytes: cert, 258 | } 259 | 260 | return pem.Encode(certFile, certPEM) 261 | } 262 | -------------------------------------------------------------------------------- /circuit_breaker.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "errors" 10 | "net/http" 11 | "sync/atomic" 12 | "time" 13 | ) 14 | 15 | // CircuitBreaker struct implements a state machine to monitor and manage the 16 | // state of a circuit breaker. The three states are: 17 | // - Closed: requests are allowed 18 | // - Open: requests are blocked 19 | // - Half-Open: a single request is allowed through to determine the next state 20 | // 21 | // Transitions: 22 | // - To Closed State: when the success count reaches the success threshold. 23 | // - To Open State: when the failure count reaches the failure threshold. 24 | // - Half-Open Check: when the specified timeout elapses, a single request is allowed 25 | // to determine the transition state; if failed, it goes back to the open state.
26 | type CircuitBreaker struct { 27 | policies []CircuitBreakerPolicy 28 | timeout time.Duration 29 | failureThreshold uint32 30 | successThreshold uint32 31 | state atomic.Value // circuitBreakerState 32 | failureCount atomic.Uint32 33 | successCount atomic.Uint32 34 | lastFailureAt time.Time 35 | } 36 | 37 | // NewCircuitBreaker method creates a new [CircuitBreaker] with default settings. 38 | // 39 | // The default settings are: 40 | // - Timeout: 10 seconds 41 | // - FailThreshold: 3 42 | // - SuccessThreshold: 1 43 | // - Policies: CircuitBreaker5xxPolicy 44 | func NewCircuitBreaker() *CircuitBreaker { 45 | cb := &CircuitBreaker{ 46 | policies: []CircuitBreakerPolicy{CircuitBreaker5xxPolicy}, 47 | timeout: 10 * time.Second, 48 | failureThreshold: 3, 49 | successThreshold: 1, 50 | } 51 | cb.state.Store(circuitBreakerStateClosed) 52 | return cb 53 | } 54 | 55 | // SetPolicies method sets the one or more given CircuitBreakerPolicy(s) into 56 | // [CircuitBreaker], which will be used to determine whether a request is failed 57 | // or successful by evaluating the response instance. 58 | // 59 | // // set one policy 60 | // cb.SetPolicies(CircuitBreaker5xxPolicy) 61 | // 62 | // // set multiple polices 63 | // cb.SetPolicies(policy1, policy2, policy3) 64 | // 65 | // // if you have slice, do 66 | // cb.SetPolicies(policies...) 67 | // 68 | // NOTE: This method overwrites the policies with the given new ones. See [CircuitBreaker.AddPolicies] 69 | func (cb *CircuitBreaker) SetPolicies(policies ...CircuitBreakerPolicy) *CircuitBreaker { 70 | cb.policies = policies 71 | return cb 72 | } 73 | 74 | // SetTimeout method sets the timeout duration for the [CircuitBreaker]. When the 75 | // timeout reaches, a single request is allowed to determine the state. 
76 | func (cb *CircuitBreaker) SetTimeout(timeout time.Duration) *CircuitBreaker { 77 | cb.timeout = timeout 78 | return cb 79 | } 80 | 81 | // SetFailureThreshold method sets the number of failures that must occur within the 82 | // timeout duration for the [CircuitBreaker] to transition to the Open state. 83 | func (cb *CircuitBreaker) SetFailureThreshold(threshold uint32) *CircuitBreaker { 84 | cb.failureThreshold = threshold 85 | return cb 86 | } 87 | 88 | // SetSuccessThreshold method sets the number of successes that must occur to transition 89 | // the [CircuitBreaker] from the Half-Open state to the Closed state. 90 | func (cb *CircuitBreaker) SetSuccessThreshold(threshold uint32) *CircuitBreaker { 91 | cb.successThreshold = threshold 92 | return cb 93 | } 94 | 95 | // CircuitBreakerPolicy is a function type that determines whether a response should 96 | // trip the [CircuitBreaker]. 97 | type CircuitBreakerPolicy func(resp *http.Response) bool 98 | 99 | // CircuitBreaker5xxPolicy is a [CircuitBreakerPolicy] that trips the [CircuitBreaker] if 100 | // the response status code is 500 or greater. 
101 | func CircuitBreaker5xxPolicy(resp *http.Response) bool { 102 | return resp.StatusCode > 499 103 | } 104 | 105 | var ErrCircuitBreakerOpen = errors.New("resty: circuit breaker open") 106 | 107 | type circuitBreakerState uint32 108 | 109 | const ( 110 | circuitBreakerStateClosed circuitBreakerState = iota 111 | circuitBreakerStateOpen 112 | circuitBreakerStateHalfOpen 113 | ) 114 | 115 | func (cb *CircuitBreaker) getState() circuitBreakerState { 116 | return cb.state.Load().(circuitBreakerState) 117 | } 118 | 119 | func (cb *CircuitBreaker) allow() error { 120 | if cb == nil { 121 | return nil 122 | } 123 | 124 | if cb.getState() == circuitBreakerStateOpen { 125 | return ErrCircuitBreakerOpen 126 | } 127 | 128 | return nil 129 | } 130 | 131 | func (cb *CircuitBreaker) applyPolicies(resp *http.Response) { 132 | if cb == nil { 133 | return 134 | } 135 | 136 | failed := false 137 | for _, policy := range cb.policies { 138 | if policy(resp) { 139 | failed = true 140 | break 141 | } 142 | } 143 | 144 | if failed { 145 | if cb.failureCount.Load() > 0 && time.Since(cb.lastFailureAt) > cb.timeout { 146 | cb.failureCount.Store(0) 147 | } 148 | 149 | switch cb.getState() { 150 | case circuitBreakerStateClosed: 151 | failCount := cb.failureCount.Add(1) 152 | if failCount >= cb.failureThreshold { 153 | cb.open() 154 | } else { 155 | cb.lastFailureAt = time.Now() 156 | } 157 | case circuitBreakerStateHalfOpen: 158 | cb.open() 159 | } 160 | } else { 161 | switch cb.getState() { 162 | case circuitBreakerStateClosed: 163 | return 164 | case circuitBreakerStateHalfOpen: 165 | successCount := cb.successCount.Add(1) 166 | if successCount >= cb.successThreshold { 167 | cb.changeState(circuitBreakerStateClosed) 168 | } 169 | } 170 | } 171 | } 172 | 173 | func (cb *CircuitBreaker) open() { 174 | cb.changeState(circuitBreakerStateOpen) 175 | go func() { 176 | time.Sleep(cb.timeout) 177 | cb.changeState(circuitBreakerStateHalfOpen) 178 | }() 179 | } 180 | 181 | func (cb 
*CircuitBreaker) changeState(state circuitBreakerState) { 182 | cb.failureCount.Store(0) 183 | cb.successCount.Store(0) 184 | cb.state.Store(state) 185 | } 186 | -------------------------------------------------------------------------------- /context_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // 2016 Andrew Grigorev (https://github.com/ei-grad) 3 | // resty source code and usage is governed by a MIT style 4 | // license that can be found in the LICENSE file. 5 | // SPDX-License-Identifier: MIT 6 | 7 | package resty 8 | 9 | import ( 10 | "context" 11 | "errors" 12 | "net/http" 13 | "sync/atomic" 14 | "testing" 15 | "time" 16 | ) 17 | 18 | func TestClientSetContext(t *testing.T) { 19 | ts := createGetServer(t) 20 | defer ts.Close() 21 | 22 | c := dcnl() 23 | 24 | assertNil(t, c.Context()) 25 | 26 | c.SetContext(context.Background()) 27 | 28 | resp, err := c.R().Get(ts.URL + "/") 29 | 30 | assertError(t, err) 31 | assertEqual(t, http.StatusOK, resp.StatusCode()) 32 | assertEqual(t, "200 OK", resp.Status()) 33 | assertEqual(t, "TestGet: text response", resp.String()) 34 | 35 | logResponse(t, resp) 36 | } 37 | 38 | func TestRequestSetContext(t *testing.T) { 39 | ts := createGetServer(t) 40 | defer ts.Close() 41 | 42 | resp, err := dcnl().R(). 43 | SetContext(context.Background()). 44 | Get(ts.URL + "/") 45 | 46 | assertError(t, err) 47 | assertEqual(t, http.StatusOK, resp.StatusCode()) 48 | assertEqual(t, "200 OK", resp.Status()) 49 | assertEqual(t, "TestGet: text response", resp.String()) 50 | 51 | logResponse(t, resp) 52 | } 53 | 54 | func TestSetContextWithError(t *testing.T) { 55 | ts := createGetServer(t) 56 | defer ts.Close() 57 | 58 | resp, err := dcnlr(). 59 | SetContext(context.Background()). 
60 | Get(ts.URL + "/mypage") 61 | 62 | assertError(t, err) 63 | assertEqual(t, http.StatusBadRequest, resp.StatusCode()) 64 | assertEqual(t, "", resp.String()) 65 | 66 | logResponse(t, resp) 67 | } 68 | 69 | func TestSetContextCancel(t *testing.T) { 70 | ch := make(chan struct{}) 71 | ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { 72 | defer func() { 73 | ch <- struct{}{} // tell test request is finished 74 | }() 75 | t.Logf("Server: %v %v", r.Method, r.URL.Path) 76 | ch <- struct{}{} 77 | <-ch // wait for client to finish request 78 | n, err := w.Write([]byte("TestSetContextCancel: response")) 79 | // FIXME? test server doesn't handle request cancellation 80 | t.Logf("Server: wrote %d bytes", n) 81 | t.Logf("Server: err is %v ", err) 82 | }) 83 | defer ts.Close() 84 | 85 | ctx, cancel := context.WithCancel(context.Background()) 86 | 87 | go func() { 88 | <-ch // wait for server to start request handling 89 | cancel() 90 | }() 91 | 92 | _, err := dcnl().R(). 93 | SetContext(ctx). 94 | Get(ts.URL + "/") 95 | 96 | ch <- struct{}{} // tell server to continue request handling 97 | 98 | <-ch // wait for server to finish request handling 99 | 100 | t.Logf("Error: %v", err) 101 | if !errIsContextCanceled(err) { 102 | t.Errorf("Got unexpected error: %v", err) 103 | } 104 | } 105 | 106 | func TestSetContextCancelRetry(t *testing.T) { 107 | reqCount := 0 108 | ch := make(chan struct{}) 109 | ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { 110 | reqCount++ 111 | defer func() { 112 | ch <- struct{}{} // tell test request is finished 113 | }() 114 | t.Logf("Server: %v %v", r.Method, r.URL.Path) 115 | ch <- struct{}{} 116 | <-ch // wait for client to finish request 117 | n, err := w.Write([]byte("TestSetContextCancel: response")) 118 | // FIXME? 
test server doesn't handle request cancellation 119 | t.Logf("Server: wrote %d bytes", n) 120 | t.Logf("Server: err is %v ", err) 121 | }) 122 | defer ts.Close() 123 | 124 | ctx, cancel := context.WithCancel(context.Background()) 125 | 126 | go func() { 127 | <-ch // wait for server to start request handling 128 | cancel() 129 | }() 130 | 131 | c := dcnl(). 132 | SetTimeout(time.Second * 3). 133 | SetRetryCount(3) 134 | 135 | _, err := c.R(). 136 | SetContext(ctx). 137 | Get(ts.URL + "/") 138 | 139 | ch <- struct{}{} // tell server to continue request handling 140 | 141 | <-ch // wait for server to finish request handling 142 | 143 | t.Logf("Error: %v", err) 144 | if !errIsContextCanceled(err) { 145 | t.Errorf("Got unexpected error: %v", err) 146 | } 147 | 148 | if reqCount != 1 { 149 | t.Errorf("Request was retried %d times instead of 1", reqCount) 150 | } 151 | } 152 | 153 | func TestSetContextCancelWithError(t *testing.T) { 154 | ch := make(chan struct{}) 155 | ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { 156 | defer func() { 157 | ch <- struct{}{} // tell test request is finished 158 | }() 159 | t.Logf("Server: %v %v", r.Method, r.URL.Path) 160 | t.Log("Server: sending StatusBadRequest response") 161 | w.WriteHeader(http.StatusBadRequest) 162 | ch <- struct{}{} 163 | <-ch // wait for client to finish request 164 | n, err := w.Write([]byte("TestSetContextCancelWithError: response")) 165 | // FIXME? test server doesn't handle request cancellation 166 | t.Logf("Server: wrote %d bytes", n) 167 | t.Logf("Server: err is %v ", err) 168 | }) 169 | defer ts.Close() 170 | 171 | ctx, cancel := context.WithCancel(context.Background()) 172 | 173 | go func() { 174 | <-ch // wait for server to start request handling 175 | cancel() 176 | }() 177 | 178 | _, err := dcnl().R(). 179 | SetContext(ctx). 
180 | Get(ts.URL + "/") 181 | 182 | ch <- struct{}{} // tell server to continue request handling 183 | 184 | <-ch // wait for server to finish request handling 185 | 186 | t.Logf("Error: %v", err) 187 | if !errIsContextCanceled(err) { 188 | t.Errorf("Got unexpected error: %v", err) 189 | } 190 | } 191 | 192 | func TestClientRetryWithSetContext(t *testing.T) { 193 | var attempt int32 194 | ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { 195 | t.Logf("Method: %v", r.Method) 196 | t.Logf("Path: %v", r.URL.Path) 197 | if atomic.AddInt32(&attempt, 1) <= 4 { 198 | time.Sleep(100 * time.Millisecond) 199 | } 200 | _, _ = w.Write([]byte("TestClientRetry page")) 201 | }) 202 | defer ts.Close() 203 | 204 | c := dcnl(). 205 | SetTimeout(50 * time.Millisecond). 206 | SetRetryCount(3) 207 | 208 | _, err := c.R(). 209 | SetContext(context.Background()). 210 | Get(ts.URL + "/") 211 | 212 | assertNotNil(t, ts) 213 | assertNotNil(t, err) 214 | assertEqual(t, true, errors.Is(err, context.DeadlineExceeded)) 215 | } 216 | 217 | func TestRequestContext(t *testing.T) { 218 | client := dcnl() 219 | r := client.NewRequest() 220 | assertNotNil(t, r.Context()) 221 | 222 | r.SetContext(context.Background()) 223 | assertNotNil(t, r.Context()) 224 | } 225 | 226 | func errIsContextCanceled(err error) bool { 227 | return errors.Is(err, context.Canceled) 228 | } 229 | -------------------------------------------------------------------------------- /curl.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "bytes" 10 | "io" 11 | "net/http" 12 | "regexp" 13 | 14 | "net/url" 15 | "strings" 16 | ) 17 | 18 | func buildCurlCmd(req *Request) string { 19 | // generate curl raw headers 20 | var curl = "curl -X " + req.Method + " " 21 | headers := dumpCurlHeaders(req.RawRequest) 22 | for _, kv := range *headers { 23 | curl += "-H " + cmdQuote(kv[0]+": "+kv[1]) + " " 24 | } 25 | 26 | // generate curl cookies 27 | if cookieJar := req.client.CookieJar(); cookieJar != nil { 28 | if cookies := cookieJar.Cookies(req.RawRequest.URL); len(cookies) > 0 { 29 | curl += "-H " + cmdQuote(dumpCurlCookies(cookies)) + " " 30 | } 31 | } 32 | 33 | // generate curl body except for io.Reader and multipart request flow 34 | if req.RawRequest.GetBody != nil { 35 | body, err := req.RawRequest.GetBody() 36 | if err == nil { 37 | buf, _ := io.ReadAll(body) 38 | curl += "-d " + cmdQuote(string(bytes.TrimRight(buf, "\n"))) + " " 39 | } else { 40 | req.log.Errorf("curl: %v", err) 41 | curl += "-d ''" 42 | } 43 | } 44 | 45 | urlString := cmdQuote(req.RawRequest.URL.String()) 46 | if urlString == "''" { 47 | urlString = "'http://unexecuted-request'" 48 | } 49 | curl += urlString 50 | return curl 51 | } 52 | 53 | // dumpCurlCookies dumps cookies to curl format 54 | func dumpCurlCookies(cookies []*http.Cookie) string { 55 | sb := strings.Builder{} 56 | sb.WriteString("Cookie: ") 57 | for _, cookie := range cookies { 58 | sb.WriteString(cookie.Name + "=" + url.QueryEscape(cookie.Value) + "&") 59 | } 60 | return strings.TrimRight(sb.String(), "&") 61 | } 62 | 63 | // dumpCurlHeaders dumps headers to curl format 64 | func dumpCurlHeaders(req *http.Request) *[][2]string { 65 | headers := [][2]string{} 66 | for k, vs := range req.Header { 67 | for _, v := range vs { 68 | headers = append(headers, [2]string{k, v}) 69 | } 70 | } 71 | n := len(headers) 72 | for i := 0; i < n; i++ { 73 | for j := n - 1; j > i; j-- { 74 | jj := j - 1 75 | 
h1, h2 := headers[j], headers[jj] 76 | if h1[0] < h2[0] { 77 | headers[jj], headers[j] = headers[j], headers[jj] 78 | } 79 | } 80 | } 81 | return &headers 82 | } 83 | 84 | var regexCmdQuote = regexp.MustCompile(`[^\w@%+=:,./-]`) 85 | 86 | // cmdQuote method to escape arbitrary strings for a safe use as 87 | // command line arguments in the most common POSIX shells. 88 | // 89 | // The original Python package which this work was inspired by can be found 90 | // at https://pypi.python.org/pypi/shellescape. 91 | func cmdQuote(s string) string { 92 | if len(s) == 0 { 93 | return "''" 94 | } 95 | 96 | if regexCmdQuote.MatchString(s) { 97 | return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'" 98 | } 99 | 100 | return s 101 | } 102 | -------------------------------------------------------------------------------- /curl_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "bytes" 10 | "errors" 11 | "io" 12 | "net/http" 13 | "net/http/cookiejar" 14 | "strings" 15 | "testing" 16 | ) 17 | 18 | func TestCurlGenerateUnexecutedRequest(t *testing.T) { 19 | req := dcnldr(). 20 | SetBody(map[string]string{ 21 | "name": "Resty", 22 | }). 23 | SetCookies( 24 | []*http.Cookie{ 25 | {Name: "count", Value: "1"}, 26 | }, 27 | ). 
28 | SetMethod(MethodPost) 29 | 30 | assertEqual(t, "", req.CurlCmd()) 31 | 32 | curlCmdUnexecuted := req.EnableGenerateCurlCmd().CurlCmd() 33 | req.DisableGenerateCurlCmd() 34 | 35 | if !strings.Contains(curlCmdUnexecuted, "Cookie: count=1") || 36 | !strings.Contains(curlCmdUnexecuted, "curl -X POST") || 37 | !strings.Contains(curlCmdUnexecuted, `-d '{"name":"Resty"}'`) { 38 | t.Fatal("Incomplete curl:", curlCmdUnexecuted) 39 | } else { 40 | t.Log("curlCmdUnexecuted: \n", curlCmdUnexecuted) 41 | } 42 | 43 | } 44 | 45 | func TestCurlGenerateExecutedRequest(t *testing.T) { 46 | ts := createPostServer(t) 47 | defer ts.Close() 48 | 49 | data := map[string]string{ 50 | "name": "Resty", 51 | } 52 | c := dcnl().EnableDebug() 53 | req := c.R(). 54 | SetBody(data). 55 | SetCookies( 56 | []*http.Cookie{ 57 | {Name: "count", Value: "1"}, 58 | }, 59 | ) 60 | 61 | url := ts.URL + "/curl-cmd-post" 62 | resp, err := req. 63 | EnableGenerateCurlCmd(). 64 | Post(url) 65 | if err != nil { 66 | t.Fatal(err) 67 | } 68 | curlCmdExecuted := resp.Request.CurlCmd() 69 | 70 | c.DisableGenerateCurlCmd() 71 | req.DisableGenerateCurlCmd() 72 | if !strings.Contains(curlCmdExecuted, "Cookie: count=1") || 73 | !strings.Contains(curlCmdExecuted, "curl -X POST") || 74 | !strings.Contains(curlCmdExecuted, `-d '{"name":"Resty"}'`) || 75 | !strings.Contains(curlCmdExecuted, url) { 76 | t.Fatal("Incomplete curl:", curlCmdExecuted) 77 | } else { 78 | t.Log("curlCmdExecuted: \n", curlCmdExecuted) 79 | } 80 | } 81 | 82 | func TestCurlCmdDebugMode(t *testing.T) { 83 | ts := createPostServer(t) 84 | defer ts.Close() 85 | 86 | c, logBuf := dcldb() 87 | c.EnableGenerateCurlCmd(). 88 | SetDebugLogCurlCmd(true) 89 | 90 | // Build request 91 | req := c.R(). 92 | SetBody(map[string]string{ 93 | "name": "Resty", 94 | }). 95 | SetCookies( 96 | []*http.Cookie{ 97 | {Name: "count", Value: "1"}, 98 | }, 99 | ). 
100 | SetDebugLogCurlCmd(true) 101 | 102 | // Execute request: set debug mode 103 | url := ts.URL + "/curl-cmd-post" 104 | _, err := req.SetDebug(true).Post(url) 105 | if err != nil { 106 | t.Fatal(err) 107 | } 108 | 109 | c.DisableGenerateCurlCmd() 110 | req.DisableGenerateCurlCmd() 111 | 112 | // test logContent curl cmd 113 | logContent := logBuf.String() 114 | if !strings.Contains(logContent, "Cookie: count=1") || 115 | !strings.Contains(logContent, `-d '{"name":"Resty"}'`) { 116 | t.Fatal("Incomplete debug curl info:", logContent) 117 | } 118 | } 119 | 120 | func TestCurl_buildCurlCmd(t *testing.T) { 121 | tests := []struct { 122 | name string 123 | method string 124 | url string 125 | headers map[string]string 126 | body string 127 | cookies []*http.Cookie 128 | expected string 129 | }{ 130 | { 131 | name: "With Headers", 132 | method: "GET", 133 | url: "http://example.com", 134 | headers: map[string]string{"Content-Type": "application/json", "Authorization": "Bearer token"}, 135 | expected: "curl -X GET -H 'Authorization: Bearer token' -H 'Content-Type: application/json' http://example.com", 136 | }, 137 | { 138 | name: "With Body", 139 | method: "POST", 140 | url: "http://example.com", 141 | headers: map[string]string{"Content-Type": "application/json"}, 142 | body: `{"key":"value"}`, 143 | expected: "curl -X POST -H 'Content-Type: application/json' -d '{\"key\":\"value\"}' http://example.com", 144 | }, 145 | { 146 | name: "With Empty Body", 147 | method: "POST", 148 | url: "http://example.com", 149 | headers: map[string]string{"Content-Type": "application/json"}, 150 | expected: "curl -X POST -H 'Content-Type: application/json' http://example.com", 151 | }, 152 | { 153 | name: "With Query Params", 154 | method: "GET", 155 | url: "http://example.com?param1=value1¶m2=value2", 156 | expected: "curl -X GET 'http://example.com?param1=value1¶m2=value2'", 157 | }, 158 | { 159 | name: "With Special Characters in URL", 160 | method: "GET", 161 | url: 
"http://example.com/path with spaces", 162 | expected: "curl -X GET http://example.com/path%20with%20spaces", 163 | }, 164 | { 165 | name: "With Cookies", 166 | method: "GET", 167 | url: "http://example.com", 168 | cookies: []*http.Cookie{{Name: "session_id", Value: "abc123"}}, 169 | expected: "curl -X GET -H 'Cookie: session_id=abc123' http://example.com", 170 | }, 171 | { 172 | name: "Without Cookies", 173 | method: "GET", 174 | url: "http://example.com", 175 | expected: "curl -X GET http://example.com", 176 | }, 177 | { 178 | name: "With Multiple Cookies", 179 | method: "GET", 180 | url: "http://example.com", 181 | cookies: []*http.Cookie{{Name: "session_id", Value: "abc123"}, {Name: "user_id", Value: "user456"}}, 182 | expected: "curl -X GET -H 'Cookie: session_id=abc123&user_id=user456' http://example.com", 183 | }, 184 | { 185 | name: "With Empty Cookie Jar", 186 | method: "GET", 187 | url: "http://example.com", 188 | expected: "curl -X GET http://example.com", 189 | }, 190 | } 191 | for _, tt := range tests { 192 | t.Run(tt.name, func(t *testing.T) { 193 | c := dcnl() 194 | req := c.R().SetMethod(tt.method).SetURL(tt.url) 195 | 196 | if !isStringEmpty(tt.body) { 197 | req.SetBody(bytes.NewBufferString(tt.body)) 198 | } 199 | 200 | for k, v := range tt.headers { 201 | req.SetHeader(k, v) 202 | } 203 | 204 | err := createRawRequest(c, req) 205 | assertNil(t, err) 206 | 207 | if len(tt.cookies) > 0 { 208 | cookieJar, _ := cookiejar.New(nil) 209 | cookieJar.SetCookies(req.RawRequest.URL, tt.cookies) 210 | c.SetCookieJar(cookieJar) 211 | } 212 | 213 | curlCmd := buildCurlCmd(req) 214 | assertEqual(t, tt.expected, curlCmd) 215 | }) 216 | } 217 | } 218 | 219 | func TestCurlRequestGetBodyError(t *testing.T) { 220 | c := dcnl(). 221 | EnableDebug(). 
222 | SetRequestMiddlewares( 223 | PrepareRequestMiddleware, 224 | func(_ *Client, r *Request) error { 225 | r.RawRequest.GetBody = func() (io.ReadCloser, error) { 226 | return nil, errors.New("test case error") 227 | } 228 | return nil 229 | }, 230 | ) 231 | 232 | req := c.R(). 233 | SetBody(map[string]string{ 234 | "name": "Resty", 235 | }). 236 | SetCookies( 237 | []*http.Cookie{ 238 | {Name: "count", Value: "1"}, 239 | }, 240 | ). 241 | SetMethod(MethodPost) 242 | 243 | assertEqual(t, "", req.CurlCmd()) 244 | 245 | curlCmdUnexecuted := req.EnableGenerateCurlCmd().CurlCmd() 246 | req.DisableGenerateCurlCmd() 247 | 248 | if !strings.Contains(curlCmdUnexecuted, "Cookie: count=1") || 249 | !strings.Contains(curlCmdUnexecuted, "curl -X POST") || 250 | !strings.Contains(curlCmdUnexecuted, `-d ''`) { 251 | t.Fatal("Incomplete curl:", curlCmdUnexecuted) 252 | } else { 253 | t.Log("curlCmdUnexecuted: \n", curlCmdUnexecuted) 254 | } 255 | } 256 | 257 | func TestCurlRequestMiddlewaresError(t *testing.T) { 258 | errMsg := "middleware error" 259 | c := dcnl().EnableDebug(). 260 | SetRequestMiddlewares( 261 | func(c *Client, r *Request) error { 262 | return errors.New(errMsg) 263 | }, 264 | PrepareRequestMiddleware, 265 | ) 266 | 267 | curlCmdUnexecuted := c.R().EnableGenerateCurlCmd().CurlCmd() 268 | assertEqual(t, "", curlCmdUnexecuted) 269 | } 270 | 271 | func TestCurlMiscTestCoverage(t *testing.T) { 272 | cookieStr := dumpCurlCookies([]*http.Cookie{ 273 | {Name: "count", Value: "1"}, 274 | }) 275 | assertEqual(t, "Cookie: count=1", cookieStr) 276 | } 277 | -------------------------------------------------------------------------------- /debug.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "fmt" 10 | "net/http" 11 | "time" 12 | ) 13 | 14 | type ( 15 | // DebugLogCallbackFunc function type is for request and response debug log callback purposes. 16 | // It gets called before Resty logs it 17 | DebugLogCallbackFunc func(*DebugLog) 18 | 19 | // DebugLogFormatterFunc function type is used to implement debug log formatting. 20 | // See out of the box [DebugLogStringFormatter], [DebugLogJSONFormatter] 21 | DebugLogFormatterFunc func(*DebugLog) string 22 | 23 | // DebugLog struct is used to collect details from Resty request and response 24 | // for debug logging callback purposes. 25 | DebugLog struct { 26 | Request *DebugLogRequest `json:"request"` 27 | Response *DebugLogResponse `json:"response"` 28 | TraceInfo *TraceInfo `json:"trace_info"` 29 | } 30 | 31 | // DebugLogRequest type used to capture debug info about the [Request]. 32 | DebugLogRequest struct { 33 | Host string `json:"host"` 34 | URI string `json:"uri"` 35 | Method string `json:"method"` 36 | Proto string `json:"proto"` 37 | Header http.Header `json:"header"` 38 | CurlCmd string `json:"curl_cmd"` 39 | RetryTraceID string `json:"retry_trace_id"` 40 | Attempt int `json:"attempt"` 41 | Body string `json:"body"` 42 | } 43 | 44 | // DebugLogResponse type used to capture debug info about the [Response]. 45 | DebugLogResponse struct { 46 | StatusCode int `json:"status_code"` 47 | Status string `json:"status"` 48 | Proto string `json:"proto"` 49 | ReceivedAt time.Time `json:"received_at"` 50 | Duration time.Duration `json:"duration"` 51 | Size int64 `json:"size"` 52 | Header http.Header `json:"header"` 53 | Body string `json:"body"` 54 | } 55 | ) 56 | 57 | // DebugLogFormatter function formats the given debug log info in human readable 58 | // format. 59 | // 60 | // This is the default debug log formatter in the Resty. 
61 | func DebugLogFormatter(dl *DebugLog) string { 62 | debugLog := "\n==============================================================================\n" 63 | 64 | req := dl.Request 65 | if len(req.CurlCmd) > 0 { 66 | debugLog += "~~~ REQUEST(CURL) ~~~\n" + 67 | fmt.Sprintf(" %v\n", req.CurlCmd) 68 | } 69 | debugLog += "~~~ REQUEST ~~~\n" + 70 | fmt.Sprintf("%s %s %s\n", req.Method, req.URI, req.Proto) + 71 | fmt.Sprintf("HOST : %s\n", req.Host) + 72 | fmt.Sprintf("HEADERS:\n%s\n", composeHeaders(req.Header)) + 73 | fmt.Sprintf("BODY :\n%v\n", req.Body) + 74 | "------------------------------------------------------------------------------\n" 75 | if len(req.RetryTraceID) > 0 { 76 | debugLog += fmt.Sprintf("RETRY TRACE ID: %s\n", req.RetryTraceID) + 77 | fmt.Sprintf("ATTEMPT : %d\n", req.Attempt) + 78 | "------------------------------------------------------------------------------\n" 79 | } 80 | 81 | res := dl.Response 82 | debugLog += "~~~ RESPONSE ~~~\n" + 83 | fmt.Sprintf("STATUS : %s\n", res.Status) + 84 | fmt.Sprintf("PROTO : %s\n", res.Proto) + 85 | fmt.Sprintf("RECEIVED AT : %v\n", res.ReceivedAt.Format(time.RFC3339Nano)) + 86 | fmt.Sprintf("DURATION : %v\n", res.Duration) + 87 | "HEADERS :\n" + 88 | composeHeaders(res.Header) + "\n" + 89 | fmt.Sprintf("BODY :\n%v\n", res.Body) 90 | if dl.TraceInfo != nil { 91 | debugLog += "------------------------------------------------------------------------------\n" 92 | debugLog += fmt.Sprintf("%v\n", dl.TraceInfo) 93 | } 94 | debugLog += "==============================================================================\n" 95 | 96 | return debugLog 97 | } 98 | 99 | // DebugLogJSONFormatter function formats the given debug log info in JSON format. 
100 | func DebugLogJSONFormatter(dl *DebugLog) string { 101 | return toJSON(dl) 102 | } 103 | 104 | func debugLogger(c *Client, res *Response) { 105 | req := res.Request 106 | if !req.Debug { 107 | return 108 | } 109 | 110 | rdl := &DebugLogResponse{ 111 | StatusCode: res.StatusCode(), 112 | Status: res.Status(), 113 | Proto: res.Proto(), 114 | ReceivedAt: res.ReceivedAt(), 115 | Duration: res.Duration(), 116 | Size: res.Size(), 117 | Header: sanitizeHeaders(res.Header().Clone()), 118 | Body: res.fmtBodyString(res.Request.DebugBodyLimit), 119 | } 120 | 121 | dl := &DebugLog{ 122 | Request: req.values[debugRequestLogKey].(*DebugLogRequest), 123 | Response: rdl, 124 | } 125 | 126 | if res.Request.IsTrace { 127 | ti := req.TraceInfo() 128 | dl.TraceInfo = &ti 129 | } 130 | 131 | dblCallback := c.debugLogCallbackFunc() 132 | if dblCallback != nil { 133 | dblCallback(dl) 134 | } 135 | 136 | formatterFunc := c.debugLogFormatterFunc() 137 | if formatterFunc != nil { 138 | debugLog := formatterFunc(dl) 139 | req.log.Debugf("%s", debugLog) 140 | } 141 | } 142 | 143 | const debugRequestLogKey = "__restyDebugRequestLog" 144 | 145 | func prepareRequestDebugInfo(c *Client, r *Request) { 146 | if !r.Debug { 147 | return 148 | } 149 | 150 | rr := r.RawRequest 151 | rh := rr.Header.Clone() 152 | if c.Client().Jar != nil { 153 | for _, cookie := range c.Client().Jar.Cookies(r.RawRequest.URL) { 154 | s := fmt.Sprintf("%s=%s", cookie.Name, cookie.Value) 155 | if c := rh.Get(hdrCookieKey); isStringEmpty(c) { 156 | rh.Set(hdrCookieKey, s) 157 | } else { 158 | rh.Set(hdrCookieKey, c+"; "+s) 159 | } 160 | } 161 | } 162 | 163 | rdl := &DebugLogRequest{ 164 | Host: rr.URL.Host, 165 | URI: rr.URL.RequestURI(), 166 | Method: r.Method, 167 | Proto: rr.Proto, 168 | Header: sanitizeHeaders(rh), 169 | Body: r.fmtBodyString(r.DebugBodyLimit), 170 | } 171 | if r.generateCurlCmd && r.debugLogCurlCmd { 172 | rdl.CurlCmd = r.resultCurlCmd 173 | } 174 | if len(r.RetryTraceID) > 0 { 175 | rdl.Attempt = 
r.Attempt 176 | rdl.RetryTraceID = r.RetryTraceID 177 | } 178 | 179 | r.initValuesMap() 180 | r.values[debugRequestLogKey] = rdl 181 | } 182 | -------------------------------------------------------------------------------- /digest.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // 2023 Segev Dagan (https://github.com/segevda) 3 | // 2024 Philipp Wolfer (https://github.com/phw) 4 | // resty source code and usage is governed by a MIT style 5 | // license that can be found in the LICENSE file. 6 | // SPDX-License-Identifier: MIT 7 | 8 | package resty 9 | 10 | import ( 11 | "bytes" 12 | "crypto/md5" 13 | "crypto/rand" 14 | "crypto/sha256" 15 | "crypto/sha512" 16 | "encoding/hex" 17 | "errors" 18 | "fmt" 19 | "hash" 20 | "io" 21 | "net/http" 22 | "strconv" 23 | "strings" 24 | ) 25 | 26 | var ( 27 | ErrDigestBadChallenge = errors.New("resty: digest: challenge is bad") 28 | ErrDigestInvalidCharset = errors.New("resty: digest: invalid charset") 29 | ErrDigestAlgNotSupported = errors.New("resty: digest: algorithm is not supported") 30 | ErrDigestQopNotSupported = errors.New("resty: digest: qop is not supported") 31 | ) 32 | 33 | // Reference: https://datatracker.ietf.org/doc/html/rfc7616#section-6.1 34 | var digestHashFuncs = map[string]func() hash.Hash{ 35 | "": md5.New, 36 | "MD5": md5.New, 37 | "MD5-sess": md5.New, 38 | "SHA-256": sha256.New, 39 | "SHA-256-sess": sha256.New, 40 | "SHA-512": sha512.New, 41 | "SHA-512-sess": sha512.New, 42 | "SHA-512-256": sha512.New512_256, 43 | "SHA-512-256-sess": sha512.New512_256, 44 | } 45 | 46 | const ( 47 | qopAuth = "auth" 48 | qopAuthInt = "auth-int" 49 | ) 50 | 51 | type digestTransport struct { 52 | *credentials 53 | transport http.RoundTripper 54 | } 55 | 56 | func (dt *digestTransport) RoundTrip(req *http.Request) (*http.Response, error) { 57 | // first request without body for all HTTP verbs 58 | 
req1 := dt.cloneReq(req, true) 59 | 60 | // make a request to get the 401 that contains the challenge. 61 | res, err := dt.transport.RoundTrip(req1) 62 | if err != nil { 63 | return nil, err 64 | } 65 | if res.StatusCode != http.StatusUnauthorized { 66 | return res, nil 67 | } 68 | _, _ = ioCopy(io.Discard, res.Body) 69 | closeq(res.Body) 70 | 71 | chaHdrValue := strings.TrimSpace(res.Header.Get(hdrWwwAuthenticateKey)) 72 | if chaHdrValue == "" { 73 | return nil, ErrDigestBadChallenge 74 | } 75 | 76 | cha, err := dt.parseChallenge(chaHdrValue) 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | // prepare second request 82 | req2 := dt.cloneReq(req, false) 83 | cred, err := dt.createCredentials(cha, req2) 84 | if err != nil { 85 | return nil, err 86 | } 87 | 88 | auth, err := cred.digest(cha) 89 | if err != nil { 90 | return nil, err 91 | } 92 | 93 | req2.Header.Set(hdrAuthorizationKey, auth) 94 | return dt.transport.RoundTrip(req2) 95 | } 96 | 97 | func (dt *digestTransport) cloneReq(r *http.Request, first bool) *http.Request { 98 | r1 := r.Clone(r.Context()) 99 | if first { 100 | r1.Body = http.NoBody 101 | r1.ContentLength = 0 102 | r1.GetBody = nil 103 | } 104 | return r1 105 | } 106 | 107 | func (dt *digestTransport) parseChallenge(input string) (*digestChallenge, error) { 108 | const ws = " \n\r\t" 109 | s := strings.Trim(input, ws) 110 | if !strings.HasPrefix(s, "Digest ") { 111 | return nil, ErrDigestBadChallenge 112 | } 113 | 114 | s = strings.Trim(s[7:], ws) 115 | c := &digestChallenge{} 116 | b := strings.Builder{} 117 | key := "" 118 | quoted := false 119 | for _, r := range s { 120 | switch r { 121 | case '"': 122 | quoted = !quoted 123 | case ',': 124 | if quoted { 125 | b.WriteRune(r) 126 | } else { 127 | val := strings.Trim(b.String(), ws) 128 | b.Reset() 129 | if err := c.setValue(key, val); err != nil { 130 | return nil, err 131 | } 132 | key = "" 133 | } 134 | case '=': 135 | if quoted { 136 | b.WriteRune(r) 137 | } else { 138 | key = 
strings.Trim(b.String(), ws) 139 | b.Reset() 140 | } 141 | default: 142 | b.WriteRune(r) 143 | } 144 | } 145 | 146 | key = strings.TrimSpace(key) 147 | if quoted || (key == "" && b.Len() > 0) { 148 | return nil, ErrDigestBadChallenge 149 | } 150 | 151 | if key != "" { 152 | val := strings.Trim(b.String(), ws) 153 | if err := c.setValue(key, val); err != nil { 154 | return nil, err 155 | } 156 | } 157 | 158 | return c, nil 159 | } 160 | 161 | func (dt *digestTransport) createCredentials(cha *digestChallenge, req *http.Request) (*digestCredentials, error) { 162 | cred := &digestCredentials{ 163 | username: dt.Username, 164 | password: dt.Password, 165 | uri: req.URL.RequestURI(), 166 | method: req.Method, 167 | realm: cha.realm, 168 | nonce: cha.nonce, 169 | nc: cha.nc, 170 | algorithm: cha.algorithm, 171 | sessAlgorithm: strings.HasSuffix(cha.algorithm, "-sess"), 172 | opaque: cha.opaque, 173 | userHash: cha.userHash, 174 | } 175 | 176 | if cha.isQopSupported(qopAuthInt) { 177 | if err := dt.prepareBody(req); err != nil { 178 | return nil, fmt.Errorf("resty: digest: failed to prepare body for auth-int: %w", err) 179 | } 180 | body, err := req.GetBody() 181 | if err != nil { 182 | return nil, fmt.Errorf("resty: digest: failed to get body for auth-int: %w", err) 183 | } 184 | if body != http.NoBody { 185 | defer closeq(body) 186 | h := newHashFunc(cha.algorithm) 187 | if _, err := ioCopy(h, body); err != nil { 188 | return nil, err 189 | } 190 | cred.bodyHash = hex.EncodeToString(h.Sum(nil)) 191 | } 192 | } 193 | 194 | return cred, nil 195 | } 196 | 197 | func (dt *digestTransport) prepareBody(req *http.Request) error { 198 | if req.GetBody != nil { 199 | return nil 200 | } 201 | 202 | if req.Body == nil || req.Body == http.NoBody { 203 | req.GetBody = func() (io.ReadCloser, error) { 204 | return http.NoBody, nil 205 | } 206 | return nil 207 | } 208 | 209 | b, err := ioReadAll(req.Body) 210 | if err != nil { 211 | return err 212 | } 213 | closeq(req.Body) 214 | 
req.Body = io.NopCloser(bytes.NewReader(b)) 215 | req.GetBody = func() (io.ReadCloser, error) { 216 | return io.NopCloser(bytes.NewReader(b)), nil 217 | } 218 | 219 | return nil 220 | } 221 | 222 | type digestChallenge struct { 223 | realm string 224 | domain string 225 | nonce string 226 | opaque string 227 | stale string 228 | algorithm string 229 | qop []string 230 | nc int 231 | userHash string 232 | } 233 | 234 | func (dc *digestChallenge) isQopSupported(qop string) bool { 235 | for _, v := range dc.qop { 236 | if v == qop { 237 | return true 238 | } 239 | } 240 | return false 241 | } 242 | 243 | func (dc *digestChallenge) setValue(k, v string) error { 244 | switch k { 245 | case "realm": 246 | dc.realm = v 247 | case "domain": 248 | dc.domain = v 249 | case "nonce": 250 | dc.nonce = v 251 | case "opaque": 252 | dc.opaque = v 253 | case "stale": 254 | dc.stale = v 255 | case "algorithm": 256 | dc.algorithm = v 257 | case "qop": 258 | if !isStringEmpty(v) { 259 | dc.qop = strings.Split(v, ",") 260 | } 261 | case "charset": 262 | if strings.ToUpper(v) != "UTF-8" { 263 | return ErrDigestInvalidCharset 264 | } 265 | case "nc": 266 | nc, err := strconv.ParseInt(v, 16, 32) 267 | if err != nil { 268 | return fmt.Errorf("resty: digest: invalid nc: %w", err) 269 | } 270 | dc.nc = int(nc) 271 | case "userhash": 272 | dc.userHash = v 273 | default: 274 | return ErrDigestBadChallenge 275 | } 276 | return nil 277 | } 278 | 279 | type digestCredentials struct { 280 | username string 281 | password string 282 | userHash string 283 | method string 284 | uri string 285 | realm string 286 | nonce string 287 | algorithm string 288 | sessAlgorithm bool 289 | cnonce string 290 | opaque string 291 | qop string 292 | nc int 293 | response string 294 | bodyHash string 295 | } 296 | 297 | func (dc *digestCredentials) parseQop(cha *digestChallenge) error { 298 | if len(cha.qop) == 0 { 299 | return nil 300 | } 301 | 302 | if cha.isQopSupported(qopAuth) { 303 | dc.qop = qopAuth 304 | 
return nil 305 | } 306 | 307 | if cha.isQopSupported(qopAuthInt) { 308 | dc.qop = qopAuthInt 309 | return nil 310 | } 311 | 312 | return ErrDigestQopNotSupported 313 | } 314 | 315 | func (dc *digestCredentials) h(data string) string { 316 | h := newHashFunc(dc.algorithm) 317 | _, _ = h.Write([]byte(data)) 318 | return hex.EncodeToString(h.Sum(nil)) 319 | } 320 | 321 | func (dc *digestCredentials) digest(cha *digestChallenge) (string, error) { 322 | if _, ok := digestHashFuncs[dc.algorithm]; !ok { 323 | return "", ErrDigestAlgNotSupported 324 | } 325 | 326 | if err := dc.parseQop(cha); err != nil { 327 | return "", err 328 | } 329 | 330 | dc.nc++ 331 | 332 | b := make([]byte, 16) 333 | _, _ = io.ReadFull(rand.Reader, b) 334 | dc.cnonce = hex.EncodeToString(b) 335 | 336 | ha1 := dc.ha1() 337 | ha2 := dc.ha2() 338 | 339 | var resp string 340 | switch dc.qop { 341 | case "": 342 | resp = fmt.Sprintf("%s:%s:%s", ha1, dc.nonce, ha2) 343 | case qopAuth, qopAuthInt: 344 | resp = fmt.Sprintf("%s:%s:%08x:%s:%s:%s", 345 | ha1, dc.nonce, dc.nc, dc.cnonce, dc.qop, ha2) 346 | } 347 | dc.response = dc.h(resp) 348 | 349 | return "Digest " + dc.String(), nil 350 | } 351 | 352 | // https://datatracker.ietf.org/doc/html/rfc7616#section-3.4.2 353 | func (dc *digestCredentials) ha1() string { 354 | a1 := dc.h(fmt.Sprintf("%s:%s:%s", dc.username, dc.realm, dc.password)) 355 | if dc.sessAlgorithm { 356 | return dc.h(fmt.Sprintf("%s:%s:%s", a1, dc.nonce, dc.cnonce)) 357 | } 358 | return a1 359 | } 360 | 361 | // https://datatracker.ietf.org/doc/html/rfc7616#section-3.4.3 362 | func (dc *digestCredentials) ha2() string { 363 | if dc.qop == qopAuthInt { 364 | return dc.h(fmt.Sprintf("%s:%s:%s", dc.method, dc.uri, dc.bodyHash)) 365 | } 366 | return dc.h(fmt.Sprintf("%s:%s", dc.method, dc.uri)) 367 | } 368 | 369 | func (dc *digestCredentials) String() string { 370 | sl := make([]string, 0, 10) 371 | // https://datatracker.ietf.org/doc/html/rfc7616#section-3.4.4 372 | if dc.userHash == "true" 
{ 373 | dc.username = dc.h(fmt.Sprintf("%s:%s", dc.username, dc.realm)) 374 | } 375 | sl = append(sl, fmt.Sprintf(`username="%s"`, dc.username)) 376 | sl = append(sl, fmt.Sprintf(`realm="%s"`, dc.realm)) 377 | sl = append(sl, fmt.Sprintf(`nonce="%s"`, dc.nonce)) 378 | sl = append(sl, fmt.Sprintf(`uri="%s"`, dc.uri)) 379 | if dc.algorithm != "" { 380 | sl = append(sl, fmt.Sprintf(`algorithm=%s`, dc.algorithm)) 381 | } 382 | if dc.opaque != "" { 383 | sl = append(sl, fmt.Sprintf(`opaque="%s"`, dc.opaque)) 384 | } 385 | if dc.qop != "" { 386 | sl = append(sl, fmt.Sprintf("qop=%s", dc.qop)) 387 | sl = append(sl, fmt.Sprintf("nc=%08x", dc.nc)) 388 | sl = append(sl, fmt.Sprintf(`cnonce="%s"`, dc.cnonce)) 389 | } 390 | sl = append(sl, fmt.Sprintf(`userhash=%s`, dc.userHash)) 391 | sl = append(sl, fmt.Sprintf(`response="%s"`, dc.response)) 392 | 393 | return strings.Join(sl, ", ") 394 | } 395 | 396 | func newHashFunc(algorithm string) hash.Hash { 397 | hf := digestHashFuncs[algorithm] 398 | h := hf() 399 | h.Reset() 400 | return h 401 | } 402 | -------------------------------------------------------------------------------- /digest_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "errors" 10 | "io" 11 | "net/http" 12 | "strings" 13 | "testing" 14 | ) 15 | 16 | type digestServerConfig struct { 17 | realm, qop, nonce, opaque, algo, uri, charset, username, password, nc string 18 | } 19 | 20 | func defaultDigestServerConf() *digestServerConfig { 21 | return &digestServerConfig{ 22 | realm: "testrealm@host.com", 23 | qop: "auth", 24 | nonce: "dcd98b7102dd2f0e8b11d0f600bfb0c093", 25 | opaque: "5ccc069c403ebaf9f0171e9517f40e41", 26 | algo: "MD5", 27 | uri: "/dir/index.html", 28 | charset: "utf-8", 29 | username: "Mufasa", 30 | password: "Circle Of Life", 31 | nc: "00000001", 32 | } 33 | } 34 | 35 | func TestClientDigestAuth(t *testing.T) { 36 | conf := *defaultDigestServerConf() 37 | ts := createDigestServer(t, &conf) 38 | defer ts.Close() 39 | 40 | c := dcnl(). 41 | SetBaseURL(ts.URL+"/"). 42 | SetDigestAuth(conf.username, conf.password) 43 | 44 | resp, err := c.R(). 45 | SetResult(&AuthSuccess{}). 46 | Get(conf.uri) 47 | assertError(t, err) 48 | assertEqual(t, http.StatusOK, resp.StatusCode()) 49 | } 50 | 51 | func TestClientDigestAuthSession(t *testing.T) { 52 | conf := *defaultDigestServerConf() 53 | conf.algo = "MD5-sess" 54 | conf.qop = "auth, auth-int" 55 | ts := createDigestServer(t, &conf) 56 | defer ts.Close() 57 | 58 | c := dcnl(). 59 | SetBaseURL(ts.URL+"/"). 60 | SetDigestAuth(conf.username, conf.password) 61 | 62 | resp, err := c.R(). 63 | SetResult(&AuthSuccess{}). 
64 | Get(conf.uri) 65 | assertError(t, err) 66 | assertEqual(t, http.StatusOK, resp.StatusCode()) 67 | } 68 | 69 | func TestClientDigestAuthErrors(t *testing.T) { 70 | type test struct { 71 | mutateConf func(*digestServerConfig) 72 | expect error 73 | } 74 | tests := []test{ 75 | {mutateConf: func(c *digestServerConfig) { c.algo = "BAD_ALGO" }, expect: ErrDigestAlgNotSupported}, 76 | {mutateConf: func(c *digestServerConfig) { c.qop = "bad-qop" }, expect: ErrDigestQopNotSupported}, 77 | {mutateConf: func(c *digestServerConfig) { c.charset = "utf-16" }, expect: ErrDigestInvalidCharset}, 78 | {mutateConf: func(c *digestServerConfig) { c.uri = "/bad" }, expect: ErrDigestBadChallenge}, 79 | {mutateConf: func(c *digestServerConfig) { c.uri = "/unknown_param" }, expect: ErrDigestBadChallenge}, 80 | {mutateConf: func(c *digestServerConfig) { c.uri = "/missing_value" }, expect: ErrDigestBadChallenge}, 81 | {mutateConf: func(c *digestServerConfig) { c.uri = "/unclosed_quote" }, expect: ErrDigestBadChallenge}, 82 | {mutateConf: func(c *digestServerConfig) { c.uri = "/no_challenge" }, expect: ErrDigestBadChallenge}, 83 | {mutateConf: func(c *digestServerConfig) { c.uri = "/status_500" }, expect: nil}, 84 | } 85 | 86 | for _, tc := range tests { 87 | conf := *defaultDigestServerConf() 88 | tc.mutateConf(&conf) 89 | ts := createDigestServer(t, &conf) 90 | 91 | c := dcnl(). 92 | SetBaseURL(ts.URL+"/"). 93 | SetDigestAuth(conf.username, conf.password) 94 | 95 | _, err := c.R().Get(conf.uri) 96 | assertErrorIs(t, tc.expect, err) 97 | ts.Close() 98 | } 99 | } 100 | 101 | func TestClientDigestAuthWithBody(t *testing.T) { 102 | conf := *defaultDigestServerConf() 103 | ts := createDigestServer(t, &conf) 104 | defer ts.Close() 105 | 106 | c := dcnl().SetDigestAuth(conf.username, conf.password) 107 | 108 | resp, err := c.R(). 109 | SetResult(&AuthSuccess{}). 110 | SetHeader(hdrContentTypeKey, "application/json"). 111 | SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). 
112 | Post(ts.URL + conf.uri) 113 | 114 | assertError(t, err) 115 | assertEqual(t, http.StatusOK, resp.StatusCode()) 116 | } 117 | 118 | func TestClientDigestAuthWithBodyQopAuthInt(t *testing.T) { 119 | conf := *defaultDigestServerConf() 120 | conf.qop = "auth-int" 121 | ts := createDigestServer(t, &conf) 122 | defer ts.Close() 123 | 124 | c := dcnl().SetDigestAuth(conf.username, conf.password) 125 | 126 | resp, err := c.R(). 127 | SetResult(&AuthSuccess{}). 128 | SetHeader(hdrContentTypeKey, "application/json"). 129 | SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). 130 | Post(ts.URL + conf.uri) 131 | 132 | assertError(t, err) 133 | assertEqual(t, http.StatusOK, resp.StatusCode()) 134 | } 135 | 136 | func TestClientDigestAuthWithBodyQopAuthIntIoCopyError(t *testing.T) { 137 | conf := *defaultDigestServerConf() 138 | conf.qop = "auth-int" 139 | ts := createDigestServer(t, &conf) 140 | defer ts.Close() 141 | 142 | c := dcnl().SetDigestAuth(conf.username, conf.password) 143 | 144 | errCopyMsg := "test copy error" 145 | ioCopy = func(dst io.Writer, src io.Reader) (written int64, err error) { 146 | return 0, errors.New(errCopyMsg) 147 | } 148 | t.Cleanup(func() { 149 | ioCopy = io.Copy 150 | }) 151 | 152 | resp, err := c.R(). 153 | SetResult(&AuthSuccess{}). 154 | SetHeader(hdrContentTypeKey, "application/json"). 155 | SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). 156 | Post(ts.URL + conf.uri) 157 | 158 | assertNotNil(t, err) 159 | assertEqual(t, true, strings.Contains(err.Error(), errCopyMsg)) 160 | assertEqual(t, 0, resp.StatusCode()) 161 | } 162 | 163 | func TestClientDigestAuthRoundTripError(t *testing.T) { 164 | conf := *defaultDigestServerConf() 165 | ts := createDigestServer(t, &conf) 166 | defer ts.Close() 167 | 168 | c := dcnl().SetTransport(&CustomRoundTripper2{returnErr: true}) 169 | c.SetDigestAuth(conf.username, conf.password) 170 | 171 | _, err := c.R(). 172 | SetResult(&AuthSuccess{}). 
173 | SetHeader(hdrContentTypeKey, "application/json"). 174 | SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). 175 | Post(ts.URL + conf.uri) 176 | 177 | assertNotNil(t, err) 178 | assertEqual(t, true, strings.Contains(err.Error(), "test req mock error")) 179 | } 180 | 181 | func TestClientDigestAuthWithBodyQopAuthIntGetBodyNil(t *testing.T) { 182 | conf := *defaultDigestServerConf() 183 | conf.qop = "auth-int" 184 | ts := createDigestServer(t, &conf) 185 | defer ts.Close() 186 | 187 | c := dcnl().SetDigestAuth(conf.username, conf.password) 188 | c.SetRequestMiddlewares( 189 | PrepareRequestMiddleware, 190 | func(c *Client, r *Request) error { 191 | r.RawRequest.GetBody = nil 192 | return nil 193 | }, 194 | ) 195 | 196 | resp, err := c.R(). 197 | SetResult(&AuthSuccess{}). 198 | SetHeader(hdrContentTypeKey, "application/json"). 199 | SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). 200 | Post(ts.URL + conf.uri) 201 | 202 | assertError(t, err) 203 | assertEqual(t, http.StatusOK, resp.StatusCode()) 204 | } 205 | 206 | func TestClientDigestAuthWithGetBodyError(t *testing.T) { 207 | conf := *defaultDigestServerConf() 208 | conf.qop = "auth-int" 209 | ts := createDigestServer(t, &conf) 210 | defer ts.Close() 211 | 212 | c := dcnl().SetDigestAuth(conf.username, conf.password) 213 | c.SetRequestMiddlewares( 214 | PrepareRequestMiddleware, 215 | func(c *Client, r *Request) error { 216 | r.RawRequest.GetBody = func() (_ io.ReadCloser, _ error) { 217 | return nil, errors.New("get body test error") 218 | } 219 | return nil 220 | }, 221 | ) 222 | 223 | resp, err := c.R(). 224 | SetResult(&AuthSuccess{}). 225 | SetHeader(hdrContentTypeKey, "application/json"). 226 | SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). 
227 | Post(ts.URL + conf.uri) 228 | 229 | assertNotNil(t, err) 230 | assertEqual(t, true, strings.Contains(err.Error(), "resty: digest: failed to get body for auth-int: get body test error")) 231 | assertEqual(t, 0, resp.StatusCode()) 232 | } 233 | 234 | func TestClientDigestAuthWithGetBodyNilReadError(t *testing.T) { 235 | conf := *defaultDigestServerConf() 236 | conf.qop = "auth-int" 237 | ts := createDigestServer(t, &conf) 238 | defer ts.Close() 239 | 240 | c := dcnl().SetDigestAuth(conf.username, conf.password) 241 | c.SetRequestMiddlewares( 242 | PrepareRequestMiddleware, 243 | func(c *Client, r *Request) error { 244 | r.RawRequest.GetBody = nil 245 | return nil 246 | }, 247 | ) 248 | 249 | resp, err := c.R(). 250 | SetResult(&AuthSuccess{}). 251 | SetHeader(hdrContentTypeKey, "application/json"). 252 | SetBody(&brokenReadCloser{}). 253 | Post(ts.URL + conf.uri) 254 | 255 | assertNotNil(t, err) 256 | assertEqual(t, true, strings.Contains(err.Error(), "resty: digest: failed to prepare body for auth-int: read error")) 257 | assertEqual(t, 0, resp.StatusCode()) 258 | } 259 | 260 | func TestClientDigestAuthWithNoBodyQopAuthInt(t *testing.T) { 261 | conf := *defaultDigestServerConf() 262 | conf.qop = "auth-int" 263 | ts := createDigestServer(t, &conf) 264 | defer ts.Close() 265 | 266 | c := dcnl().SetDigestAuth(conf.username, conf.password) 267 | 268 | resp, err := c.R().Get(ts.URL + conf.uri) 269 | 270 | assertError(t, err) 271 | assertEqual(t, http.StatusOK, resp.StatusCode()) 272 | } 273 | 274 | func TestClientDigestAuthNoQop(t *testing.T) { 275 | conf := *defaultDigestServerConf() 276 | conf.qop = "" 277 | 278 | ts := createDigestServer(t, &conf) 279 | defer ts.Close() 280 | 281 | c := dcnl().SetDigestAuth(conf.username, conf.password) 282 | 283 | resp, err := c.R(). 284 | SetResult(&AuthSuccess{}). 285 | SetHeader(hdrContentTypeKey, "application/json"). 286 | SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). 
287 | Post(ts.URL + conf.uri) 288 | 289 | assertNil(t, err) 290 | assertEqual(t, "200 OK", resp.Status()) 291 | } 292 | 293 | func TestClientDigestAuthWithIncorrectNcValue(t *testing.T) { 294 | conf := *defaultDigestServerConf() 295 | conf.nc = "1234567890" 296 | 297 | ts := createDigestServer(t, &conf) 298 | defer ts.Close() 299 | 300 | c := dcnl().SetDigestAuth(conf.username, conf.password) 301 | 302 | resp, err := c.R(). 303 | SetResult(&AuthSuccess{}). 304 | SetHeader(hdrContentTypeKey, "application/json"). 305 | SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). 306 | Post(ts.URL + conf.uri) 307 | 308 | assertNotNil(t, err) 309 | assertEqual(t, true, strings.Contains(err.Error(), `parsing "1234567890": value out of range`)) 310 | assertEqual(t, "", resp.Status()) 311 | } 312 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module resty.dev/v3 2 | 3 | go 1.21 4 | 5 | require golang.org/x/net v0.33.0 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= 2 | golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= 3 | -------------------------------------------------------------------------------- /load_balancer.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "net" 12 | "net/url" 13 | "strings" 14 | "sync" 15 | "time" 16 | ) 17 | 18 | // LoadBalancer is the interface that wraps the HTTP client load-balancing 19 | // algorithm that returns the "Next" Base URL for the request to target 20 | type LoadBalancer interface { 21 | Next() (string, error) 22 | Feedback(*RequestFeedback) 23 | Close() error 24 | } 25 | 26 | // RequestFeedback struct is used to send the request feedback to load balancing 27 | // algorithm 28 | type RequestFeedback struct { 29 | BaseURL string 30 | Success bool 31 | Attempt int 32 | } 33 | 34 | // NewRoundRobin method creates the new Round-Robin(RR) request load balancer 35 | // instance with given base URLs 36 | func NewRoundRobin(baseURLs ...string) (*RoundRobin, error) { 37 | rr := &RoundRobin{lock: new(sync.Mutex)} 38 | if err := rr.Refresh(baseURLs...); err != nil { 39 | return rr, err 40 | } 41 | return rr, nil 42 | } 43 | 44 | var _ LoadBalancer = (*RoundRobin)(nil) 45 | 46 | // RoundRobin struct used to implement the Round-Robin(RR) request 47 | // load balancer algorithm 48 | type RoundRobin struct { 49 | lock *sync.Mutex 50 | baseURLs []string 51 | current int 52 | } 53 | 54 | // Next method returns the next Base URL based on the Round-Robin(RR) algorithm 55 | func (rr *RoundRobin) Next() (string, error) { 56 | rr.lock.Lock() 57 | defer rr.lock.Unlock() 58 | 59 | baseURL := rr.baseURLs[rr.current] 60 | rr.current = (rr.current + 1) % len(rr.baseURLs) 61 | return baseURL, nil 62 | } 63 | 64 | // Feedback method does nothing in Round-Robin(RR) request load balancer 65 | func (rr *RoundRobin) Feedback(_ *RequestFeedback) {} 66 | 67 | // Close method does nothing in Round-Robin(RR) request load balancer 68 | func (rr *RoundRobin) Close() error { return nil } 69 | 70 | // Refresh method reset the existing Base URLs with the given Base URLs slice to refresh it 71 | func (rr *RoundRobin) 
Refresh(baseURLs ...string) error { 72 | rr.lock.Lock() 73 | defer rr.lock.Unlock() 74 | result := make([]string, 0) 75 | for _, u := range baseURLs { 76 | baseURL, err := extractBaseURL(u) 77 | if err != nil { 78 | return err 79 | } 80 | result = append(result, baseURL) 81 | } 82 | 83 | // after processing, assign the updates 84 | rr.baseURLs = result 85 | return nil 86 | } 87 | 88 | // Host struct used to represent the host information and its weight 89 | // to load balance the requests 90 | type Host struct { 91 | // BaseURL represents the targeted host base URL 92 | // https://resty.dev 93 | BaseURL string 94 | 95 | // Weight represents the host weight to determine 96 | // the percentage of requests to send 97 | Weight int 98 | 99 | // MaxFailures represents the value to mark the host as 100 | // not usable until it reaches the Recovery duration 101 | // Default value is 5 102 | MaxFailures int 103 | 104 | state HostState 105 | currentWeight int 106 | failedRequests int 107 | } 108 | 109 | func (h *Host) addWeight() { 110 | h.currentWeight += h.Weight 111 | } 112 | 113 | func (h *Host) resetWeight(totalWeight int) { 114 | h.currentWeight -= totalWeight 115 | } 116 | 117 | type HostState int 118 | 119 | // Host transition states 120 | const ( 121 | HostStateInActive HostState = iota 122 | HostStateActive 123 | ) 124 | 125 | // HostStateChangeFunc type provides feedback on host state transitions 126 | type HostStateChangeFunc func(baseURL string, from, to HostState) 127 | 128 | // ErrNoActiveHost error returned when all hosts are inactive on the load balancer 129 | var ErrNoActiveHost = errors.New("resty: no active host") 130 | 131 | // NewWeightedRoundRobin method creates the new Weighted Round-Robin(WRR) 132 | // request load balancer instance with given recovery duration and hosts slice 133 | func NewWeightedRoundRobin(recovery time.Duration, hosts ...*Host) (*WeightedRoundRobin, error) { 134 | if recovery == 0 { 135 | recovery = 120 * time.Second // defaults 
to 120 seconds 136 | } 137 | wrr := &WeightedRoundRobin{ 138 | lock: new(sync.Mutex), 139 | hosts: make([]*Host, 0), 140 | tick: time.NewTicker(recovery), 141 | recovery: recovery, 142 | } 143 | 144 | err := wrr.Refresh(hosts...) 145 | 146 | go wrr.ticker() 147 | 148 | return wrr, err 149 | } 150 | 151 | var _ LoadBalancer = (*WeightedRoundRobin)(nil) 152 | 153 | // WeightedRoundRobin struct used to represent the host details for 154 | // Weighted Round-Robin(WRR) algorithm implementation 155 | type WeightedRoundRobin struct { 156 | lock *sync.Mutex 157 | hosts []*Host 158 | totalWeight int 159 | tick *time.Ticker 160 | onStateChange HostStateChangeFunc 161 | 162 | // Recovery duration is used to set the timer to put 163 | // the host back in the pool for the next turn and 164 | // reset the failed request count for the segment 165 | recovery time.Duration 166 | } 167 | 168 | // Next method returns the next Base URL based on Weighted Round-Robin(WRR) 169 | func (wrr *WeightedRoundRobin) Next() (string, error) { 170 | wrr.lock.Lock() 171 | defer wrr.lock.Unlock() 172 | 173 | var best *Host 174 | total := 0 175 | for _, h := range wrr.hosts { 176 | if h.state == HostStateInActive { 177 | continue 178 | } 179 | 180 | h.addWeight() 181 | total += h.Weight 182 | 183 | if best == nil || h.currentWeight > best.currentWeight { 184 | best = h 185 | } 186 | } 187 | 188 | if best == nil { 189 | return "", ErrNoActiveHost 190 | } 191 | 192 | best.resetWeight(total) 193 | return best.BaseURL, nil 194 | } 195 | 196 | // Feedback method process the request feedback for Weighted Round-Robin(WRR) 197 | // request load balancer 198 | func (wrr *WeightedRoundRobin) Feedback(f *RequestFeedback) { 199 | wrr.lock.Lock() 200 | defer wrr.lock.Unlock() 201 | 202 | for _, host := range wrr.hosts { 203 | if host.BaseURL == f.BaseURL { 204 | if !f.Success { 205 | host.failedRequests++ 206 | } 207 | if host.failedRequests >= host.MaxFailures { 208 | host.state = HostStateInActive 209 | if 
wrr.onStateChange != nil { 210 | wrr.onStateChange(host.BaseURL, HostStateActive, HostStateInActive) 211 | } 212 | } 213 | break 214 | } 215 | } 216 | } 217 | 218 | // Close method does the cleanup by stopping the [time.Ticker] on 219 | // Weighted Round-Robin(WRR) request load balancer 220 | func (wrr *WeightedRoundRobin) Close() error { 221 | wrr.lock.Lock() 222 | defer wrr.lock.Unlock() 223 | wrr.tick.Stop() 224 | return nil 225 | } 226 | 227 | // Refresh method reset the existing values with the given [Host] slice to refresh it 228 | func (wrr *WeightedRoundRobin) Refresh(hosts ...*Host) error { 229 | if hosts == nil { 230 | return nil 231 | } 232 | 233 | wrr.lock.Lock() 234 | defer wrr.lock.Unlock() 235 | newTotalWeight := 0 236 | for _, h := range hosts { 237 | baseURL, err := extractBaseURL(h.BaseURL) 238 | if err != nil { 239 | return err 240 | } 241 | 242 | h.BaseURL = baseURL 243 | h.state = HostStateActive 244 | newTotalWeight += h.Weight 245 | 246 | // assign defaults if not provided 247 | if h.MaxFailures == 0 { 248 | h.MaxFailures = 5 // default value is 5 249 | } 250 | } 251 | 252 | // after processing, assign the updates 253 | wrr.hosts = hosts 254 | wrr.totalWeight = newTotalWeight 255 | return nil 256 | } 257 | 258 | // SetOnStateChange method used to set a callback for the host transition state 259 | func (wrr *WeightedRoundRobin) SetOnStateChange(fn HostStateChangeFunc) { 260 | wrr.lock.Lock() 261 | defer wrr.lock.Unlock() 262 | wrr.onStateChange = fn 263 | } 264 | 265 | // SetRecoveryDuration method is used to change the existing recovery duration for the host 266 | func (wrr *WeightedRoundRobin) SetRecoveryDuration(d time.Duration) { 267 | wrr.lock.Lock() 268 | defer wrr.lock.Unlock() 269 | wrr.recovery = d 270 | wrr.tick.Reset(d) 271 | } 272 | 273 | func (wrr *WeightedRoundRobin) ticker() { 274 | for range wrr.tick.C { 275 | wrr.lock.Lock() 276 | for _, host := range wrr.hosts { 277 | if host.state == HostStateInActive { 278 | host.state = 
HostStateActive 279 | host.failedRequests = 0 280 | 281 | if wrr.onStateChange != nil { 282 | wrr.onStateChange(host.BaseURL, HostStateInActive, HostStateActive) 283 | } 284 | } 285 | } 286 | wrr.lock.Unlock() 287 | } 288 | } 289 | 290 | // NewSRVWeightedRoundRobin method creates a new Weighted Round-Robin(WRR) load balancer instance 291 | // with given SRV values 292 | func NewSRVWeightedRoundRobin(service, proto, domainName, httpScheme string) (*SRVWeightedRoundRobin, error) { 293 | if isStringEmpty(proto) { 294 | proto = "tcp" 295 | } 296 | if isStringEmpty(httpScheme) { 297 | httpScheme = "https" 298 | } 299 | 300 | wrr, _ := NewWeightedRoundRobin(0) // with this input error will not occur 301 | swrr := &SRVWeightedRoundRobin{ 302 | Service: service, 303 | Proto: proto, 304 | DomainName: domainName, 305 | HttpScheme: httpScheme, 306 | wrr: wrr, 307 | tick: time.NewTicker(180 * time.Second), // default is 180 seconds 308 | lock: new(sync.Mutex), 309 | lookupSRV: func() ([]*net.SRV, error) { 310 | _, addrs, err := net.LookupSRV(service, proto, domainName) 311 | return addrs, err 312 | }, 313 | } 314 | 315 | err := swrr.Refresh() 316 | 317 | go swrr.ticker() 318 | 319 | return swrr, err 320 | } 321 | 322 | var _ LoadBalancer = (*SRVWeightedRoundRobin)(nil) 323 | 324 | // SRVWeightedRoundRobin struct used to implement SRV Weighted Round-Robin(RR) algorithm 325 | type SRVWeightedRoundRobin struct { 326 | Service string 327 | Proto string 328 | DomainName string 329 | HttpScheme string 330 | 331 | wrr *WeightedRoundRobin 332 | tick *time.Ticker 333 | lock *sync.Mutex 334 | lookupSRV func() ([]*net.SRV, error) 335 | } 336 | 337 | // Next method returns the next SRV Base URL based on Weighted Round-Robin(RR) 338 | func (swrr *SRVWeightedRoundRobin) Next() (string, error) { 339 | return swrr.wrr.Next() 340 | } 341 | 342 | // Feedback method does nothing in SRV Base URL based on Weighted Round-Robin(WRR) 343 | // request load balancer 344 | func (swrr 
*SRVWeightedRoundRobin) Feedback(f *RequestFeedback) { 345 | swrr.wrr.Feedback(f) 346 | } 347 | 348 | // Close method does the cleanup by stopping the [time.Ticker] SRV Base URL based 349 | // on Weighted Round-Robin(WRR) request load balancer 350 | func (swrr *SRVWeightedRoundRobin) Close() error { 351 | swrr.lock.Lock() 352 | defer swrr.lock.Unlock() 353 | swrr.wrr.Close() 354 | swrr.tick.Stop() 355 | return nil 356 | } 357 | 358 | // Refresh method reset the values based [net.LookupSRV] values to refresh it 359 | func (swrr *SRVWeightedRoundRobin) Refresh() error { 360 | swrr.lock.Lock() 361 | defer swrr.lock.Unlock() 362 | addrs, err := swrr.lookupSRV() 363 | if err != nil { 364 | return err 365 | } 366 | 367 | hosts := make([]*Host, len(addrs)) 368 | for idx, addr := range addrs { 369 | domain := strings.TrimRight(addr.Target, ".") 370 | baseURL := fmt.Sprintf("%s://%s:%d", swrr.HttpScheme, domain, addr.Port) 371 | hosts[idx] = &Host{BaseURL: baseURL, Weight: int(addr.Weight)} 372 | } 373 | 374 | return swrr.wrr.Refresh(hosts...) 
375 | } 376 | 377 | // SetRefreshDuration method assists in changing the default (180 seconds) refresh duration 378 | func (swrr *SRVWeightedRoundRobin) SetRefreshDuration(d time.Duration) { 379 | swrr.lock.Lock() 380 | defer swrr.lock.Unlock() 381 | swrr.tick.Reset(d) 382 | } 383 | 384 | // SetOnStateChange method used to set a callback for the host transition state 385 | func (swrr *SRVWeightedRoundRobin) SetOnStateChange(fn HostStateChangeFunc) { 386 | swrr.wrr.SetOnStateChange(fn) 387 | } 388 | 389 | // SetRecoveryDuration method is used to change the existing recovery duration for the host 390 | func (swrr *SRVWeightedRoundRobin) SetRecoveryDuration(d time.Duration) { 391 | swrr.wrr.SetRecoveryDuration(d) 392 | } 393 | 394 | func (swrr *SRVWeightedRoundRobin) ticker() { 395 | for range swrr.tick.C { 396 | swrr.Refresh() 397 | } 398 | } 399 | 400 | func extractBaseURL(u string) (string, error) { 401 | baseURL, err := url.Parse(u) 402 | if err != nil { 403 | return "", err 404 | } 405 | 406 | // we only require base URL LB 407 | baseURL.Path = "" 408 | baseURL.RawQuery = "" 409 | 410 | return strings.TrimRight(baseURL.String(), "/"), nil 411 | } 412 | -------------------------------------------------------------------------------- /load_balancer_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "errors" 10 | "net" 11 | "net/http" 12 | "net/url" 13 | "sync/atomic" 14 | "testing" 15 | "time" 16 | ) 17 | 18 | func TestRoundRobin(t *testing.T) { 19 | 20 | t.Run("2 base urls", func(t *testing.T) { 21 | rr, err := NewRoundRobin("https://example1.com", "https://example2.com") 22 | assertNil(t, err) 23 | 24 | runCount := 5 25 | var result []string 26 | for i := 0; i < runCount; i++ { 27 | baseURL, _ := rr.Next() 28 | result = append(result, baseURL) 29 | } 30 | 31 | expected := []string{ 32 | "https://example1.com", "https://example2.com", "https://example1.com", 33 | "https://example2.com", "https://example1.com", 34 | } 35 | 36 | assertEqual(t, runCount, len(expected)) 37 | assertEqual(t, runCount, len(result)) 38 | assertEqual(t, expected, result) 39 | 40 | rr.Feedback(&RequestFeedback{}) 41 | rr.Close() 42 | }) 43 | 44 | t.Run("5 base urls", func(t *testing.T) { 45 | input := []string{"https://example1.com", "https://example2.com", 46 | "https://example3.com", "https://example4.com", "https://example5.com"} 47 | rr, err := NewRoundRobin(input...) 48 | assertNil(t, err) 49 | 50 | runCount := 30 51 | var result []string 52 | for i := 0; i < runCount; i++ { 53 | baseURL, _ := rr.Next() 54 | result = append(result, baseURL) 55 | } 56 | 57 | var expected []string 58 | for i := 0; i < runCount/len(input); i++ { 59 | expected = append(expected, input...) 
60 | } 61 | 62 | assertEqual(t, runCount, len(expected)) 63 | assertEqual(t, runCount, len(result)) 64 | assertEqual(t, expected, result) 65 | 66 | rr.Feedback(&RequestFeedback{}) 67 | rr.Close() 68 | }) 69 | 70 | t.Run("2 base urls with refresh", func(t *testing.T) { 71 | rr, err := NewRoundRobin("https://example1.com", "https://example2.com") 72 | assertNil(t, err) 73 | 74 | err = rr.Refresh("https://example3.com", "https://example4.com") 75 | assertNil(t, err) 76 | 77 | runCount := 5 78 | var result []string 79 | for i := 0; i < runCount; i++ { 80 | baseURL, _ := rr.Next() 81 | result = append(result, baseURL) 82 | } 83 | 84 | expected := []string{ 85 | "https://example3.com", "https://example4.com", "https://example3.com", 86 | "https://example4.com", "https://example3.com", 87 | } 88 | 89 | assertEqual(t, runCount, len(expected)) 90 | assertEqual(t, runCount, len(result)) 91 | assertEqual(t, expected, result) 92 | 93 | rr.Feedback(&RequestFeedback{}) 94 | rr.Close() 95 | }) 96 | } 97 | 98 | func TestWeightedRoundRobin(t *testing.T) { 99 | t.Run("3 hosts with weight {5,2,1}", func(t *testing.T) { 100 | hosts := []*Host{ 101 | {BaseURL: "https://example1.com", Weight: 5}, 102 | {BaseURL: "https://example2.com", Weight: 2}, 103 | {BaseURL: "https://example3.com", Weight: 1}, 104 | } 105 | 106 | wrr, err := NewWeightedRoundRobin(200*time.Millisecond, hosts...) 
107 | assertNil(t, err) 108 | defer wrr.Close() 109 | 110 | runCount := 5 111 | var result []string 112 | for i := 0; i < runCount; i++ { 113 | baseURL, err := wrr.Next() 114 | assertNil(t, err) 115 | result = append(result, baseURL) 116 | } 117 | 118 | expected := []string{ 119 | "https://example1.com", "https://example2.com", "https://example1.com", 120 | "https://example1.com", "https://example3.com", 121 | } 122 | 123 | assertEqual(t, runCount, len(expected)) 124 | assertEqual(t, runCount, len(result)) 125 | assertEqual(t, expected, result) 126 | }) 127 | 128 | t.Run("3 hosts with weight {2,1,10}", func(t *testing.T) { 129 | hosts := []*Host{ 130 | {BaseURL: "https://example1.com", Weight: 2}, 131 | {BaseURL: "https://example2.com", Weight: 1}, 132 | {BaseURL: "https://example3.com", Weight: 10, MaxFailures: 3}, 133 | } 134 | 135 | wrr, err := NewWeightedRoundRobin(200*time.Millisecond, hosts...) 136 | assertNil(t, err) 137 | defer wrr.Close() 138 | 139 | var stateChangeCalled int32 140 | wrr.SetOnStateChange(func(baseURL string, from, to HostState) { 141 | atomic.AddInt32(&stateChangeCalled, 1) 142 | }) 143 | 144 | runCount := 10 145 | var result []string 146 | for i := 0; i < runCount; i++ { 147 | baseURL, err := wrr.Next() 148 | assertNil(t, err) 149 | result = append(result, baseURL) 150 | if baseURL == "https://example3.com" && i%2 != 0 { 151 | wrr.Feedback(&RequestFeedback{BaseURL: baseURL, Success: false, Attempt: 1}) 152 | } else { 153 | wrr.Feedback(&RequestFeedback{BaseURL: baseURL, Success: true, Attempt: 1}) 154 | } 155 | } 156 | 157 | expected := []string{ 158 | "https://example3.com", "https://example3.com", "https://example1.com", 159 | "https://example3.com", "https://example3.com", "https://example3.com", 160 | "https://example2.com", "https://example2.com", "https://example1.com", 161 | "https://example1.com", 162 | } 163 | 164 | assertEqual(t, int32(1), stateChangeCalled) 165 | assertEqual(t, runCount, len(expected)) 166 | assertEqual(t, 
runCount, len(result)) 167 | assertEqual(t, expected, result) 168 | }) 169 | 170 | t.Run("2 hosts with weight {5,5} and refresh", func(t *testing.T) { 171 | wrr, err := NewWeightedRoundRobin( 172 | 200*time.Millisecond, 173 | &Host{BaseURL: "https://example1.com", Weight: 5}, 174 | &Host{BaseURL: "https://example2.com", Weight: 5}, 175 | ) 176 | assertNil(t, err) 177 | defer wrr.Close() 178 | 179 | err = wrr.Refresh( 180 | &Host{BaseURL: "https://example3.com", Weight: 5}, 181 | &Host{BaseURL: "https://example4.com", Weight: 5}, 182 | ) 183 | assertNil(t, err) 184 | 185 | runCount := 5 186 | var result []string 187 | for i := 0; i < runCount; i++ { 188 | baseURL, err := wrr.Next() 189 | assertNil(t, err) 190 | result = append(result, baseURL) 191 | } 192 | 193 | expected := []string{ 194 | "https://example3.com", "https://example4.com", "https://example3.com", 195 | "https://example4.com", "https://example3.com", 196 | } 197 | 198 | assertEqual(t, runCount, len(expected)) 199 | assertEqual(t, runCount, len(result)) 200 | assertEqual(t, expected, result) 201 | }) 202 | 203 | t.Run("no active hosts error", func(t *testing.T) { 204 | wrr, err := NewWeightedRoundRobin(200 * time.Millisecond) 205 | assertNil(t, err) 206 | defer wrr.Close() 207 | 208 | _, err = wrr.Next() 209 | assertErrorIs(t, ErrNoActiveHost, err) 210 | }) 211 | } 212 | 213 | func TestSRVWeightedRoundRobin(t *testing.T) { 214 | t.Run("3 records with weight {50,30,20}", func(t *testing.T) { 215 | srv, err := NewSRVWeightedRoundRobin("_sample-server", "", "example.com", "") 216 | assertNotNil(t, err) 217 | assertNotNil(t, srv) 218 | var dnsErr *net.DNSError 219 | assertEqual(t, true, errors.As(err, &dnsErr)) 220 | 221 | // mock net.LookupSRV call 222 | srv.lookupSRV = func() ([]*net.SRV, error) { 223 | return []*net.SRV{ 224 | {Target: "service1.example.com.", Port: 443, Priority: 10, Weight: 50}, 225 | {Target: "service2.example.com.", Port: 443, Priority: 20, Weight: 30}, 226 | {Target: 
"service3.example.com.", Port: 443, Priority: 20, Weight: 20}, 227 | }, nil 228 | } 229 | err = srv.Refresh() 230 | assertNil(t, err) 231 | 232 | srv.SetRecoveryDuration(200 * time.Millisecond) 233 | 234 | runCount := 5 235 | var result []string 236 | for i := 0; i < runCount; i++ { 237 | baseURL, err := srv.Next() 238 | assertNil(t, err) 239 | result = append(result, baseURL) 240 | } 241 | 242 | expected := []string{ 243 | "https://service1.example.com:443", "https://service2.example.com:443", 244 | "https://service3.example.com:443", "https://service1.example.com:443", 245 | "https://service1.example.com:443", 246 | } 247 | 248 | assertEqual(t, runCount, len(expected)) 249 | assertEqual(t, runCount, len(result)) 250 | assertEqual(t, expected, result) 251 | }) 252 | 253 | t.Run("2 records with weight {50,50}", func(t *testing.T) { 254 | srv, err := NewSRVWeightedRoundRobin("_sample-server", "", "example.com", "") 255 | assertNotNil(t, err) 256 | assertNotNil(t, srv) 257 | var dnsErr *net.DNSError 258 | assertEqual(t, true, errors.As(err, &dnsErr)) 259 | 260 | // mock net.LookupSRV call 261 | srv.lookupSRV = func() ([]*net.SRV, error) { 262 | return []*net.SRV{ 263 | {Target: "service1.example.com.", Port: 443, Priority: 10, Weight: 50}, 264 | {Target: "service2.example.com.", Port: 443, Priority: 20, Weight: 50}, 265 | }, nil 266 | } 267 | err = srv.Refresh() 268 | assertNil(t, err) 269 | 270 | srv.SetRecoveryDuration(200 * time.Millisecond) 271 | 272 | runCount := 5 273 | var result []string 274 | for i := 0; i < runCount; i++ { 275 | baseURL, err := srv.Next() 276 | assertNil(t, err) 277 | result = append(result, baseURL) 278 | } 279 | 280 | expected := []string{ 281 | "https://service1.example.com:443", "https://service2.example.com:443", 282 | "https://service1.example.com:443", "https://service2.example.com:443", 283 | "https://service1.example.com:443", 284 | } 285 | 286 | assertEqual(t, runCount, len(expected)) 287 | assertEqual(t, runCount, len(result)) 
288 | assertEqual(t, expected, result) 289 | }) 290 | 291 | t.Run("3 records with weight {60,20,20}", func(t *testing.T) { 292 | srv, err := NewSRVWeightedRoundRobin("_sample-server", "", "example.com", "") 293 | assertNotNil(t, err) 294 | assertNotNil(t, srv) 295 | var dnsErr *net.DNSError 296 | assertEqual(t, true, errors.As(err, &dnsErr)) 297 | 298 | // mock net.LookupSRV call 299 | srv.lookupSRV = func() ([]*net.SRV, error) { 300 | return []*net.SRV{ 301 | {Target: "service1.example.com.", Port: 443, Priority: 10, Weight: 60}, 302 | {Target: "service2.example.com.", Port: 443, Priority: 20, Weight: 20}, 303 | {Target: "service3.example.com.", Port: 443, Priority: 20, Weight: 20}, 304 | }, nil 305 | } 306 | err = srv.Refresh() 307 | assertNil(t, err) 308 | 309 | var stateChangeCalled int32 310 | srv.SetOnStateChange(func(baseURL string, from, to HostState) { 311 | atomic.AddInt32(&stateChangeCalled, 1) 312 | }) 313 | 314 | srv.SetRecoveryDuration(200 * time.Millisecond) 315 | 316 | runCount := 20 317 | var result []string 318 | for i := 0; i < runCount; i++ { 319 | baseURL, err := srv.Next() 320 | assertNil(t, err) 321 | result = append(result, baseURL) 322 | 323 | if baseURL == "https://service1.example.com:443" { 324 | srv.Feedback(&RequestFeedback{BaseURL: baseURL, Success: false, Attempt: 1}) 325 | } else { 326 | srv.Feedback(&RequestFeedback{BaseURL: baseURL, Success: true, Attempt: 1}) 327 | } 328 | } 329 | 330 | expected := []string{ 331 | "https://service1.example.com:443", "https://service2.example.com:443", "https://service1.example.com:443", 332 | "https://service3.example.com:443", "https://service1.example.com:443", "https://service1.example.com:443", 333 | "https://service2.example.com:443", "https://service1.example.com:443", "https://service3.example.com:443", 334 | "https://service3.example.com:443", "https://service3.example.com:443", "https://service2.example.com:443", 335 | "https://service3.example.com:443", 
"https://service2.example.com:443", "https://service3.example.com:443", 336 | "https://service2.example.com:443", "https://service3.example.com:443", "https://service2.example.com:443", 337 | "https://service3.example.com:443", "https://service2.example.com:443", 338 | } 339 | 340 | assertEqual(t, runCount, len(expected)) 341 | assertEqual(t, runCount, len(result)) 342 | assertEqual(t, expected, result) 343 | }) 344 | 345 | t.Run("srv record with refresh duration 100ms", func(t *testing.T) { 346 | srv, err := NewSRVWeightedRoundRobin("_sample-server", "", "example.com", "") 347 | assertNotNil(t, err) 348 | assertNotNil(t, srv) 349 | var dnsErr *net.DNSError 350 | assertEqual(t, true, errors.As(err, &dnsErr)) 351 | 352 | // mock net.LookupSRV call 353 | srv.lookupSRV = func() ([]*net.SRV, error) { 354 | return []*net.SRV{ 355 | {Target: "service1.example.com.", Port: 443, Priority: 10, Weight: 50}, 356 | {Target: "service2.example.com.", Port: 443, Priority: 20, Weight: 50}, 357 | }, nil 358 | } 359 | err = srv.Refresh() 360 | assertNil(t, err) 361 | 362 | srv.SetRecoveryDuration(200 * time.Millisecond) 363 | 364 | go func() { 365 | for i := 0; i < 10; i++ { 366 | baseURL, _ := srv.Next() 367 | assertNotNil(t, baseURL) 368 | time.Sleep(15 * time.Millisecond) 369 | } 370 | }() 371 | 372 | srv.SetRefreshDuration(150 * time.Millisecond) 373 | time.Sleep(320 * time.Millisecond) 374 | srv.Close() 375 | }) 376 | 377 | t.Run("srv record with error on default lookupSRV", func(t *testing.T) { 378 | srv, err := NewSRVWeightedRoundRobin("_sample-server", "", "example.com", "") 379 | assertNotNil(t, err) 380 | assertNotNil(t, srv) 381 | var dnsErr *net.DNSError 382 | assertEqual(t, true, errors.As(err, &dnsErr)) 383 | 384 | // default error flow 385 | err = srv.Refresh() 386 | assertNotNil(t, err) 387 | assertEqual(t, true, errors.As(err, &dnsErr)) 388 | 389 | // replace with mock error flow 390 | errMockTest := errors.New("network error") 391 | srv.lookupSRV = func() 
([]*net.SRV, error) { return nil, errMockTest } 392 | err = srv.Refresh() 393 | assertNotNil(t, err) 394 | assertErrorIs(t, errMockTest, err) 395 | 396 | }) 397 | 398 | } 399 | 400 | func TestLoadBalancerRequest(t *testing.T) { 401 | ts1 := createGetServer(t) 402 | defer ts1.Close() 403 | 404 | ts2 := createGetServer(t) 405 | defer ts2.Close() 406 | 407 | rr, err := NewRoundRobin(ts1.URL, ts2.URL) 408 | assertNil(t, err) 409 | 410 | c := dcnl() 411 | defer c.Close() 412 | 413 | c.SetLoadBalancer(rr) 414 | 415 | ts1URL, ts2URL := 0, 0 416 | for i := 0; i < 20; i++ { 417 | resp, err := c.R().Get("/") 418 | assertNil(t, err) 419 | switch resp.Request.baseURL { 420 | case ts1.URL: 421 | ts1URL++ 422 | case ts2.URL: 423 | ts2URL++ 424 | } 425 | } 426 | assertEqual(t, ts1URL, ts2URL) 427 | } 428 | 429 | func TestLoadBalancerRequestFlowError(t *testing.T) { 430 | 431 | t.Run("obtain next url error", func(t *testing.T) { 432 | wrr, err := NewWeightedRoundRobin(0) 433 | assertNil(t, err) 434 | 435 | c := dcnl() 436 | defer c.Close() 437 | 438 | c.SetLoadBalancer(wrr) 439 | 440 | resp, err := c.R().Get("/") 441 | assertEqual(t, ErrNoActiveHost, err) 442 | assertNil(t, resp) 443 | }) 444 | 445 | t.Run("round-robin invalid url input", func(t *testing.T) { 446 | rr, err := NewRoundRobin("://example.com") 447 | assertType(t, url.Error{}, err) 448 | assertNotNil(t, rr) 449 | 450 | wrr, err := NewWeightedRoundRobin(0, &Host{BaseURL: "://example.com"}) 451 | assertType(t, url.Error{}, err) 452 | assertNotNil(t, wrr) 453 | }) 454 | 455 | t.Run("weighted round-robin invalid url input", func(t *testing.T) { 456 | wrr, err := NewWeightedRoundRobin(0, &Host{BaseURL: "://example.com"}) 457 | assertType(t, url.Error{}, err) 458 | assertNotNil(t, wrr) 459 | }) 460 | } 461 | 462 | func Test_extractBaseURL(t *testing.T) { 463 | for _, tt := range []struct { 464 | name string 465 | inputURL string 466 | expectedURL string 467 | expectedErr error 468 | }{ 469 | { 470 | name: "simple relative 
path", 471 | inputURL: "https://resty.dev/welcome", 472 | expectedURL: "https://resty.dev", 473 | }, 474 | { 475 | name: "longer relative path with file extension", 476 | inputURL: "https://resty.dev/welcome/path/to/remove.html", 477 | expectedURL: "https://resty.dev", 478 | }, 479 | { 480 | name: "longer relative path with file extension and query params", 481 | inputURL: "https://resty.dev/welcome/path/to/remove.html?a=1&b=2", 482 | expectedURL: "https://resty.dev", 483 | }, 484 | { 485 | name: "invalid url input", 486 | inputURL: "://resty.dev/welcome", 487 | expectedURL: "", 488 | expectedErr: &url.Error{Op: "parse", URL: "://resty.dev/welcome", Err: errors.New("missing protocol scheme")}, 489 | }, 490 | } { 491 | t.Run(tt.name, func(t *testing.T) { 492 | outputURL, err := extractBaseURL(tt.inputURL) 493 | if tt.expectedErr != nil { 494 | assertEqual(t, tt.expectedErr, err) 495 | } 496 | assertEqual(t, tt.expectedURL, outputURL) 497 | }) 498 | } 499 | } 500 | 501 | func TestLoadBalancerRequestFailures(t *testing.T) { 502 | ts1 := createGetServer(t) 503 | ts1.Close() 504 | 505 | ts2 := createGetServer(t) 506 | defer ts2.Close() 507 | 508 | rr, err := NewWeightedRoundRobin(200*time.Millisecond, 509 | &Host{BaseURL: ts1.URL, Weight: 50, MaxFailures: 3}, &Host{BaseURL: ts2.URL, Weight: 50}) 510 | assertNil(t, err) 511 | 512 | c := dcnl() 513 | defer c.Close() 514 | 515 | c.SetLoadBalancer(rr) 516 | 517 | ts1URL, ts2URL := 0, 0 518 | for i := 0; i < 10; i++ { 519 | resp, _ := c.R().Get("/") 520 | switch resp.Request.baseURL { 521 | case ts1.URL: 522 | ts1URL++ 523 | case ts2.URL: 524 | assertError(t, err) 525 | ts2URL++ 526 | } 527 | } 528 | assertEqual(t, 3, ts1URL) 529 | assertEqual(t, 7, ts2URL) 530 | } 531 | 532 | type mockTimeoutErr struct{} 533 | 534 | func (e *mockTimeoutErr) Error() string { return "i/o timeout" } 535 | func (e *mockTimeoutErr) Timeout() bool { return true } 536 | 537 | func TestLoadBalancerCoverage(t *testing.T) { 538 | t.Run("mock net op 
timeout error", func(t *testing.T) { 539 | wrr, err := NewWeightedRoundRobin(0) 540 | assertNil(t, err) 541 | 542 | c := dcnl() 543 | defer c.Close() 544 | 545 | c.SetLoadBalancer(wrr) 546 | 547 | req := c.R() 548 | 549 | netOpErr := &net.OpError{Op: "mock", Net: "mock", Err: &mockTimeoutErr{}} 550 | req.sendLoadBalancerFeedback(&Response{}, netOpErr) 551 | 552 | req.sendLoadBalancerFeedback(&Response{RawResponse: &http.Response{ 553 | StatusCode: http.StatusInternalServerError, 554 | }}, nil) 555 | }) 556 | } 557 | -------------------------------------------------------------------------------- /middleware.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "bytes" 10 | "fmt" 11 | "io" 12 | "mime" 13 | "mime/multipart" 14 | "net/http" 15 | "net/textproto" 16 | "net/url" 17 | "path" 18 | "path/filepath" 19 | "reflect" 20 | "strconv" 21 | "strings" 22 | ) 23 | 24 | //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ 25 | // Request Middleware(s) 26 | //_______________________________________________________________________ 27 | 28 | // PrepareRequestMiddleware method is used to prepare HTTP requests from 29 | // user provides request values. 
Request preparation fails if any error occurs 30 | func PrepareRequestMiddleware(c *Client, r *Request) (err error) { 31 | if err = parseRequestURL(c, r); err != nil { 32 | return err 33 | } 34 | 35 | // no error returned 36 | parseRequestHeader(c, r) 37 | 38 | if err = parseRequestBody(c, r); err != nil { 39 | return err 40 | } 41 | 42 | // at this point, possible error from `http.NewRequestWithContext` 43 | // is URL-related, and those get caught up in the `parseRequestURL` 44 | createRawRequest(c, r) 45 | 46 | addCredentials(c, r) 47 | 48 | _ = r.generateCurlCommand() 49 | 50 | return nil 51 | } 52 | 53 | func parseRequestURL(c *Client, r *Request) error { 54 | if len(c.PathParams())+len(r.PathParams) > 0 { 55 | // GitHub #103 Path Params, #663 Raw Path Params 56 | for p, v := range c.PathParams() { 57 | if _, ok := r.PathParams[p]; ok { 58 | continue 59 | } 60 | r.PathParams[p] = v 61 | } 62 | 63 | var prev int 64 | buf := acquireBuffer() 65 | defer releaseBuffer(buf) 66 | // search for the next or first opened curly bracket 67 | for curr := strings.Index(r.URL, "{"); curr == 0 || curr > prev; curr = prev + strings.Index(r.URL[prev:], "{") { 68 | // write everything from the previous position up to the current 69 | if curr > prev { 70 | buf.WriteString(r.URL[prev:curr]) 71 | } 72 | // search for the closed curly bracket from current position 73 | next := curr + strings.Index(r.URL[curr:], "}") 74 | // if not found, then write the remainder and exit 75 | if next < curr { 76 | buf.WriteString(r.URL[curr:]) 77 | prev = len(r.URL) 78 | break 79 | } 80 | // special case for {}, without parameter's name 81 | if next == curr+1 { 82 | buf.WriteString("{}") 83 | } else { 84 | // check for the replacement 85 | key := r.URL[curr+1 : next] 86 | value, ok := r.PathParams[key] 87 | // keep the original string if the replacement not found 88 | if !ok { 89 | value = r.URL[curr : next+1] 90 | } 91 | buf.WriteString(value) 92 | } 93 | 94 | // set the previous position after the 
closed curly bracket 95 | prev = next + 1 96 | if prev >= len(r.URL) { 97 | break 98 | } 99 | } 100 | if buf.Len() > 0 { 101 | // write remainder 102 | if prev < len(r.URL) { 103 | buf.WriteString(r.URL[prev:]) 104 | } 105 | r.URL = buf.String() 106 | } 107 | } 108 | 109 | // Parsing request URL 110 | reqURL, err := url.Parse(r.URL) 111 | if err != nil { 112 | return &invalidRequestError{Err: err} 113 | } 114 | 115 | // If [Request.URL] is a relative path, then the following 116 | // gets evaluated in the order 117 | // 1. [Client.LoadBalancer] is used to obtain the base URL if not nil 118 | // 2. [Client.BaseURL] is used to obtain the base URL 119 | // 3. Otherwise [Request.URL] is used as-is 120 | if !reqURL.IsAbs() { 121 | r.URL = reqURL.String() 122 | if len(r.URL) > 0 && r.URL[0] != '/' { 123 | r.URL = "/" + r.URL 124 | } 125 | 126 | if r.client.LoadBalancer() != nil { 127 | r.baseURL, err = r.client.LoadBalancer().Next() 128 | if err != nil { 129 | return &invalidRequestError{Err: err} 130 | } 131 | } 132 | 133 | reqURL, err = url.Parse(r.baseURL + r.URL) 134 | if err != nil { 135 | return &invalidRequestError{Err: err} 136 | } 137 | } 138 | 139 | // GH #407 && #318 140 | if reqURL.Scheme == "" && len(c.Scheme()) > 0 { 141 | reqURL.Scheme = c.Scheme() 142 | } 143 | 144 | // Adding Query Param 145 | if len(c.QueryParams())+len(r.QueryParams) > 0 { 146 | for k, v := range c.QueryParams() { 147 | if _, ok := r.QueryParams[k]; ok { 148 | continue 149 | } 150 | r.QueryParams[k] = v[:] 151 | } 152 | 153 | // GitHub #123 Preserve query string order partially. 
154 | // Since not feasible in `SetQuery*` resty methods, because 155 | // standard package `url.Encode(...)` sorts the query params 156 | // alphabetically 157 | if isStringEmpty(reqURL.RawQuery) { 158 | reqURL.RawQuery = r.QueryParams.Encode() 159 | } else { 160 | reqURL.RawQuery = reqURL.RawQuery + "&" + r.QueryParams.Encode() 161 | } 162 | } 163 | 164 | // GH#797 Unescape query parameters (non-standard - not recommended) 165 | if r.unescapeQueryParams && len(reqURL.RawQuery) > 0 { 166 | // at this point, all errors caught up in the above operations 167 | // so ignore the return error on query unescape; I realized 168 | // while writing the unit test 169 | unescapedQuery, _ := url.QueryUnescape(reqURL.RawQuery) 170 | reqURL.RawQuery = strings.ReplaceAll(unescapedQuery, " ", "+") // otherwise request becomes bad request 171 | } 172 | 173 | r.URL = reqURL.String() 174 | 175 | return nil 176 | } 177 | 178 | func parseRequestHeader(c *Client, r *Request) error { 179 | for k, v := range c.Header() { 180 | if _, ok := r.Header[k]; ok { 181 | continue 182 | } 183 | r.Header[k] = v[:] 184 | } 185 | 186 | if !r.isHeaderExists(hdrUserAgentKey) { 187 | r.Header.Set(hdrUserAgentKey, hdrUserAgentValue) 188 | } 189 | 190 | if !r.isHeaderExists(hdrAcceptEncodingKey) { 191 | r.Header.Set(hdrAcceptEncodingKey, r.client.ContentDecompresserKeys()) 192 | } 193 | 194 | return nil 195 | } 196 | 197 | func parseRequestBody(c *Client, r *Request) error { 198 | if r.isMultiPart && !(r.Method == MethodPost || r.Method == MethodPut || r.Method == MethodPatch) { 199 | err := fmt.Errorf("resty: multipart is not allowed in HTTP verb: %v", r.Method) 200 | return &invalidRequestError{Err: err} 201 | } 202 | 203 | if r.isPayloadSupported() { 204 | switch { 205 | case r.isMultiPart: // Handling Multipart 206 | if err := handleMultipart(c, r); err != nil { 207 | return &invalidRequestError{Err: err} 208 | } 209 | case len(c.FormData()) > 0 || len(r.FormData) > 0: // Handling Form Data 210 | 
handleFormData(c, r) 211 | case r.Body != nil: // Handling Request body 212 | if err := handleRequestBody(c, r); err != nil { 213 | return &invalidRequestError{Err: err} 214 | } 215 | } 216 | } else { 217 | r.Body = nil // if the payload is not supported by HTTP verb, set explicit nil 218 | } 219 | 220 | // by default resty won't set content length, but user can opt-in 221 | if r.setContentLength { 222 | cntLen := 0 223 | if r.bodyBuf != nil { 224 | cntLen = r.bodyBuf.Len() 225 | } else if b, ok := r.Body.(*bytes.Reader); ok { 226 | cntLen = b.Len() 227 | } 228 | r.Header.Set(hdrContentLengthKey, strconv.Itoa(cntLen)) 229 | } 230 | 231 | return nil 232 | } 233 | 234 | func createRawRequest(c *Client, r *Request) (err error) { 235 | // init client trace if enabled 236 | r.initTraceIfEnabled() 237 | 238 | if r.bodyBuf == nil { 239 | if reader, ok := r.Body.(io.Reader); ok { 240 | r.RawRequest, err = http.NewRequestWithContext(r.Context(), r.Method, r.URL, reader) 241 | } else { 242 | r.RawRequest, err = http.NewRequestWithContext(r.Context(), r.Method, r.URL, nil) 243 | } 244 | } else { 245 | r.RawRequest, err = http.NewRequestWithContext(r.Context(), r.Method, r.URL, r.bodyBuf) 246 | } 247 | 248 | if err != nil { 249 | return &invalidRequestError{Err: err} 250 | } 251 | 252 | // get the context reference back from underlying RawRequest 253 | r.ctx = r.RawRequest.Context() 254 | 255 | // Assign close connection option 256 | r.RawRequest.Close = r.CloseConnection 257 | 258 | // Add headers into http request 259 | r.RawRequest.Header = r.Header 260 | 261 | // Add cookies from client instance into http request 262 | for _, cookie := range c.Cookies() { 263 | r.RawRequest.AddCookie(cookie) 264 | } 265 | 266 | // Add cookies from request instance into http request 267 | for _, cookie := range r.Cookies { 268 | r.RawRequest.AddCookie(cookie) 269 | } 270 | 271 | return 272 | } 273 | 274 | func addCredentials(c *Client, r *Request) error { 275 | credentialsAdded := false 276 
| // Basic Auth 277 | if r.credentials != nil { 278 | credentialsAdded = true 279 | r.RawRequest.SetBasicAuth(r.credentials.Username, r.credentials.Password) 280 | } 281 | 282 | // Build the token Auth header 283 | if !isStringEmpty(r.AuthToken) { 284 | credentialsAdded = true 285 | r.RawRequest.Header.Set(r.HeaderAuthorizationKey, strings.TrimSpace(r.AuthScheme+" "+r.AuthToken)) 286 | } 287 | 288 | if !c.IsDisableWarn() && credentialsAdded { 289 | if r.RawRequest.URL.Scheme == "http" { 290 | r.log.Warnf("Using sensitive credentials in HTTP mode is not secure. Use HTTPS") 291 | } 292 | } 293 | 294 | return nil 295 | } 296 | 297 | func handleMultipart(c *Client, r *Request) error { 298 | for k, v := range c.FormData() { 299 | if _, ok := r.FormData[k]; ok { 300 | continue 301 | } 302 | r.FormData[k] = v[:] 303 | } 304 | 305 | mfLen := len(r.multipartFields) 306 | if mfLen == 0 { 307 | r.bodyBuf = acquireBuffer() 308 | mw := multipart.NewWriter(r.bodyBuf) 309 | 310 | // set boundary if it is provided by the user 311 | if !isStringEmpty(r.multipartBoundary) { 312 | if err := mw.SetBoundary(r.multipartBoundary); err != nil { 313 | return err 314 | } 315 | } 316 | 317 | if err := r.writeFormData(mw); err != nil { 318 | return err 319 | } 320 | 321 | r.Header.Set(hdrContentTypeKey, mw.FormDataContentType()) 322 | closeq(mw) 323 | 324 | return nil 325 | } 326 | 327 | // multipart streaming 328 | bodyReader, bodyWriter := io.Pipe() 329 | mw := multipart.NewWriter(bodyWriter) 330 | r.Body = bodyReader 331 | r.multipartErrChan = make(chan error, 1) 332 | 333 | // set boundary if it is provided by the user 334 | if !isStringEmpty(r.multipartBoundary) { 335 | if err := mw.SetBoundary(r.multipartBoundary); err != nil { 336 | return err 337 | } 338 | } 339 | 340 | go func() { 341 | defer close(r.multipartErrChan) 342 | if err := createMultipart(mw, r); err != nil { 343 | r.multipartErrChan <- err 344 | } 345 | closeq(mw) 346 | closeq(bodyWriter) 347 | }() 348 | 349 | 
r.Header.Set(hdrContentTypeKey, mw.FormDataContentType()) 350 | return nil 351 | } 352 | 353 | var mpCreatePart = func(w *multipart.Writer, h textproto.MIMEHeader) (io.Writer, error) { 354 | return w.CreatePart(h) 355 | } 356 | 357 | func createMultipart(w *multipart.Writer, r *Request) error { 358 | if err := r.writeFormData(w); err != nil { 359 | return err 360 | } 361 | 362 | for _, mf := range r.multipartFields { 363 | if len(mf.Values) > 0 { 364 | for _, v := range mf.Values { 365 | w.WriteField(mf.Name, v) 366 | } 367 | continue 368 | } 369 | 370 | if err := mf.openFileIfRequired(); err != nil { 371 | return err 372 | } 373 | 374 | p := make([]byte, 512) 375 | size, err := mf.Reader.Read(p) 376 | if err != nil && err != io.EOF { 377 | return err 378 | } 379 | // auto detect content type if empty 380 | if isStringEmpty(mf.ContentType) { 381 | mf.ContentType = http.DetectContentType(p[:size]) 382 | } 383 | 384 | partWriter, err := mpCreatePart(w, mf.createHeader()) 385 | if err != nil { 386 | return err 387 | } 388 | 389 | partWriter = mf.wrapProgressCallbackIfPresent(partWriter) 390 | partWriter.Write(p[:size]) 391 | 392 | if _, err = ioCopy(partWriter, mf.Reader); err != nil { 393 | return err 394 | } 395 | } 396 | 397 | return nil 398 | } 399 | 400 | func handleFormData(c *Client, r *Request) { 401 | for k, v := range c.FormData() { 402 | if _, ok := r.FormData[k]; ok { 403 | continue 404 | } 405 | r.FormData[k] = v[:] 406 | } 407 | 408 | r.bodyBuf = acquireBuffer() 409 | r.bodyBuf.WriteString(r.FormData.Encode()) 410 | r.Header.Set(hdrContentTypeKey, formContentType) 411 | r.isFormData = true 412 | } 413 | 414 | func handleRequestBody(c *Client, r *Request) error { 415 | contentType := r.Header.Get(hdrContentTypeKey) 416 | if isStringEmpty(contentType) { 417 | // it is highly recommended that the user provide a request content-type 418 | // so that we can minimize memory allocation and compute. 
419 | contentType = detectContentType(r.Body) 420 | } 421 | if !r.isHeaderExists(hdrContentTypeKey) { 422 | r.Header.Set(hdrContentTypeKey, contentType) 423 | } 424 | 425 | r.bodyBuf = acquireBuffer() 426 | 427 | switch body := r.Body.(type) { 428 | case io.Reader: 429 | // Resty v3 onwards io.Reader used as-is with the request body. 430 | releaseBuffer(r.bodyBuf) 431 | r.bodyBuf = nil 432 | 433 | // enable multiple reads if body is *bytes.Buffer 434 | if b, ok := r.Body.(*bytes.Buffer); ok { 435 | v := b.Bytes() 436 | r.Body = bytes.NewReader(v) 437 | } 438 | 439 | // do seek start for retry attempt if io.ReadSeeker 440 | // interface supported 441 | if r.Attempt > 1 { 442 | if rs, ok := r.Body.(io.ReadSeeker); ok { 443 | _, _ = rs.Seek(0, io.SeekStart) 444 | } 445 | } 446 | return nil 447 | case []byte: 448 | r.bodyBuf.Write(body) 449 | case string: 450 | r.bodyBuf.Write([]byte(body)) 451 | default: 452 | encKey := inferContentTypeMapKey(contentType) 453 | if jsonKey == encKey { 454 | if !r.jsonEscapeHTML { 455 | return encodeJSONEscapeHTML(r.bodyBuf, r.Body, r.jsonEscapeHTML) 456 | } 457 | } else if xmlKey == encKey { 458 | if inferKind(r.Body) != reflect.Struct { 459 | releaseBuffer(r.bodyBuf) 460 | r.bodyBuf = nil 461 | return ErrUnsupportedRequestBodyKind 462 | } 463 | } 464 | 465 | // user registered encoders with resty fallback key 466 | encFunc, found := c.inferContentTypeEncoder(contentType, encKey) 467 | if !found { 468 | releaseBuffer(r.bodyBuf) 469 | r.bodyBuf = nil 470 | return fmt.Errorf("resty: content-type encoder not found for %s", contentType) 471 | } 472 | if err := encFunc(r.bodyBuf, r.Body); err != nil { 473 | releaseBuffer(r.bodyBuf) 474 | r.bodyBuf = nil 475 | return err 476 | } 477 | } 478 | 479 | return nil 480 | } 481 | 482 | //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ 483 | // Response Middleware(s) 484 | //_______________________________________________________________________ 485 | 486 | // 
AutoParseResponseMiddleware method is used to parse the response body automatically 487 | // based on registered HTTP response `Content-Type` decoder, see [Client.AddContentTypeDecoder]; 488 | // if [Request.SetResult], [Request.SetError], or [Client.SetError] is used 489 | func AutoParseResponseMiddleware(c *Client, res *Response) (err error) { 490 | if res.Err != nil || res.Request.DoNotParseResponse { 491 | return // move on 492 | } 493 | 494 | if res.StatusCode() == http.StatusNoContent { 495 | res.Request.Error = nil 496 | return 497 | } 498 | 499 | rct := firstNonEmpty( 500 | res.Request.ForceResponseContentType, 501 | res.Header().Get(hdrContentTypeKey), 502 | res.Request.ExpectResponseContentType, 503 | ) 504 | decKey := inferContentTypeMapKey(rct) 505 | decFunc, found := c.inferContentTypeDecoder(rct, decKey) 506 | if !found { 507 | // the Content-Type decoder is not found; just read all the body bytes 508 | err = res.readAll() 509 | return 510 | } 511 | 512 | // HTTP status code > 199 and < 300, considered as Result 513 | if res.IsSuccess() && res.Request.Result != nil { 514 | res.Request.Error = nil 515 | defer closeq(res.Body) 516 | err = decFunc(res.Body, res.Request.Result) 517 | res.IsRead = true 518 | return 519 | } 520 | 521 | // HTTP status code > 399, considered as Error 522 | if res.IsError() { 523 | // global error type registered at client-instance 524 | if res.Request.Error == nil { 525 | res.Request.Error = c.newErrorInterface() 526 | } 527 | 528 | if res.Request.Error != nil { 529 | defer closeq(res.Body) 530 | err = decFunc(res.Body, res.Request.Error) 531 | res.IsRead = true 532 | return 533 | } 534 | } 535 | 536 | return 537 | } 538 | 539 | var hostnameReplacer = strings.NewReplacer(":", "_", ".", "_") 540 | 541 | // SaveToFileResponseMiddleware method used to write HTTP response body into 542 | // file.
The filename is determined in the following order - 543 | // - [Request.SetOutputFileName] 544 | // - Content-Disposition header 545 | // - Request URL using [path.Base] 546 | // 547 | // The Content-Disposition filename is supplied by the server, so only its 548 | // final path element is used (RFC 6266, section 4.3); values such as 549 | // "../../etc/passwd" or "/etc/passwd" can't place the file outside the target directory. 550 | func SaveToFileResponseMiddleware(c *Client, res *Response) error { 551 | if res.Err != nil || !res.Request.IsSaveResponse { 552 | return nil 553 | } 554 | 555 | file := res.Request.OutputFileName 556 | if isStringEmpty(file) { 557 | cntDispositionValue := res.Header().Get(hdrContentDisposition) 558 | if len(cntDispositionValue) > 0 { 559 | if _, params, err := mime.ParseMediaType(cntDispositionValue); err == nil { 560 | file = sanitizeContentDispositionFileName(params["filename"]) 561 | } 562 | } 563 | if isStringEmpty(file) { 564 | rURL, err := url.Parse(res.Request.URL) 565 | if err != nil { 566 | return err 567 | } 568 | if isStringEmpty(rURL.Path) || rURL.Path == "/" { 569 | file = hostnameReplacer.Replace(rURL.Host) 570 | } else { 571 | file = path.Base(rURL.Path) 572 | } 573 | } 574 | } 575 | 576 | if len(c.OutputDirectory()) > 0 && !filepath.IsAbs(file) { 577 | file = filepath.Join(c.OutputDirectory(), string(filepath.Separator), file) 578 | } 579 | 580 | file = filepath.Clean(file) 581 | if err := createDirectory(filepath.Dir(file)); err != nil { 582 | return err 583 | } 584 | 585 | outFile, err := createFile(file) 586 | if err != nil { 587 | return err 588 | } 589 | 590 | defer func() { 591 | closeq(outFile) 592 | closeq(res.Body) 593 | }() 594 | 595 | // io.Copy reads maximum 32kb size, it is perfect for large file download too 596 | res.size, err = ioCopy(outFile, res.Body) 597 | 598 | return err 599 | } 600 | 601 | // sanitizeContentDispositionFileName reduces a server-supplied filename to its 602 | // final path element; both '/' and '\' count as separators so Windows-style 603 | // paths are stripped on every OS. 
It returns "" when nothing usable remains 604 | // (empty, ".", ".." or a bare separator) so the caller falls back to the URL. 605 | func sanitizeContentDispositionFileName(name string) string { 606 | name = path.Base(strings.ReplaceAll(name, `\`, "/")) 607 | if name == "." || name == ".." || name == "/" { 608 | return "" 609 | } 610 | return name 611 | } 612 | -------------------------------------------------------------------------------- /multipart.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file.
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "fmt" 10 | "io" 11 | "net/textproto" 12 | "os" 13 | "path/filepath" 14 | "strings" 15 | ) 16 | 17 | var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") 18 | 19 | func escapeQuotes(s string) string { 20 | return quoteEscaper.Replace(s) 21 | } 22 | 23 | // MultipartField struct represents the multipart field to compose 24 | // all [io.Reader] capable input for multipart form request 25 | type MultipartField struct { 26 | // Name of the multipart field name that the server expects it 27 | Name string 28 | 29 | // FileName is used to set the file name we have to send to the server 30 | FileName string 31 | 32 | // ContentType is a multipart file content-type value. It is highly 33 | // recommended setting it if you know the content-type so that Resty 34 | // don't have to do additional computing to auto-detect (Optional) 35 | ContentType string 36 | 37 | // Reader is an input of [io.Reader] for multipart upload. It 38 | // is optional if you set the FilePath value 39 | Reader io.Reader 40 | 41 | // FilePath is a file path for multipart upload. It 42 | // is optional if you set the Reader value 43 | FilePath string 44 | 45 | // FileSize in bytes is used just for the information purpose of 46 | // sharing via [MultipartFieldCallbackFunc] (Optional) 47 | FileSize int64 48 | 49 | // ProgressCallback function is used to provide live progress details 50 | // during a multipart upload (Optional) 51 | // 52 | // NOTE: It is recommended to set the FileSize value when using `MultipartField.Reader` 53 | // with `ProgressCallback` feature so that Resty sends the FileSize 54 | // value via [MultipartFieldProgress] 55 | ProgressCallback MultipartFieldCallbackFunc 56 | 57 | // Values field is used to provide form field value.
(Optional, unless it's a form-data field) 58 | // 59 | // It is primarily added for ordered multipart form-data field use cases 60 | Values []string 61 | } 62 | 63 | // Clone method returns a copy of mf. The Values slice is copied, so the 64 | // clone's values can be changed independently; the [io.Reader] and the 65 | // ProgressCallback are shared with the original. 66 | func (mf *MultipartField) Clone() *MultipartField { 67 | mf2 := new(MultipartField) 68 | *mf2 = *mf 69 | if mf.Values != nil { 70 | mf2.Values = make([]string, len(mf.Values)) 71 | copy(mf2.Values, mf.Values) 72 | } 73 | return mf2 74 | } 75 | 76 | func (mf *MultipartField) resetReader() error { 77 | if rs, ok := mf.Reader.(io.ReadSeeker); ok { 78 | _, err := rs.Seek(0, io.SeekStart) 79 | return err 80 | } 81 | return nil 82 | } 83 | 84 | func (mf *MultipartField) close() { 85 | closeq(mf.Reader) 86 | } 87 | 88 | func (mf *MultipartField) createHeader() textproto.MIMEHeader { 89 | h := make(textproto.MIMEHeader) 90 | if isStringEmpty(mf.FileName) { 91 | h.Set(hdrContentDisposition, 92 | fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(mf.Name))) 93 | } else { 94 | h.Set(hdrContentDisposition, 95 | fmt.Sprintf(`form-data; name="%s"; filename="%s"`, 96 | escapeQuotes(mf.Name), escapeQuotes(mf.FileName))) 97 | } 98 | if !isStringEmpty(mf.ContentType) { 99 | h.Set(hdrContentTypeKey, mf.ContentType) 100 | } 101 | return h 102 | } 103 |
104 | // openFileIfRequired opens FilePath when no Reader is set, and fills in 105 | // Reader, FileSize and (when empty) FileName from the opened file. 106 | func (mf *MultipartField) openFileIfRequired() error { 107 | if isStringEmpty(mf.FilePath) || mf.Reader != nil { 108 | return nil 109 | } 110 | 111 | file, err := os.Open(mf.FilePath) 112 | if err != nil { 113 | return err 114 | } 115 | 116 | // Stat can fail even after a successful open (e.g. an I/O error); close the 117 | // file and report the error instead of dereferencing a nil FileInfo 118 | fileStat, err := file.Stat() 119 | if err != nil { 120 | closeq(file) 121 | return err 122 | } 123 | 124 | if isStringEmpty(mf.FileName) { 125 | mf.FileName = filepath.Base(mf.FilePath) 126 | } 127 | 128 | mf.Reader = file 129 | mf.FileSize = fileStat.Size() 130 | 131 | return nil 132 | } 133 | 134 | func (mf *MultipartField) wrapProgressCallbackIfPresent(pw io.Writer) io.Writer { 135 | if mf.ProgressCallback == nil { 136 | return pw 137 | } 138 | 139 | return &multipartProgressWriter{ 140 | w: pw, 141 | f: func(pb int64) { 142 | mf.ProgressCallback(MultipartFieldProgress{ 143 | Name: mf.Name, 144 | FileName: mf.FileName, 145 | FileSize: mf.FileSize, 146 | Written: pb, 147 | }) 148 | }, 149 | } 150 | } 151 |
152 | // MultipartFieldCallbackFunc function used to transmit live multipart upload 153 | // progress in bytes count 154 | type MultipartFieldCallbackFunc func(MultipartFieldProgress) 155 | 156 | // MultipartFieldProgress struct used to provide multipart field upload progress 157 | // details via callback function 158 | type MultipartFieldProgress struct { 159 | Name string 160 | FileName string 161 | FileSize int64 162 | Written int64 163 | } 164 | 165 | // String method creates the string representation of [MultipartFieldProgress] 166 | func (mfp MultipartFieldProgress) String() string { 167 | return fmt.Sprintf("FieldName: %s, FileName: %s, FileSize: %v, Written: %v", 168 | mfp.Name, mfp.FileName, mfp.FileSize, mfp.Written) 169 | } 170 | 171 | type multipartProgressWriter struct { 172 | w io.Writer 173 | pb int64 174 | f func(int64) 175 | } 176 | 177 | func (mpw *multipartProgressWriter) Write(p []byte) (n int, err error) { 178 | n, err = mpw.w.Write(p) 179 | if n <= 0 { 180 | return 181 | } 182 | mpw.pb += int64(n) 183 | mpw.f(mpw.pb) 184 | return 185 | } 186 | -------------------------------------------------------------------------------- /redirect.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "net" 12 | "net/http" 13 | "strings" 14 | ) 15 | 16 | type ( 17 | // RedirectPolicy to regulate the redirects in the Resty client.
18 | // Objects implementing the [RedirectPolicy] interface can be registered as the client's redirect policy via SetRedirectPolicy. 19 | // 20 | // Apply function should return nil to continue the redirect journey; otherwise 21 | // return error to stop the redirect. 22 | RedirectPolicy interface { 23 | Apply(*http.Request, []*http.Request) error 24 | } 25 | 26 | // The [RedirectPolicyFunc] type is an adapter to allow the use of ordinary 27 | // functions as [RedirectPolicy]. If `f` is a function with the appropriate 28 | // signature, RedirectPolicyFunc(f) is a RedirectPolicy object that calls `f`. 29 | RedirectPolicyFunc func(*http.Request, []*http.Request) error 30 | 31 | // RedirectInfo struct is used to capture the URL and status code for the redirect history 32 | RedirectInfo struct { 33 | URL string 34 | StatusCode int 35 | } 36 | ) 37 | 38 | // Apply calls f(req, via). 39 | func (f RedirectPolicyFunc) Apply(req *http.Request, via []*http.Request) error { 40 | return f(req, via) 41 | } 42 | 43 | // NoRedirectPolicy is used to disable the redirects in the Resty client 44 | // 45 | // resty.SetRedirectPolicy(resty.NoRedirectPolicy()) 46 | func NoRedirectPolicy() RedirectPolicy { 47 | return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { 48 | return http.ErrUseLastResponse 49 | }) 50 | } 51 | 52 | // FlexibleRedirectPolicy method is convenient for creating several redirect policies for Resty clients. 53 | // 54 | // resty.SetRedirectPolicy(FlexibleRedirectPolicy(20)) 55 | func FlexibleRedirectPolicy(noOfRedirect int) RedirectPolicy { 56 | return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { 57 | if len(via) >= noOfRedirect { 58 | return fmt.Errorf("resty: stopped after %d redirects", noOfRedirect) 59 | } 60 | checkHostAndAddHeaders(req, via[0]) 61 | return nil 62 | }) 63 | } 64 | 65 | // DomainCheckRedirectPolicy method is convenient for defining domain name redirect rules in Resty clients. 66 | // Redirect is allowed only for the host mentioned in the policy.
67 | // 68 | // resty.SetRedirectPolicy(resty.DomainCheckRedirectPolicy("host1.com", "host2.org", "host3.net")) 69 | func DomainCheckRedirectPolicy(hostnames ...string) RedirectPolicy { 70 | hosts := make(map[string]bool) 71 | for _, h := range hostnames { 72 | hosts[strings.ToLower(h)] = true 73 | } 74 | 75 | return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { 76 | if ok := hosts[getHostname(req.URL.Host)]; !ok { 77 | return errors.New("redirect is not allowed as per DomainCheckRedirectPolicy") 78 | } 79 | checkHostAndAddHeaders(req, via[0]) 80 | return nil 81 | }) 82 | } 83 | 84 | // getHostname returns the lower-cased hostname of a "host[:port]" value, 85 | // without the port and without IPv6 brackets, e.g. "[::1]:8080" -> "::1". 86 | func getHostname(host string) string { 87 | if h, _, err := net.SplitHostPort(host); err == nil { 88 | return strings.ToLower(h) 89 | } 90 | // no port: use the host as-is, minus the brackets of an IPv6 literal like "[::1]" 91 | return strings.ToLower(strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")) 92 | } 93 | 94 | // checkHostAndAddHeaders copies every header of pre onto cur when both 95 | // requests target the same hostname (case-insensitive, port ignored) and 96 | // leaves cur untouched otherwise. Go's http.Client forwards most headers on 97 | // redirect by itself but drops sensitive ones for a different domain; see 98 | // https://github.com/golang/go/issues/4800 for the background discussion. 99 | func checkHostAndAddHeaders(cur *http.Request, pre *http.Request) { 100 | curHostname := getHostname(cur.URL.Host) 101 | preHostname := getHostname(pre.URL.Host) 102 | if strings.EqualFold(curHostname, preHostname) { 103 | for key, val := range pre.Header { 104 | cur.Header[key] = val 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /response.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file.
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "bytes" 10 | "encoding/json" 11 | "fmt" 12 | "io" 13 | "net/http" 14 | "strings" 15 | "time" 16 | ) 17 | 18 | //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ 19 | // Response struct and methods 20 | //_______________________________________________________________________ 21 | 22 | // Response struct holds the response values of an executed request. 23 | type Response struct { 24 | Request *Request 25 | Body io.ReadCloser 26 | RawResponse *http.Response 27 | IsRead bool 28 | 29 | // Err field is used to cascade the response middleware error 30 | // through the middleware chain 31 | Err error 32 | 33 | bodyBytes []byte 34 | size int64 35 | receivedAt time.Time 36 | } 37 | 38 | // Status method returns the HTTP status string for the executed request. 39 | // 40 | // Example: 200 OK 41 | func (r *Response) Status() string { 42 | if r.RawResponse == nil { 43 | return "" 44 | } 45 | return r.RawResponse.Status 46 | } 47 | 48 | // StatusCode method returns the HTTP status code for the executed request. 49 | // 50 | // Example: 200 51 | func (r *Response) StatusCode() int { 52 | if r.RawResponse == nil { 53 | return 0 54 | } 55 | return r.RawResponse.StatusCode 56 | } 57 | 58 | // Proto method returns the HTTP protocol of the response (e.g. "HTTP/1.1"), or "" when there is no raw response.
59 | func (r *Response) Proto() string { 60 | if r.RawResponse == nil { 61 | return "" 62 | } 63 | return r.RawResponse.Proto 64 | } 65 | 66 | // Result method returns the response value as an object if it has one 67 | // 68 | // See [Request.SetResult] 69 | func (r *Response) Result() any { 70 | return r.Request.Result 71 | } 72 | 73 | // Error method returns the error object if it has one 74 | // 75 | // See [Request.SetError], [Client.SetError] 76 | func (r *Response) Error() any { 77 | return r.Request.Error 78 | } 79 | 80 | // Header method returns the response headers 81 | func (r *Response) Header() http.Header { 82 | if r.RawResponse == nil { 83 | return http.Header{} 84 | } 85 | return r.RawResponse.Header 86 | } 87 | 88 | // Cookies method returns all the response cookies 89 | func (r *Response) Cookies() []*http.Cookie { 90 | if r.RawResponse == nil { 91 | return make([]*http.Cookie, 0) 92 | } 93 | return r.RawResponse.Cookies() 94 | } 95 | 96 | // String method returns the body of the HTTP response as a `string`. 97 | // It returns an empty string if it is nil or the body is zero length. 98 | // 99 | // NOTE: 100 | // - Returns an empty string on auto-unmarshal scenarios, unless 101 | // [Client.SetResponseBodyUnlimitedReads] or [Request.SetResponseBodyUnlimitedReads] set. 102 | // - Returns an empty string when [Client.SetDoNotParseResponse] or [Request.SetDoNotParseResponse] set. 103 | func (r *Response) String() string { 104 | r.readIfRequired() 105 | return strings.TrimSpace(string(r.bodyBytes)) 106 | } 107 | 108 | // Bytes method returns the body of the HTTP response as a byte slice. 109 | // It returns an empty byte slice if it is nil or the body is zero length. 110 | // 111 | // NOTE: 112 | // - Returns an empty byte slice on auto-unmarshal scenarios, unless 113 | // [Client.SetResponseBodyUnlimitedReads] or [Request.SetResponseBodyUnlimitedReads] set.
114 | // - Returns an empty byte slice when [Client.SetDoNotParseResponse] or [Request.SetDoNotParseResponse] set. 115 | func (r *Response) Bytes() []byte { 116 | r.readIfRequired() 117 | return r.bodyBytes 118 | } 119 | 120 | // Duration method returns the time elapsed between sending the request and 121 | // receiving its response (the trace total time when tracing is enabled). 122 | // 123 | // See [Response.ReceivedAt] to know when the client received a response and see 124 | // `Response.Request.Time` to know when the client sent a request. 125 | func (r *Response) Duration() time.Duration { 126 | if r.Request.trace != nil { 127 | return r.Request.TraceInfo().TotalTime 128 | } 129 | return r.receivedAt.Sub(r.Request.Time) 130 | } 131 | 132 | // ReceivedAt method returns the time we received a response from the server for the request. 133 | func (r *Response) ReceivedAt() time.Time { 134 | return r.receivedAt 135 | } 136 | 137 | // Size method returns the HTTP response size in bytes. The HTTP `Content-Length` 138 | // header is not available for chunked transfer/compressed responses, so 139 | // Resty captures the response size while processing the response body, when 140 | // possible, to report the actual number of response bytes. 141 | func (r *Response) Size() int64 { 142 | r.readIfRequired() 143 | return r.size 144 | } 145 | 146 | // IsSuccess method returns true if HTTP status `code >= 200 and <= 299` otherwise false. 147 | func (r *Response) IsSuccess() bool { 148 | return r.StatusCode() > 199 && r.StatusCode() < 300 149 | } 150 | 151 | // IsError method returns true if HTTP status `code >= 400` otherwise false.
152 | func (r *Response) IsError() bool { 153 | return r.StatusCode() > 399 154 | } 155 | 156 | // RedirectHistory method returns a redirect history slice with the URL and status code 157 | func (r *Response) RedirectHistory() []*RedirectInfo { 158 | if r.RawResponse == nil { 159 | return nil 160 | } 161 | 162 | redirects := make([]*RedirectInfo, 0) 163 | res := r.RawResponse 164 | for res != nil { 165 | req := res.Request 166 | redirects = append(redirects, &RedirectInfo{ 167 | StatusCode: res.StatusCode, 168 | URL: req.URL.String(), 169 | }) 170 | res = req.Response 171 | } 172 | 173 | return redirects 174 | } 175 | 176 | func (r *Response) setReceivedAt() { 177 | r.receivedAt = time.Now() 178 | if r.Request.trace != nil { 179 | r.Request.trace.endTime = r.receivedAt 180 | } 181 | } 182 | 183 | func (r *Response) fmtBodyString(sl int) string { 184 | if r.Request.DoNotParseResponse { 185 | return "***** DO NOT PARSE RESPONSE - Enabled *****" 186 | } 187 | 188 | if r.Request.IsSaveResponse { 189 | return "***** RESPONSE WRITTEN INTO FILE *****" 190 | } 191 | 192 | bl := len(r.bodyBytes) 193 | if r.IsRead && bl == 0 { 194 | return "***** RESPONSE BODY IS ALREADY READ - see Response.{Result()/Error()} *****" 195 | } 196 | 197 | if bl > 0 { 198 | if bl > sl { 199 | return fmt.Sprintf("***** RESPONSE TOO LARGE (size - %d) *****", bl) 200 | } 201 | 202 | ct := r.Header().Get(hdrContentTypeKey) 203 | ctKey := inferContentTypeMapKey(ct) 204 | if jsonKey == ctKey { 205 | out := acquireBuffer() 206 | defer releaseBuffer(out) 207 | err := json.Indent(out, r.bodyBytes, "", " ") 208 | if err != nil { 209 | r.Request.log.Errorf("DebugLog: Response.fmtBodyString: %v", err) 210 | return "" 211 | } 212 | return out.String() 213 | } 214 | return r.String() 215 | } 216 | 217 | return "***** NO CONTENT *****" 218 | } 219 | 220 | func (r *Response) readIfRequired() { 221 | if len(r.bodyBytes) == 0 && !r.Request.DoNotParseResponse { 222 | _ = r.readAll() 223 | } 224 | } 225 | 226 | var
ioReadAll = io.ReadAll 227 | 228 | // auto-unmarshal didn't happen, so fallback to 229 | // old behavior of reading response as body bytes 230 | func (r *Response) readAll() (err error) { 231 | if r.Body == nil || r.IsRead { 232 | return nil 233 | } 234 | 235 | if _, ok := r.Body.(*copyReadCloser); ok { 236 | _, err = ioReadAll(r.Body) 237 | } else { 238 | r.bodyBytes, err = ioReadAll(r.Body) 239 | closeq(r.Body) 240 | r.Body = &nopReadCloser{r: bytes.NewReader(r.bodyBytes)} 241 | } 242 | if err == io.ErrUnexpectedEOF { 243 | // content-encoding scenarios - empty/no response body from server 244 | err = nil 245 | } 246 | 247 | r.IsRead = true 248 | return 249 | } 250 | 251 | func (r *Response) wrapLimitReadCloser() { 252 | r.Body = &limitReadCloser{ 253 | r: r.Body, 254 | l: r.Request.ResponseBodyLimit, 255 | f: func(s int64) { 256 | r.size = s 257 | }, 258 | } 259 | } 260 | 261 | func (r *Response) wrapCopyReadCloser() { 262 | r.Body = &copyReadCloser{ 263 | s: r.Body, 264 | t: acquireBuffer(), 265 | f: func(b *bytes.Buffer) { 266 | r.bodyBytes = append([]byte{}, b.Bytes()...)
267 | closeq(r.Body) 268 | r.Body = &nopReadCloser{r: bytes.NewReader(r.bodyBytes)} 269 | releaseBuffer(b) 270 | }, 271 | } 272 | } 273 | 274 | func (r *Response) wrapContentDecompresser() error { 275 | ce := r.Header().Get(hdrContentEncodingKey) 276 | if isStringEmpty(ce) { 277 | return nil 278 | } 279 | 280 | if decFunc, f := r.Request.client.ContentDecompressers()[ce]; f { 281 | dec, err := decFunc(r.Body) 282 | if err != nil { 283 | if err == io.EOF { 284 | // empty/no response body from server 285 | err = nil 286 | } 287 | return err 288 | } 289 | 290 | r.Body = dec 291 | r.Header().Del(hdrContentEncodingKey) 292 | r.Header().Del(hdrContentLengthKey) 293 | r.RawResponse.ContentLength = -1 294 | } else { 295 | return ErrContentDecompresserNotFound 296 | } 297 | 298 | return nil 299 | } 300 | -------------------------------------------------------------------------------- /resty.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | // SPDX-License-Identifier: MIT 5 | 6 | // Package resty provides Simple HTTP, REST, and SSE client library for Go. 7 | package resty // import "resty.dev/v3" 8 | 9 | import ( 10 | "math" 11 | "net" 12 | "net/http" 13 | "net/http/cookiejar" 14 | "net/url" 15 | "runtime" 16 | "sync" 17 | "time" 18 | 19 | "golang.org/x/net/publicsuffix" 20 | ) 21 | 22 | // Version # of resty 23 | const Version = "3.0.0-beta.1" 24 | 25 | // New method creates a new Resty client. 26 | func New() *Client { 27 | return NewWithTransportSettings(nil) 28 | } 29 | 30 | // NewWithTransportSettings method creates a new Resty client with the given 31 | // [TransportSettings]; a nil value uses the default transport settings.
32 | func NewWithTransportSettings(transportSettings *TransportSettings) *Client { 33 | return NewWithDialerAndTransportSettings(nil, transportSettings) 34 | } 35 | 36 | // NewWithClient method creates a new Resty client with given [http.Client]. 37 | func NewWithClient(hc *http.Client) *Client { 38 | return createClient(hc) 39 | } 40 | 41 | // NewWithDialer method creates a new Resty client that dials connections 42 | // using the given [net.Dialer]. 43 | func NewWithDialer(dialer *net.Dialer) *Client { 44 | return NewWithDialerAndTransportSettings(dialer, nil) 45 | } 46 | 47 | // NewWithLocalAddr method creates a new Resty client with the given Local Address. 48 | func NewWithLocalAddr(localAddr net.Addr) *Client { 49 | return NewWithDialerAndTransportSettings( 50 | &net.Dialer{LocalAddr: localAddr}, 51 | nil, 52 | ) 53 | } 54 | 55 | // NewWithDialerAndTransportSettings method creates a new Resty client with the given 56 | // [net.Dialer] and [TransportSettings]; either may be nil to use the defaults. 57 | func NewWithDialerAndTransportSettings(dialer *net.Dialer, transportSettings *TransportSettings) *Client { 58 | return createClient(&http.Client{ 59 | Jar: createCookieJar(), 60 | Transport: createTransport(dialer, transportSettings), 61 | }) 62 | } 63 | 64 | //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ 65 | // Unexported methods 66 | //_______________________________________________________________________ 67 | 68 | func createTransport(dialer *net.Dialer, transportSettings *TransportSettings) *http.Transport { 69 | if transportSettings == nil { 70 | transportSettings = &TransportSettings{} 71 | } 72 | 73 | // Dialer (modified in place): explicit settings win, then values already set on it, then defaults 74 | 75 | if dialer == nil { 76 | dialer = &net.Dialer{} 77 | } 78 | 79 | if transportSettings.DialerTimeout > 0 { 80 | dialer.Timeout = transportSettings.DialerTimeout 81 | } else if dialer.Timeout == 0 { 82 | dialer.Timeout = 30 * time.Second 83 | } 84 | 85 | if transportSettings.DialerKeepAlive > 0 { 86 | dialer.KeepAlive = transportSettings.DialerKeepAlive 87 | } else if dialer.KeepAlive == 0 { 88 | dialer.KeepAlive = 30 *
time.Second 89 | } 90 | 91 | // Transport 92 | t := &http.Transport{ 93 | Proxy: http.ProxyFromEnvironment, 94 | DialContext: transportDialContext(dialer), 95 | DisableKeepAlives: transportSettings.DisableKeepAlives, 96 | DisableCompression: true, // Resty handles it, see [Client.AddContentDecompresser] 97 | ForceAttemptHTTP2: true, 98 | } 99 | 100 | if transportSettings.IdleConnTimeout > 0 { 101 | t.IdleConnTimeout = transportSettings.IdleConnTimeout 102 | } else { 103 | t.IdleConnTimeout = 90 * time.Second 104 | } 105 | 106 | if transportSettings.TLSHandshakeTimeout > 0 { 107 | t.TLSHandshakeTimeout = transportSettings.TLSHandshakeTimeout 108 | } else { 109 | t.TLSHandshakeTimeout = 10 * time.Second 110 | } 111 | 112 | if transportSettings.ExpectContinueTimeout > 0 { 113 | t.ExpectContinueTimeout = transportSettings.ExpectContinueTimeout 114 | } else { 115 | t.ExpectContinueTimeout = 1 * time.Second 116 | } 117 | 118 | if transportSettings.MaxIdleConns > 0 { 119 | t.MaxIdleConns = transportSettings.MaxIdleConns 120 | } else { 121 | t.MaxIdleConns = 100 122 | } 123 | 124 | if transportSettings.MaxIdleConnsPerHost > 0 { 125 | t.MaxIdleConnsPerHost = transportSettings.MaxIdleConnsPerHost 126 | } else { 127 | t.MaxIdleConnsPerHost = runtime.GOMAXPROCS(0) + 1 128 | } 129 | 130 | // 131 | // No default value in Resty for following settings, added to 132 | // provide ability to set value otherwise the Go HTTP client 133 | // default value applies.
134 | // 135 | 136 | if transportSettings.ResponseHeaderTimeout > 0 { 137 | t.ResponseHeaderTimeout = transportSettings.ResponseHeaderTimeout 138 | } 139 | 140 | if transportSettings.MaxResponseHeaderBytes > 0 { 141 | t.MaxResponseHeaderBytes = transportSettings.MaxResponseHeaderBytes 142 | } 143 | 144 | if transportSettings.WriteBufferSize > 0 { 145 | t.WriteBufferSize = transportSettings.WriteBufferSize 146 | } 147 | 148 | if transportSettings.ReadBufferSize > 0 { 149 | t.ReadBufferSize = transportSettings.ReadBufferSize 150 | } 151 | 152 | return t 153 | } 154 | 155 | func createCookieJar() *cookiejar.Jar { 156 | cookieJar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) 157 | return cookieJar 158 | } 159 | 160 | func createClient(hc *http.Client) *Client { 161 | c := &Client{ // not setting language default values 162 | lock: &sync.RWMutex{}, 163 | queryParams: url.Values{}, 164 | formData: url.Values{}, 165 | header: http.Header{}, 166 | authScheme: defaultAuthScheme, 167 | cookies: make([]*http.Cookie, 0), 168 | retryWaitTime: defaultWaitTime, 169 | retryMaxWaitTime: defaultMaxWaitTime, 170 | isRetryDefaultConditions: true, 171 | pathParams: make(map[string]string), 172 | headerAuthorizationKey: hdrAuthorizationKey, 173 | jsonEscapeHTML: true, 174 | httpClient: hc, 175 | debugBodyLimit: math.MaxInt32, 176 | contentTypeEncoders: make(map[string]ContentTypeEncoder), 177 | contentTypeDecoders: make(map[string]ContentTypeDecoder), 178 | contentDecompresserKeys: make([]string, 0), 179 | contentDecompressers: make(map[string]ContentDecompresser), 180 | certWatcherStopChan: make(chan bool), 181 | } 182 | 183 | // Logger 184 | c.SetLogger(createLogger()) 185 | c.SetDebugLogFormatter(DebugLogFormatter) 186 | 187 | c.AddContentTypeEncoder(jsonKey, encodeJSON) 188 | c.AddContentTypeEncoder(xmlKey, encodeXML) 189 | 190 | c.AddContentTypeDecoder(jsonKey, decodeJSON) 191 | c.AddContentTypeDecoder(xmlKey, decodeXML) 192 | 193 | // Order matters,
giving priority to gzip 194 | c.AddContentDecompresser("deflate", decompressDeflate) 195 | c.AddContentDecompresser("gzip", decompressGzip) 196 | 197 | // request middlewares 198 | c.SetRequestMiddlewares( 199 | PrepareRequestMiddleware, 200 | ) 201 | 202 | // response middlewares 203 | c.SetResponseMiddlewares( 204 | AutoParseResponseMiddleware, 205 | SaveToFileResponseMiddleware, 206 | ) 207 | 208 | return c 209 | } 210 | -------------------------------------------------------------------------------- /retry.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "crypto/tls" 10 | "math" 11 | "math/rand" 12 | "net/http" 13 | "net/url" 14 | "regexp" 15 | "strconv" 16 | "sync" 17 | "time" 18 | ) 19 | 20 | const ( 21 | defaultWaitTime = time.Duration(100) * time.Millisecond 22 | defaultMaxWaitTime = time.Duration(2000) * time.Millisecond 23 | ) 24 | 25 | type ( 26 | // RetryConditionFunc type is for the retry condition function 27 | // input: non-nil Response OR request execution error 28 | RetryConditionFunc func(*Response, error) bool 29 | 30 | // RetryHookFunc is for side-effecting functions triggered on retry 31 | RetryHookFunc func(*Response, error) 32 | 33 | // RetryStrategyFunc type is for custom retry strategy implementation 34 | // By default Resty uses the capped exponential backoff with a jitter strategy 35 | RetryStrategyFunc func(*Response, error) (time.Duration, error) 36 | ) 37 | 38 | var ( 39 | regexErrTooManyRedirects = regexp.MustCompile(`stopped after \d+ redirects\z`) 40 | regexErrScheme = regexp.MustCompile("unsupported protocol scheme") 41 | regexErrInvalidHeader = regexp.MustCompile("invalid header") 42 | ) 43 | 44 | func
applyRetryDefaultConditions(res *Response, err error) bool { 45 | // no retry on TLS error 46 | if _, ok := err.(*tls.CertificateVerificationError); ok { 47 | return false 48 | } 49 | 50 | // validate url error, so we can decide to retry or not 51 | if u, ok := err.(*url.Error); ok { 52 | if regexErrTooManyRedirects.MatchString(u.Error()) { 53 | return false 54 | } 55 | if regexErrScheme.MatchString(u.Error()) { 56 | return false 57 | } 58 | if regexErrInvalidHeader.MatchString(u.Error()) { 59 | return false 60 | } 61 | return u.Temporary() // possible retry if it's true 62 | } 63 | 64 | if res == nil { 65 | return false 66 | } 67 | 68 | // certain HTTP status codes are temporary so that we can retry 69 | // - 429 Too Many Requests 70 | // - 500 or above (it's better to ignore 501 Not Implemented) 71 | // - 0 No status code received 72 | if res.StatusCode() == http.StatusTooManyRequests || 73 | (res.StatusCode() >= 500 && res.StatusCode() != http.StatusNotImplemented) || 74 | res.StatusCode() == 0 { 75 | return true 76 | } 77 | 78 | return false 79 | } 80 | 81 | func newBackoffWithJitter(min, max time.Duration) *backoffWithJitter { 82 | if min <= 0 { 83 | min = defaultWaitTime 84 | } 85 | if max == 0 { 86 | max = defaultMaxWaitTime 87 | } 88 | 89 | return &backoffWithJitter{ 90 | lock: new(sync.Mutex), 91 | rnd: rand.New(rand.NewSource(time.Now().UnixNano())), 92 | min: min, 93 | max: max, 94 | } 95 | } 96 | 97 | type backoffWithJitter struct { 98 | lock *sync.Mutex 99 | rnd *rand.Rand 100 | min time.Duration 101 | max time.Duration 102 | } 103 | 104 | func (b *backoffWithJitter) NextWaitDuration(c *Client, res *Response, err error, attempt int) (time.Duration, error) { 105 | if res != nil { 106 | if res.StatusCode() == http.StatusTooManyRequests || res.StatusCode() == http.StatusServiceUnavailable { 107 | if delay, ok := parseRetryAfterHeader(res.Header().Get(hdrRetryAfterKey)); ok { 108 | return delay, nil 109 | } 110 | } 111 | } 112 | 113 | const maxInt = 1<<31
- 1 // max int for arch 386; NOTE(review): as a time.Duration this caps at ~2.15s — confirm intent 114 | if b.max < 0 { 115 | b.max = maxInt 116 | } 117 | 118 | var retryStrategyFunc RetryStrategyFunc 119 | if c != nil { 120 | retryStrategyFunc = c.RetryStrategy() 121 | } 122 | if res == nil || retryStrategyFunc == nil { 123 | return b.balanceMinMax(b.defaultStrategy(attempt)), nil 124 | } 125 | 126 | delay, rsErr := retryStrategyFunc(res, err) 127 | if rsErr != nil { 128 | return 0, rsErr 129 | } 130 | return b.balanceMinMax(delay), nil 131 | } 132 | 133 | // Return capped exponential backoff with jitter 134 | // https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ 135 | func (b *backoffWithJitter) defaultStrategy(attempt int) time.Duration { 136 | temp := math.Min(float64(b.max), float64(b.min)*math.Exp2(float64(attempt))) 137 | ri := time.Duration(temp / 2) 138 | if ri <= 0 { 139 | ri = time.Nanosecond 140 | } 141 | return b.randDuration(ri) 142 | } 143 | 144 | func (b *backoffWithJitter) randDuration(center time.Duration) time.Duration { 145 | b.lock.Lock() 146 | defer b.lock.Unlock() 147 | 148 | var ri = int64(center) 149 | var jitter = b.rnd.Int63n(ri) 150 | return time.Duration(math.Abs(float64(ri + jitter))) 151 | } 152 | 153 | func (b *backoffWithJitter) balanceMinMax(delay time.Duration) time.Duration { 154 | if delay <= 0 || b.max < delay { 155 | return b.max 156 | } 157 | if delay < b.min { 158 | return b.min 159 | } 160 | return delay 161 | } 162 | 163 | var timeNow = time.Now 164 | 165 | // parseRetryAfterHeader parses the Retry-After header and returns the 166 | // delay duration according to the spec: https://httpwg.org/specs/rfc7231.html#header.retry-after 167 | // The bool returned will be true if the header was successfully parsed. 168 | // Otherwise, the header was either not present, or was not parseable according to the spec.
169 | // 170 | // Retry-After headers come in two flavors: Seconds or HTTP-Date 171 | // 172 | // Examples: 173 | // - Retry-After: Fri, 31 Dec 1999 23:59:59 GMT 174 | // - Retry-After: 120 175 | func parseRetryAfterHeader(v string) (time.Duration, bool) { 176 | if isStringEmpty(v) { 177 | return 0, false 178 | } 179 | 180 | // Retry-After: 120 181 | if delay, err := strconv.ParseInt(v, 10, 64); err == nil { 182 | if delay < 0 || delay > math.MaxInt64/int64(time.Second) { // negative or time.Duration-overflowing delays don't make sense 183 | return 0, false 184 | } 185 | return time.Second * time.Duration(delay), true 186 | } 187 | 188 | // Retry-After: Fri, 31 Dec 1999 23:59:59 GMT 189 | retryTime, err := time.Parse(time.RFC1123, v) 190 | if err != nil { 191 | return 0, false 192 | } 193 | if until := retryTime.Sub(timeNow()); until > 0 { 194 | return until, true 195 | } 196 | 197 | // date is in the past 198 | return 0, true 199 | } 200 | -------------------------------------------------------------------------------- /sse.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file.
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "bufio" 10 | "bytes" 11 | "errors" 12 | "fmt" 13 | "io" 14 | "net/http" 15 | "slices" 16 | "strconv" 17 | "strings" 18 | "sync" 19 | "time" 20 | ) 21 | 22 | // Spec: https://html.spec.whatwg.org/multipage/server-sent-events.html 23 | 24 | var ( 25 | defaultSseMaxBufSize = 1 << 15 // 32kb 26 | defaultEventName = "message" 27 | defaultHTTPMethod = MethodGet 28 | 29 | headerID = []byte("id:") 30 | headerData = []byte("data:") 31 | headerEvent = []byte("event:") 32 | headerRetry = []byte("retry:") 33 | 34 | hdrCacheControlKey = http.CanonicalHeaderKey("Cache-Control") 35 | hdrConnectionKey = http.CanonicalHeaderKey("Connection") 36 | hdrLastEvevntID = http.CanonicalHeaderKey("Last-Event-ID") 37 | ) 38 | 39 | type ( 40 | // EventOpenFunc is a callback function type used to receive notification 41 | // when Resty establishes a connection with the server for the 42 | // Server-Sent Events(SSE) 43 | EventOpenFunc func(url string) 44 | 45 | // EventMessageFunc is a callback function type used to receive event details 46 | // from the Server-Sent Events(SSE) stream 47 | EventMessageFunc func(any) 48 | 49 | // EventErrorFunc is a callback function type used to receive notification 50 | // when an error occurs with [EventSource] processing 51 | EventErrorFunc func(error) 52 | 53 | // Event struct represents the event details from the Server-Sent Events(SSE) stream 54 | Event struct { 55 | ID string 56 | Name string 57 | Data string 58 | } 59 | 60 | // EventSource struct implements the Server-Sent Events(SSE) [specification] to receive 61 | // stream from the server 62 | // 63 | // [specification]: https://html.spec.whatwg.org/multipage/server-sent-events.html 64 | EventSource struct { 65 | lock *sync.RWMutex 66 | url string 67 | method string 68 | header http.Header 69 | body io.Reader 70 | lastEventID string 71 | retryCount int 72 | retryWaitTime time.Duration 73 | retryMaxWaitTime time.Duration 74 
| serverSentRetry time.Duration 75 | maxBufSize int 76 | onOpen EventOpenFunc 77 | onError EventErrorFunc 78 | onEvent map[string]*callback 79 | log Logger 80 | closed bool 81 | httpClient *http.Client 82 | } 83 | 84 | callback struct { 85 | Func EventMessageFunc 86 | Result any 87 | } 88 | ) 89 | 90 | // NewEventSource method creates a new instance of [EventSource] 91 | // with default values for Server-Sent Events(SSE) 92 | // 93 | // es := NewEventSource(). 94 | // SetURL("https://sse.dev/test"). 95 | // OnMessage( 96 | // func(e any) { 97 | // e = e.(*Event) 98 | // fmt.Println(e) 99 | // }, 100 | // nil, // see method godoc 101 | // ) 102 | // 103 | // err := es.Connect() 104 | // fmt.Println(err) 105 | // 106 | // See [EventSource.OnMessage], [EventSource.AddEventListener] 107 | func NewEventSource() *EventSource { 108 | es := &EventSource{ 109 | lock: new(sync.RWMutex), 110 | header: make(http.Header), 111 | retryCount: 3, 112 | retryWaitTime: defaultWaitTime, 113 | retryMaxWaitTime: defaultMaxWaitTime, 114 | maxBufSize: defaultSseMaxBufSize, 115 | onEvent: make(map[string]*callback), 116 | httpClient: &http.Client{ 117 | Jar: createCookieJar(), 118 | Transport: createTransport(nil, nil), 119 | }, 120 | } 121 | return es 122 | } 123 | 124 | // SetURL method sets a [EventSource] connection URL in the instance 125 | // 126 | // es.SetURL("https://sse.dev/test") 127 | func (es *EventSource) SetURL(url string) *EventSource { 128 | es.url = url 129 | return es 130 | } 131 | 132 | // SetMethod method sets a [EventSource] connection HTTP method in the instance 133 | // 134 | // es.SetMethod("POST"), or es.SetMethod(resty.MethodPost) 135 | func (es *EventSource) SetMethod(method string) *EventSource { 136 | es.method = method 137 | return es 138 | } 139 | 140 | // SetHeader method sets a header and its value to the [EventSource] instance. 141 | // It overwrites the header value if the key already exists. 
These headers will be sent in 142 | // the request while establishing a connection to the event source 143 | // 144 | // es.SetHeader("Authorization", "token here"). 145 | // SetHeader("X-Header", "value") 146 | func (es *EventSource) SetHeader(header, value string) *EventSource { 147 | es.lock.Lock() 148 | defer es.lock.Unlock() 149 | es.header.Set(header, value) 150 | return es 151 | } 152 | 153 | // SetBody method sets body value to the [EventSource] instance 154 | // 155 | // Example: 156 | // es.SetBody(bytes.NewReader([]byte(`{"test":"put_data"}`))) 157 | func (es *EventSource) SetBody(body io.Reader) *EventSource { 158 | es.body = body 159 | return es 160 | } 161 | 162 | // AddHeader method adds a header and its value to the [EventSource] instance. 163 | // If the header key already exists, it appends. These headers will be sent in 164 | // the request while establishing a connection to the event source 165 | // 166 | // es.AddHeader("Authorization", "token here"). 167 | // AddHeader("X-Header", "value") 168 | func (es *EventSource) AddHeader(header, value string) *EventSource { 169 | es.lock.Lock() 170 | defer es.lock.Unlock() 171 | es.header.Add(header, value) 172 | return es 173 | } 174 | 175 | // SetRetryCount method enables retry attempts on the SSE client while establishing 176 | // connection with the server 177 | // 178 | // first attempt + retry count = total attempts 179 | // 180 | // Default is 3 181 | // 182 | // es.SetRetryCount(10) 183 | func (es *EventSource) SetRetryCount(count int) *EventSource { 184 | es.lock.Lock() 185 | defer es.lock.Unlock() 186 | es.retryCount = count 187 | return es 188 | } 189 | 190 | // SetRetryWaitTime method sets the default wait time for sleep before retrying 191 | // the request 192 | // 193 | // Default is 100 milliseconds. 194 | // 195 | // NOTE: The server-sent retry value takes precedence if present. 
196 | // 197 | // es.SetRetryWaitTime(1 * time.Second) 198 | func (es *EventSource) SetRetryWaitTime(waitTime time.Duration) *EventSource { 199 | es.lock.Lock() 200 | defer es.lock.Unlock() 201 | es.retryWaitTime = waitTime 202 | return es 203 | } 204 | 205 | // SetRetryMaxWaitTime method sets the max wait time for sleep before retrying 206 | // the request 207 | // 208 | // Default is 2 seconds. 209 | // 210 | // NOTE: The server-sent retry value takes precedence if present. 211 | // 212 | // es.SetRetryMaxWaitTime(3 * time.Second) 213 | func (es *EventSource) SetRetryMaxWaitTime(maxWaitTime time.Duration) *EventSource { 214 | es.lock.Lock() 215 | defer es.lock.Unlock() 216 | es.retryMaxWaitTime = maxWaitTime 217 | return es 218 | } 219 | 220 | // SetMaxBufSize method sets the given buffer size into the SSE client 221 | // 222 | // Default is 32kb 223 | // 224 | // es.SetMaxBufSize(64 * 1024) // 64kb 225 | func (es *EventSource) SetMaxBufSize(bufSize int) *EventSource { 226 | es.lock.Lock() 227 | defer es.lock.Unlock() 228 | es.maxBufSize = bufSize 229 | return es 230 | } 231 | 232 | // SetLogger method sets given writer for logging 233 | // 234 | // Compliant to interface [resty.Logger] 235 | func (es *EventSource) SetLogger(l Logger) *EventSource { 236 | es.lock.Lock() 237 | defer es.lock.Unlock() 238 | es.log = l 239 | return es 240 | } 241 | 242 | // just an internal helper method for test case 243 | func (es *EventSource) outputLogTo(w io.Writer) *EventSource { 244 | es.lock.Lock() 245 | defer es.lock.Unlock() 246 | es.log.(*logger).l.SetOutput(w) 247 | return es 248 | } 249 | 250 | // OnOpen registered callback gets triggered when the connection is 251 | // established with the server 252 | // 253 | // es.OnOpen(func(url string) { 254 | // fmt.Println("I'm connected:", url) 255 | // }) 256 | func (es *EventSource) OnOpen(ef EventOpenFunc) *EventSource { 257 | es.lock.Lock() 258 | defer es.lock.Unlock() 259 | if es.onOpen != nil { 260 | 
es.log.Warnf("Overwriting an existing OnOpen callback from=%s to=%s", 261 | functionName(es.onOpen), functionName(ef)) 262 | } 263 | es.onOpen = ef 264 | return es 265 | } 266 | 267 | // OnError registered callback gets triggered when the error occurred 268 | // in the process 269 | // 270 | // es.OnError(func(err error) { 271 | // fmt.Println("Error occurred:", err) 272 | // }) 273 | func (es *EventSource) OnError(ef EventErrorFunc) *EventSource { 274 | es.lock.Lock() 275 | defer es.lock.Unlock() 276 | if es.onError != nil { 277 | es.log.Warnf("Overwriting an existing OnError callback from=%s to=%s", 278 | functionName(es.OnError), functionName(ef)) 279 | } 280 | es.onError = ef 281 | return es 282 | } 283 | 284 | // OnMessage method registers a callback to emit every SSE event message 285 | // from the server. The second result argument is optional; it can be used 286 | // to register the data type for JSON data. 287 | // 288 | // es.OnMessage( 289 | // func(e any) { 290 | // e = e.(*Event) 291 | // fmt.Println("Event message", e) 292 | // }, 293 | // nil, 294 | // ) 295 | // 296 | // // Receiving JSON data from the server, you can set result type 297 | // // to do auto-unmarshal 298 | // es.OnMessage( 299 | // func(e any) { 300 | // e = e.(*MyData) 301 | // fmt.Println(e) 302 | // }, 303 | // MyData{}, 304 | // ) 305 | func (es *EventSource) OnMessage(ef EventMessageFunc, result any) *EventSource { 306 | return es.AddEventListener(defaultEventName, ef, result) 307 | } 308 | 309 | // AddEventListener method registers a callback to consume a specific event type 310 | // messages from the server. The second result argument is optional; it can be used 311 | // to register the data type for JSON data. 
312 | // 313 | // es.AddEventListener( 314 | // "friend_logged_in", 315 | // func(e any) { 316 | // e = e.(*Event) 317 | // fmt.Println(e) 318 | // }, 319 | // nil, 320 | // ) 321 | // 322 | // // Receiving JSON data from the server, you can set result type 323 | // // to do auto-unmarshal 324 | // es.AddEventListener( 325 | // "friend_logged_in", 326 | // func(e any) { 327 | // e = e.(*UserLoggedIn) 328 | // fmt.Println(e) 329 | // }, 330 | // UserLoggedIn{}, 331 | // ) 332 | func (es *EventSource) AddEventListener(eventName string, ef EventMessageFunc, result any) *EventSource { 333 | es.lock.Lock() 334 | defer es.lock.Unlock() 335 | if e, found := es.onEvent[eventName]; found { 336 | es.log.Warnf("Overwriting an existing OnEvent callback from=%s to=%s", 337 | functionName(e), functionName(ef)) 338 | } 339 | cb := &callback{Func: ef, Result: nil} 340 | if result != nil { 341 | cb.Result = getPointer(result) 342 | } 343 | es.onEvent[eventName] = cb 344 | return es 345 | } 346 | 347 | // Get method establishes the connection with the server. 348 | // 349 | // es := NewEventSource(). 350 | // SetURL("https://sse.dev/test"). 351 | // OnMessage( 352 | // func(e any) { 353 | // e = e.(*Event) 354 | // fmt.Println(e) 355 | // }, 356 | // nil, // see method godoc 357 | // ) 358 | // 359 | // err := es.Get() 360 | // fmt.Println(err) 361 | func (es *EventSource) Get() error { 362 | // Validate required values 363 | if isStringEmpty(es.url) { 364 | return fmt.Errorf("resty:sse: event source URL is required") 365 | } 366 | 367 | if isStringEmpty(es.method) { 368 | // It is up to the user to choose which http method to use, depending on the specific code implementation. No restrictions are imposed here. 
369 | // Ensure compatibility, use GET as default http method 370 | es.method = defaultHTTPMethod 371 | } 372 | 373 | if len(es.onEvent) == 0 { 374 | return fmt.Errorf("resty:sse: At least one OnMessage/AddEventListener func is required") 375 | } 376 | 377 | // reset to begin 378 | es.enableConnect() 379 | 380 | for { 381 | if es.isClosed() { 382 | return nil 383 | } 384 | res, err := es.connect() 385 | if err != nil { 386 | return err 387 | } 388 | es.triggerOnOpen() 389 | if err := es.listenStream(res); err != nil { 390 | return err 391 | } 392 | } 393 | } 394 | 395 | // Close method used to close SSE connection explicitly 396 | func (es *EventSource) Close() { 397 | es.lock.Lock() 398 | defer es.lock.Unlock() 399 | es.closed = true 400 | } 401 | 402 | func (es *EventSource) enableConnect() { 403 | es.lock.Lock() 404 | defer es.lock.Unlock() 405 | es.closed = false 406 | } 407 | 408 | func (es *EventSource) isClosed() bool { 409 | es.lock.RLock() 410 | defer es.lock.RUnlock() 411 | return es.closed 412 | } 413 | 414 | func (es *EventSource) triggerOnOpen() { 415 | es.lock.RLock() 416 | defer es.lock.RUnlock() 417 | if es.onOpen != nil { 418 | es.onOpen(strings.Clone(es.url)) 419 | } 420 | } 421 | 422 | func (es *EventSource) triggerOnError(err error) { 423 | es.lock.RLock() 424 | defer es.lock.RUnlock() 425 | if es.onError != nil { 426 | es.onError(err) 427 | } 428 | } 429 | 430 | func (es *EventSource) createRequest() (*http.Request, error) { 431 | req, err := http.NewRequest(es.method, es.url, es.body) 432 | if err != nil { 433 | return nil, err 434 | } 435 | 436 | req.Header = es.header.Clone() 437 | req.Header.Set(hdrAcceptKey, "text/event-stream") 438 | req.Header.Set(hdrCacheControlKey, "no-cache") 439 | req.Header.Set(hdrConnectionKey, "keep-alive") 440 | if len(es.lastEventID) > 0 { 441 | req.Header.Set(hdrLastEvevntID, es.lastEventID) 442 | } 443 | 444 | return req, nil 445 | } 446 | 447 | func (es *EventSource) connect() (*http.Response, error) { 448 | 
es.lock.RLock() 449 | defer es.lock.RUnlock() 450 | 451 | var backoff *backoffWithJitter 452 | if es.serverSentRetry > 0 { 453 | backoff = newBackoffWithJitter(es.serverSentRetry, es.serverSentRetry) 454 | } else { 455 | backoff = newBackoffWithJitter(es.retryWaitTime, es.retryMaxWaitTime) 456 | } 457 | 458 | var ( 459 | err error 460 | attempt int 461 | ) 462 | for i := 0; i <= es.retryCount; i++ { 463 | attempt++ 464 | req, reqErr := es.createRequest() 465 | if reqErr != nil { 466 | err = reqErr 467 | break 468 | } 469 | 470 | resp, doErr := es.httpClient.Do(req) 471 | if resp != nil && resp.StatusCode == http.StatusOK { 472 | return resp, nil 473 | } 474 | 475 | // we have reached the maximum no. of requests 476 | // first attempt + retry count = total attempts 477 | if attempt-1 == es.retryCount { 478 | err = doErr 479 | break 480 | } 481 | 482 | rRes := wrapResponse(resp) 483 | needsRetry := applyRetryDefaultConditions(rRes, doErr) 484 | 485 | // retry not required stop here 486 | if !needsRetry { 487 | if rRes != nil { 488 | err = wrapErrors(fmt.Errorf("resty:sse: %v", rRes.Status()), doErr) 489 | } else { 490 | err = doErr 491 | } 492 | break 493 | } 494 | 495 | // let's drain the response body, before retry wait 496 | drainBody(rRes) 497 | 498 | waitDuration, _ := backoff.NextWaitDuration(nil, rRes, doErr, attempt) 499 | timer := time.NewTimer(waitDuration) 500 | <-timer.C 501 | timer.Stop() 502 | } 503 | 504 | if err != nil { 505 | return nil, err 506 | } 507 | 508 | return nil, fmt.Errorf("resty:sse: unable to connect stream") 509 | } 510 | 511 | func (es *EventSource) listenStream(res *http.Response) error { 512 | defer closeq(res.Body) 513 | 514 | scanner := bufio.NewScanner(res.Body) 515 | scanner.Buffer(make([]byte, slices.Min([]int{4096, es.maxBufSize})), es.maxBufSize) 516 | scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { 517 | if atEOF && len(data) == 0 { 518 | return 0, nil, nil 519 | } 520 | if i := 
bytes.Index(data, []byte{'\n', '\n'}); i >= 0 { 521 | // We have a full double newline-terminated line. 522 | return i + 1, data[0:i], nil 523 | } 524 | // If we're at EOF, we have a final, non-terminated line. Return it. 525 | if atEOF { 526 | return len(data), data, nil 527 | } 528 | // Request more data. 529 | return 0, nil, nil 530 | }) 531 | 532 | for { 533 | if es.isClosed() { 534 | return nil 535 | } 536 | 537 | if err := es.processEvent(scanner); err != nil { 538 | return err 539 | } 540 | } 541 | } 542 | 543 | func (es *EventSource) processEvent(scanner *bufio.Scanner) error { 544 | e, err := readEvent(scanner) 545 | if err != nil { 546 | if err == io.EOF { 547 | return err 548 | } 549 | es.triggerOnError(err) 550 | return err 551 | } 552 | 553 | ed, err := parseEvent(e) 554 | if err != nil { 555 | es.triggerOnError(err) 556 | return nil // parsing errors, will not return error. 557 | } 558 | defer putRawEvent(ed) 559 | 560 | if len(ed.ID) > 0 { 561 | es.lock.Lock() 562 | es.lastEventID = string(ed.ID) 563 | es.lock.Unlock() 564 | } 565 | 566 | if len(ed.Retry) > 0 { 567 | if retry, err := strconv.Atoi(string(ed.Retry)); err == nil { 568 | es.lock.Lock() 569 | es.serverSentRetry = time.Millisecond * time.Duration(retry) 570 | es.lock.Unlock() 571 | } else { 572 | es.triggerOnError(err) 573 | } 574 | } 575 | 576 | if len(ed.Data) > 0 { 577 | es.handleCallback(&Event{ 578 | ID: string(ed.ID), 579 | Name: string(ed.Event), 580 | Data: string(ed.Data), 581 | }) 582 | } 583 | 584 | return nil 585 | } 586 | 587 | func (es *EventSource) handleCallback(e *Event) { 588 | es.lock.RLock() 589 | defer es.lock.RUnlock() 590 | 591 | eventName := e.Name 592 | if len(eventName) == 0 { 593 | eventName = defaultEventName 594 | } 595 | if cb, found := es.onEvent[eventName]; found { 596 | if cb.Result == nil { 597 | cb.Func(e) 598 | return 599 | } 600 | r := newInterface(cb.Result) 601 | if err := decodeJSON(strings.NewReader(e.Data), r); err != nil { 602 | 
es.triggerOnError(err) 603 | return 604 | } 605 | cb.Func(r) 606 | } 607 | } 608 | 609 | var readEvent = readEventFunc 610 | 611 | func readEventFunc(scanner *bufio.Scanner) ([]byte, error) { 612 | if scanner.Scan() { 613 | event := scanner.Bytes() 614 | return event, nil 615 | } 616 | if err := scanner.Err(); err != nil { 617 | return nil, err 618 | } 619 | return nil, io.EOF 620 | } 621 | 622 | func wrapResponse(res *http.Response) *Response { 623 | if res == nil { 624 | return nil 625 | } 626 | return &Response{RawResponse: res} 627 | } 628 | 629 | type rawEvent struct { 630 | ID []byte 631 | Data []byte 632 | Event []byte 633 | Retry []byte 634 | } 635 | 636 | var parseEvent = parseEventFunc 637 | 638 | // event value parsing logic obtained and modified for Resty processing flow. 639 | // https://github.com/r3labs/sse/blob/c6d5381ee3ca63828b321c16baa008fd6c0b4564/client.go#L322 640 | func parseEventFunc(msg []byte) (*rawEvent, error) { 641 | if len(msg) < 1 { 642 | return nil, errors.New("resty:sse: event message was empty") 643 | } 644 | 645 | e := newRawEvent() 646 | 647 | // Split the line by "\n" 648 | for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' }) { 649 | switch { 650 | case bytes.HasPrefix(line, headerID): 651 | e.ID = append([]byte(nil), trimHeader(len(headerID), line)...) 652 | case bytes.HasPrefix(line, headerData): 653 | // The spec allows for multiple data fields per event, concatenated them with "\n" 654 | e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...) 655 | // The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body. 656 | case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))): 657 | e.Data = append(e.Data, byte('\n')) 658 | case bytes.HasPrefix(line, headerEvent): 659 | e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...) 
660 | case bytes.HasPrefix(line, headerRetry): 661 | e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...) 662 | default: 663 | // Ignore anything that doesn't match the header 664 | } 665 | } 666 | 667 | // Trim the last "\n" per the spec 668 | e.Data = bytes.TrimSuffix(e.Data, []byte("\n")) 669 | 670 | return e, nil 671 | } 672 | 673 | func trimHeader(size int, data []byte) []byte { 674 | if data == nil || len(data) < size { 675 | return data 676 | } 677 | data = data[size:] 678 | data = bytes.TrimSpace(data) 679 | data = bytes.TrimSuffix(data, []byte("\n")) 680 | return data 681 | } 682 | 683 | var rawEventPool = &sync.Pool{New: func() any { return new(rawEvent) }} 684 | 685 | func newRawEvent() *rawEvent { 686 | e := rawEventPool.Get().(*rawEvent) 687 | e.ID = e.ID[:0] 688 | e.Data = e.Data[:0] 689 | e.Event = e.Event[:0] 690 | e.Retry = e.Retry[:0] 691 | return e 692 | } 693 | 694 | func putRawEvent(e *rawEvent) { 695 | rawEventPool.Put(e) 696 | } 697 | -------------------------------------------------------------------------------- /sse_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "bufio" 10 | "bytes" 11 | "errors" 12 | "fmt" 13 | "io" 14 | "net/http" 15 | "net/http/httptest" 16 | "strconv" 17 | "strings" 18 | "testing" 19 | "time" 20 | ) 21 | 22 | func TestEventSourceSimpleFlow(t *testing.T) { 23 | messageCounter := 0 24 | messageFunc := func(e any) { 25 | event := e.(*Event) 26 | assertEqual(t, strconv.Itoa(messageCounter), event.ID) 27 | assertEqual(t, true, strings.HasPrefix(event.Data, "The time is")) 28 | messageCounter++ 29 | } 30 | 31 | counter := 0 32 | es := createEventSource(t, "", messageFunc, nil) 33 | ts := createSSETestServer( 34 | t, 35 | 10*time.Millisecond, 36 | func(w io.Writer) error { 37 | if counter == 100 { 38 | es.Close() 39 | return fmt.Errorf("stop sending events") 40 | } 41 | _, err := fmt.Fprintf(w, "id: %v\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate)) 42 | counter++ 43 | return err 44 | }, 45 | ) 46 | defer ts.Close() 47 | 48 | es.SetURL(ts.URL) 49 | es.SetMethod(MethodPost) 50 | err := es.Get() 51 | assertNil(t, err) 52 | assertEqual(t, counter, messageCounter) 53 | } 54 | 55 | func TestEventSourceMultipleEventTypes(t *testing.T) { 56 | type userEvent struct { 57 | UserName string `json:"username"` 58 | Message string `json:"msg"` 59 | Time time.Time `json:"time"` 60 | } 61 | 62 | tm := time.Now().Add(-1 * time.Minute) 63 | userConnectCounter := 0 64 | userConnectFunc := func(e any) { 65 | data := e.(*userEvent) 66 | assertEqual(t, "username"+strconv.Itoa(userConnectCounter), data.UserName) 67 | assertEqual(t, true, data.Time.After(tm)) 68 | userConnectCounter++ 69 | } 70 | 71 | userMessageCounter := 0 72 | userMessageFunc := func(e any) { 73 | data := e.(*userEvent) 74 | assertEqual(t, "username"+strconv.Itoa(userConnectCounter), data.UserName) 75 | assertEqual(t, "Hello, how are you?", data.Message) 76 | assertEqual(t, true, data.Time.After(tm)) 77 | userMessageCounter++ 78 | } 79 | 80 | counter := 0 81 | es 
:= createEventSource(t, "", func(any) {}, nil) 82 | ts := createSSETestServer( 83 | t, 84 | 10*time.Millisecond, 85 | func(w io.Writer) error { 86 | if counter == 100 { 87 | es.Close() 88 | return fmt.Errorf("stop sending events") 89 | } 90 | 91 | id := counter / 2 92 | if counter%2 == 0 { 93 | event := fmt.Sprintf("id: %v\n"+ 94 | "event: user_message\n"+ 95 | `data: {"username": "%v", "time": "%v", "msg": "Hello, how are you?"}`+"\n\n", 96 | id, 97 | "username"+strconv.Itoa(id), 98 | time.Now().Format(time.RFC3339), 99 | ) 100 | fmt.Fprint(w, event) 101 | } else { 102 | event := fmt.Sprintf("id: %v\n"+ 103 | "event: user_connect\n"+ 104 | `data: {"username": "%v", "time": "%v"}`+"\n\n", 105 | int(id), 106 | "username"+strconv.Itoa(int(id)), 107 | time.Now().Format(time.RFC3339), 108 | ) 109 | fmt.Fprint(w, event) 110 | } 111 | 112 | counter++ 113 | return nil 114 | }, 115 | ) 116 | defer ts.Close() 117 | 118 | es.SetURL(ts.URL). 119 | SetMethod(MethodPost). 120 | AddEventListener("user_connect", userConnectFunc, userEvent{}). 
121 | AddEventListener("user_message", userMessageFunc, userEvent{}) 122 | 123 | err := es.Get() 124 | assertNil(t, err) 125 | assertEqual(t, userConnectCounter, userMessageCounter) 126 | } 127 | 128 | func TestEventSourceOverwriteFuncs(t *testing.T) { 129 | messageFunc1 := func(e any) { 130 | assertNotNil(t, e) 131 | } 132 | message2Counter := 0 133 | messageFunc2 := func(e any) { 134 | event := e.(*Event) 135 | assertEqual(t, strconv.Itoa(message2Counter), event.ID) 136 | assertEqual(t, true, strings.HasPrefix(event.Data, "The time is")) 137 | message2Counter++ 138 | } 139 | 140 | counter := 0 141 | es := createEventSource(t, "", messageFunc1, nil) 142 | ts := createSSETestServer( 143 | t, 144 | 10*time.Millisecond, 145 | func(w io.Writer) error { 146 | if counter == 50 { 147 | es.Close() 148 | return fmt.Errorf("stop sending events") 149 | } 150 | _, err := fmt.Fprintf(w, "id: %v\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate)) 151 | counter++ 152 | return err 153 | }, 154 | ) 155 | defer ts.Close() 156 | 157 | lb := new(bytes.Buffer) 158 | es.outputLogTo(lb) 159 | 160 | es.SetURL(ts.URL). 161 | OnMessage(messageFunc2, nil). 162 | OnOpen(func(url string) { 163 | t.Log("from overwrite func", url) 164 | }). 
165 | OnError(func(err error) { 166 | t.Log("from overwrite func", err) 167 | }) 168 | 169 | err := es.Get() 170 | assertNil(t, err) 171 | assertEqual(t, counter, message2Counter) 172 | 173 | logLines := lb.String() 174 | assertEqual(t, true, strings.Contains(logLines, "Overwriting an existing OnEvent callback")) 175 | assertEqual(t, true, strings.Contains(logLines, "Overwriting an existing OnOpen callback")) 176 | assertEqual(t, true, strings.Contains(logLines, "Overwriting an existing OnError callback")) 177 | } 178 | 179 | func TestEventSourceRetry(t *testing.T) { 180 | messageCounter := 2 // 0 & 1 connection failure 181 | messageFunc := func(e any) { 182 | event := e.(*Event) 183 | assertEqual(t, strconv.Itoa(messageCounter), event.ID) 184 | assertEqual(t, true, strings.HasPrefix(event.Data, "The time is")) 185 | messageCounter++ 186 | } 187 | 188 | counter := 0 189 | es := createEventSource(t, "", messageFunc, nil) 190 | ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { 191 | if counter == 1 && r.URL.Query().Get("reconnect") == "1" { 192 | w.WriteHeader(http.StatusTooManyRequests) 193 | counter++ 194 | return 195 | } 196 | if counter < 2 || counter == 7 { 197 | w.WriteHeader(http.StatusTooManyRequests) 198 | counter++ 199 | return 200 | } 201 | 202 | w.Header().Set("Content-Type", "text/event-stream") 203 | w.Header().Set("Cache-Control", "no-cache") 204 | w.Header().Set("Connection", "keep-alive") 205 | 206 | // for local testing allow it 207 | w.Header().Set("Access-Control-Allow-Origin", "*") 208 | 209 | // Create a channel for client disconnection 210 | clientGone := r.Context().Done() 211 | 212 | rc := http.NewResponseController(w) 213 | tick := time.NewTicker(10 * time.Millisecond) 214 | defer tick.Stop() 215 | for { 216 | select { 217 | case <-clientGone: 218 | t.Log("Client disconnected") 219 | return 220 | case <-tick.C: 221 | if counter == 5 { 222 | fmt.Fprintf(w, "id: %v\nretry: abc\ndata: The time is %s\n\n", counter, 
time.Now().Format(time.UnixDate))
					counter++
					return
				}
				if counter == 15 {
					es.Close()
					return // stop sending events
				}
				fmt.Fprintf(w, "id: %v\nretry: 1\ndata: The time is %s\ndata\n\n", counter, time.Now().Format(time.UnixDate))
				counter++
				if err := rc.Flush(); err != nil {
					t.Log(err)
					return
				}
			}
		}
	})
	defer ts.Close()

	// first round
	es.SetURL(ts.URL)
	err1 := es.Get()
	assertNotNil(t, err1)

	// second round
	counter = 0
	messageCounter = 2
	es.SetRetryCount(1).
		SetURL(ts.URL + "?reconnect=1")
	err2 := es.Get()
	assertNotNil(t, err2)
}

// TestEventSourceNoRetryRequired asserts that a server answering
// 400 Bad Request makes Get return an error mentioning that status.
func TestEventSourceNoRetryRequired(t *testing.T) {
	es := createEventSource(t, "", func(any) {}, nil)
	ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
	})
	defer ts.Close()

	es.SetURL(ts.URL)
	err := es.Get()
	// Fail cleanly rather than panicking on err.Error() below.
	if err == nil {
		t.Fatal("expected an error for a 400 Bad Request response, got nil")
	}
	assertEqual(t, true, strings.Contains(err.Error(), "400 Bad Request"))
}

// TestEventSourceHTTPError asserts that a redirect to an unparsable host
// surfaces the URL parse error from Get.
func TestEventSourceHTTPError(t *testing.T) {
	es := createEventSource(t, "", func(any) {}, nil)
	ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, "http://local host", http.StatusTemporaryRedirect)
	})
	defer ts.Close()

	es.SetURL(ts.URL)
	err := es.Get()
	// Fail cleanly rather than panicking on err.Error() below.
	if err == nil {
		t.Fatal("expected an error for a redirect to an invalid host, got nil")
	}
	assertEqual(t, true, strings.Contains(err.Error(), `invalid character " " in host name`))
}

// TestEventSourceParseAndReadError streams events the typed handler cannot
// use, then repeats the run with parseEvent forced to fail; Get is expected
// to return nil both times.
func TestEventSourceParseAndReadError(t *testing.T) {
	type data struct{}
	counter := 0
	es := createEventSource(t, "", func(any) {}, data{})
	ts := createSSETestServer(
		t,
		5*time.Millisecond,
		func(w io.Writer) error {
			if counter == 5 {
				es.Close()
				return fmt.Errorf("stop sending events")
			}
			_, err := fmt.Fprintf(w, "id: %v\n"+
				`data: The time is %s\n\n`+"\n\n", counter, time.Now().Format(time.UnixDate))
			counter++
			return err
		},
	)
	defer ts.Close()

	es.SetURL(ts.URL)
	err := es.Get()
	assertNil(t, err)

	// parse error: register the restore before swapping the package-level
	// hook, so it is reset even if the Get below fails fatally or panics.
	t.Cleanup(func() {
		parseEvent = parseEventFunc
	})
	parseEvent = func(_ []byte) (*rawEvent, error) {
		return nil, errors.New("test error")
	}
	counter = 0
	err = es.Get()
	assertNil(t, err)
}

// TestEventSourceReadError forces readEvent to fail and expects Get to
// return that error.
func TestEventSourceReadError(t *testing.T) {
	es := createEventSource(t, "", func(any) {}, nil)
	ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	defer ts.Close()

	// read error
	readEvent = func(_ *bufio.Scanner) ([]byte, error) {
		return nil, errors.New("read event test error")
	}
	t.Cleanup(func() {
		readEvent = readEventFunc
	})

	es.SetURL(ts.URL)
	err := es.Get()
	assertNotNil(t, err)
	assertEqual(t, true, strings.Contains(err.Error(), "read event test error"))
}

// TestEventSourceCoverage exercises the validation errors of Get and a few
// internal helpers for coverage.
func TestEventSourceCoverage(t *testing.T) {
	es := NewEventSource()
	err1 := es.Get()
	assertEqual(t, "resty:sse: event source URL is required", err1.Error())

	es.SetURL("https://sse.dev/test")
	err2 := es.Get()
	assertEqual(t, "resty:sse: At least one OnMessage/AddEventListener func is required", err2.Error())

	es.OnMessage(func(a any) {}, nil)
	es.SetURL("//res%20ty.dev")
	err3 := es.Get()
	assertEqual(t, true, strings.Contains(err3.Error(), `invalid URL escape "%20"`))

	wrapResponse(nil)
	trimHeader(2, nil)
	parseEvent([]byte{})
}

// createEventSource builds an EventSource preconfigured with test headers,
// retry settings, a 16KB max buffer, a logger and open/error callbacks that
// log through t. fn, when non-nil, is registered via OnMessage with rt.
func createEventSource(t *testing.T, url string, fn EventMessageFunc, rt any) *EventSource {
	es := NewEventSource().
		SetURL(url).
		SetMethod(MethodGet).
		AddHeader("X-Test-Header-1", "test header 1").
		SetHeader("X-Test-Header-2", "test header 2").
		SetRetryCount(2).
		SetRetryWaitTime(200 * time.Millisecond).
		SetRetryMaxWaitTime(1000 * time.Millisecond).
		SetMaxBufSize(1 << 14). // 16kb
		SetLogger(createLogger()).
		OnOpen(func(url string) {
			t.Log("I'm connected:", url)
		}).
		OnError(func(err error) {
			t.Log("Error occurred:", err)
		})
	if fn != nil {
		es.OnMessage(fn, rt)
	}
	return es
}

// serveSSETestStream writes the SSE response headers and then, on every tick,
// calls fn to emit an event and flushes it. It returns when the client goes
// away, fn returns an error, or a flush fails.
func serveSSETestStream(t *testing.T, w http.ResponseWriter, r *http.Request, ticker time.Duration, fn func(io.Writer) error) {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	// for local testing allow it
	w.Header().Set("Access-Control-Allow-Origin", "*")

	// Create a channel for client disconnection
	clientGone := r.Context().Done()

	rc := http.NewResponseController(w)
	tick := time.NewTicker(ticker)
	defer tick.Stop()
	for {
		select {
		case <-clientGone:
			t.Log("Client disconnected")
			return
		case <-tick.C:
			if err := fn(w); err != nil {
				t.Log(err)
				return
			}
			if err := rc.Flush(); err != nil {
				t.Log(err)
				return
			}
		}
	}
}

// createSSETestServer starts a test server that streams events produced by fn
// every ticker interval.
func createSSETestServer(t *testing.T, ticker time.Duration, fn func(io.Writer) error) *httptest.Server {
	return createTestServer(func(w http.ResponseWriter, r *http.Request) {
		serveSSETestStream(t, w, r, ticker, fn)
	})
}

// TestEventSourceWithDifferentMethods streams events over each HTTP method
// and checks the server saw the expected method/body and the client received
// every event in order.
func TestEventSourceWithDifferentMethods(t *testing.T) {
	testCases := []struct {
		name   string
		method string
		body   []byte
	}{
		{
			name:   "GET Method",
			method: MethodGet,
			body:   nil,
		},
		{
			name:   "POST Method",
			method: MethodPost,
			body:   []byte(`{"test":"post_data"}`),
		},
		{
			name:   "PUT Method",
			method: MethodPut,
			body:   []byte(`{"test":"put_data"}`),
		},
		{
			name:   "DELETE Method",
			method: MethodDelete,
			body:   nil,
		},
		{
			name:   "PATCH Method",
			method: MethodPatch,
			body:   []byte(`{"test":"patch_data"}`),
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			messageCounter := 0
			messageFunc := func(e any) {
				event := e.(*Event)
				assertEqual(t, strconv.Itoa(messageCounter), event.ID)
				assertEqual(t, true, strings.HasPrefix(event.Data, fmt.Sprintf("%s method test:", tc.method)))
				messageCounter++
			}

			counter := 0
			methodVerified := false
			bodyVerified := false

			es := createEventSource(t, "", messageFunc, nil)
			ts := createMethodVerifyingSSETestServer(
				t,
				10*time.Millisecond,
				tc.method,
				tc.body,
				&methodVerified,
				&bodyVerified,
				func(w io.Writer) error {
					if counter == 20 {
						es.Close()
						return fmt.Errorf("stop sending events")
					}
					_, err := fmt.Fprintf(w, "id: %v\ndata: %s method test: %s\n\n", counter, tc.method, time.Now().Format(time.RFC3339))
					counter++
					return err
				},
			)
			defer ts.Close()

			es.SetURL(ts.URL)
			es.SetMethod(tc.method)

			// set body
			if tc.body != nil {
				es.SetBody(bytes.NewBuffer(tc.body))
			}

			err := es.Get()
			assertNil(t, err)

			// check the message count
			assertEqual(t, counter, messageCounter)

			// check if server receive correct method and body
			assertEqual(t, true, methodVerified)
			if tc.body != nil {
				assertEqual(t, true, bodyVerified)
			}
		})
	}
}

// createMethodVerifyingSSETestServer behaves like createSSETestServer but
// first checks the request: *methodVerified is set when the method equals
// expectedMethod, and (only when expectedBody is non-nil) *bodyVerified is set
// when the request body matches it exactly. Mismatches are reported with
// t.Errorf.
func createMethodVerifyingSSETestServer(
	t *testing.T,
	ticker time.Duration,
	expectedMethod string,
	expectedBody []byte,
	methodVerified *bool,
	bodyVerified *bool,
	fn func(io.Writer) error,
) *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// validate method
		if r.Method == expectedMethod {
			*methodVerified = true
		} else {
			t.Errorf("Expected method %s, got %s", expectedMethod, r.Method)
		}

		// validate body
		if expectedBody != nil {
			body, err := io.ReadAll(r.Body)
			switch {
			case err != nil:
				t.Errorf("Failed to read request body: %v", err)
			case bytes.Equal(body, expectedBody):
				*bodyVerified = true
			default:
				t.Errorf("Expected body %s, got %s", string(expectedBody), string(body))
			}
		}

		serveSSETestStream(t, w, r, ticker, fn)
	}))
}

// --------------------------------------------------------------------------------
// /stream.go:
// --------------------------------------------------------------------------------
// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "bytes" 10 | "compress/flate" 11 | "compress/gzip" 12 | "encoding/json" 13 | "encoding/xml" 14 | "errors" 15 | "io" 16 | "sync" 17 | ) 18 | 19 | var ( 20 | ErrContentDecompresserNotFound = errors.New("resty: content decoder not found") 21 | ) 22 | 23 | type ( 24 | // ContentTypeEncoder type is for encoding the request body based on header Content-Type 25 | ContentTypeEncoder func(io.Writer, any) error 26 | 27 | // ContentTypeDecoder type is for decoding the response body based on header Content-Type 28 | ContentTypeDecoder func(io.Reader, any) error 29 | 30 | // ContentDecompresser type is for decompressing response body based on header Content-Encoding 31 | // ([RFC 9110]) 32 | // 33 | // For example, gzip, deflate, etc. 34 | // 35 | // [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110 36 | ContentDecompresser func(io.ReadCloser) (io.ReadCloser, error) 37 | ) 38 | 39 | func encodeJSON(w io.Writer, v any) error { 40 | return encodeJSONEscapeHTML(w, v, true) 41 | } 42 | 43 | func encodeJSONEscapeHTML(w io.Writer, v any, esc bool) error { 44 | enc := json.NewEncoder(w) 45 | enc.SetEscapeHTML(esc) 46 | return enc.Encode(v) 47 | } 48 | 49 | func encodeJSONEscapeHTMLIndent(w io.Writer, v any, esc bool, indent string) error { 50 | enc := json.NewEncoder(w) 51 | enc.SetEscapeHTML(esc) 52 | enc.SetIndent("", indent) 53 | return enc.Encode(v) 54 | } 55 | 56 | func decodeJSON(r io.Reader, v any) error { 57 | dec := json.NewDecoder(r) 58 | for { 59 | if err := dec.Decode(v); err == io.EOF { 60 | break 61 | } else if err != nil { 62 | return err 63 | } 64 | } 65 | return nil 66 | } 67 | 68 | func encodeXML(w io.Writer, v any) error { 69 | return xml.NewEncoder(w).Encode(v) 70 | } 71 | 72 | func decodeXML(r io.Reader, v any) error { 73 | dec := xml.NewDecoder(r) 74 | for { 75 | if err := dec.Decode(v); err == io.EOF { 76 | break 77 | } else if err != nil { 78 | return err 79 | } 80 | } 81 | 
return nil 82 | } 83 | 84 | var gzipPool = sync.Pool{New: func() any { return new(gzip.Reader) }} 85 | 86 | func decompressGzip(r io.ReadCloser) (io.ReadCloser, error) { 87 | gr := gzipPool.Get().(*gzip.Reader) 88 | err := gr.Reset(r) 89 | return &gzipReader{s: r, r: gr}, err 90 | } 91 | 92 | type gzipReader struct { 93 | s io.ReadCloser 94 | r *gzip.Reader 95 | } 96 | 97 | func (gz *gzipReader) Read(p []byte) (n int, err error) { 98 | return gz.r.Read(p) 99 | } 100 | 101 | func (gz *gzipReader) Close() error { 102 | gz.r.Reset(nopReader{}) 103 | gzipPool.Put(gz.r) 104 | closeq(gz.s) 105 | return nil 106 | } 107 | 108 | var flatePool = sync.Pool{New: func() any { return flate.NewReader(nopReader{}) }} 109 | 110 | func decompressDeflate(r io.ReadCloser) (io.ReadCloser, error) { 111 | fr := flatePool.Get().(io.ReadCloser) 112 | err := fr.(flate.Resetter).Reset(r, nil) 113 | return &deflateReader{s: r, r: fr}, err 114 | } 115 | 116 | type deflateReader struct { 117 | s io.ReadCloser 118 | r io.ReadCloser 119 | } 120 | 121 | func (d *deflateReader) Read(p []byte) (n int, err error) { 122 | return d.r.Read(p) 123 | } 124 | 125 | func (d *deflateReader) Close() error { 126 | d.r.(flate.Resetter).Reset(nopReader{}, nil) 127 | flatePool.Put(d.r) 128 | closeq(d.s) 129 | return nil 130 | } 131 | 132 | var ErrReadExceedsThresholdLimit = errors.New("resty: read exceeds the threshold limit") 133 | 134 | var _ io.ReadCloser = (*limitReadCloser)(nil) 135 | 136 | type limitReadCloser struct { 137 | r io.Reader 138 | l int64 139 | t int64 140 | f func(s int64) 141 | } 142 | 143 | func (l *limitReadCloser) Read(p []byte) (n int, err error) { 144 | if l.l == 0 { 145 | n, err = l.r.Read(p) 146 | l.t += int64(n) 147 | l.f(l.t) 148 | return n, err 149 | } 150 | if l.t > l.l { 151 | return 0, ErrReadExceedsThresholdLimit 152 | } 153 | n, err = l.r.Read(p) 154 | l.t += int64(n) 155 | l.f(l.t) 156 | return n, err 157 | } 158 | 159 | func (l *limitReadCloser) Close() error { 160 | if c, ok 
:= l.r.(io.Closer); ok { 161 | return c.Close() 162 | } 163 | return nil 164 | } 165 | 166 | var _ io.ReadCloser = (*copyReadCloser)(nil) 167 | 168 | type copyReadCloser struct { 169 | s io.Reader 170 | t *bytes.Buffer 171 | c bool 172 | f func(*bytes.Buffer) 173 | } 174 | 175 | func (r *copyReadCloser) Read(p []byte) (int, error) { 176 | n, err := r.s.Read(p) 177 | if n > 0 { 178 | _, _ = r.t.Write(p[:n]) 179 | } 180 | if err == io.EOF || err == ErrReadExceedsThresholdLimit { 181 | if !r.c { 182 | r.f(r.t) 183 | r.c = true 184 | } 185 | } 186 | return n, err 187 | } 188 | 189 | func (r *copyReadCloser) Close() error { 190 | if c, ok := r.s.(io.Closer); ok { 191 | return c.Close() 192 | } 193 | return nil 194 | } 195 | 196 | var _ io.ReadCloser = (*nopReadCloser)(nil) 197 | 198 | type nopReadCloser struct { 199 | r *bytes.Reader 200 | } 201 | 202 | func (r *nopReadCloser) Read(p []byte) (int, error) { 203 | n, err := r.r.Read(p) 204 | if err == io.EOF { 205 | r.r.Seek(0, io.SeekStart) 206 | } 207 | return n, err 208 | } 209 | 210 | func (r *nopReadCloser) Close() error { return nil } 211 | 212 | var _ flate.Reader = (*nopReader)(nil) 213 | 214 | type nopReader struct{} 215 | 216 | func (nopReader) Read([]byte) (int, error) { return 0, io.EOF } 217 | func (nopReader) ReadByte() (byte, error) { return 0, io.EOF } 218 | -------------------------------------------------------------------------------- /trace.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "context" 10 | "crypto/tls" 11 | "fmt" 12 | "net/http/httptrace" 13 | "time" 14 | ) 15 | 16 | // TraceInfo struct is used to provide request trace info such as DNS lookup 17 | // duration, Connection obtain duration, Server processing duration, etc. 18 | type TraceInfo struct { 19 | // DNSLookup is the duration that transport took to perform 20 | // DNS lookup. 21 | DNSLookup time.Duration `json:"dns_lookup_time"` 22 | 23 | // ConnTime is the duration it took to obtain a successful connection. 24 | ConnTime time.Duration `json:"connection_time"` 25 | 26 | // TCPConnTime is the duration it took to obtain the TCP connection. 27 | TCPConnTime time.Duration `json:"tcp_connection_time"` 28 | 29 | // TLSHandshake is the duration of the TLS handshake. 30 | TLSHandshake time.Duration `json:"tls_handshake_time"` 31 | 32 | // ServerTime is the server's duration for responding to the first byte. 33 | ServerTime time.Duration `json:"server_time"` 34 | 35 | // ResponseTime is the duration since the first response byte from the server to 36 | // request completion. 37 | ResponseTime time.Duration `json:"response_time"` 38 | 39 | // TotalTime is the duration of the total time request taken end-to-end. 40 | TotalTime time.Duration `json:"total_time"` 41 | 42 | // IsConnReused is whether this connection has been previously 43 | // used for another HTTP request. 44 | IsConnReused bool `json:"is_connection_reused"` 45 | 46 | // IsConnWasIdle is whether this connection was obtained from an 47 | // idle pool. 48 | IsConnWasIdle bool `json:"is_connection_was_idle"` 49 | 50 | // ConnIdleTime is the duration how long the connection that was previously 51 | // idle, if IsConnWasIdle is true. 52 | ConnIdleTime time.Duration `json:"connection_idle_time"` 53 | 54 | // RequestAttempt is to represent the request attempt made during a Resty 55 | // request execution flow, including retry count. 
56 | RequestAttempt int `json:"request_attempt"` 57 | 58 | // RemoteAddr returns the remote network address. 59 | RemoteAddr string `json:"remote_address"` 60 | } 61 | 62 | // String method returns string representation of request trace information. 63 | func (ti TraceInfo) String() string { 64 | return fmt.Sprintf(`TRACE INFO: 65 | DNSLookupTime : %v 66 | ConnTime : %v 67 | TCPConnTime : %v 68 | TLSHandshake : %v 69 | ServerTime : %v 70 | ResponseTime : %v 71 | TotalTime : %v 72 | IsConnReused : %v 73 | IsConnWasIdle : %v 74 | ConnIdleTime : %v 75 | RequestAttempt: %v 76 | RemoteAddr : %v`, ti.DNSLookup, ti.ConnTime, ti.TCPConnTime, 77 | ti.TLSHandshake, ti.ServerTime, ti.ResponseTime, ti.TotalTime, 78 | ti.IsConnReused, ti.IsConnWasIdle, ti.ConnIdleTime, ti.RequestAttempt, 79 | ti.RemoteAddr) 80 | } 81 | 82 | // JSON method returns the JSON string of request trace information 83 | func (ti TraceInfo) JSON() string { 84 | return toJSON(ti) 85 | } 86 | 87 | // Clone method returns the clone copy of [TraceInfo] 88 | func (ti TraceInfo) Clone() *TraceInfo { 89 | ti2 := new(TraceInfo) 90 | *ti2 = ti 91 | return ti2 92 | } 93 | 94 | // clientTrace struct maps the [httptrace.ClientTrace] hooks into Fields 95 | // with the same naming for easy understanding. Plus additional insights 96 | // [Request]. 
type clientTrace struct {
	getConn              time.Time
	dnsStart             time.Time
	dnsDone              time.Time
	connectDone          time.Time
	tlsHandshakeStart    time.Time
	tlsHandshakeDone     time.Time
	gotConn              time.Time
	gotFirstResponseByte time.Time
	endTime              time.Time
	gotConnInfo          httptrace.GotConnInfo
}

// createContext derives a context from ctx that carries an
// [httptrace.ClientTrace]. Each hook stores time.Now() into the matching field
// of t; GotConn also keeps the connection info. endTime is not set here.
//
// NOTE(review): the httptrace documentation allows hooks to be called
// concurrently (e.g. several ConnectStart/ConnectDone calls while dialing);
// these writes are not synchronized.
func (t *clientTrace) createContext(ctx context.Context) context.Context {
	mark := func(field *time.Time) { *field = time.Now() }

	hooks := &httptrace.ClientTrace{
		GetConn:  func(string) { mark(&t.getConn) },
		DNSStart: func(httptrace.DNSStartInfo) { mark(&t.dnsStart) },
		DNSDone:  func(httptrace.DNSDoneInfo) { mark(&t.dnsDone) },
		ConnectStart: func(string, string) {
			// Backfill DNS timestamps the DNS hooks did not record, so a
			// missing DNS phase measures as zero length.
			if t.dnsDone.IsZero() {
				mark(&t.dnsDone)
			}
			if t.dnsStart.IsZero() {
				t.dnsStart = t.dnsDone
			}
		},
		ConnectDone:       func(string, string, error) { mark(&t.connectDone) },
		TLSHandshakeStart: func() { mark(&t.tlsHandshakeStart) },
		TLSHandshakeDone:  func(tls.ConnectionState, error) { mark(&t.tlsHandshakeDone) },
		GotConn: func(info httptrace.GotConnInfo) {
			mark(&t.gotConn)
			t.gotConnInfo = info
		},
		GotFirstResponseByte: func() { mark(&t.gotFirstResponseByte) },
	}
	return httptrace.WithClientTrace(ctx, hooks)
}

// --------------------------------------------------------------------------------
// /transport_dial.go:
// --------------------------------------------------------------------------------
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4 | 5 | //go:build !(js && wasm) 6 | // +build !js !wasm 7 | 8 | package resty 9 | 10 | import ( 11 | "context" 12 | "net" 13 | ) 14 | 15 | func transportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { 16 | return dialer.DialContext 17 | } 18 | -------------------------------------------------------------------------------- /transport_dial_wasm.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build (js && wasm) || wasip1 6 | // +build js,wasm wasip1 7 | 8 | package resty 9 | 10 | import ( 11 | "context" 12 | "net" 13 | ) 14 | 15 | func transportDialContext(_ *net.Dialer) func(context.Context, string, string) (net.Conn, error) { 16 | return nil 17 | } 18 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "bytes" 10 | "crypto/md5" 11 | "crypto/rand" 12 | "encoding/binary" 13 | "encoding/hex" 14 | "errors" 15 | "fmt" 16 | "io" 17 | "log" 18 | "net/http" 19 | "net/url" 20 | "os" 21 | "reflect" 22 | "runtime" 23 | "sort" 24 | "strings" 25 | "sync/atomic" 26 | "time" 27 | ) 28 | 29 | //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ 30 | // Logger interface 31 | //_______________________________________________________________________ 32 | 33 | // Logger interface is to abstract the logging from Resty. Gives control to 34 | // the Resty users, choice of the logger. 
type Logger interface {
	Errorf(format string, v ...any)
	Warnf(format string, v ...any)
	Debugf(format string, v ...any)
}

// createLogger builds the default logger: a standard library *log.Logger that
// writes to os.Stderr, prefixing each line with the date and time in
// microseconds.
func createLogger() *logger {
	return &logger{l: log.New(os.Stderr, "", log.Ldate|log.Lmicroseconds)}
}

// Compile-time proof that *logger implements Logger.
var _ Logger = (*logger)(nil)

// logger is the built-in Logger, backed by a standard library *log.Logger.
type logger struct {
	l *log.Logger
}

// Errorf logs with the "ERROR RESTY " prefix.
func (l *logger) Errorf(format string, v ...any) { l.output("ERROR RESTY "+format, v...) }

// Warnf logs with the "WARN RESTY " prefix.
func (l *logger) Warnf(format string, v ...any) { l.output("WARN RESTY "+format, v...) }

// Debugf logs with the "DEBUG RESTY " prefix.
func (l *logger) Debugf(format string, v ...any) { l.output("DEBUG RESTY "+format, v...) }

// output formats with Printf only when arguments are supplied; without
// arguments the message is printed verbatim, so stray '%' signs survive.
func (l *logger) output(format string, v ...any) {
	if len(v) > 0 {
		l.l.Printf(format, v...)
		return
	}
	l.l.Print(format)
}

// credentials holds a username/password pair.
type credentials struct {
	Username, Password string
}

// Clone returns a copy of c.
78 | func (c *credentials) Clone() *credentials { 79 | cc := new(credentials) 80 | *cc = *c 81 | return cc 82 | } 83 | 84 | // String method returns masked value of username and password 85 | func (c credentials) String() string { 86 | return "Username: **********, Password: **********" 87 | } 88 | 89 | //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ 90 | // Package Helper methods 91 | //_______________________________________________________________________ 92 | 93 | // isStringEmpty method tells whether given string is empty or not 94 | func isStringEmpty(str string) bool { 95 | return len(strings.TrimSpace(str)) == 0 96 | } 97 | 98 | // detectContentType method is used to figure out `Request.Body` content type for request header 99 | func detectContentType(body any) string { 100 | contentType := plainTextType 101 | kind := inferKind(body) 102 | switch kind { 103 | case reflect.Struct, reflect.Map: 104 | contentType = jsonContentType 105 | case reflect.String: 106 | contentType = plainTextType 107 | default: 108 | if b, ok := body.([]byte); ok { 109 | contentType = http.DetectContentType(b) 110 | } else if kind == reflect.Slice { // check slice here to differentiate between any slice vs byte slice 111 | contentType = jsonContentType 112 | } 113 | } 114 | 115 | return contentType 116 | } 117 | 118 | func isJSONContentType(ct string) bool { 119 | return strings.Contains(ct, jsonKey) 120 | } 121 | 122 | func isXMLContentType(ct string) bool { 123 | return strings.Contains(ct, xmlKey) 124 | } 125 | 126 | func inferContentTypeMapKey(v string) string { 127 | if isJSONContentType(v) { 128 | return jsonKey 129 | } else if isXMLContentType(v) { 130 | return xmlKey 131 | } 132 | return "" 133 | } 134 | 135 | func firstNonEmpty(v ...string) string { 136 | for _, s := range v { 137 | if !isStringEmpty(s) { 138 | return s 139 | } 140 | } 141 | return "" 142 | } 143 | 144 | var ( 145 | mkdirAll = os.MkdirAll 146 | createFile = os.Create 147 | ioCopy = 
io.Copy 148 | ) 149 | 150 | func createDirectory(dir string) (err error) { 151 | if _, err = os.Stat(dir); err != nil { 152 | if os.IsNotExist(err) { 153 | if err = mkdirAll(dir, 0755); err != nil { 154 | return 155 | } 156 | } 157 | } 158 | return 159 | } 160 | 161 | func getPointer(v any) any { 162 | if v == nil { 163 | return nil 164 | } 165 | vv := reflect.ValueOf(v) 166 | if vv.Kind() == reflect.Ptr { 167 | return v 168 | } 169 | return reflect.New(vv.Type()).Interface() 170 | } 171 | 172 | func inferType(v any) reflect.Type { 173 | return reflect.Indirect(reflect.ValueOf(v)).Type() 174 | } 175 | 176 | func inferKind(v any) reflect.Kind { 177 | return inferType(v).Kind() 178 | } 179 | 180 | func newInterface(v any) any { 181 | if v == nil { 182 | return nil 183 | } 184 | return reflect.New(inferType(v)).Interface() 185 | } 186 | 187 | func functionName(i any) string { 188 | return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() 189 | } 190 | 191 | func acquireBuffer() *bytes.Buffer { 192 | buf := bufPool.Get().(*bytes.Buffer) 193 | if buf.Len() == 0 { 194 | buf.Reset() 195 | return buf 196 | } 197 | bufPool.Put(buf) 198 | return new(bytes.Buffer) 199 | } 200 | 201 | func releaseBuffer(buf *bytes.Buffer) { 202 | if buf != nil { 203 | buf.Reset() 204 | bufPool.Put(buf) 205 | } 206 | } 207 | 208 | func backToBufPool(buf *bytes.Buffer) { 209 | if buf != nil { 210 | bufPool.Put(buf) 211 | } 212 | } 213 | 214 | func closeq(v any) { 215 | if c, ok := v.(io.Closer); ok { 216 | silently(c.Close()) 217 | } 218 | } 219 | 220 | func silently(_ ...any) {} 221 | 222 | var sanitizeHeaderToken = []string{ 223 | "authorization", 224 | "auth", 225 | "token", 226 | } 227 | 228 | func isSanitizeHeader(k string) bool { 229 | kk := strings.ToLower(k) 230 | for _, v := range sanitizeHeaderToken { 231 | if strings.Contains(kk, v) { 232 | return true 233 | } 234 | } 235 | return false 236 | } 237 | 238 | func sanitizeHeaders(hdr http.Header) http.Header { 239 | for k := range 
hdr { 240 | if isSanitizeHeader(k) { 241 | hdr[k] = []string{"********************"} 242 | } 243 | } 244 | return hdr 245 | } 246 | 247 | func composeHeaders(hdr http.Header) string { 248 | str := make([]string, 0, len(hdr)) 249 | for _, k := range sortHeaderKeys(hdr) { 250 | str = append(str, "\t"+strings.TrimSpace(fmt.Sprintf("%25s: %s", k, strings.Join(hdr[k], ", ")))) 251 | } 252 | return strings.Join(str, "\n") 253 | } 254 | 255 | func sortHeaderKeys(hdr http.Header) []string { 256 | keys := make([]string, 0, len(hdr)) 257 | for key := range hdr { 258 | keys = append(keys, key) 259 | } 260 | sort.Strings(keys) 261 | return keys 262 | } 263 | 264 | func wrapErrors(n error, inner error) error { 265 | if n == nil && inner == nil { 266 | return nil 267 | } 268 | if inner == nil { 269 | return n 270 | } 271 | if n == nil { 272 | return inner 273 | } 274 | return &restyError{ 275 | err: n, 276 | inner: inner, 277 | } 278 | } 279 | 280 | type restyError struct { 281 | err error 282 | inner error 283 | } 284 | 285 | func (e *restyError) Error() string { 286 | return e.err.Error() 287 | } 288 | 289 | func (e *restyError) Unwrap() error { 290 | return e.inner 291 | } 292 | 293 | // cloneURLValues is a helper function to deep copy url.Values. 
294 | func cloneURLValues(v url.Values) url.Values { 295 | if v == nil { 296 | return nil 297 | } 298 | return url.Values(http.Header(v).Clone()) 299 | } 300 | 301 | func cloneCookie(c *http.Cookie) *http.Cookie { 302 | return &http.Cookie{ 303 | Name: c.Name, 304 | Value: c.Value, 305 | Path: c.Path, 306 | Domain: c.Domain, 307 | Expires: c.Expires, 308 | RawExpires: c.RawExpires, 309 | MaxAge: c.MaxAge, 310 | Secure: c.Secure, 311 | HttpOnly: c.HttpOnly, 312 | SameSite: c.SameSite, 313 | Raw: c.Raw, 314 | Unparsed: c.Unparsed, 315 | } 316 | } 317 | 318 | type invalidRequestError struct { 319 | Err error 320 | } 321 | 322 | func (ire *invalidRequestError) Error() string { 323 | return ire.Err.Error() 324 | } 325 | 326 | func drainBody(res *Response) { 327 | if res != nil && res.Body != nil { 328 | defer closeq(res.Body) 329 | _, _ = io.Copy(io.Discard, res.Body) 330 | } 331 | } 332 | 333 | func toJSON(v any) string { 334 | buf := acquireBuffer() 335 | defer releaseBuffer(buf) 336 | _ = encodeJSON(buf, v) 337 | return buf.String() 338 | } 339 | 340 | //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ 341 | // GUID generation 342 | // Code inspired from mgo/bson ObjectId 343 | // Code obtained from https://github.com/go-aah/aah/blob/edge/essentials/guid.go 344 | //___________________________________ 345 | 346 | var ( 347 | // guidCounter is atomically incremented when generating a new GUID 348 | // using UniqueID() function. It's used as a counter part of an id. 349 | guidCounter = readRandomUint32() 350 | 351 | // machineID stores machine id generated once and used in subsequent calls 352 | // to UniqueId function. 353 | machineID = readMachineID() 354 | 355 | // processID is current Process Id 356 | processID = os.Getpid() 357 | ) 358 | 359 | // newGUID method returns a new Globally Unique Identifier (GUID). 
360 | // 361 | // The 12-byte `UniqueId` consists of- 362 | // - 4-byte value representing the seconds since the Unix epoch, 363 | // - 3-byte machine identifier, 364 | // - 2-byte process id, and 365 | // - 3-byte counter, starting with a random value. 366 | // 367 | // Uses Mongo Object ID algorithm to generate globally unique ids - 368 | // https://docs.mongodb.com/manual/reference/method/ObjectId/ 369 | func newGUID() string { 370 | var b [12]byte 371 | // Timestamp, 4 bytes, big endian 372 | binary.BigEndian.PutUint32(b[:], uint32(time.Now().Unix())) 373 | 374 | // Machine, first 3 bytes of md5(hostname) 375 | b[4], b[5], b[6] = machineID[0], machineID[1], machineID[2] 376 | 377 | // Pid, 2 bytes, specs don't specify endianness, but we use big endian. 378 | b[7], b[8] = byte(processID>>8), byte(processID) 379 | 380 | // Increment, 3 bytes, big endian 381 | i := atomic.AddUint32(&guidCounter, 1) 382 | b[9], b[10], b[11] = byte(i>>16), byte(i>>8), byte(i) 383 | 384 | return hex.EncodeToString(b[:]) 385 | } 386 | 387 | var ioReadFull = io.ReadFull 388 | 389 | // readRandomUint32 returns a random guidCounter. 390 | func readRandomUint32() uint32 { 391 | var b [4]byte 392 | if _, err := ioReadFull(rand.Reader, b[:]); err == nil { 393 | return (uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) 394 | } 395 | 396 | // To initialize package unexported variable 'guidCounter'. 397 | // This panic would happen at program startup, so no worries at runtime panic. 398 | panic(errors.New("resty - guid: unable to generate random object id")) 399 | } 400 | 401 | var osHostname = os.Hostname 402 | 403 | // readMachineID generates and returns a machine id. 404 | // If this function fails to get the hostname it will cause a runtime error. 
405 | func readMachineID() []byte { 406 | var sum [3]byte 407 | id := sum[:] 408 | 409 | if hostname, err := osHostname(); err == nil { 410 | hw := md5.New() 411 | _, _ = hw.Write([]byte(hostname)) 412 | copy(id, hw.Sum(nil)) 413 | return id 414 | } 415 | 416 | if _, err := ioReadFull(rand.Reader, id); err == nil { 417 | return id 418 | } 419 | 420 | // To initialize package unexported variable 'machineID'. 421 | // This panic would happen at program startup, so no worries at runtime panic. 422 | panic(errors.New("resty - guid: unable to get hostname and random bytes")) 423 | } 424 | -------------------------------------------------------------------------------- /util_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 2 | // resty source code and usage is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | // SPDX-License-Identifier: MIT 5 | 6 | package resty 7 | 8 | import ( 9 | "bytes" 10 | "errors" 11 | "io" 12 | "net/url" 13 | "os" 14 | "path/filepath" 15 | "strings" 16 | "testing" 17 | ) 18 | 19 | func TestIsJSONContentType(t *testing.T) { 20 | for _, test := range []struct { 21 | input string 22 | expect bool 23 | }{ 24 | {"application/json", true}, 25 | {"application/xml+json", true}, 26 | {"application/vnd.foo+json", true}, 27 | 28 | {"application/json; charset=utf-8", true}, 29 | {"application/vnd.foo+json; charset=utf-8", true}, 30 | 31 | {"text/json", true}, 32 | {"text/vnd.foo+json", true}, 33 | 34 | {"application/foo-json", true}, 35 | {"application/foo.json", true}, 36 | {"application/vnd.foo-json", true}, 37 | {"application/vnd.foo.json", true}, 38 | {"application/x-amz-json-1.1", true}, 39 | 40 | {"text/foo-json", true}, 41 | {"text/foo.json", true}, 42 | {"text/vnd.foo-json", true}, 43 | {"text/vnd.foo.json", true}, 44 | } { 45 | result := isJSONContentType(test.input) 46 | 47 | if result != 
test.expect { 48 | t.Errorf("failed on %q: want %v, got %v", test.input, test.expect, result) 49 | } 50 | } 51 | } 52 | 53 | func TestIsXMLContentType(t *testing.T) { 54 | for _, test := range []struct { 55 | input string 56 | expect bool 57 | }{ 58 | {"application/xml", true}, 59 | {"application/vnd.foo+xml", true}, 60 | 61 | {"application/xml; charset=utf-8", true}, 62 | {"application/vnd.foo+xml; charset=utf-8", true}, 63 | 64 | {"text/xml", true}, 65 | {"text/vnd.foo+xml", true}, 66 | 67 | {"application/foo-xml", true}, 68 | {"application/foo.xml", true}, 69 | {"application/vnd.foo-xml", true}, 70 | {"application/vnd.foo.xml", true}, 71 | 72 | {"text/foo-xml", true}, 73 | {"text/foo.xml", true}, 74 | {"text/vnd.foo-xml", true}, 75 | {"text/vnd.foo.xml", true}, 76 | } { 77 | result := isXMLContentType(test.input) 78 | 79 | if result != test.expect { 80 | t.Errorf("failed on %q: want %v, got %v", test.input, test.expect, result) 81 | } 82 | } 83 | } 84 | 85 | func TestCloneURLValues(t *testing.T) { 86 | v := url.Values{} 87 | v.Add("foo", "bar") 88 | v.Add("foo", "baz") 89 | v.Add("qux", "quux") 90 | 91 | c := cloneURLValues(v) 92 | nilUrl := cloneURLValues(nil) 93 | assertEqual(t, v, c) 94 | assertNil(t, nilUrl) 95 | } 96 | 97 | func TestRestyErrorFuncs(t *testing.T) { 98 | ne1 := errors.New("new error 1") 99 | nie1 := errors.New("inner error 1") 100 | 101 | assertNil(t, wrapErrors(nil, nil)) 102 | 103 | e := wrapErrors(ne1, nie1) 104 | assertEqual(t, "new error 1", e.Error()) 105 | assertEqual(t, "inner error 1", errors.Unwrap(e).Error()) 106 | 107 | e = wrapErrors(ne1, nil) 108 | assertEqual(t, "new error 1", e.Error()) 109 | 110 | e = wrapErrors(nil, nie1) 111 | assertEqual(t, "inner error 1", e.Error()) 112 | } 113 | 114 | func Test_createDirectory(t *testing.T) { 115 | errMsg := "test dir error" 116 | mkdirAll = func(path string, perm os.FileMode) error { 117 | return errors.New(errMsg) 118 | } 119 | t.Cleanup(func() { 120 | mkdirAll = os.MkdirAll 121 | }) 
122 | 123 | tempDir := filepath.Join(t.TempDir(), "test-dir") 124 | err := createDirectory(tempDir) 125 | assertEqual(t, errMsg, err.Error()) 126 | } 127 | 128 | func TestUtil_readRandomUint32(t *testing.T) { 129 | defer func() { 130 | if r := recover(); r == nil { 131 | // panic: resty - guid: unable to generate random object id 132 | t.Errorf("The code did not panic") 133 | } 134 | }() 135 | errMsg := "read full error" 136 | ioReadFull = func(_ io.Reader, _ []byte) (int, error) { 137 | return 0, errors.New(errMsg) 138 | } 139 | t.Cleanup(func() { 140 | ioReadFull = io.ReadFull 141 | }) 142 | 143 | readRandomUint32() 144 | } 145 | 146 | func TestUtil_readMachineID(t *testing.T) { 147 | t.Run("hostname error", func(t *testing.T) { 148 | errHostMsg := "hostname error" 149 | osHostname = func() (string, error) { 150 | return "", errors.New(errHostMsg) 151 | } 152 | t.Cleanup(func() { 153 | osHostname = os.Hostname 154 | }) 155 | 156 | readMachineID() 157 | }) 158 | 159 | t.Run("hostname and read full error", func(t *testing.T) { 160 | defer func() { 161 | if r := recover(); r == nil { 162 | // panic: resty - guid: unable to get hostname and random bytes 163 | t.Errorf("The code did not panic") 164 | } 165 | }() 166 | errHostMsg := "hostname error" 167 | osHostname = func() (string, error) { 168 | return "", errors.New(errHostMsg) 169 | } 170 | errReadMsg := "read full error" 171 | ioReadFull = func(_ io.Reader, _ []byte) (int, error) { 172 | return 0, errors.New(errReadMsg) 173 | } 174 | t.Cleanup(func() { 175 | osHostname = os.Hostname 176 | ioReadFull = io.ReadFull 177 | }) 178 | 179 | readMachineID() 180 | }) 181 | } 182 | 183 | // This test methods exist for test coverage purpose 184 | // to validate the getter and setter 185 | func TestUtilMiscTestCoverage(t *testing.T) { 186 | l := &limitReadCloser{r: strings.NewReader("hello test close for no io.Closer")} 187 | assertNil(t, l.Close()) 188 | 189 | r := ©ReadCloser{s: strings.NewReader("hello test close for no 
io.Closer")} 190 | assertNil(t, r.Close()) 191 | 192 | v := struct { 193 | ID string `json:"id"` 194 | Message string `json:"message"` 195 | }{} 196 | err := decodeJSON(bytes.NewReader([]byte(`{\" \": \"some value\"}`)), &v) 197 | assertEqual(t, "invalid character '\\\\' looking for beginning of object key string", err.Error()) 198 | 199 | ireErr := &invalidRequestError{Err: errors.New("test coverage")} 200 | assertEqual(t, "test coverage", ireErr.Error()) 201 | } 202 | --------------------------------------------------------------------------------