├── .github └── workflows │ ├── build.yml │ ├── docker-release.yml │ └── release.yml ├── .gitignore ├── .goreleaser.yaml ├── Dockerfile ├── LICENSE ├── LICENSE-3rdparty.csv ├── Makefile ├── README.md ├── attache.jpg ├── cmd ├── attache │ └── main.go └── demo-runner │ └── main.go ├── demo ├── attache.sh ├── client.sh ├── config.yaml └── vault.sh ├── go.mod ├── go.sum ├── internal ├── cache │ ├── maintainer.go │ ├── maintainer_test.go │ ├── options.go │ ├── reporter.go │ ├── reporter_test.go │ ├── synchronization │ │ └── lock.go │ └── value.go ├── imds │ ├── aws.go │ ├── aws_session_cache.go │ ├── aws_session_cache_test.go │ ├── aws_test.go │ ├── azure.go │ ├── azure_subscription.go │ ├── azure_test.go │ ├── doc.go │ ├── gcp.go │ ├── gcp_test.go │ ├── metadataserver.go │ ├── metadataserver_test.go │ ├── providers.go │ ├── providers_test.go │ └── vault_test.go ├── rate │ ├── doc.go │ ├── rate.go │ └── rate_test.go ├── retry │ ├── retry.go │ └── retry_test.go ├── server │ ├── server.go │ └── server_test.go └── vault │ ├── client.go │ └── client_test.go └── scripts └── add-license-copyright.py /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Attaché 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up Go 15 | uses: actions/setup-go@v4 16 | with: 17 | go-version: '1.22' 18 | - name: Build 19 | run: go build -v ./... 20 | - name: Test 21 | run: go test -v ./... 22 | - name: Local docker build 23 | run: docker build . 
-t attache 24 | -------------------------------------------------------------------------------- /.github/workflows/docker-release.yml: -------------------------------------------------------------------------------- 1 | name: docker 2 | 3 | on: 4 | push: 5 | tags: 6 | - "*" 7 | 8 | env: 9 | REGISTRY: ghcr.io 10 | IMAGE_NAME: datadog/attache 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | docker-build-push: 17 | runs-on: ubuntu-latest 18 | permissions: 19 | contents: read 20 | packages: write 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v4 24 | with: 25 | fetch-depth: 0 26 | 27 | - name: Log into registry ${{ env.REGISTRY }} 28 | uses: docker/login-action@v3 29 | with: 30 | registry: ${{ env.REGISTRY }} 31 | username: ${{ github.actor }} 32 | password: ${{ secrets.GITHUB_TOKEN }} 33 | 34 | - name: Build and push Docker image 35 | uses: docker/build-push-action@v5 36 | with: 37 | context: . 38 | push: true 39 | tags: | 40 | ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.ref_name }} 41 | ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "*" 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | goreleaser: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 19 | 20 | - name: Set up Go 21 | uses: actions/setup-go@v4 22 | with: 23 | go-version: 1.22 24 | 25 | - name: Run GoReleaser 26 | timeout-minutes: 60 27 | uses: goreleaser/goreleaser-action@v6 28 | with: 29 | distribution: goreleaser 30 | args: release --clean --config .goreleaser.yaml 31 | env: 32 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | go.work.sum 23 | 24 | /attache 25 | /demo-runner 26 | /demo/vault-audit.log 27 | bin/ 28 | -------------------------------------------------------------------------------- /.goreleaser.yaml: -------------------------------------------------------------------------------- 1 | before: 2 | hooks: 3 | - go mod tidy 4 | builds: 5 | - env: 6 | - CGO_ENABLED=0 7 | goos: 8 | - linux 9 | - darwin 10 | dir: ./cmd/attache 11 | binary: attache 12 | archives: 13 | - name_template: >- 14 | {{ .ProjectName }}_ 15 | {{- title .Os }}_ 16 | {{- if eq .Arch "amd64" }}x86_64 17 | {{- else if eq .Arch "386" }}i386 18 | {{- else }}{{ .Arch }}{{ end }} 19 | checksum: 20 | name_template: 'checksums.txt' 21 | snapshot: 22 | name_template: "{{ incpatch .Version }}-next" 23 | changelog: 24 | sort: asc 25 | filters: 26 | exclude: 27 | - '^docs:' 28 | - '^test:' 29 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.22-alpine AS builder 2 | WORKDIR /build 3 | RUN apk add --update make 4 | ADD . 
/build 5 | RUN make attache 6 | 7 | FROM golang:1.22-alpine AS runner 8 | LABEL org.opencontainers.image.source="https://github.com/DataDog/attache/" 9 | COPY --from=builder /build/attache /attache 10 | ENTRYPOINT ["/attache"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-3rdparty.csv: -------------------------------------------------------------------------------- 1 | github.com/DataDog/appsec-internal-go,https://github.com/DataDog/appsec-internal-go/blob/v1.5.0/LICENSE,Apache-2.0,"Datadog, Inc." 2 | github.com/DataDog/attache,https://github.com/DataDog/attache/blob/HEAD/LICENSE,Apache-2.0,"Datadog, Inc." 3 | github.com/DataDog/datadog-agent/pkg/obfuscate,https://github.com/DataDog/datadog-agent/blob/pkg/obfuscate/v0.48.0/pkg/obfuscate/LICENSE,Apache-2.0,"Datadog, Inc." 4 | github.com/DataDog/datadog-agent/pkg/remoteconfig/state,https://github.com/DataDog/datadog-agent/blob/pkg/remoteconfig/state/v0.49.0-devel/pkg/remoteconfig/state/LICENSE,Apache-2.0,"Datadog, Inc." 5 | github.com/DataDog/datadog-go/v5/statsd,https://github.com/DataDog/datadog-go/blob/v5.5.0/LICENSE.txt,MIT,"Datadog, Inc." 6 | github.com/DataDog/go-libddwaf/v2,https://github.com/DataDog/go-libddwaf/blob/v2.4.2/LICENSE,Apache-2.0,"Datadog, Inc." 
7 | github.com/DataDog/go-tuf,https://github.com/DataDog/go-tuf/blob/v1.0.2-0.5.2/LICENSE,BSD-3-Clause,"Datadog, Inc." 8 | github.com/DataDog/sketches-go/ddsketch,https://github.com/DataDog/sketches-go/blob/v1.4.2/LICENSE,Apache-2.0,"Datadog, Inc." 9 | github.com/aws/aws-sdk-go-v2,https://github.com/aws/aws-sdk-go-v2/blob/v1.25.2/LICENSE.txt,Apache-2.0,"Amazon.com, Inc. or its affiliates" 10 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds,https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.15.2/feature/ec2/imds/LICENSE.txt,Apache-2.0,"Amazon.com, Inc. or its affiliates" 11 | github.com/aws/aws-sdk-go-v2/internal/sync/singleflight,https://github.com/aws/aws-sdk-go-v2/blob/v1.25.2/internal/sync/singleflight/LICENSE,BSD-3-Clause,"Amazon.com, Inc. or its affiliates" 12 | github.com/aws/smithy-go,https://github.com/aws/smithy-go/blob/v1.20.1/LICENSE,Apache-2.0,"Amazon.com, Inc. or its affiliates" 13 | github.com/aws/smithy-go/internal/sync/singleflight,https://github.com/aws/smithy-go/blob/v1.20.1/internal/sync/singleflight/LICENSE,BSD-3-Clause,"Amazon.com, Inc. or its affiliates" 14 | github.com/cenkalti/backoff/v3,https://github.com/cenkalti/backoff/blob/v3.2.2/LICENSE,MIT,"Cenk Altı" 15 | github.com/cespare/xxhash/v2,https://github.com/cespare/xxhash/blob/v2.2.0/LICENSE.txt,MIT,"Caleb Spare" 16 | github.com/dustin/go-humanize,https://github.com/dustin/go-humanize/blob/v1.0.1/LICENSE,MIT,"Dustin Sallings " 17 | github.com/ebitengine/purego,https://github.com/ebitengine/purego/blob/v0.6.0-alpha.5/LICENSE,Apache-2.0,"Ebitengine" 18 | github.com/go-jose/go-jose/v4,https://github.com/go-jose/go-jose/blob/v4.0.1/LICENSE,Apache-2.0,"Square Inc. and The Go Authors" 19 | github.com/go-jose/go-jose/v4/json,https://github.com/go-jose/go-jose/blob/v4.0.1/json/LICENSE,BSD-3-Clause,"Square Inc. 
and The Go Authors" 20 | github.com/golang/protobuf/proto,https://github.com/golang/protobuf/blob/v1.5.4/LICENSE,BSD-3-Clause,"The Go Authors" 21 | github.com/google/uuid,https://github.com/google/uuid/blob/v1.6.0/LICENSE,BSD-3-Clause,"Google Inc." 22 | github.com/gorilla/mux,https://github.com/gorilla/mux/blob/v1.8.0/LICENSE,BSD-3-Clause,"The Gorilla Authors" 23 | github.com/hashicorp/errwrap,https://github.com/hashicorp/errwrap/blob/v1.1.0/LICENSE,MPL-2.0,"HashiCorp, Inc." 24 | github.com/hashicorp/go-cleanhttp,https://github.com/hashicorp/go-cleanhttp/blob/v0.5.2/LICENSE,MPL-2.0,"HashiCorp, Inc." 25 | github.com/hashicorp/go-immutable-radix,https://github.com/hashicorp/go-immutable-radix/blob/v1.3.1/LICENSE,MPL-2.0,"HashiCorp, Inc." 26 | github.com/hashicorp/go-metrics,https://github.com/hashicorp/go-metrics/blob/v0.5.3/LICENSE,MIT,"HashiCorp, Inc." 27 | github.com/hashicorp/go-multierror,https://github.com/hashicorp/go-multierror/blob/v1.1.1/LICENSE,MPL-2.0,"HashiCorp, Inc." 28 | github.com/hashicorp/go-retryablehttp,https://github.com/hashicorp/go-retryablehttp/blob/v0.7.6/LICENSE,MPL-2.0,"HashiCorp, Inc." 29 | github.com/hashicorp/go-rootcerts,https://github.com/hashicorp/go-rootcerts/blob/v1.0.2/LICENSE,MPL-2.0,"HashiCorp, Inc." 30 | github.com/hashicorp/go-secure-stdlib/parseutil,https://github.com/hashicorp/go-secure-stdlib/blob/parseutil/v0.1.8/parseutil/LICENSE,MPL-2.0,"HashiCorp, Inc." 31 | github.com/hashicorp/go-secure-stdlib/strutil,https://github.com/hashicorp/go-secure-stdlib/blob/strutil/v0.1.2/strutil/LICENSE,MPL-2.0,"HashiCorp, Inc." 32 | github.com/hashicorp/go-sockaddr,https://github.com/hashicorp/go-sockaddr/blob/v1.0.6/LICENSE,MPL-2.0,"HashiCorp, Inc." 33 | github.com/hashicorp/golang-lru/simplelru,https://github.com/hashicorp/golang-lru/blob/v1.0.2/LICENSE,MPL-2.0,"HashiCorp, Inc." 34 | github.com/hashicorp/hcl,https://github.com/hashicorp/hcl/blob/v1.0.1-vault-5/LICENSE,MPL-2.0,"HashiCorp, Inc." 
35 | github.com/hashicorp/vault/api,https://github.com/hashicorp/vault/blob/api/v1.14.0/api/LICENSE,MPL-2.0,"HashiCorp, Inc." 36 | github.com/hashicorp/vault/sdk/helper/consts,https://github.com/hashicorp/vault/blob/sdk/v0.12.0/sdk/LICENSE,MPL-2.0,"HashiCorp, Inc." 37 | github.com/mitchellh/go-homedir,https://github.com/mitchellh/go-homedir/blob/v1.1.0/LICENSE,MIT,"Mitchell Hashimoto" 38 | github.com/mitchellh/mapstructure,https://github.com/mitchellh/mapstructure/blob/v1.5.0/LICENSE,MIT,"Mitchell Hashimoto" 39 | github.com/outcaste-io/ristretto,https://github.com/outcaste-io/ristretto/blob/v0.2.3/LICENSE,Apache-2.0,"Outcaste LLC" 40 | github.com/outcaste-io/ristretto/z,https://github.com/outcaste-io/ristretto/blob/v0.2.3/z/LICENSE,MIT,"Outcaste LLC" 41 | github.com/philhofer/fwd,https://github.com/philhofer/fwd/blob/v1.1.2/LICENSE.md,MIT,"Phil Hofer" 42 | github.com/pkg/errors,https://github.com/pkg/errors/blob/v0.9.1/LICENSE,BSD-2-Clause,"Dave Cheney " 43 | github.com/ryanuber/go-glob,https://github.com/ryanuber/go-glob/blob/v1.0.0/LICENSE,MIT,"Ryan Uber" 44 | github.com/secure-systems-lab/go-securesystemslib/cjson,https://github.com/secure-systems-lab/go-securesystemslib/blob/v0.8.0/LICENSE,MIT,"NYU Secure Systems Lab" 45 | github.com/tinylib/msgp/msgp,https://github.com/tinylib/msgp/blob/v1.1.8/LICENSE,MIT,"Philip Hofer and The Go Authors" 46 | go.uber.org/atomic,https://github.com/uber-go/atomic/blob/v1.11.0/LICENSE.txt,MIT,"Uber Technologies, Inc." 47 | go.uber.org/multierr,https://github.com/uber-go/multierr/blob/v1.11.0/LICENSE.txt,MIT,"Uber Technologies, Inc." 48 | go.uber.org/zap,https://github.com/uber-go/zap/blob/v1.27.0/LICENSE,MIT,"Uber Technologies, Inc." 
49 | golang.org/x/crypto/pbkdf2,https://cs.opensource.google/go/x/crypto/+/v0.23.0:LICENSE,BSD-3-Clause,"The Go Authors" 50 | golang.org/x/net,https://cs.opensource.google/go/x/net/+/v0.25.0:LICENSE,BSD-3-Clause,"The Go Authors" 51 | golang.org/x/sys/unix,https://cs.opensource.google/go/x/sys/+/v0.20.0:LICENSE,BSD-3-Clause,"The Go Authors" 52 | golang.org/x/text,https://cs.opensource.google/go/x/text/+/v0.15.0:LICENSE,BSD-3-Clause,"The Go Authors" 53 | golang.org/x/time/rate,https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE,BSD-3-Clause,"The Go Authors" 54 | golang.org/x/xerrors,https://cs.opensource.google/go/x/xerrors/+/104605ab:LICENSE,BSD-3-Clause,"The Go Authors" 55 | google.golang.org/genproto/googleapis/rpc/status,https://github.com/googleapis/go-genproto/blob/fc5f0ca64291/googleapis/rpc/LICENSE,Apache-2.0,"Google Inc." 56 | google.golang.org/grpc,https://github.com/grpc/grpc-go/blob/v1.64.0/LICENSE,Apache-2.0,"Google Inc." 57 | google.golang.org/protobuf,https://github.com/protocolbuffers/protobuf-go/blob/v1.34.1/LICENSE,BSD-3-Clause,"Google Inc." 58 | gopkg.in/DataDog/dd-trace-go.v1,https://github.com/DataDog/dd-trace-go/blob/v1.64.0/LICENSE,Apache-2.0,"Datadog, Inc." 59 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | GOFILES:=$(shell find . -type f -iname '*.go') 2 | 3 | attache: $(GOFILES) 4 | go build -o attache ./cmd/attache/main.go 5 | 6 | demo-runner: $(GOFILES) 7 | go build -o demo-runner ./cmd/demo-runner/main.go 8 | 9 | thirdparty-licenses: 10 | @echo "Retrieving third-party licenses..." 
11 | go get github.com/google/go-licenses 12 | go install github.com/google/go-licenses 13 | $(GOPATH)/bin/go-licenses csv github.com/DataDog/attache/cmd | sort > LICENSE-3rdparty.csv 14 | @echo "Third-party licenses retrieved and saved to $(ROOT_DIR)/LICENSE-3rdparty.csv" 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Attaché 2 | 3 | [![made-with-Go](https://img.shields.io/badge/Made%20with-Go-1f425f.svg)](http://golang.org) 4 | 5 | Attaché provides an emulation layer for cloud provider instance metadata APIs, allowing for seamless multi-cloud IAM using Hashicorp Vault. 6 | 7 | More information can be found in the companion talk, [Freeing Identity from Infrastructure](https://www.youtube.com/watch?v=xifpJbTepCs). 8 | 9 |

10 | Attaché 11 |

12 | 13 | 14 | ## How it works 15 | 16 | 1. Attaché intercepts requests that applications perform to the cloud provider's instance metadata service (IMDS) 17 | 2. Attaché forwards these requests to a pre-configured cloud secrets backend of Hashicorp Vault to retrieve application-scoped cloud credentials 18 | 3. Finally, Attaché returns the requested credentials to the application 19 | 20 | ## Installation 21 | 22 | You can use the pre-built binaries from the [releases page](https://github.com/DataDog/attache/releases) or use the provided Docker image: 23 | 24 | ``` 25 | docker run --rm -it docker pull ghcr.io/datadog/attache 26 | ``` 27 | 28 | ## Sample usage 29 | 30 | In this example, we will use Attaché to have a local application that uses the AWS and Google Cloud SDKs to seamlessly retrieve cloud credentials. 31 | 32 | ### 1. Set up a Vault server 33 | 34 | ```bash 35 | vault server -dev -dev-root-token-id=local -log-level=DEBUG 36 | ``` 37 | ### 2. Set up your application roles in AWS and Google Cloud 38 | 39 | Create an AWS IAM role caled `application-role`, and a Google Cloud service account called `application-role` (they have to match). 40 | 41 | ### 3. Configure your Vault AWS secrets backend 42 | 43 | Let's mount an AWS secret backend. 
For this demo, we'll authenticate Vault to AWS using IAM user access keys, which is (to say the least) a bad practice not to follow in production: 44 | 45 | Create an AWS IAM user: 46 | 47 | ```bash 48 | accountId=$(aws sts get-caller-identity --query Account --output text) 49 | aws iam create-user --user-name vault-demo 50 | credentials=$(aws iam create-access-key --user-name vault-demo) 51 | accessKeyId=$(echo "$credentials" | jq -r '.AccessKey.AccessKeyId') 52 | secretAccessKey=$(echo "$credentials" | jq -r '.AccessKey.SecretAccessKey') 53 | ``` 54 | 55 | Allow Vault to assume the role we want to give our application: 56 | 57 | ```bash 58 | aws iam put-user-policy --user-name vault-demo --policy-name vault-demo --policy-document '{ 59 | "Version": "2012-10-17", 60 | "Statement": [ 61 | { 62 | "Effect": "Allow", 63 | "Action": "sts:AssumeRole", 64 | "Resource": "arn:aws:iam::'$accountId':role/application-role" 65 | }, 66 | { 67 | "Sid": "AllowVaultToRotateItsOwnCredentials", 68 | "Effect": "Allow", 69 | "Action": ["iam:GetUser", "iam:DeleteAccessKey", "iam:CreateAccessKey"], 70 | "Resource": "arn:aws:iam::'$accountId':user/vault-demo" 71 | } 72 | ] 73 | }' 74 | ``` 75 | 76 | Then we configure the Vault AWS credentials backend: 77 | 78 | ```bash 79 | # Point the Vault CLI to our local test server 80 | export VAULT_ADDR="http://127.0.0.1:8200" 81 | export VAULT_TOKEN="local" 82 | 83 | vault secrets enable -path cloud-iam/aws/0123456789012 aws 84 | vault write cloud-iam/aws/0123456789012/config/root access_key="$accessKeyId" secret_key="$secretAccessKey" 85 | vault write cloud-iam/aws/0123456789012/roles/application-role credential_type=assumed_role role_arns="arn:aws:iam::${accountId}:role/application-role" 86 | vault write -f cloud-iam/aws/0123456789012/config/rotate-root # rotate the IAM user access key so Vault only knows its own static credentials 87 | ``` 88 | 89 | We can confirm that Vault is able to retrieve our role credentials by using `vault read 
cloud-iam/aws/0123456789012/creds/application-role`. 90 | 91 | ### 4. Configure your Vault GCP secrets backend 92 | 93 | Let's mount a Google Cloud secret backend. For this demo, we'll authenticate Vault to GCP using a service account key, which is also suboptimal in production: 94 | 95 | ```bash 96 | project=$(gcloud config get-value project) 97 | vaultSa=vault-demo@$project.iam.gserviceaccount.com 98 | gcloud iam service-accounts create vault-demo 99 | gcloud iam service-accounts keys create gcp-creds.json --iam-account=$vaultSa 100 | 101 | # Allow the Vault service account to impersonate the application service account 102 | gcloud iam service-accounts add-iam-policy-binding application-role@$project.iam.gserviceaccount.com \ 103 | --role=roles/iam.serviceAccountTokenCreator \ 104 | --member=serviceAccount:$vaultSa 105 | ``` 106 | 107 | Then we configure the Vault GCP credentials backend, so it can access our prerequisite 108 | 109 | ```bash 110 | gcloud 111 | vault secrets enable -path cloud-iam/gcp/gcp-sandbox gcp 112 | vault write cloud-iam/gcp/gcp-sandbox/config credentials=@gcp-creds.json 113 | vault write cloud-iam/gcp/gcp-sandbox/impersonated-account/application-role service_account_email="application-role@gcp-sandbox.iam.gserviceaccount.com" token_scopes="https://www.googleapis.com/auth/cloud-platform" ttl="4h" 114 | ``` 115 | 116 | We can verify this worked by running `vault read cloud-iam/gcp/gcp-sandbox/impersonated-account/application-role/token` 117 | 118 | ### 5. 
Configure and run Attaché 119 | 120 | Let's create a configuration file for Attaché (see [Configuration reference](#configuration-reference)): 121 | 122 | ```yaml 123 | ## 124 | # Attaché global configuration 125 | ## 126 | server: 127 | bind_address: 127.0.0.1:8080 128 | graceful_timeout: 0s 129 | rate_limit: "" 130 | 131 | # We're running locally 132 | provider: "" 133 | region: "" 134 | zone: "" 135 | 136 | # AWS configuration 137 | aws_vault_mount_path: cloud-iam/aws/012345678901 138 | iam_role: application-role 139 | imds_v1_allowed: false 140 | 141 | # GCP configuration 142 | gcp_vault_mount_path: cloud-iam/gcp/gcp-sandbox 143 | gcp_project_ids: 144 | cloud-iam/gcp/gcp-sandbox: "712781682929" 145 | 146 | # Azure configuration (unused here) 147 | azure_vault_mount_path: "" 148 | ``` 149 | 150 | Then we can run Attaché: 151 | 152 | ``` 153 | $ export VAULT_ADDR="http://127.0.0.1:8200" 154 | $ export VAULT_TOKEN="local" 155 | $ attache ./config.yaml 156 | 2024-06-17T16:51:23.283+0200 DEBUG attache/main.go:35 loading configuration {"path": "./config.yaml"} 157 | 2024-06-17T16:51:23.283+0200 DEBUG attache/main.go:49 configuration loaded {"configuration": {"IamRole":"application-role","IMDSv1Allowed":false,"GcpVaultMountPath":"cloud-iam/gcp/gcp-sandbox","GcpProjectIds":{"cloud-iam/gcp/gcp-sandbox":"712781682929"},"AwsVaultMountPath":"cloud-iam/aws/012345678901","AzureVaultMountPath":"","ServerConfig":{"BindAddress":"127.0.0.1:8080","GracefulTimeout":0,"RateLimit":""},"Provider":"","Region":"","Zone":""}} 158 | 2024-06-17T16:51:23.284+0200 INFO cloud-iam-server server/server.go:110 server starting {"address": "127.0.0.1:8080"} 159 | ``` 160 | 161 | Note how we're able to manually retrieve credentials as if we were hitting the AWS IMDS, which Attaché emulates: 162 | 163 | ```bash 164 | $ IMDSV2_TOKEN=$(curl -XPUT localhost:8080/latest/api/token -H x-aws-ec2-metadata-token-ttl-seconds:21600) 165 | 166 | $ curl -H "X-aws-ec2-metadata-token: $IMDSV2_TOKEN" 
localhost:8080/latest/meta-data/iam/security-credentials/ 167 | application-role 168 | 169 | $ curl -H "X-aws-ec2-metadata-token: $IMDSV2_TOKEN" localhost:8080/latest/meta-data/iam/security-credentials/application-role 170 | { 171 | "AccessKeyId": "ASIAZ3..", 172 | "Code": "Success", 173 | "SecretAccessKey": "liqX1...", 174 | "Token": "IQoJ..." 175 | } 176 | ``` 177 | 178 | Same as if we were hitting the GCP IMDS: 179 | 180 | ```bash 181 | $ curl -H Metadata-Flavor:Google localhost:8080/computeMetadata/v1/instance/service-accounts/ 182 | default/ 183 | application-role@gcp-sandbox.iam.gserviceaccount.com/ 184 | 185 | $ curl -H Metadata-Flavor:Google localhost:8080/computeMetadata/v1/instance/service-accounts/application-role@gcp-sandbox.iam.gserviceaccount.com/ 186 | default/ 187 | { 188 | "access_token": "ya29.c.c0AY_VpZ...", 189 | "token_type": "Bearer", 190 | "expires_in": 3597 191 | } 192 | ``` 193 | 194 | ### 6. Run your application 195 | 196 | Let's use the following application that lists AWS S3 and Google Cloud GCS buckets: 197 | 198 | ```python 199 | import boto3 200 | from google.cloud import storage 201 | 202 | def list_s3_buckets(): 203 | s3 = boto3.client('s3') 204 | 205 | response = s3.list_buckets() 206 | print(f"Found {len(response['Buckets'])} AWS S3 buckets!") 207 | 208 | def list_gcs_buckets(): 209 | client = storage.Client() 210 | 211 | buckets = client.list_buckets() 212 | print(f"Found {len(list(buckets))} GCS buckets!") 213 | 214 | list_s3_buckets() 215 | list_gcs_buckets() 216 | ``` 217 | 218 | We can set the required environment variables to point to Attaché: 219 | 220 | ```bash 221 | export AWS_EC2_METADATA_SERVICE_ENDPOINT="http://127.0.0.1:8080/" 222 | export GCE_METADATA_HOST="127.0.0.1:8080" 223 | ``` 224 | 225 | ... and then run it! 226 | 227 | ```bash 228 | pip install boto3 google-cloud-storage 229 | python app.py 230 | ``` 231 | 232 | We see: 233 | 234 | ```bash 235 | Found 154 AWS S3 buckets! 236 | Found 2 GCS buckets! 
237 | ``` 238 | 239 | And in the Attaché logs: 240 | 241 | ``` 242 | 2024-06-17T17:23:15.463+0200 INFO cloud-iam-server server/server.go:170 request {"address": "127.0.0.1:8080", "path": "/latest/api/token", "method": "PUT", "userAgent": "Boto3/1.34.77 Python/3.10.13 Darwin/23.5.0 Botocore/1.34.80"} 243 | 2024-06-17T17:23:15.463+0200 INFO cloud-iam-server server/server.go:177 response {"address": "127.0.0.1:8080", "path": "/latest/api/token", "method": "PUT", "statusCode": 200, "userAgent": "Boto3/1.34.77 Python/3.10.13 Darwin/23.5.0 Botocore/1.34.80"} 244 | 2024-06-17T17:23:15.464+0200 INFO cloud-iam-server server/server.go:170 request {"address": "127.0.0.1:8080", "path": "/latest/meta-data/iam/security-credentials/", "method": "GET", "userAgent": "Boto3/1.34.77 Python/3.10.13 Darwin/23.5.0 Botocore/1.34.80"} 245 | 2024-06-17T17:23:15.465+0200 INFO cloud-iam-server server/server.go:177 response {"address": "127.0.0.1:8080", "path": "/latest/meta-data/iam/security-credentials/", "method": "GET", "statusCode": 200, "userAgent": "Boto3/1.34.77 Python/3.10.13 Darwin/23.5.0 Botocore/1.34.80"} 246 | 2024-06-17T17:23:15.466+0200 INFO cloud-iam-server server/server.go:170 request {"address": "127.0.0.1:8080", "path": "/latest/meta-data/iam/security-credentials/application-role", "method": "GET", "userAgent": "Boto3/1.34.77 Python/3.10.13 Darwin/23.5.0 Botocore/1.34.80"} 247 | 2024-06-17T17:23:15.895+0200 DEBUG token maintainer cache/maintainer.go:188 Updating cached value {"fetcher": "aws-sts-token-vault", "expiration": "2024-06-17T16:23:14.000Z"} 248 | 2024-06-17T17:23:15.895+0200 DEBUG token maintainer cache/maintainer.go:201 scheduling value refresh {"fetcher": "aws-sts-token-vault", "delay": "20m22.713030691s"} 249 | 2024-06-17T17:23:15.895+0200 INFO cloud-iam-server server/server.go:177 response {"address": "127.0.0.1:8080", "path": "/latest/meta-data/iam/security-credentials/application-role", "method": "GET", "statusCode": 200, "userAgent": "Boto3/1.34.77 
Python/3.10.13 Darwin/23.5.0 Botocore/1.34.80"} 250 | ``` 251 | 252 | ## Considerations for running in production 253 | 254 | TBA 255 | 256 | ## Caching 257 | 258 | TBA 259 | 260 | ## Configuration reference 261 | 262 | ```yaml 263 | ## 264 | # Attaché global configuration 265 | ## 266 | server: 267 | bind_address: 127.0.0.1:8080 268 | graceful_timeout: 0s 269 | rate_limit: "" 270 | 271 | # If applicable, the current cloud environment where attaché is running 272 | provider: "" 273 | 274 | # If applicable, current cloud region (e.g., us-east-1) where attaché is running 275 | region: "" 276 | 277 | # If applicable, current cloud availability zone (e.g., us-east-1a) where attaché is running 278 | zone: "" 279 | 280 | ## 281 | # AWS configuration 282 | ## 283 | 284 | # Vault path where the AWS secrets backend is mounted 285 | aws_vault_mount_path: cloud-iam/aws/012345678901 286 | 287 | # The AWS IAM role name that Attaché will assume to retrieve AWS credentials 288 | iam_role: my-role 289 | 290 | # Whether clients may use IMDSv1 (no session token); set to false to require IMDSv2 291 | imds_v1_allowed: false 292 | 293 | ## 294 | # GCP configuration 295 | ## 296 | 297 | # Vault path where the Google Cloud secrets backend is mounted 298 | gcp_vault_mount_path: cloud-iam/gcp/my-gcp-sandbox 299 | 300 | # Mapping of Vault paths to Google Cloud project IDs 301 | gcp_project_ids: 302 | cloud-iam/gcp/my-gcp-sandbox: "012345678901" 303 | 304 | ## 305 | # Azure configuration 306 | ## 307 | azure_vault_mount_path: cloud-iam/azure/my-azure-role 308 | ``` 309 | -------------------------------------------------------------------------------- /attache.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataDog/attache/f9509b098ee87f6375cdd356f63035ff5dd68875/attache.jpg -------------------------------------------------------------------------------- /cmd/attache/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | 
import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "os/signal" 8 | "syscall" 9 | 10 | "github.com/DataDog/attache/internal/imds" 11 | "github.com/DataDog/attache/internal/vault" 12 | "github.com/hashicorp/go-metrics" 13 | "go.uber.org/zap" 14 | "gopkg.in/yaml.v3" 15 | ) 16 | 17 | func main() { 18 | fmt.Println("starting attaché") 19 | 20 | osSignals := make(chan os.Signal, 1) 21 | signal.Notify(osSignals, syscall.SIGINT, syscall.SIGTERM) 22 | 23 | log, err := zap.NewDevelopment() 24 | if err != nil { 25 | fmt.Printf("could not initialize logger: %v\n", err) 26 | os.Exit(11) 27 | } 28 | 29 | if len(os.Args) != 2 { 30 | log.Error("usage: attache ") 31 | os.Exit(12) 32 | } 33 | 34 | filePath := os.Args[1] 35 | log.Debug("loading configuration", zap.String("path", filePath)) 36 | 37 | config := &imds.Config{} 38 | b, err := os.ReadFile(filePath) 39 | if err != nil { 40 | log.Error("unable to load configuration file", zap.String("path", filePath), zap.Error(err)) 41 | os.Exit(15) 42 | } 43 | err = yaml.Unmarshal(b, config) 44 | if err != nil { 45 | log.Error("unable to parse configuration file", zap.String("path", filePath), zap.Error(err)) 46 | os.Exit(19) 47 | } 48 | 49 | log.Debug("configuration loaded", zap.Any("configuration", config)) 50 | 51 | vConfig := vault.DefaultConfig() 52 | v, err := vault.NewClient(vConfig) 53 | 54 | server, closeFunc, err := imds.NewServer(context.Background(), &imds.MetadataServerConfig{ 55 | CloudiamConf: *config, 56 | DDVaultClient: v, 57 | MetricSink: &metrics.BlackholeSink{}, 58 | Log: log, 59 | }) 60 | if err != nil { 61 | fmt.Printf("could not initialize imds server: %v\n", err) 62 | os.Exit(19) 63 | } 64 | defer closeFunc() 65 | 66 | errs := make(chan error, 1) 67 | shutdown := server.Run(errs) 68 | defer shutdown() 69 | 70 | for { 71 | select { 72 | case err := <-errs: 73 | log.Error("attaché imds server error", zap.Error(err)) 74 | os.Exit(58) 75 | case sig := <-osSignals: 76 | log.Info("received os signal", 
zap.Stringer("os.Signal", sig)) 77 | return 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /cmd/demo-runner/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "os" 8 | 9 | "cloud.google.com/go/storage" 10 | ) 11 | 12 | func main() { 13 | ctx := context.Background() 14 | client, err := storage.NewClient(ctx) 15 | if err != nil { 16 | fmt.Printf("new client: %v", err) 17 | os.Exit(1) 18 | } 19 | 20 | // Read the object1 from bucket. 21 | rc, err := client.Bucket("emissary").Object("sherman.txt").NewReader(ctx) 22 | if err != nil { 23 | fmt.Printf("get object: %v\n", err) 24 | os.Exit(1) 25 | } 26 | defer rc.Close() 27 | body, err := io.ReadAll(rc) 28 | if err != nil { 29 | fmt.Printf("read object: %v\n", err) 30 | os.Exit(1) 31 | } 32 | 33 | fmt.Println(string(body)) 34 | } 35 | -------------------------------------------------------------------------------- /demo/attache.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -u 4 | set -o pipefail 5 | 6 | export VAULT_ADDR="http://127.0.0.1:8200" 7 | export VAULT_TOKEN="local" 8 | 9 | ./attache ./demo/config.yaml 10 | -------------------------------------------------------------------------------- /demo/client.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # raw commands to print tokens 4 | # curl --silent http://127.0.0.1:8080/v1/meta-data/iam/security-credentials/dd.frostbiteFalls_bullwinkle 5 | # curl -H "Metadata-Flavor: Google" 'http://127.0.0.1:8080/computeMetadata/v1/instance/service-accounts/default/token' 6 | 7 | # without pointing at our local attaché IMDS, these should both fail 8 | unset AWS_EC2_METADATA_SERVICE_ENDPOINT 9 | aws s3 cp s3://emissary/rocky.txt - 10 | 11 | # but pointing at attaché it will work 12 | 
export AWS_EC2_METADATA_SERVICE_ENDPOINT="http://127.0.0.1:8080/" 13 | aws s3 cp s3://emissary/rocky.txt - 14 | 15 | # same thing with a GCP golang SDK, without pointing at attache it will fail 16 | unset GCE_METADATA_HOST 17 | ./demo-runner 18 | 19 | # with GCE's metadata server env var set at attache it works: 20 | export GCE_METADATA_HOST="127.0.0.1:8080" 21 | ./demo-runner 22 | 23 | -------------------------------------------------------------------------------- /demo/config.yaml: -------------------------------------------------------------------------------- 1 | iam_role: frostbite-falls_bullwinkle 2 | imds_v1_allowed: false 3 | gcp_vault_mount_path: cloud-iam/gcp/datadog-sandbox 4 | gcp_project_ids: 5 | cloud-iam/gcp/datadog-sandbox: "958371799887" 6 | aws_vault_mount_path: cloud-iam/aws/601427279990 7 | azure_vault_mount_path: "" 8 | server: 9 | bind_address: 127.0.0.1:8080 10 | graceful_timeout: 0s 11 | rate_limit: "" 12 | provider: "" 13 | region: "" 14 | zone: "" 15 | 16 | -------------------------------------------------------------------------------- /demo/vault.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set -u 4 | set -o pipefail 5 | 6 | pkill -15 vault || true 7 | vault server -dev -dev-root-token-id=local -log-level=DEBUG & 8 | sleep 1 9 | 10 | export VAULT_ADDR="http://127.0.0.1:8200" 11 | export VAULT_TOKEN="local" 12 | 13 | if [[ -z "$V_AWS_ACCESS_KEY" ]]; then 14 | echo "V_AWS_ACCESS_KEY must be set" >&2 15 | exit 1 16 | fi 17 | if [[ -z "$V_AWS_SECRET_KEY" ]]; then 18 | echo "V_AWS_SECRET_KEY must be set" >&2 19 | exit 1 20 | fi 21 | if [[ -z "$V_GCP_SERVICE_ACCOUNT_JSON" ]]; then 22 | echo "V_GCP_SERVICE_ACCOUNT_JSON must be set" >&2 23 | exit 1 24 | fi 25 | 26 | vault audit enable file file_path="./demo/vault-audit.log" 27 | 28 | vault secrets enable -path cloud-iam/aws/601427279990 aws 29 | vault write cloud-iam/aws/601427279990/config/root access_key="$V_AWS_ACCESS_KEY" 
secret_key="$V_AWS_SECRET_KEY" 30 | vault write cloud-iam/aws/601427279990/roles/frostbite-falls_bullwinkle credential_type=assumed_role role_arns="arn:aws:iam::601427279990:role/dd.frostbiteFalls_bullwinkle" 31 | 32 | vault secrets enable -path cloud-iam/gcp/datadog-sandbox gcp 33 | vault write cloud-iam/gcp/datadog-sandbox/config credentials="@$V_GCP_SERVICE_ACCOUNT_JSON" 34 | vault write cloud-iam/gcp/datadog-sandbox/impersonated-account/frostbite-falls_bullwinkle service_account_email="dd-frostbite-bullwinkl-3c9e72b@datadog-sandbox.iam.gserviceaccount.com" token_scopes="https://www.googleapis.com/auth/cloud-platform" ttl="4h" 35 | 36 | wait 37 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/DataDog/attache 2 | 3 | go 1.22.0 4 | 5 | require ( 6 | cloud.google.com/go/storage v1.42.0 7 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.2 8 | github.com/fatih/structs v1.1.0 9 | github.com/google/uuid v1.6.0 10 | github.com/gorilla/mux v1.8.0 11 | github.com/hashicorp/go-cleanhttp v0.5.2 12 | github.com/hashicorp/go-hclog v1.6.3 13 | github.com/hashicorp/go-metrics v0.5.3 14 | github.com/hashicorp/go-multierror v1.1.1 15 | github.com/hashicorp/go-retryablehttp v0.7.6 16 | github.com/hashicorp/vault v1.16.3 17 | github.com/hashicorp/vault/api v1.14.0 18 | github.com/hashicorp/vault/sdk v0.12.0 19 | github.com/mitchellh/mapstructure v1.5.0 20 | github.com/stretchr/testify v1.9.0 21 | go.uber.org/zap v1.27.0 22 | golang.org/x/net v0.25.0 23 | golang.org/x/time v0.5.0 24 | gopkg.in/DataDog/dd-trace-go.v1 v1.64.0 25 | gopkg.in/yaml.v3 v3.0.1 26 | ) 27 | 28 | require ( 29 | cloud.google.com/go v0.114.0 // indirect 30 | cloud.google.com/go/auth v0.5.1 // indirect 31 | cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect 32 | cloud.google.com/go/cloudsqlconn v1.4.3 // indirect 33 | cloud.google.com/go/compute/metadata 
v0.3.0 // indirect 34 | cloud.google.com/go/iam v1.1.8 // indirect 35 | cloud.google.com/go/kms v1.17.1 // indirect 36 | cloud.google.com/go/longrunning v0.5.7 // indirect 37 | cloud.google.com/go/monitoring v1.19.0 // indirect 38 | github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect 39 | github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.1 // indirect 40 | github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1 // indirect 41 | github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.1 // indirect 42 | github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0 // indirect 43 | github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 // indirect 44 | github.com/Azure/go-autorest v14.2.1-0.20210210161804-c7f947c0610d+incompatible // indirect 45 | github.com/Azure/go-autorest/autorest v0.11.29 // indirect 46 | github.com/Azure/go-autorest/autorest/adal v0.9.23 // indirect 47 | github.com/Azure/go-autorest/autorest/azure/auth v0.5.12 // indirect 48 | github.com/Azure/go-autorest/autorest/azure/cli v0.4.6 // indirect 49 | github.com/Azure/go-autorest/autorest/date v0.3.0 // indirect 50 | github.com/Azure/go-autorest/autorest/to v0.4.0 // indirect 51 | github.com/Azure/go-autorest/autorest/validation v0.3.1 // indirect 52 | github.com/Azure/go-autorest/logger v0.2.1 // indirect 53 | github.com/Azure/go-autorest/tracing v0.6.0 // indirect 54 | github.com/AzureAD/microsoft-authentication-library-for-go v1.2.1 // indirect 55 | github.com/BurntSushi/toml v1.3.2 // indirect 56 | github.com/DataDog/appsec-internal-go v1.5.0 // indirect 57 | github.com/DataDog/datadog-agent/pkg/obfuscate v0.48.0 // indirect 58 | github.com/DataDog/datadog-agent/pkg/remoteconfig/state v0.49.0-devel // indirect 59 | github.com/DataDog/datadog-go v4.8.3+incompatible // indirect 60 | github.com/DataDog/datadog-go/v5 v5.5.0 // indirect 61 | github.com/DataDog/go-libddwaf/v2 v2.4.2 // indirect 62 | github.com/DataDog/go-tuf v1.0.2-0.5.2 // indirect 63 | 
github.com/DataDog/sketches-go v1.4.2 // indirect 64 | github.com/Jeffail/gabs v1.4.0 // indirect 65 | github.com/Masterminds/goutils v1.1.1 // indirect 66 | github.com/Masterminds/semver/v3 v3.2.1 // indirect 67 | github.com/Masterminds/sprig/v3 v3.2.3 // indirect 68 | github.com/Microsoft/go-winio v0.6.1 // indirect 69 | github.com/ProtonMail/go-crypto v0.0.0-20230923063757-afb1ddc0824c // indirect 70 | github.com/aliyun/alibaba-cloud-sdk-go v1.62.676 // indirect 71 | github.com/armon/go-metrics v0.4.1 // indirect 72 | github.com/armon/go-radix v1.0.0 // indirect 73 | github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect 74 | github.com/aws/aws-sdk-go v1.50.13 // indirect 75 | github.com/aws/aws-sdk-go-v2 v1.25.2 // indirect 76 | github.com/aws/aws-sdk-go-v2/credentials v1.17.6 // indirect 77 | github.com/aws/smithy-go v1.20.1 // indirect 78 | github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a // indirect 79 | github.com/benbjohnson/immutable v0.4.0 // indirect 80 | github.com/beorn7/perks v1.0.1 // indirect 81 | github.com/bgentry/speakeasy v0.1.0 // indirect 82 | github.com/boltdb/bolt v1.3.1 // indirect 83 | github.com/boombuler/barcode v1.0.1 // indirect 84 | github.com/cenkalti/backoff/v3 v3.2.2 // indirect 85 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect 86 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 87 | github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible // indirect 88 | github.com/circonus-labs/circonusllhist v0.1.3 // indirect 89 | github.com/cloudflare/circl v1.3.7 // indirect 90 | github.com/coreos/etcd v3.3.27+incompatible // indirect 91 | github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect 92 | github.com/coreos/pkg v0.0.0-20220810130054-c7d1c02cb6cf // indirect 93 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 94 | github.com/denverdino/aliyungo v0.0.0-20190125010748-a747050bb1ba // indirect 95 | github.com/dgryski/go-metro 
v0.0.0-20180109044635-280f6062b5bc // indirect 96 | github.com/digitalocean/godo v1.7.5 // indirect 97 | github.com/dimchansky/utfbom v1.1.1 // indirect 98 | github.com/distribution/reference v0.6.0 // indirect 99 | github.com/docker/docker v25.0.5+incompatible // indirect 100 | github.com/docker/go-connections v0.4.0 // indirect 101 | github.com/docker/go-units v0.5.0 // indirect 102 | github.com/duosecurity/duo_api_golang v0.0.0-20190308151101-6c680f768e74 // indirect 103 | github.com/dustin/go-humanize v1.0.1 // indirect 104 | github.com/ebitengine/purego v0.6.0-alpha.5 // indirect 105 | github.com/emicklei/go-restful/v3 v3.11.0 // indirect 106 | github.com/evanphx/json-patch/v5 v5.6.0 // indirect 107 | github.com/fatih/color v1.16.0 // indirect 108 | github.com/felixge/httpsnoop v1.0.4 // indirect 109 | github.com/gammazero/deque v0.2.1 // indirect 110 | github.com/gammazero/workerpool v1.1.3 // indirect 111 | github.com/go-jose/go-jose/v3 v3.0.3 // indirect 112 | github.com/go-jose/go-jose/v4 v4.0.1 // indirect 113 | github.com/go-logr/logr v1.4.1 // indirect 114 | github.com/go-logr/stdr v1.2.2 // indirect 115 | github.com/go-ole/go-ole v1.2.6 // indirect 116 | github.com/go-openapi/analysis v0.21.4 // indirect 117 | github.com/go-openapi/errors v0.20.4 // indirect 118 | github.com/go-openapi/jsonpointer v0.20.0 // indirect 119 | github.com/go-openapi/jsonreference v0.20.2 // indirect 120 | github.com/go-openapi/loads v0.21.2 // indirect 121 | github.com/go-openapi/spec v0.20.9 // indirect 122 | github.com/go-openapi/strfmt v0.21.7 // indirect 123 | github.com/go-openapi/swag v0.22.4 // indirect 124 | github.com/go-openapi/validate v0.22.2 // indirect 125 | github.com/go-ozzo/ozzo-validation v3.6.0+incompatible // indirect 126 | github.com/go-sql-driver/mysql v1.7.1 // indirect 127 | github.com/go-test/deep v1.1.0 // indirect 128 | github.com/gogo/protobuf v1.3.2 // indirect 129 | github.com/golang-jwt/jwt/v4 v4.5.0 // indirect 130 | 
github.com/golang-jwt/jwt/v5 v5.2.0 // indirect 131 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 132 | github.com/golang/protobuf v1.5.4 // indirect 133 | github.com/golang/snappy v0.0.4 // indirect 134 | github.com/google/gnostic-models v0.6.8 // indirect 135 | github.com/google/go-cmp v0.6.0 // indirect 136 | github.com/google/go-metrics-stackdriver v0.2.0 // indirect 137 | github.com/google/go-querystring v1.1.0 // indirect 138 | github.com/google/gofuzz v1.2.0 // indirect 139 | github.com/google/s2a-go v0.1.7 // indirect 140 | github.com/google/tink/go v1.7.0 // indirect 141 | github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect 142 | github.com/googleapis/gax-go/v2 v2.12.4 // indirect 143 | github.com/gophercloud/gophercloud v0.1.0 // indirect 144 | github.com/hashicorp-forge/bbolt v1.3.8-hc3 // indirect 145 | github.com/hashicorp/cli v1.1.6 // indirect 146 | github.com/hashicorp/consul/sdk v0.14.1 // indirect 147 | github.com/hashicorp/errwrap v1.1.0 // indirect 148 | github.com/hashicorp/eventlogger v0.2.8 // indirect 149 | github.com/hashicorp/go-bexpr v0.1.12 // indirect 150 | github.com/hashicorp/go-discover v0.0.0-20210818145131-c573d69da192 // indirect 151 | github.com/hashicorp/go-immutable-radix v1.3.1 // indirect 152 | github.com/hashicorp/go-kms-wrapping/entropy/v2 v2.0.1 // indirect 153 | github.com/hashicorp/go-kms-wrapping/v2 v2.0.16 // indirect 154 | github.com/hashicorp/go-kms-wrapping/wrappers/aead/v2 v2.0.9 // indirect 155 | github.com/hashicorp/go-kms-wrapping/wrappers/alicloudkms/v2 v2.0.3 // indirect 156 | github.com/hashicorp/go-kms-wrapping/wrappers/awskms/v2 v2.0.9 // indirect 157 | github.com/hashicorp/go-kms-wrapping/wrappers/azurekeyvault/v2 v2.0.11 // indirect 158 | github.com/hashicorp/go-kms-wrapping/wrappers/gcpckms/v2 v2.0.12 // indirect 159 | github.com/hashicorp/go-kms-wrapping/wrappers/ocikms/v2 v2.0.7 // indirect 160 | github.com/hashicorp/go-kms-wrapping/wrappers/transit/v2 
v2.0.11 // indirect 161 | github.com/hashicorp/go-memdb v1.3.4 // indirect 162 | github.com/hashicorp/go-msgpack/v2 v2.1.1 // indirect 163 | github.com/hashicorp/go-plugin v1.6.0 // indirect 164 | github.com/hashicorp/go-raftchunking v0.6.3-0.20191002164813-7e9e8525653a // indirect 165 | github.com/hashicorp/go-rootcerts v1.0.2 // indirect 166 | github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0 // indirect 167 | github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 // indirect 168 | github.com/hashicorp/go-secure-stdlib/mlock v0.1.3 // indirect 169 | github.com/hashicorp/go-secure-stdlib/parseutil v0.1.8 // indirect 170 | github.com/hashicorp/go-secure-stdlib/plugincontainer v0.3.0 // indirect 171 | github.com/hashicorp/go-secure-stdlib/reloadutil v0.1.1 // indirect 172 | github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect 173 | github.com/hashicorp/go-secure-stdlib/tlsutil v0.1.3 // indirect 174 | github.com/hashicorp/go-sockaddr v1.0.6 // indirect 175 | github.com/hashicorp/go-syslog v1.0.0 // indirect 176 | github.com/hashicorp/go-uuid v1.0.3 // indirect 177 | github.com/hashicorp/go-version v1.6.0 // indirect 178 | github.com/hashicorp/golang-lru v1.0.2 // indirect 179 | github.com/hashicorp/hcl v1.0.1-vault-5 // indirect 180 | github.com/hashicorp/hcp-sdk-go v0.75.0 // indirect 181 | github.com/hashicorp/mdns v1.0.4 // indirect 182 | github.com/hashicorp/raft v1.6.0 // indirect 183 | github.com/hashicorp/raft-autopilot v0.2.0 // indirect 184 | github.com/hashicorp/raft-boltdb/v2 v2.3.0 // indirect 185 | github.com/hashicorp/raft-snapshot v1.0.4 // indirect 186 | github.com/hashicorp/raft-wal v0.4.0 // indirect 187 | github.com/hashicorp/vault-plugin-secrets-kv v0.17.0 // indirect 188 | github.com/hashicorp/vic v1.5.1-0.20190403131502-bbfe86ec9443 // indirect 189 | github.com/hashicorp/yamux v0.1.1 // indirect 190 | github.com/huandu/xstrings v1.4.0 // indirect 191 | github.com/imdario/mergo v0.3.16 // indirect 192 | 
github.com/jackc/chunkreader/v2 v2.0.1 // indirect 193 | github.com/jackc/pgconn v1.14.3 // indirect 194 | github.com/jackc/pgio v1.0.0 // indirect 195 | github.com/jackc/pgpassfile v1.0.0 // indirect 196 | github.com/jackc/pgproto3/v2 v2.3.3 // indirect 197 | github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect 198 | github.com/jackc/pgtype v1.14.3 // indirect 199 | github.com/jackc/pgx/v4 v4.18.3 // indirect 200 | github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f // indirect 201 | github.com/jefferai/jsonx v1.0.0 // indirect 202 | github.com/jmespath/go-jmespath v0.4.0 // indirect 203 | github.com/josharian/intern v1.0.0 // indirect 204 | github.com/joshlf/go-acl v0.0.0-20200411065538-eae00ae38531 // indirect 205 | github.com/joyent/triton-go v1.7.1-0.20200416154420-6801d15b779f // indirect 206 | github.com/json-iterator/go v1.1.12 // indirect 207 | github.com/kelseyhightower/envconfig v1.4.0 // indirect 208 | github.com/klauspost/compress v1.17.2 // indirect 209 | github.com/kylelemons/godebug v1.1.0 // indirect 210 | github.com/linode/linodego v0.7.1 // indirect 211 | github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect 212 | github.com/mailru/easyjson v0.7.7 // indirect 213 | github.com/mattn/go-colorable v0.1.13 // indirect 214 | github.com/mattn/go-isatty v0.0.20 // indirect 215 | github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect 216 | github.com/miekg/dns v1.1.55 // indirect 217 | github.com/mitchellh/copystructure v1.2.0 // indirect 218 | github.com/mitchellh/go-homedir v1.1.0 // indirect 219 | github.com/mitchellh/go-testing-interface v1.14.1 // indirect 220 | github.com/mitchellh/pointerstructure v1.2.1 // indirect 221 | github.com/mitchellh/reflectwalk v1.0.2 // indirect 222 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 223 | github.com/modern-go/reflect2 v1.0.2 // indirect 224 | github.com/munnerz/goautoneg 
v0.0.0-20191010083416-a7dc8b61c822 // indirect 225 | github.com/nicolai86/scaleway-sdk v1.10.2-0.20180628010248-798f60e20bb2 // indirect 226 | github.com/oklog/run v1.1.0 // indirect 227 | github.com/oklog/ulid v1.3.1 // indirect 228 | github.com/okta/okta-sdk-golang/v2 v2.12.1 // indirect 229 | github.com/opencontainers/go-digest v1.0.0 // indirect 230 | github.com/opencontainers/image-spec v1.1.0 // indirect 231 | github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect 232 | github.com/oracle/oci-go-sdk/v60 v60.0.0 // indirect 233 | github.com/outcaste-io/ristretto v0.2.3 // indirect 234 | github.com/packethost/packngo v0.1.1-0.20180711074735-b9cb5096f54c // indirect 235 | github.com/patrickmn/go-cache v2.1.0+incompatible // indirect 236 | github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect 237 | github.com/philhofer/fwd v1.1.2 // indirect 238 | github.com/pierrec/lz4 v2.6.1+incompatible // indirect 239 | github.com/pires/go-proxyproto v0.6.1 // indirect 240 | github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect 241 | github.com/pkg/errors v0.9.1 // indirect 242 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 243 | github.com/posener/complete v1.2.3 // indirect 244 | github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect 245 | github.com/pquerna/otp v1.2.1-0.20191009055518-468c2dd2b58d // indirect 246 | github.com/prometheus/client_golang v1.16.0 // indirect 247 | github.com/prometheus/client_model v0.5.0 // indirect 248 | github.com/prometheus/common v0.44.0 // indirect 249 | github.com/prometheus/procfs v0.12.0 // indirect 250 | github.com/rboyer/safeio v0.2.1 // indirect 251 | github.com/renier/xmlrpc v0.0.0-20170708154548-ce4a1a486c03 // indirect 252 | github.com/rogpeppe/go-internal v1.12.0 // indirect 253 | github.com/ryanuber/go-glob v1.0.0 // indirect 254 | github.com/sasha-s/go-deadlock v0.2.1-0.20190427202633-1595213edefa // 
indirect 255 | github.com/secure-systems-lab/go-securesystemslib v0.8.0 // indirect 256 | github.com/segmentio/fasthash v1.0.3 // indirect 257 | github.com/sethvargo/go-limiter v0.7.1 // indirect 258 | github.com/shirou/gopsutil/v3 v3.23.5 // indirect 259 | github.com/shoenig/go-m1cpu v0.1.6 // indirect 260 | github.com/shopspring/decimal v1.3.1 // indirect 261 | github.com/sirupsen/logrus v1.9.3 // indirect 262 | github.com/softlayer/softlayer-go v0.0.0-20180806151055-260589d94c7d // indirect 263 | github.com/sony/gobreaker v0.5.0 // indirect 264 | github.com/spf13/cast v1.6.0 // indirect 265 | github.com/spf13/pflag v1.0.5 // indirect 266 | github.com/stretchr/objx v0.5.2 // indirect 267 | github.com/tencentcloud/tencentcloud-sdk-go v3.0.171+incompatible // indirect 268 | github.com/tinylib/msgp v1.1.8 // indirect 269 | github.com/tklauser/go-sysconf v0.3.11 // indirect 270 | github.com/tklauser/numcpus v0.6.0 // indirect 271 | github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c // indirect 272 | github.com/vmware/govmomi v0.20.3 // indirect 273 | github.com/yusufpapurcu/wmi v1.2.3 // indirect 274 | go.etcd.io/bbolt v1.3.7 // indirect 275 | go.mongodb.org/mongo-driver v1.13.1 // indirect 276 | go.opencensus.io v0.24.0 // indirect 277 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect 278 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect 279 | go.opentelemetry.io/otel v1.27.0 // indirect 280 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.27.0 // indirect 281 | go.opentelemetry.io/otel/metric v1.27.0 // indirect 282 | go.opentelemetry.io/otel/sdk v1.27.0 // indirect 283 | go.opentelemetry.io/otel/trace v1.27.0 // indirect 284 | go.uber.org/atomic v1.11.0 // indirect 285 | go.uber.org/multierr v1.11.0 // indirect 286 | golang.org/x/crypto v0.23.0 // indirect 287 | golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect 288 | golang.org/x/mod v0.16.0 
// indirect 289 | golang.org/x/oauth2 v0.21.0 // indirect 290 | golang.org/x/sync v0.7.0 // indirect 291 | golang.org/x/sys v0.20.0 // indirect 292 | golang.org/x/term v0.20.0 // indirect 293 | golang.org/x/text v0.15.0 // indirect 294 | golang.org/x/tools v0.19.0 // indirect 295 | golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect 296 | google.golang.org/api v0.183.0 // indirect 297 | google.golang.org/genproto v0.0.0-20240528184218-531527333157 // indirect 298 | google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 // indirect 299 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 // indirect 300 | google.golang.org/grpc v1.64.0 // indirect 301 | google.golang.org/protobuf v1.34.1 // indirect 302 | gopkg.in/inf.v0 v0.9.1 // indirect 303 | gopkg.in/ini.v1 v1.67.0 // indirect 304 | gopkg.in/resty.v1 v1.12.0 // indirect 305 | gopkg.in/square/go-jose.v2 v2.6.0 // indirect 306 | gopkg.in/yaml.v2 v2.4.0 // indirect 307 | k8s.io/api v0.29.1 // indirect 308 | k8s.io/apimachinery v0.29.1 // indirect 309 | k8s.io/client-go v0.29.1 // indirect 310 | k8s.io/klog/v2 v2.110.1 // indirect 311 | k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 // indirect 312 | k8s.io/utils v0.0.0-20230726121419-3b25d923346b // indirect 313 | sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect 314 | sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect 315 | sigs.k8s.io/yaml v1.3.0 // indirect 316 | ) 317 | -------------------------------------------------------------------------------- /internal/cache/maintainer.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "math/rand" 8 | "reflect" 9 | "sync" 10 | "time" 11 | 12 | "github.com/DataDog/attache/internal/cache/synchronization" 13 | "github.com/DataDog/attache/internal/retry" 14 | "github.com/hashicorp/go-metrics" 15 | "go.uber.org/zap" 16 | ) 17 | 
18 | const ( 19 | delayErrorMsg = "initial delay of %v for maintainer cannot be <= 0" 20 | ) 21 | 22 | var ( 23 | statsdCacheMaintainerStem = []string{"cache", "maintainer"} 24 | statsdCacheMaintainerExecute = append(statsdCacheMaintainerStem, "execute") 25 | statsdCacheMaintainerGet = append(statsdCacheMaintainerStem, "get") 26 | statsdCacheMaintainerUpdate = append(statsdCacheMaintainerStem, "update") 27 | statsdCacheMaintainerScheduled = append(statsdCacheMaintainerStem, "scheduled") 28 | statsdCacheMaintainerExpirationTTL = append(statsdCacheMaintainerStem, "expiration", "ttl") 29 | statsdCacheMaintainerNextRefresh = append(statsdCacheMaintainerStem, "refresh", "ttl") 30 | statsdCacheMaintainerState = append(statsdCacheMaintainerStem, "state") 31 | statsdCacheMaintainerExpired = append(statsdCacheMaintainerStem, "expired") 32 | statsdCacheMaintainerRunning = metrics.Label{Name: "state", Value: "running"} 33 | statsdCacheMaintainerStopped = metrics.Label{Name: "state", Value: "stopped"} 34 | statsdCacheMaintainerInit = metrics.Label{Name: "state", Value: "init"} 35 | cacheHit = metrics.Label{Name: "cache", Value: "hit"} 36 | cacheMiss = metrics.Label{Name: "cache", Value: "miss"} 37 | cacheError = metrics.Label{Name: "cache", Value: "error"} 38 | cacheFalse = metrics.Label{Name: "cache", Value: "false"} 39 | statusFailed = metrics.Label{Name: "status", Value: "failed"} 40 | statusSuccess = metrics.Label{Name: "status", Value: "success"} 41 | ) 42 | 43 | // Maintainer is a Refresh-Ahead Cache used to ensure that a ExpiringValue is kept valid for a given RefreshAtFunc. 
44 | // Any call to execute must be done under the syncLock to ensure there is no concurrent calls to the provider 45 | type Maintainer[T any] struct { 46 | config[T] //nolint:unused,nolintlint 47 | fetcher Fetcher[T] 48 | refreshAtFunc RefreshAtFunc 49 | 50 | cachedValue cachedValue[T] //nolint:unused,nolintlint 51 | 52 | // Used to lock updates to all attributes of the maintainer except 53 | // isClosed 54 | syncLock synchronization.CancellableLock 55 | // locks updates to isClosed 56 | isClosedLock sync.Mutex 57 | refreshAt time.Time 58 | isMaintaining bool 59 | isClosed bool 60 | maintainerCtx context.Context 61 | maintainerCtxCancel context.CancelFunc 62 | wg sync.WaitGroup 63 | 64 | metricsReporter metricsReporter 65 | } 66 | 67 | func NewMaintainer[T any](fetcher Fetcher[T], refreshAtFunc RefreshAtFunc, options ...option) *Maintainer[T] { 68 | m := &Maintainer[T]{ 69 | fetcher: fetcher, 70 | refreshAtFunc: refreshAtFunc, 71 | syncLock: *synchronization.NewCancellableLock(), 72 | } 73 | for _, opt := range options { 74 | opt(&m.config) 75 | } 76 | 77 | if log := m.log; log != nil { 78 | m.log = log.With(zap.String("fetcher", fetcher.String())) 79 | m.retryOpts = append(m.retryOpts, retry.Logger(log)) 80 | } else { 81 | m.log = zap.NewNop() 82 | } 83 | if m.metricSink == nil { 84 | m.metricSink = &metrics.BlackholeSink{} 85 | } 86 | // Default retry delay after expiration to 10s 87 | if m.retryAfterExpirationDelay <= 0 { 88 | m.retryAfterExpirationDelay = 10 * time.Second 89 | } 90 | 91 | // Creates a context tied to the maintainer lifecycle. 
92 | // It is used to track both metrics reporting and renewal loop 93 | m.maintainerCtx, m.maintainerCtxCancel = context.WithCancel(context.Background()) 94 | 95 | // start a background routine tied to the maintainer's context to update TTL metrics 96 | m.wg.Add(1) 97 | go func() { 98 | defer m.wg.Done() 99 | m.metricsReporter.run(m.maintainerCtx, m.metricSink, m.tags()) 100 | }() 101 | 102 | return m 103 | } 104 | 105 | func (m *Maintainer[T]) IsClosed() bool { 106 | m.isClosedLock.Lock() 107 | defer m.isClosedLock.Unlock() 108 | return m.isClosed 109 | } 110 | 111 | func (m *Maintainer[T]) Close() { 112 | m.isClosedLock.Lock() 113 | defer m.isClosedLock.Unlock() 114 | if m.isClosed { 115 | return 116 | } 117 | m.maintainerCtxCancel() 118 | m.wg.Wait() 119 | // ignore returned error: https://github.com/uber-go/zap/issues/328 120 | _ = m.log.Sync() 121 | m.isClosed = true 122 | } 123 | 124 | func isNil[T any](t T) bool { 125 | tType := reflect.TypeOf(t) 126 | if tType == nil { 127 | return true 128 | } 129 | switch tType.Kind() { 130 | case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.UnsafePointer, reflect.Interface, reflect.Slice: 131 | return reflect.ValueOf(t).IsNil() 132 | } 133 | return false 134 | } 135 | 136 | // Get returns the ExpiringValue returned by Fetcher until the ExpiringValue expires. 
137 | func (m *Maintainer[T]) Get(ctx context.Context) (T, error) { 138 | if val, ok := m.cachedValue.getValue(); ok && !isNil(val) { 139 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerGet, 1.0, append(m.tags(), cacheHit)) 140 | return val, nil 141 | } 142 | 143 | // Attempts to syncLock the maintainer, returns if the context is cancelled first 144 | err := m.syncLock.LockIfNotCancelled(ctx) 145 | if err != nil { 146 | var empty T 147 | return empty, err 148 | } 149 | defer m.syncLock.Unlock() 150 | 151 | // check if the value has been cached between syncLock waiting and syncLock acquisition 152 | if val, ok := m.cachedValue.getValue(); ok && !isNil(val) { 153 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerGet, 1.0, append(m.tags(), cacheHit)) 154 | return val, nil 155 | } 156 | 157 | // execute directly when `Get` is called to defer to the client (aws sdk, vault, etc)'s 158 | // preferences for retry policies 159 | s, err := m.execute(ctx) 160 | if err != nil { 161 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerGet, 1.0, append(m.tags(), cacheError)) 162 | if m.errorHandler != nil { 163 | m.errorHandler(err) 164 | } 165 | var empty T 166 | return empty, err 167 | } 168 | 169 | // on success, start a background refresh loop 170 | if !m.isMaintaining && s.ExpiresAt != (time.Time{}) { 171 | delay := time.Until(m.refreshAt) 172 | if delay <= 0 { 173 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerGet, 1.0, append(m.tags(), cacheError)) 174 | var empty T 175 | return empty, fmt.Errorf(delayErrorMsg, delay) 176 | } 177 | 178 | // uses the main maintainer context -- the one passed in as an argument to Get is tied 179 | // to the caller's request context, we don't want to tie our background to that 180 | m.schedule(m.maintainerCtx, delay) 181 | } 182 | 183 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerGet, 1.0, append(m.tags(), cacheMiss)) 184 | return s.Value, nil 185 | } 186 | 187 | func (m *Maintainer[T]) 
updateCacheValue(value *ExpiringValue[T]) { 188 | m.log.Debug("Updating cached value", zap.Time("expiration", value.ExpiresAt)) 189 | m.cachedValue.updateValue(value) 190 | m.metricsReporter.setExpiresAt(value.ExpiresAt) 191 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerUpdate, 1.0, m.tags()) 192 | if m.updateHandler != nil { 193 | // Passes the value by copy to ensure times are not altered. 194 | m.updateHandler(value.Value) 195 | } 196 | } 197 | 198 | func (m *Maintainer[T]) schedule(ctx context.Context, delay time.Duration) { 199 | tags := m.tags() 200 | 201 | m.log.Debug("scheduling value refresh", zap.Duration("delay", delay)) 202 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerScheduled, 1.0, tags) 203 | 204 | m.isMaintaining = true 205 | m.metricsReporter.setMaintaining(true) 206 | m.wg.Add(1) 207 | go func() { 208 | defer m.wg.Done() 209 | defer func() { 210 | m.syncLock.Lock() 211 | m.isMaintaining = false 212 | m.syncLock.Unlock() 213 | m.metricsReporter.setMaintaining(false) 214 | }() 215 | 216 | ticker := time.NewTicker(delay) 217 | defer ticker.Stop() 218 | for { 219 | select { 220 | case <-ctx.Done(): 221 | m.log.Info("Schedule loop has been cancelled, exiting") 222 | return 223 | case <-ticker.C: 224 | err := retry.Do(ctx, func() error { 225 | err := m.syncLock.LockIfNotCancelled(ctx) 226 | if err != nil { 227 | // The context was cancelled, there is no point in running the loop 228 | return err 229 | } 230 | defer m.syncLock.Unlock() 231 | 232 | // cachedValue was already refreshed by Get(context.Context) 233 | if time.Now().Before(m.refreshAt) { 234 | return nil 235 | } 236 | 237 | _, err = m.execute(ctx) 238 | if err != nil { 239 | return err 240 | } 241 | 242 | return nil 243 | }, m.retryOpts...) 
244 | 245 | if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { 246 | m.log.Warn("Failed to refresh the value", zap.Error(err)) 247 | if m.errorHandler != nil { 248 | m.errorHandler(err) 249 | } 250 | } 251 | 252 | // m.refreshAt is protected only under this syncLock 253 | lockErr := m.syncLock.LockIfNotCancelled(ctx) 254 | if lockErr != nil { 255 | m.log.Info("Schedule loop has been cancelled, exiting") 256 | return 257 | } 258 | // Get a constant vision of now to avoid edge cases here 259 | now := time.Now() 260 | 261 | expiresAt := m.cachedValue.expiresAt() 262 | if expiresAt.Before(now) { 263 | // we were unable to get a credential before our current one expired. 264 | // switch to a linear retry to avoid hammering the provider on recovery while allowing for a fast enough recovery. 265 | m.refreshAt = now.Add(m.retryAfterExpirationDelay) 266 | m.log.Warn("new value could not be retrieved before value expiration, refresh has been rescheduled", zap.Time("refreshAt", m.refreshAt)) 267 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerExpired, 1.0, append(m.tags(), cacheError)) 268 | } else if m.refreshAt.Before(now) { 269 | m.refreshAt = m.refreshAtFunc(now, expiresAt) 270 | if m.refreshAt.Before(now) { 271 | // User method is returning a refresh time in the past, switch to linear maintaining and raise an error 272 | m.refreshAt = now.Add(m.retryAfterExpirationDelay) 273 | m.log.Error("computed refreshAt time is still in the past, refresh has been rescheduled", zap.Time("refreshAt", m.refreshAt)) 274 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerUpdate, 1.0, append(m.tags(), cacheError)) 275 | } 276 | m.log.Warn("new value could not be retrieved before initial refresh deadline, refresh has been rescheduled", zap.Time("refreshAt", m.refreshAt)) 277 | } 278 | 279 | m.metricsReporter.setRefreshAt(m.refreshAt) 280 | delay := time.Until(m.refreshAt) 281 | m.log.Debug("scheduling value refresh", 
zap.Duration("delay", delay)) 282 | ticker.Reset(delay) 283 | m.syncLock.Unlock() 284 | 285 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerScheduled, 1.0, tags) 286 | } 287 | } 288 | }() 289 | } 290 | 291 | func (m *Maintainer[T]) execute(ctx context.Context) (*ExpiringValue[T], error) { 292 | timeNow := time.Now() 293 | value, err := m.fetcher.Fetch(ctx) 294 | if err != nil { 295 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerExecute, 1.0, append(m.tags(), statusFailed)) 296 | 297 | return nil, err 298 | } 299 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerExecute, 1.0, append(m.tags(), statusSuccess)) 300 | 301 | if value == nil { 302 | return nil, errors.New("fetcher returned a nil value and nil error") 303 | } 304 | 305 | // Do not cache value or set refreshAt if ExpiresAt is empty. 306 | // 307 | // If the Maintainer is refreshing cached credentials, let the Maintainer continue 308 | // trying to refresh what was previously cached (using retryAfterExpirationDelay). 309 | if value.ExpiresAt == (time.Time{}) { 310 | m.metricSink.IncrCounterWithLabels(statsdCacheMaintainerUpdate, 1.0, append(m.tags(), cacheFalse)) 311 | return value, nil 312 | } 313 | 314 | m.updateCacheValue(value) 315 | m.refreshAt = m.refreshAtFunc(timeNow, value.ExpiresAt) 316 | m.metricsReporter.setRefreshAt(m.refreshAt) 317 | return value, nil 318 | } 319 | 320 | // tags returns a new slice containing metrics tags related to the current maintainer's operations 321 | func (m *Maintainer[T]) tags() []metrics.Label { 322 | return []metrics.Label{{Name: "fetcher", Value: m.fetcher.String()}} 323 | } 324 | 325 | type Fetcher[T any] interface { 326 | // Stringer interface for returning a unique identifier 327 | fmt.Stringer 328 | 329 | // Fetch returns a valid ExpiringValue 330 | Fetch(ctx context.Context) (*ExpiringValue[T], error) 331 | } 332 | 333 | // RefreshAtFunc returns the time.Time to re-fetch the ExpiringValue. 
This returned time.Time should never be 334 | // in the past. 335 | type RefreshAtFunc func(notBefore time.Time, notAfter time.Time) time.Time 336 | 337 | // NewPercentageRemainingRefreshAt is a basic RefreshAtFunc for calculating the time.Time at which to 338 | // fetch a ExpiringValue based on the remaining percentage of the lifetime of a ExpiringValue. 339 | // 340 | // renewAfterPercentage is the percentage of time remaining at which point a ExpiringValue should be re-fetched. 341 | // jitterPercentage is the maximum percentage of (notAfter.Sub(notBefore) * renewAfterPercentage) to use for calculating a jitter value. 342 | // 343 | // i.e. if notBefore and notAfter are 60s apart, renewAfterPercentage is 0.33, and jitter percentage is 10%, 344 | // the non jittered time will be 60s * 0.33, or 20s from now, with an added jitter of 0-2 (20s * 10%) seconds, for a 345 | // final return value evenly spaced between 20 and 22 seconds from now. 346 | func NewPercentageRemainingRefreshAt(renewAfterPercentage float64, jitterPercentage float64) RefreshAtFunc { 347 | return func(notBefore time.Time, notAfter time.Time) time.Time { 348 | duration := notAfter.Sub(notBefore) 349 | 350 | renewAt := duration.Seconds() * renewAfterPercentage 351 | jitter := (duration.Seconds() - renewAt) * jitterPercentage * rand.Float64() 352 | 353 | return notBefore.Add(time.Duration((renewAt + jitter) * float64(time.Second))) 354 | } 355 | } 356 | -------------------------------------------------------------------------------- /internal/cache/options.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/DataDog/attache/internal/retry" 7 | "github.com/hashicorp/go-metrics" 8 | "go.uber.org/zap" 9 | ) 10 | 11 | // CacheUpdateHandler is called every time the cached value is updated if requested as an option 12 | // The value provided is the ExpiringValue construct which is itself containing the value 
within the cache 13 | // The handler must not modify the value provided 14 | type CacheUpdateHandler[T any] func(T) 15 | 16 | // ErrorHandler is called every time an error is returned when attempting to update the cached value 17 | type ErrorHandler func(error) 18 | 19 | // Go generics are currently limited when using functions. 20 | // This prevents passing Options potentially referring to the type without having all options typed. 21 | // To avoid pushing this cumbersome syntax on users (e.g. having to write WithLogger[T]), 22 | // we abstract it here, knowing that all those are internal types and the user only deals with the main calls. 23 | type optionConfig interface { 24 | setLogger(*zap.Logger) 25 | setMetricsSink(metrics.MetricSink) 26 | setRetryOptions([]retry.Option) 27 | setRetryAfterExpirarionDelay(time.Duration) 28 | setErrorHandler(ErrorHandler) 29 | } 30 | 31 | type typedOptionConfig[T any] interface { 32 | setUpdateHandler(CacheUpdateHandler[T]) 33 | } 34 | 35 | type config[T any] struct { 36 | retryOpts []retry.Option 37 | retryAfterExpirationDelay time.Duration 38 | log *zap.Logger 39 | metricSink metrics.MetricSink 40 | errorHandler ErrorHandler 41 | updateHandler CacheUpdateHandler[T] //nolint:unused,nolintlint 42 | } 43 | 44 | var _ typedOptionConfig[string] = &config[string]{} 45 | 46 | func (c *config[T]) setLogger(log *zap.Logger) { 47 | c.log = log 48 | } 49 | 50 | func (c *config[T]) setMetricsSink(metricSink metrics.MetricSink) { 51 | c.metricSink = metricSink 52 | } 53 | 54 | func (c *config[T]) setRetryOptions(retryOpts []retry.Option) { 55 | c.retryOpts = retryOpts 56 | } 57 | 58 | func (c *config[T]) setRetryAfterExpirarionDelay(delay time.Duration) { 59 | c.retryAfterExpirationDelay = delay 60 | } 61 | 62 | func (c *config[T]) setErrorHandler(errorHandler ErrorHandler) { 63 | c.errorHandler = errorHandler 64 | } 65 | 66 | //lint:ignore U1000 linters are having a hard time with generics 67 | func (c *config[T]) 
setUpdateHandler(updateHandler CacheUpdateHandler[T]) { 68 | c.updateHandler = updateHandler 69 | } 70 | 71 | type option func(interface{}) 72 | 73 | func WithRetryOptions(retryOpts []retry.Option) option { 74 | return func(param interface{}) { 75 | if c, ok := param.(optionConfig); ok { 76 | c.setRetryOptions(retryOpts) 77 | } 78 | } 79 | } 80 | 81 | // If a value is not renewed prior to its expiration, the maintainer uses a linear retry with a default delay of 10s. 82 | // WithRetryAfterExpirarionDelay sets this delay to another value if desired by the user. 83 | // Decreasing the delay does speedup recovery once the provider is available, but also increases the impact of the provider on recovery. 84 | func WithRetryAfterExpirarionDelay(delay time.Duration) option { 85 | return func(param interface{}) { 86 | if c, ok := param.(optionConfig); ok { 87 | c.setRetryAfterExpirarionDelay(delay) 88 | } 89 | } 90 | } 91 | 92 | func WithLogger(log *zap.Logger) option { 93 | return func(param interface{}) { 94 | if c, ok := param.(optionConfig); ok { 95 | c.setLogger(log) 96 | } 97 | } 98 | } 99 | 100 | func WithMetricsSink(sink metrics.MetricSink) option { 101 | return func(param interface{}) { 102 | if c, ok := param.(optionConfig); ok { 103 | c.setMetricsSink(sink) 104 | } 105 | } 106 | } 107 | 108 | func WithCacheUpdateHandler[T any](updateHandler CacheUpdateHandler[T]) option { 109 | return func(param interface{}) { 110 | if c, ok := param.(typedOptionConfig[T]); ok { 111 | c.setUpdateHandler(updateHandler) 112 | } 113 | } 114 | } 115 | 116 | func WithErrorHandler(errorHandler ErrorHandler) option { 117 | return func(param interface{}) { 118 | if c, ok := param.(optionConfig); ok { 119 | c.setErrorHandler(errorHandler) 120 | } 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /internal/cache/reporter.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import 
( 4 | "context" 5 | "sync" 6 | "time" 7 | 8 | "github.com/hashicorp/go-metrics" 9 | ) 10 | 11 | type metricsReporter struct { 12 | mutex sync.Mutex 13 | isMaintaining *bool 14 | refreshAt time.Time 15 | expiresAt time.Time 16 | } 17 | 18 | func (r *metricsReporter) setMaintaining(m bool) { 19 | r.mutex.Lock() 20 | defer r.mutex.Unlock() 21 | r.isMaintaining = new(bool) 22 | *r.isMaintaining = m 23 | } 24 | 25 | func (r *metricsReporter) setExpiresAt(t time.Time) { 26 | r.mutex.Lock() 27 | defer r.mutex.Unlock() 28 | r.expiresAt = t 29 | } 30 | 31 | func (r *metricsReporter) setRefreshAt(t time.Time) { 32 | r.mutex.Lock() 33 | defer r.mutex.Unlock() 34 | r.refreshAt = t 35 | } 36 | 37 | func (r *metricsReporter) run(ctx context.Context, metricSink metrics.MetricSink, tags []metrics.Label) { 38 | ticker := time.NewTicker(10 * time.Second) 39 | for { 40 | select { 41 | case <-ticker.C: 42 | r.report(metricSink, tags) 43 | case <-ctx.Done(): 44 | ticker.Stop() 45 | return 46 | } 47 | } 48 | } 49 | 50 | func (r *metricsReporter) report(metricSink metrics.MetricSink, tags []metrics.Label) { 51 | r.mutex.Lock() 52 | defer r.mutex.Unlock() 53 | 54 | // The maintainer is not yet maintaining credentials 55 | // Metrics are not reported to avoid setting 0 56 | if r.isMaintaining == nil { 57 | metricSink.SetGaugeWithLabels(statsdCacheMaintainerState, 1, append(tags, statsdCacheMaintainerInit)) 58 | return 59 | } 60 | 61 | if *r.isMaintaining { 62 | tags = append(tags, statsdCacheMaintainerRunning) 63 | } else { 64 | tags = append(tags, statsdCacheMaintainerStopped) 65 | } 66 | 67 | ttr := float64(0) 68 | if !r.refreshAt.IsZero() { 69 | ttr = time.Until(r.refreshAt).Seconds() 70 | if ttr < 0 { 71 | ttr = -1 72 | } 73 | } 74 | 75 | metricSink.SetGaugeWithLabels(statsdCacheMaintainerNextRefresh, float32(ttr), tags) 76 | 77 | ttl := float64(0) 78 | if !r.expiresAt.IsZero() { 79 | ttl = time.Until(r.expiresAt).Seconds() 80 | if ttl < 0 { 81 | ttl = -1 82 | } 83 | } 84 | 
metricSink.SetGaugeWithLabels(statsdCacheMaintainerExpirationTTL, float32(ttl), tags) 85 | 86 | metricSink.SetGaugeWithLabels(statsdCacheMaintainerState, 1, tags) 87 | } 88 | -------------------------------------------------------------------------------- /internal/cache/reporter_test.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | "time" 7 | 8 | "github.com/hashicorp/go-metrics" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func makeKey(parts []string) string { 14 | return strings.Join(parts, ".") 15 | } 16 | 17 | type mockValue struct { 18 | value float32 19 | labels []metrics.Label 20 | } 21 | 22 | type mockSink struct { 23 | gauges map[string]mockValue 24 | } 25 | 26 | func newMockSink() mockSink { 27 | return mockSink{ 28 | gauges: make(map[string]mockValue), 29 | } 30 | } 31 | 32 | // A Gauge should retain the last value it is set to 33 | func (m *mockSink) SetGauge(key []string, val float32) { 34 | m.gauges[makeKey(key)] = mockValue{val, nil} 35 | } 36 | func (m *mockSink) SetGaugeWithLabels(key []string, val float32, labels []metrics.Label) { 37 | m.gauges[makeKey(key)] = mockValue{val, labels} 38 | } 39 | 40 | // Should emit a Key/Value pair for each call 41 | func (m *mockSink) EmitKey(key []string, val float32) {} 42 | 43 | // Counters should accumulate values 44 | func (m *mockSink) IncrCounter(key []string, val float32) {} 45 | func (m *mockSink) IncrCounterWithLabels(key []string, val float32, labels []metrics.Label) {} 46 | 47 | // Samples are for timing information, where quantiles are used 48 | func (m *mockSink) AddSample(key []string, val float32) {} 49 | func (m *mockSink) AddSampleWithLabels(key []string, val float32, labels []metrics.Label) {} 50 | 51 | func newSetTestReporter(maintaining bool, refreshAt time.Time, expiresAt time.Time) *metricsReporter { 52 | return &metricsReporter{ 53 
| isMaintaining: &maintaining, 54 | refreshAt: refreshAt, 55 | expiresAt: expiresAt, 56 | } 57 | } 58 | 59 | func TestReporter(t *testing.T) { 60 | refreshAt := time.Now().Add(5 * time.Hour) 61 | untilRefreshAt := float32(5 * time.Hour / time.Second) 62 | expiresAt := time.Now().Add(5 * time.Hour) 63 | untilExpiresAt := float32(5 * time.Hour / time.Second) 64 | 65 | for name, test := range map[string]struct { 66 | reporter *metricsReporter 67 | expectedMetrics map[string]mockValue 68 | }{ 69 | "reporter initialized but not maintaining": { 70 | reporter: &metricsReporter{}, 71 | expectedMetrics: map[string]mockValue{ 72 | makeKey(statsdCacheMaintainerState): {1, []metrics.Label{statsdCacheMaintainerInit}}, 73 | }, 74 | }, 75 | "reporter set to maintaining but other values unset": { 76 | reporter: newSetTestReporter(true, time.Time{}, time.Time{}), 77 | expectedMetrics: map[string]mockValue{ 78 | makeKey(statsdCacheMaintainerState): {1, []metrics.Label{statsdCacheMaintainerRunning}}, 79 | makeKey(statsdCacheMaintainerExpirationTTL): {0, []metrics.Label{statsdCacheMaintainerRunning}}, 80 | makeKey(statsdCacheMaintainerNextRefresh): {0, []metrics.Label{statsdCacheMaintainerRunning}}, 81 | }, 82 | }, 83 | "reporter set to maintaining and other values set": { 84 | reporter: newSetTestReporter(true, refreshAt, expiresAt), 85 | expectedMetrics: map[string]mockValue{ 86 | makeKey(statsdCacheMaintainerState): {1, []metrics.Label{statsdCacheMaintainerRunning}}, 87 | makeKey(statsdCacheMaintainerExpirationTTL): {untilExpiresAt, []metrics.Label{statsdCacheMaintainerRunning}}, 88 | makeKey(statsdCacheMaintainerNextRefresh): {untilRefreshAt, []metrics.Label{statsdCacheMaintainerRunning}}, 89 | }, 90 | }, 91 | "reporter set to not maintaining and other values set": { 92 | reporter: newSetTestReporter(false, refreshAt, expiresAt), 93 | expectedMetrics: map[string]mockValue{ 94 | makeKey(statsdCacheMaintainerState): {1, []metrics.Label{statsdCacheMaintainerStopped}}, 95 | 
makeKey(statsdCacheMaintainerExpirationTTL): {untilExpiresAt, []metrics.Label{statsdCacheMaintainerStopped}}, 96 | makeKey(statsdCacheMaintainerNextRefresh): {untilRefreshAt, []metrics.Label{statsdCacheMaintainerStopped}}, 97 | }, 98 | }, 99 | } { 100 | t.Run(name, func(t *testing.T) { 101 | mockSink := newMockSink() 102 | test.reporter.report(&mockSink, []metrics.Label{{Name: "mylabel", Value: "myvalue"}}) 103 | 104 | assert.Len(t, mockSink.gauges, len(test.expectedMetrics)) 105 | for metricName, metricValue := range test.expectedMetrics { 106 | if assert.Contains(t, mockSink.gauges, metricName) { 107 | setMetricValue := mockSink.gauges[metricName] 108 | // Tolerate a one second error margin 109 | assert.InDelta(t, metricValue.value, setMetricValue.value, 1) 110 | for _, label := range metricValue.labels { 111 | assert.Contains(t, setMetricValue.labels, label) 112 | } 113 | // Ensure user provided labels are properly kept 114 | assert.Contains(t, setMetricValue.labels, metrics.Label{Name: "mylabel", Value: "myvalue"}) 115 | } 116 | } 117 | }) 118 | } 119 | 120 | t.Run("maintaining set is never revoked", func(t *testing.T) { 121 | reporter := metricsReporter{} 122 | assert.Nil(t, reporter.isMaintaining) 123 | 124 | reporter.setMaintaining(true) 125 | require.NotNil(t, reporter.isMaintaining) 126 | assert.True(t, *reporter.isMaintaining) 127 | 128 | reporter.setMaintaining(false) 129 | require.NotNil(t, reporter.isMaintaining) 130 | assert.False(t, *reporter.isMaintaining) 131 | 132 | reporter.setMaintaining(true) 133 | require.NotNil(t, reporter.isMaintaining) 134 | assert.True(t, *reporter.isMaintaining) 135 | }) 136 | } 137 | -------------------------------------------------------------------------------- /internal/cache/synchronization/lock.go: -------------------------------------------------------------------------------- 1 | package synchronization 2 | 3 | import "context" 4 | 5 | type CancellableLock struct { 6 | lockChan chan struct{} 7 | } 8 | 9 | func 
NewCancellableLock() *CancellableLock { 10 | lock := &CancellableLock{ 11 | lockChan: make(chan struct{}, 1), 12 | } 13 | lock.Unlock() 14 | return lock 15 | } 16 | 17 | func (l *CancellableLock) Lock() { 18 | <-l.lockChan 19 | } 20 | 21 | // lockIfNotCancelled is attempting to take the lock but gives up if the context is cancelled 22 | // This allows putting a time limit on the lock, as well as locking within a cancellable call 23 | // lockIfNotCancelled returns an error if the ctx was cancelled (without taking the lock), otherwise returns nil 24 | // user must call unlock to release the lock iif the error returned is nil 25 | func (l *CancellableLock) LockIfNotCancelled(ctx context.Context) error { 26 | select { 27 | case <-l.lockChan: 28 | return nil 29 | case <-ctx.Done(): 30 | return ctx.Err() 31 | } 32 | } 33 | 34 | func (l *CancellableLock) Unlock() { 35 | l.lockChan <- struct{}{} 36 | } 37 | -------------------------------------------------------------------------------- /internal/cache/value.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | type ExpiringValue[T any] struct { 9 | Value T 10 | ExpiresAt time.Time 11 | } 12 | 13 | type cachedValue[T any] struct { 14 | mutex sync.RWMutex 15 | value *ExpiringValue[T] //nolint:unused,nolintlint 16 | } 17 | 18 | // getValue returns the value if it has not expired 19 | func (v *cachedValue[T]) getValue() (T, bool) { 20 | v.mutex.RLock() 21 | defer v.mutex.RUnlock() 22 | 23 | if v.value != nil && v.value.ExpiresAt.After(time.Now()) { 24 | return v.value.Value, true 25 | } 26 | 27 | var empty T 28 | return empty, false 29 | } 30 | 31 | func (v *cachedValue[T]) expiresAt() time.Time { 32 | v.mutex.RLock() 33 | defer v.mutex.RUnlock() 34 | if v.value != nil { 35 | return v.value.ExpiresAt 36 | } else { 37 | return time.Time{} 38 | } 39 | } 40 | 41 | func (v *cachedValue[T]) updateValue(val *ExpiringValue[T]) { 42 
| v.mutex.Lock() 43 | defer v.mutex.Unlock() 44 | v.value = val 45 | } 46 | -------------------------------------------------------------------------------- /internal/imds/aws.go: -------------------------------------------------------------------------------- 1 | package imds 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "net/http" 9 | "path" 10 | "regexp" 11 | "strconv" 12 | "sync" 13 | "time" 14 | 15 | "github.com/DataDog/attache/internal/cache" 16 | "github.com/DataDog/attache/internal/retry" 17 | vaultclient "github.com/DataDog/attache/internal/vault" 18 | ec2imds "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" 19 | "github.com/gorilla/mux" 20 | "github.com/hashicorp/go-metrics" 21 | "github.com/mitchellh/mapstructure" 22 | "go.uber.org/zap" 23 | muxt "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorilla/mux" 24 | "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" 25 | ) 26 | 27 | // AwsRoleGetter returns the configured AWS IMDS role. 28 | // This meant to be temporary and will be replaced once the AWS role name generated from 29 | // the pod namespace and service account is passed directly to the AwsProvider. 30 | type AwsRoleGetter interface { 31 | lookupRole(ctx context.Context) (*role, error) 32 | } 33 | 34 | // AwsProvider implements the AWS Metadata Service backed by Vault. 35 | type AwsProvider struct { 36 | maintainer *cache.Maintainer[*AwsCredentials] 37 | roleGetter AwsRoleGetter 38 | sessions *imdsSessionCache 39 | metricSink metrics.MetricSink 40 | identifier InstanceIdentifier 41 | v1Allowed bool 42 | } 43 | 44 | const ( 45 | defaultAwsCredTTL = time.Hour 46 | ) 47 | 48 | var ( 49 | awsRoleRegex = regexp.MustCompile(`.+:role\/(.+)$`) 50 | 51 | labelAwsProvider = metrics.Label{Name: "provider", Value: "aws"} 52 | ) 53 | 54 | // timeNow is used for testing. 55 | var timeNow = time.Now 56 | 57 | // Aws returns a new AwsProvider. 
58 | func Aws( 59 | ctx context.Context, 60 | log *zap.Logger, 61 | v1Allowed bool, 62 | metricSink metrics.MetricSink, 63 | tokenFetcher cache.Fetcher[*AwsCredentials], 64 | roleGetter AwsRoleGetter, 65 | identifier InstanceIdentifier, 66 | refreshFunc cache.RefreshAtFunc, retryOpts ...retry.Option, 67 | ) (*AwsProvider, error) { 68 | if ctx == nil { 69 | return nil, errors.New("context cannote be nil") 70 | } 71 | 72 | if log == nil { 73 | return nil, errors.New("log cannot be nil") 74 | } 75 | 76 | if metricSink == nil { 77 | return nil, errors.New("metricSink cannot be nil") 78 | } 79 | 80 | if tokenFetcher == nil { 81 | return nil, errors.New("tokenFetcher cannot be nil") 82 | } 83 | 84 | if roleGetter == nil { 85 | return nil, errors.New("roleGetter cannot be nil") 86 | } 87 | 88 | if identifier == nil { 89 | return nil, errors.New("instance identity provider cannot be nil") 90 | } 91 | 92 | if refreshFunc == nil { 93 | return nil, errors.New("refresh func cannot be nil") 94 | } 95 | 96 | maintainer := cache.NewMaintainer[*AwsCredentials]( 97 | tokenFetcher, 98 | refreshFunc, 99 | cache.WithLogger(log.Named("token maintainer")), 100 | cache.WithMetricsSink(metricSink), 101 | cache.WithRetryOptions(retryOpts), 102 | ) 103 | 104 | p := &AwsProvider{ 105 | maintainer: maintainer, 106 | roleGetter: roleGetter, 107 | sessions: newIMDSSessionCache(ctx, metricSink, maxAwsEC2MetadataTokens, time.Minute), 108 | metricSink: metricSink, 109 | identifier: identifier, 110 | v1Allowed: v1Allowed, 111 | } 112 | 113 | return p, nil 114 | } 115 | 116 | // Name returns the provider's logical name. 117 | func (p *AwsProvider) Name() string { 118 | return "aws" 119 | } 120 | 121 | // RegisterHandlers registers all HTTP handlers for the AWS provider. 
func (p *AwsProvider) RegisterHandlers(router *muxt.Router, handlerFactory *HandlerFactory) error {
	// Credential endpoints are gated by the IMDSv2 session-token verifier;
	// the token and identity-document endpoints are only tagged with the
	// IMDS version for telemetry.
	router.Handle(
		"/{version}/meta-data/iam/security-credentials",
		imdsv2Verifier(p.v1Allowed, p.sessions, handlerFactory.CreateHTTPHandler(p.Name(), p.handleSecurityCredentials)),
	)

	router.Handle(
		"/{version}/meta-data/iam/security-credentials/",
		imdsv2Verifier(p.v1Allowed, p.sessions, handlerFactory.CreateHTTPHandler(p.Name(), p.handleSecurityCredentials)),
	)

	router.Handle(
		"/{version}/meta-data/iam/security-credentials/{role:.+}",
		imdsv2Verifier(p.v1Allowed, p.sessions, handlerFactory.CreateHTTPHandler(p.Name(), p.handleSecurityCredentialsRole)),
	)

	// Session-token minting is PUT-only, mirroring EC2 IMDSv2.
	router.Handle(
		"/{version}/api/token",
		imdsVersionTag(handlerFactory.CreateHTTPHandler(p.Name(), p.handleIMDSV2Token)),
	).Methods(http.MethodPut)

	router.Handle(
		"/{version}/dynamic/instance-identity/document",
		imdsVersionTag(handlerFactory.CreateHTTPHandler(p.Name(), p.handleIdentityDocument)),
	)

	return nil
}

// handleSecurityCredentialsRole serves the STS credentials (as JSON) for the
// single configured role; any other requested role name is rejected.
func (p *AwsProvider) handleSecurityCredentialsRole(_ *zap.Logger, writer http.ResponseWriter, request *http.Request) error {
	awsRole, err := p.roleGetter.lookupRole(request.Context())
	if err != nil {
		return fmt.Errorf("error looking up IAM role arns: %w", err)
	}

	params := mux.Vars(request)
	requestedRole := params["role"]
	if awsRole.name != requestedRole {
		return errors.New("requested role not allowed")
	}

	// Pass background context to ignore cancellation signal and cache
	// credentials in case of low timeout on imds client. Retried requests
	// by client should eventually succeed once cached credentials are
	// populated.
	//
	// The parent span is copied to include upstream calls in any trace.
	reqSpan, _ := tracer.SpanFromContext(request.Context())
	reqCtx := tracer.ContextWithSpan(context.Background(), reqSpan)

	creds, err := p.maintainer.Get(reqCtx)
	if err != nil {
		return err
	}

	// Surface the client's cancellation only after the fetch above has had a
	// chance to warm the cache for subsequent retries.
	if request.Context().Err() != nil {
		return request.Context().Err()
	}

	return json.NewEncoder(writer).Encode(creds)
}

// handleSecurityCredentials writes the configured role name as plain text
// (the IMDS role-listing endpoint).
func (p *AwsProvider) handleSecurityCredentials(_ *zap.Logger, writer http.ResponseWriter, request *http.Request) error {
	role, err := p.roleGetter.lookupRole(request.Context())
	if err != nil {
		return fmt.Errorf("error looking up IAM role: %w", err)
	}

	_, err = writer.Write([]byte(role.name))

	return err
}

// handleIdentityDocument serves the instance identity document as JSON.
func (p *AwsProvider) handleIdentityDocument(_ *zap.Logger, writer http.ResponseWriter, request *http.Request) error {
	doc, err := p.identifier.GetInstanceIdentity(request.Context())
	if err != nil {
		return fmt.Errorf("getting instance identity document: %w", err)
	}

	return json.NewEncoder(writer).Encode(doc)
}

const (
	// Header names used by the EC2 IMDSv2 session-token protocol.
	awsEC2MetadataTokenTTLSeconds = "X-aws-ec2-metadata-token-ttl-seconds"
	awsEC2MetadataToken           = "X-aws-ec2-metadata-token"
)

// handleIMDSV2Token generates session tokens for IMDSv2 http clients.
210 | func (p *AwsProvider) handleIMDSV2Token(_ *zap.Logger, writer http.ResponseWriter, request *http.Request) error { 211 | writer.Header().Set("Server", "EC2ws") 212 | 213 | hdr := request.Header.Get(awsEC2MetadataTokenTTLSeconds) 214 | ttl, err := strconv.Atoi(hdr) 215 | 216 | switch { 217 | case err != nil || ttl < 0 || ttl > 21600: 218 | return HTTPError{ 219 | code: http.StatusBadRequest, 220 | error: errors.New("x-aws-ec2-metadata-token-ttl-seconds must be an integer between (0, 21600)"), 221 | } 222 | case request.Header.Get("X-Forwarded-For") != "": 223 | return HTTPError{ 224 | code: http.StatusForbidden, 225 | error: errors.New("X-Forwarded-For cannot be used with EC2 IMDS GetToken"), 226 | } 227 | } 228 | 229 | session, _, err := p.sessions.NewSession(time.Duration(ttl) * time.Second) 230 | if err != nil { 231 | return err 232 | } 233 | 234 | writer.Header().Set(awsEC2MetadataTokenTTLSeconds, strconv.Itoa(ttl)) 235 | writer.Header().Set("Content-Type", "text/plain") 236 | 237 | if _, err := writer.Write([]byte(session.ID)); err != nil { 238 | return fmt.Errorf("writing token response: %w", err) 239 | } 240 | 241 | return nil 242 | } 243 | 244 | // imdsv2Verifier enforces validity of EC2 IMDSv2 session tokens provided via 245 | // the X-aws-ec2-metadata-token header. If the header is not present, the 246 | // request is as IMDSv1 format and allowed. 
247 | // 248 | // For more details on EC2 IMDSv2, see the following documentation: 249 | // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html#instance-metadata-v2-how-it-works 250 | func imdsv2Verifier(v1Allowed bool, sessions *imdsSessionCache, next http.Handler) http.Handler { 251 | return imdsVersionTag(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 252 | now := timeNow() 253 | 254 | w.Header().Set("Server", "EC2ws") 255 | 256 | token := r.Header.Get(awsEC2MetadataToken) 257 | 258 | session, found := sessions.GetSession(token) 259 | 260 | switch { 261 | case token == "": 262 | // IMDSv1: 263 | if !v1Allowed { 264 | //IMDSv1 is disabled 265 | w.WriteHeader(http.StatusForbidden) 266 | return 267 | } 268 | case !found: 269 | // invalid or expired session token 270 | w.WriteHeader(http.StatusUnauthorized) 271 | return 272 | case r.Method != http.MethodGet && r.Method != http.MethodHead: 273 | // only Get & HEAD methods are allowed 274 | w.WriteHeader(http.StatusForbidden) 275 | return 276 | default: 277 | // annotate remaining token TTL in response 278 | ttl := fmt.Sprintf("%.0f", session.Expiry.Sub(now).Seconds()) 279 | w.Header().Set(awsEC2MetadataTokenTTLSeconds, ttl) 280 | } 281 | 282 | next.ServeHTTP(w, r) 283 | })) 284 | } 285 | 286 | // imdsVersionTag annotates requests with the IMDS api version for telemetry. 
func imdsVersionTag(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		imds := "IMDSv1"

		// Presence of either IMDSv2 header classifies the request as v2.
		if r.Header.Get(awsEC2MetadataToken) != "" ||
			r.Header.Get(awsEC2MetadataTokenTTLSeconds) != "" {
			imds = "IMDSv2"
		}

		// include imds version for instrumentation
		next.ServeHTTP(w, wrapRequestTag(r, "imds-version", imds))
	})
}

// role pairs an IAM role's short name with its full ARN.
type role struct {
	name string
	arn  string
}

// AwsCredentials mirrors the JSON body served by the EC2 IMDS
// security-credentials endpoint.
type AwsCredentials struct {
	AccessKeyID     string `json:"AccessKeyId"`
	Code            string
	Expiration      time.Time
	LastUpdated     time.Time
	SecretAccessKey string
	Token           string
	Type            string
}

// AwsVaultStsTokenFetcher fetches STS credentials for a single IAM role from
// Vault's AWS secrets engine.
type AwsVaultStsTokenFetcher struct {
	vault             *vaultclient.Client
	vaultStsEndpoint  string // <mount>/sts/<role>
	vaultRoleEndpoint string // <mount>/roles/<role>

	// awsRole memoizes the resolved role; written under lookupRoleMutex.
	awsRole         *role
	lookupRoleMutex sync.Mutex

	log        *zap.Logger
	metricSink metrics.MetricSink
}

// NewVaultAwsStsTokenFetcher validates its dependencies and builds a fetcher
// for the given IAM role under the given Vault mount path.
func NewVaultAwsStsTokenFetcher(vault *vaultclient.Client,
	iamRole, vaultMountPath string,
	log *zap.Logger,
	metricSink metrics.MetricSink,
) (*AwsVaultStsTokenFetcher, error) {
	if vault == nil {
		return nil, errors.New("vault client cannot be nil")
	}

	if iamRole == "" {
		return nil, errors.New("iam role cannot be empty")
	}

	if vaultMountPath == "" {
		return nil, errors.New("vault mount path cannot be empty")
	}

	if log == nil {
		return nil, errors.New("log cannot be nil")
	}

	if metricSink == nil {
		return nil, errors.New("metric sink cannot be nil")
	}

	fetcher := &AwsVaultStsTokenFetcher{
		vault: vault,

		vaultStsEndpoint:  path.Join(vaultMountPath, "sts", iamRole),
		vaultRoleEndpoint: path.Join(vaultMountPath, "roles", iamRole),

		metricSink: metricSink,
	}
	// The logger is attached after construction so it can carry the
	// fetcher's own String() identifier.
	fetcher.log = log.With(zap.String("fetcher", fetcher.String()))

	return fetcher, nil
}

// String identifies this fetcher in logs and metrics.
func (a *AwsVaultStsTokenFetcher) String() string {
	return "aws-sts-token-vault"
}

// Fetch reads fresh STS credentials from Vault and wraps them with their
// expiry for the cache maintainer. The named return err is read by the
// deferred metrics/tracing block below.
func (a *AwsVaultStsTokenFetcher) Fetch(ctx context.Context) (creds *cache.ExpiringValue[*AwsCredentials], err error) {
	fetchSpan, ctx := tracer.StartSpanFromContext(ctx, "AwsVaultStsTokenFetcher.Fetch")

	// Finish the span and emit a success/fail counter regardless of outcome.
	defer func() {
		fetchSpan.Finish(tracer.WithError(err))

		statusLabel := labelSuccess
		if err != nil {
			statusLabel = labelFail
		}

		labels := []metrics.Label{labelAwsProvider, labelVaultMethod, statusLabel}
		a.metricSink.IncrCounterWithLabels(statsdCloudCredRequest, 1, labels)
	}()

	role, err := a.lookupRole(ctx)
	if err != nil {
		return nil, err
	}

	// The deprecated AWS boto library only supports a format of "%Y-%m-%dT%H:%M:%SZ"
	// while the newer boto3 library appears to be more robust and support this and a
	// datetime with milliseconds. Therefore, fallback to using the older format.
	// See https://github.com/boto/boto/issues/3771
	lastUpdated := timeNow().Truncate(time.Second).UTC()
	secret, err := a.vault.ReadWithData(ctx, a.vaultStsEndpoint, map[string][]string{
		"role_arn": {role.arn},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to read secret from Vault: %w", err)
	}

	if secret == nil || secret.Data == nil {
		return nil, errors.New("unable to correctly read secret from Vault")
	}

	var decodedSecret struct {
		AccessKey     string `mapstructure:"access_key"`
		SecretKey     string `mapstructure:"secret_key"`
		SecurityToken string `mapstructure:"security_token"`
	}

	err = mapstructure.Decode(secret.Data, &decodedSecret)
	if err != nil {
		return nil, fmt.Errorf("error decoding data from Vault: %w", err)
	}

	ttl, err := secret.TokenTTL()
	if err != nil {
		return nil, fmt.Errorf("failed to get credential TTL: %w", err)
	}

	// TTL preference order: token TTL, then lease duration, then a fixed
	// default (with a warning, since a zero TTL likely indicates a
	// misconfigured Vault role).
	if ttl == 0 {
		if secret.LeaseDuration > 0 {
			ttl = time.Second * time.Duration(secret.LeaseDuration)
		} else {
			a.log.Warn("credential TTL is zero, using default", zap.Duration("default", defaultAwsCredTTL))
			ttl = defaultAwsCredTTL
		}
	}

	result := &AwsCredentials{
		AccessKeyID:     decodedSecret.AccessKey,
		Code:            "Success",
		Expiration:      lastUpdated.Add(ttl),
		LastUpdated:     lastUpdated,
		SecretAccessKey: decodedSecret.SecretKey,
		Token:           decodedSecret.SecurityToken,
		Type:            "AWS-HMAC",
	}

	return &cache.ExpiringValue[*AwsCredentials]{
		Value:     result,
		ExpiresAt: result.Expiration,
	}, nil
}

// lookupRole will look up the Vault AWS role by the Vault name. Only a single role ARN can be configured with
// Attaché. Therefore, a role configuration without exactly 1 AWS role ARN will result in an error.
449 | // The return value of this is memoized as it should never change during the Attaché lifecycle. 450 | func (a *AwsVaultStsTokenFetcher) lookupRole(ctx context.Context) (*role, error) { 451 | if a.awsRole != nil { 452 | return a.awsRole, nil 453 | } 454 | 455 | a.lookupRoleMutex.Lock() 456 | defer a.lookupRoleMutex.Unlock() 457 | 458 | if a.awsRole != nil { 459 | return a.awsRole, nil 460 | } 461 | 462 | secret, err := a.vault.Read(ctx, a.vaultRoleEndpoint) 463 | if err != nil { 464 | return nil, fmt.Errorf("unable to read path %q: %w", a.vaultRoleEndpoint, err) 465 | } 466 | 467 | if secret == nil || secret.Data == nil { 468 | return nil, newRoleDoesNotExistError(a.vaultRoleEndpoint) 469 | } 470 | 471 | var response struct { 472 | RoleArns []string `mapstructure:"role_arns"` 473 | } 474 | 475 | err = mapstructure.Decode(secret.Data, &response) 476 | if err != nil { 477 | return nil, fmt.Errorf("unable to decode Vault response: %w", err) 478 | } 479 | 480 | if response.RoleArns == nil || len(response.RoleArns) == 0 { 481 | return nil, errors.New("vault role must have at least one role_arn defined") 482 | } 483 | 484 | if len(response.RoleArns) > 1 { 485 | return nil, errors.New("cannot have multiple role_arns defined for a Vault role") 486 | } 487 | 488 | roleArn := response.RoleArns[0] 489 | 490 | matches := awsRoleRegex.FindSubmatch([]byte(roleArn)) 491 | if matches == nil || len(matches) < 1 { 492 | return nil, fmt.Errorf("unable to extract role from role ARN: %s", roleArn) 493 | } 494 | 495 | roleName := string(matches[1]) 496 | a.awsRole = &role{name: roleName, arn: roleArn} 497 | 498 | return a.awsRole, nil 499 | } 500 | 501 | // InstanceIdentifier provides an instance identity document 502 | type InstanceIdentifier interface { 503 | GetInstanceIdentity(context.Context) (interface{}, error) 504 | } 505 | 506 | // NewAwsInstanceIdentifier builds an instance identity document provider based 507 | // on a given current cloud provider context. 
func NewAwsInstanceIdentifier(provider, region, zone string) InstanceIdentifier {
	switch provider {
	case "aws":
		// use sparsely populated document if current provider is aws
		return (*staticAwsIdentifier)(&ec2imds.InstanceIdentityDocument{
			Region:           region,
			AvailabilityZone: zone,
		})
	default:
		// use zero-valued document if current provider is not aws
		return (*staticAwsIdentifier)(&ec2imds.InstanceIdentityDocument{})
	}
}

// staticAwsIdentifier serves a fixed identity document.
type staticAwsIdentifier ec2imds.InstanceIdentityDocument

func (i *staticAwsIdentifier) GetInstanceIdentity(_ context.Context) (interface{}, error) {
	return (*ec2imds.InstanceIdentityDocument)(i), nil
}
-------------------------------------------------------------------------------- /internal/imds/aws_session_cache.go: --------------------------------------------------------------------------------
package imds

import (
	"container/list"
	"context"
	"crypto/rand"
	"encoding/base64"
	"sync"
	"time"

	"github.com/hashicorp/go-metrics"
)

var (
	statsdIMDSSessionActive = []string{"imds", "active_sessions"}
	statsdIMDSSessionCreate = []string{"imds", "sessions_create"}
	statsdIMDSSessionRevoke = []string{"imds", "sessions_revoke"}
	imdsSessionExpired      = []metrics.Label{{Name: "reason", Value: "expired"}}
	imdsSessionOverflow     = []metrics.Label{{Name: "reason", Value: "overflow"}}
)

const (
	// maxAwsEC2MetadataTokens caps the number of live IMDSv2 sessions.
	maxAwsEC2MetadataTokens = 5000
)

// imdsSessionCache is an LRU cache of IMDSv2 session tokens; mu guards both
// the map and the recency list.
type imdsSessionCache struct {
	metricSink metrics.MetricSink
	sess       map[string]*imdsSession // token ID -> session
	list       *list.List              // most recently used at the front
	max        int
	mu         sync.Mutex
}

// imdsSession is one minted token plus its position in the LRU list.
type imdsSession struct {
	ID      string
	Expiry  time.Time
	element *list.Element
}

// newIMDSSessionCache builds the cache and starts the background eviction
// and metrics goroutines, which run until ctx is cancelled.
func newIMDSSessionCache(ctx context.Context, metricSink metrics.MetricSink, max int, cleanup time.Duration) *imdsSessionCache {
	sc := &imdsSessionCache{
		metricSink: metricSink,
		sess:       make(map[string]*imdsSession),
		list:       list.New(),
		max:        max,
	}

	go sc.evictBackground(ctx, cleanup)

	go sc.reportMetrics(ctx)

	return sc
}

// NewSession mints a cryptographically random session token valid for ttl;
// evicted reports whether adding it pushed out the least recently used entry.
func (sc *imdsSessionCache) NewSession(ttl time.Duration) (s *imdsSession, evicted bool, err error) {
	// match encoded length of actual EC2 IMDSv2 session tokens
	byt := make([]byte, 40)
	if _, err = rand.Read(byt); err != nil {
		return
	}

	sc.metricSink.IncrCounter(statsdIMDSSessionCreate, 1.0)

	s = &imdsSession{
		ID:     base64.URLEncoding.EncodeToString(byt),
		Expiry: timeNow().Add(ttl),
	}

	sc.mu.Lock()
	defer sc.mu.Unlock()

	sc.sess[s.ID] = s
	s.element = sc.list.PushFront(s)

	if sc.list.Len() > sc.max {
		sc.evictOverflow()
		evicted = true
	}

	return
}

// GetSession returns the live session for id (marking it most recently used),
// or (nil, false) when the token is unknown or expired.
func (sc *imdsSessionCache) GetSession(id string) (*imdsSession, bool) {
	sc.mu.Lock()
	defer sc.mu.Unlock()

	s, ok := sc.sess[id]
	if !ok || s.Expiry.Before(timeNow()) {
		return nil, false
	}

	sc.list.MoveToFront(s.element)

	return s, true
}

// evictBackground periodically drops expired sessions until ctx is done.
func (sc *imdsSessionCache) evictBackground(ctx context.Context, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			sc.mu.Lock()
			sc.evictExpired()
			sc.mu.Unlock()
		case <-ctx.Done():
			return
		}
	}
}

// evictOverflow removes least-recently-used entries until the cache fits
// within max. Callers must hold sc.mu.
func (sc *imdsSessionCache) evictOverflow() {
	for sc.list.Len() > sc.max {
		// evict least recently used
		v := sc.list.Remove(sc.list.Back())

		if s, ok := v.(*imdsSession); ok {
			delete(sc.sess, s.ID)
		}

		sc.metricSink.IncrCounterWithLabels(statsdIMDSSessionRevoke, 1.0, imdsSessionOverflow)
	}
}

// evictExpired drops every expired session. Callers (other than tests) must
// hold sc.mu; deleting from a map while ranging over it is safe in Go.
func (sc *imdsSessionCache) evictExpired() {
	now := timeNow()

	for _, s := range sc.sess {
		if s.Expiry.Before(now) {
			sc.list.Remove(s.element)

			delete(sc.sess, s.ID)

			sc.metricSink.IncrCounterWithLabels(statsdIMDSSessionRevoke, 1.0, imdsSessionExpired)
		}
	}
}

// reportMetrics publishes the live-session gauge every 10s until ctx is done.
func (sc *imdsSessionCache) reportMetrics(ctx context.Context) {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			sc.mu.Lock()
			sc.metricSink.SetGauge(statsdIMDSSessionActive, float32(sc.list.Len()))
			sc.mu.Unlock()
		case <-ctx.Done():
			return
		}
	}
}
-------------------------------------------------------------------------------- /internal/imds/aws_session_cache_test.go: --------------------------------------------------------------------------------
package imds

import (
	"context"
	"testing"
	"time"

	"github.com/hashicorp/go-metrics"
	"github.com/stretchr/testify/assert"
)

// Test_imdsSessionCache covers token minting, expiry, and LRU overflow.
func Test_imdsSessionCache(t *testing.T) {
	ctx, cancel := context.WithCancel(context.TODO())
	defer cancel()

	t.Run("generate session", func(t *testing.T) {
		c := newIMDSSessionCache(ctx, &metrics.BlackholeSink{}, 3, time.Hour)

		want, _, err := c.NewSession(time.Hour)
		assert.NoError(t, err)

		have, ok := c.GetSession(want.ID)
		assert.True(t, ok)
		assert.Equal(t, want, have)
	})

	t.Run("expire session", func(t *testing.T) {
		c := newIMDSSessionCache(ctx, &metrics.BlackholeSink{}, 3, time.Hour)

		// Pin the fake clock, then advance it past the token TTL.
		before, after := time.Now(), time.Now().Add(1*time.Hour)
		timeNow = func() time.Time { return before }
		defer func() { timeNow = time.Now }()

		s, _, err := c.NewSession(time.Hour)
		assert.NoError(t, err)

		have, ok := c.GetSession(s.ID)
		assert.True(t, ok)
		assert.Equal(t, s, have)
		assert.Equal(t, 1, c.list.Len())

		// advance time past expiry
		timeNow = func() time.Time { return after }

		// not gettable, not yet evicted
		have, ok = c.GetSession(s.ID)
		assert.False(t, ok)
		assert.Nil(t, have)
		assert.Equal(t, 1, c.list.Len())

		c.evictExpired()

		// evicted
		have, ok = c.GetSession(s.ID)
		assert.False(t, ok)
		assert.Nil(t, have)
		assert.Equal(t, 0, c.list.Len())
	})

	t.Run("lru overflow eviction", func(t *testing.T) {
		c := newIMDSSessionCache(ctx, &metrics.BlackholeSink{}, 2, time.Hour)

		s1, evicted, err := c.NewSession(time.Hour)
		assert.NoError(t, err)
		assert.False(t, evicted)

		s2, evicted, err := c.NewSession(time.Hour)
		assert.NoError(t, err)
		assert.False(t, evicted)

		// add s3, evict s1
		s3, evicted, err := c.NewSession(time.Hour)
		assert.NoError(t, err)
		assert.True(t, evicted)

		_, ok := c.GetSession(s1.ID)
		assert.False(t, ok)
		_, ok = c.GetSession(s3.ID)
		assert.True(t, ok)
		_, ok = c.GetSession(s2.ID)
		assert.True(t, ok)

		// add s4, evict s3 (s2 was touched more recently than s3 above)
		s4, evicted, err := c.NewSession(time.Hour)
		assert.NoError(t, err)
		assert.True(t, evicted)

		_, ok = c.GetSession(s3.ID)
		assert.False(t, ok)
		_, ok = c.GetSession(s2.ID)
		assert.True(t, ok)
		_, ok = c.GetSession(s4.ID)
		assert.True(t, ok)
	})
}
-------------------------------------------------------------------------------- /internal/imds/azure.go: --------------------------------------------------------------------------------
package imds

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"path"
	"strconv"
	"sync"
	"time"

	"github.com/DataDog/attache/internal/cache"
	"github.com/DataDog/attache/internal/retry"
	"github.com/DataDog/attache/internal/vault"
	"github.com/hashicorp/go-metrics"
	"github.com/mitchellh/mapstructure"
	"go.uber.org/zap"
	muxt "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorilla/mux"
	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)

var (
	labelAzureProvider = metrics.Label{Name: "provider", Value: "azure"}
)

// AzureProvider implements identity portion of the Azure Instance Metadata service.
type AzureProvider struct {
	// maintainers holds one credential cache per Azure resource URI.
	// resourceTokenMutex serializes maintainer creation per resource, and
	// mutexMapMutex guards the mutex map itself.
	maintainers          map[string]*cache.Maintainer[*AzureCredentials]
	resourceTokenMutex   map[string]*sync.Mutex
	mutexMapMutex        sync.RWMutex
	retryOpts            []retry.Option
	refreshFunc          cache.RefreshAtFunc
	ctx                  context.Context
	tokenFetcherFactory  AzureTokenFetcherFactory
	subscriptionIDGetter AzureSubscriptionIDGetter

	log        *zap.Logger
	metricSink metrics.MetricSink
}

// AzureTokenFetcherFactory builds a credential fetcher for one resource URI.
type AzureTokenFetcherFactory = func(resource string) (cache.Fetcher[*AzureCredentials], error)

// Azure returns a new AzureProvider.
func Azure(
	ctx context.Context,
	log *zap.Logger,
	metricSink metrics.MetricSink,
	refreshFunc cache.RefreshAtFunc,
	tokenFetcherFactory AzureTokenFetcherFactory,
	subscriptionIDGetter AzureSubscriptionIDGetter,
	retryOpts ...retry.Option,
) (*AzureProvider, error) {
	// Fail fast on missing dependencies.
	if ctx == nil {
		return nil, errors.New("ctx cannot be nil")
	}

	if log == nil {
		return nil, errors.New("log cannot be nil")
	}

	if metricSink == nil {
		return nil, errors.New("metricSink cannot be nil")
	}

	if refreshFunc == nil {
		return nil, errors.New("refresh func cannot be nil")
	}

	return &AzureProvider{
		log:        log,
		metricSink: metricSink,

		resourceTokenMutex:   map[string]*sync.Mutex{},
		maintainers:          map[string]*cache.Maintainer[*AzureCredentials]{},
		ctx:                  ctx,
		tokenFetcherFactory:  tokenFetcherFactory,
		subscriptionIDGetter: subscriptionIDGetter,

		refreshFunc: refreshFunc,
		retryOpts:   retryOpts,
	}, nil
}

// Name returns the provider's logical name.
func (p *AzureProvider) Name() string {
	return "azure"
}

// RegisterHandlers registers all HTTP handlers for the Azure provider.
func (p *AzureProvider) RegisterHandlers(router *muxt.Router, handlerFactory *HandlerFactory) error {
	// Both endpoints require the "Metadata: true" header and an api-version
	// query parameter, mirroring the real Azure IMDS; the token endpoint
	// additionally requires a resource parameter.
	router.Handle(
		"/metadata/identity/oauth2/token",
		metadataHeaderVerifier(azureResourceVerifier(azureAPIVersionVerifier(handlerFactory.CreateHTTPHandler(p.Name(), p.handleGetToken)))),
	)

	router.Handle(
		"/metadata/instance/compute/subscriptionId",
		metadataHeaderVerifier(azureAPIVersionVerifier(handlerFactory.CreateHTTPHandler(p.Name(), p.handleGetSubscriptionID))),
	)

	return nil
}

// handleGetSubscriptionID serves the subscription ID as plain text.
func (p *AzureProvider) handleGetSubscriptionID(logger *zap.Logger, w http.ResponseWriter, r *http.Request) error {
	subscriptionID, err := p.subscriptionIDGetter.getSubscriptionID(r.Context())
	if err != nil {
		return fmt.Errorf("unable to get azure vault config: %w", err)
	}

	logger.Debug("fetched azure config")

	w.Header().Set("Content-Type", "text/plain; charset=utf-8")

	if _, err := io.WriteString(w, subscriptionID); err != nil {
		return err
	}

	return nil
}

// handleGetToken serves an OAuth2 access token for the requested resource,
// defaulting to the ARM management endpoint.
func (p *AzureProvider) handleGetToken(logger *zap.Logger, w http.ResponseWriter, r *http.Request) error {
	resource := r.URL.Query().Get("resource")
	if resource == "" {
		resource = "https://management.azure.com/"
	}

	// Pass background context to ignore cancellation signal and cache
	// credentials in case of low timeout on imds client. Retried requests
	// by client should eventually succeed once cached credentials are
	// populated.
	//
	// The parent span is copied to include upstream calls in any trace.
	reqSpan, _ := tracer.SpanFromContext(r.Context())
	reqCtx := tracer.ContextWithSpan(context.Background(), reqSpan)

	token, err := p.getToken(reqCtx, logger, resource)
	if err != nil {
		return fmt.Errorf("unable to get azure access token: %w", err)
	}

	// Surface the client's cancellation only after the fetch above has had a
	// chance to warm the cache for subsequent retries.
	if r.Context().Err() != nil {
		return r.Context().Err()
	}

	logger.Debug("fetched azure access token")

	expiresOn, err := parseExpiresOn(token.ExpiresOn)
	if err != nil {
		return fmt.Errorf("unable to parse expires_on (%q): %w", token.ExpiresOn, err)
	}

	// recalculate expires_in since token is cached
	seconds := int(time.Until(*expiresOn).Seconds())
	token.ExpiresIn = strconv.Itoa(seconds)

	w.Header().Set("Content-Type", "application/json")

	return json.NewEncoder(w).Encode(token)
}

// getMutexForResource returns the per-resource mutex, creating it on first
// use (double-checked under mutexMapMutex).
func (p *AzureProvider) getMutexForResource(resource string) *sync.Mutex {
	p.mutexMapMutex.RLock()
	mutex, ok := p.resourceTokenMutex[resource]
	p.mutexMapMutex.RUnlock()
	if ok {
		return mutex
	}

	p.mutexMapMutex.Lock()
	defer p.mutexMapMutex.Unlock()

	// Re-check: another goroutine may have created the mutex between the
	// RUnlock and the Lock above.
	mutex, ok = p.resourceTokenMutex[resource]
	if ok {
		return mutex
	}

	mutex = &sync.Mutex{}
	p.resourceTokenMutex[resource] = mutex

	return mutex
}

// getToken returns cached credentials for resource, lazily creating the
// cache maintainer on first request for that resource. The per-resource
// mutex also serializes access to the maintainers map entry.
func (p *AzureProvider) getToken(ctx context.Context, log *zap.Logger, resource string) (*AzureCredentials, error) {
	mutex := p.getMutexForResource(resource)
	mutex.Lock()
	defer mutex.Unlock()

	if _, ok := p.maintainers[resource]; !ok {
		azureFetcher, err := p.tokenFetcherFactory(resource)
		if err != nil {
			return nil, fmt.Errorf("failed to create azure token fetcher for resource %s: %w", resource, err)
		}

		p.maintainers[resource] = cache.NewMaintainer[*AzureCredentials](
			azureFetcher,
			p.refreshFunc,
			cache.WithLogger(log.Named("token maintainer")),
			cache.WithMetricsSink(p.metricSink),
			cache.WithRetryOptions(p.retryOpts),
		)
	}

	result, err := p.maintainers[resource].Get(ctx)
	if err != nil {
		return nil, err
	}

	return result, nil
}

const invalidRequest = "invalid_request"

// azureResponseError renders errors in the Azure IMDS JSON error shape.
type azureResponseError struct {
	error       string
	description string
}

func (e *azureResponseError) Error() string {
	return fmt.Sprintf("{\"error\":\"%v\",\"error_description\":\"%v\"}", e.error, e.description)
}

// metadataHeaderVerifier ensures that the HTTP request contains the "Metadata: true" header
// and ensures that the same header is set for responses.
func metadataHeaderVerifier(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Metadata", "true")
		if r.Header.Get("Metadata") != "true" {
			e := &azureResponseError{error: invalidRequest, description: "Required metadata header not specified"}
			http.Error(w, e.Error(), http.StatusBadRequest)

			return
		}

		next.ServeHTTP(w, r)
	})
}

// azureAPIVersionVerifier requires the api-version query parameter.
func azureAPIVersionVerifier(next http.Handler) http.Handler {
	return azureQueryParamVerifier(next, "api-version")
}

// azureResourceVerifier requires the resource query parameter.
func azureResourceVerifier(next http.Handler) http.Handler {
	return azureQueryParamVerifier(next, "resource")
}

// azureQueryParamVerifier rejects requests missing (or passing an empty value
// for) the named query parameter.
func azureQueryParamVerifier(next http.Handler, paramName string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		apiVersion := r.URL.Query().Get(paramName)
		if apiVersion == "" {
			e := &azureResponseError{
				error:       invalidRequest,
				description: fmt.Sprintf("Required query variable '%v' is missing", paramName),
			}

			http.Error(w, e.Error(), http.StatusBadRequest)

			return
		}

		next.ServeHTTP(w, r)
	})
}

// AzureCredentials fields are documented here:
// https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
type AzureCredentials struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`

	ExpiresIn string `json:"expires_in"`
	ExpiresOn string `json:"expires_on"`
	NotBefore string `json:"not_before"`

	Resource string `json:"resource"`
	Type     string `json:"token_type"`
}

const (
	// For Azure expires_on formats see:
	// https://github.com/Azure/go-autorest/blob/10e0b31633f168ce1a329dcbdd0ab9842e533fb5/autorest/adal/token.go#L85-L89

	// the format for expires_on in UTC with AM/PM.
	expiresOnDateFormatPM = "1/2/2006 15:04:05 PM +00:00"

	// the format for expires_on in UTC without AM/PM.
	expiresOnDateFormat = "1/2/2006 15:04:05 +00:00"
)

// parseExpiresOn converts expires_on to time.Time.
func parseExpiresOn(s string) (*time.Time, error) {
	// Accept, in order: Unix seconds, the AM/PM date format, then the plain
	// date format. On total failure the error from the final attempt is
	// returned.
	if seconds, err := strconv.ParseInt(s, 10, 64); err == nil {
		eo := time.Unix(seconds, 0)

		return &eo, nil
	} else if eo, err := time.Parse(expiresOnDateFormatPM, s); err == nil {
		t := eo.UTC()

		return &t, nil
	} else if eo, err := time.Parse(expiresOnDateFormat, s); err == nil {
		t := eo.UTC()

		return &t, nil
	} else {
		return nil, err
	}
}

// azureVaultTokenFetcher fetches Azure access tokens for one resource from
// Vault's Azure secrets engine.
type azureVaultTokenFetcher struct {
	vault              *vault.Client
	vaultTokenEndpoint string // <mount>/token/<role>

	resource   string
	metricSink metrics.MetricSink
}

// NewAzureVaultTokenFetcher validates its inputs and builds a fetcher for
// tokens scoped to the given resource.
func NewAzureVaultTokenFetcher(
	vault *vault.Client,
	vaultMountPath, iamRole, resource string,
	metricSink metrics.MetricSink,
) (cache.Fetcher[*AzureCredentials], error) {
	if vault == nil {
		return nil, errors.New("vault client cannot be nil")
	}

	if vaultMountPath == "" {
		return nil, errors.New("vaultMountPath cannot be empty")
	}

	if iamRole == "" {
		return nil, errors.New("iamRole cannot be empty")
	}

	if resource == "" {
		return nil, errors.New("resource cannot be empty")
	}

	if metricSink == nil {
		return nil, errors.New("metric sink cannot be nil")
	}

	return &azureVaultTokenFetcher{
		vault:              vault,
		vaultTokenEndpoint: path.Join(vaultMountPath, "token", iamRole),
		resource:           resource,
		metricSink:         metricSink,
	}, nil
}

// String identifies this fetcher in logs and metrics.
func (a *azureVaultTokenFetcher) String() string {
	return "azure-token-vault"
}

// Fetch reads a fresh access token from Vault and computes its absolute
// expiry from the reported expires_in. The named return err is read by the
// deferred metrics/tracing block below.
func (a *azureVaultTokenFetcher) Fetch(ctx context.Context) (creds *cache.ExpiringValue[*AzureCredentials], err error) {
	fetchSpan, ctx := tracer.StartSpanFromContext(ctx, "AzureVaultTokenFetcher.Fetch")
	// Finish the span and emit a success/fail counter regardless of outcome.
	defer func() {
		fetchSpan.Finish(tracer.WithError(err))

		statusLabel := labelSuccess
		if err != nil {
			statusLabel = labelFail
		}

		labels := []metrics.Label{labelAzureProvider, labelVaultMethod, statusLabel}
		a.metricSink.IncrCounterWithLabels(statsdCloudCredRequest, 1, labels)
	}()

	now := timeNow()
	secret, err := a.vault.ReadWithData(ctx, a.vaultTokenEndpoint, map[string][]string{"resource": {a.resource}})
	if err != nil {
		return nil, err
	}

	if secret == nil || secret.Data == nil {
		return nil, newRoleDoesNotExistError(a.vaultTokenEndpoint)
	}

	// Decode using the json struct tags so Vault's snake_case keys land in
	// the AzureCredentials fields.
	result := &AzureCredentials{}
	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		Metadata: nil,
		Result:   result,
		TagName:  "json",
	})
	if err != nil {
		return nil, err
	}

	err = decoder.Decode(secret.Data)
	if err != nil {
		return nil, err
	}

	expiresIn, err := strconv.ParseInt(result.ExpiresIn, 10, 0)
	if err != nil {
		return nil, fmt.Errorf("unable to convert expires_in (%q): %w", result.ExpiresIn, err)
	}
	// expires_on is recomputed relative to the local clock at fetch time.
	expiresOn := now.Add(time.Duration(expiresIn) * time.Second)
	result.ExpiresOn = strconv.FormatInt(expiresOn.Unix(), 10)

	return &cache.ExpiringValue[*AzureCredentials]{
		Value:     result,
		ExpiresAt: expiresOn,
	}, nil
}
-------------------------------------------------------------------------------- /internal/imds/azure_subscription.go: --------------------------------------------------------------------------------
package imds

import (
	"context"
	"errors"
	"path"
	"sync"
	"time"

	"github.com/DataDog/attache/internal/vault"
	"github.com/mitchellh/mapstructure"
)

// AzureSubscriptionIDGetter returns the Azure subscription ID for IMDS
type AzureSubscriptionIDGetter interface {
	getSubscriptionID(ctx context.Context) (string, error)
}

// azureStaticSubscriptionIDGetter is an AzureSubscriptionIDGetter that always returns a static value
type azureStaticSubscriptionIDGetter struct {
	subscriptionID string
}

func NewAzureStaticSubscriptionIDGetter(subscriptionID string) AzureSubscriptionIDGetter {
	return &azureStaticSubscriptionIDGetter{
		subscriptionID: subscriptionID,
	}
}

// getSubscriptionID returns the fixed ID; ctx is unused but required by the
// AzureSubscriptionIDGetter interface.
func (g *azureStaticSubscriptionIDGetter) getSubscriptionID(ctx context.Context) (string, error) {
	return g.subscriptionID, nil
}

// azureVaultSubscriptionIDGetter is an AzureSubscriptionIDGetter that fetches the subscription ID from Vault
type azureVaultSubscriptionIDGetter struct {
	vault               *vault.Client
	vaultConfigEndpoint string

	// Cached subscription ID with its expiry; guarded by subscriptionMutex.
	// NOTE(review): the caching logic lives in getSubscriptionID below,
	// which continues past this chunk.
	subscription           string
	subscriptionExpiration time.Time
	subscriptionMutex      sync.RWMutex
}

func NewAzureVaultSubscriptionIDGetter(vault *vault.Client, vaultMountPath string) AzureSubscriptionIDGetter {
	return &azureVaultSubscriptionIDGetter{
		vault:               vault,
		vaultConfigEndpoint: path.Join(vaultMountPath, "config"),
	}
}

// vaultAzureConfig mirrors the fields of Vault's Azure secrets engine
// config endpoint.
type vaultAzureConfig struct {
	TenantID       string `mapstructure:"tenant_id"`
	SubscriptionID string `mapstructure:"subscription_id"`
	ClientID       string `mapstructure:"client_id"`
	Environment    string `mapstructure:"environment"`
}

// getSubscriptionID fetches the Azure subscription ID from Vault.
59 | func (g *azureVaultSubscriptionIDGetter) getSubscriptionID(ctx context.Context) (string, error) { 60 | g.subscriptionMutex.RLock() 61 | curSubscription := g.subscription 62 | curSubscriptionExpiry := g.subscriptionExpiration 63 | g.subscriptionMutex.RUnlock() 64 | if curSubscription != "" && curSubscriptionExpiry.After(timeNow()) { 65 | return curSubscription, nil 66 | } 67 | 68 | g.subscriptionMutex.Lock() 69 | defer g.subscriptionMutex.Unlock() 70 | if g.subscription != "" && g.subscriptionExpiration.After(timeNow()) { 71 | return g.subscription, nil 72 | } 73 | 74 | secret, err := g.vault.Read(ctx, g.vaultConfigEndpoint) 75 | if err != nil { 76 | return "", err 77 | } 78 | 79 | if secret == nil || secret.Data == nil { 80 | return "", errors.New("vault azure config is empty") 81 | } 82 | 83 | var c vaultAzureConfig 84 | if err = mapstructure.Decode(secret.Data, &c); err != nil { 85 | return "", err 86 | } 87 | 88 | g.subscription = c.SubscriptionID 89 | g.subscriptionExpiration = timeNow().Add(5 * time.Minute) 90 | 91 | return g.subscription, nil 92 | } 93 | -------------------------------------------------------------------------------- /internal/imds/azure_test.go: -------------------------------------------------------------------------------- 1 | package imds 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/http/httptest" 10 | "path" 11 | "strconv" 12 | "testing" 13 | "time" 14 | 15 | "github.com/DataDog/attache/internal/cache" 16 | "github.com/DataDog/attache/internal/vault" 17 | "github.com/fatih/structs" 18 | "github.com/hashicorp/go-metrics" 19 | "github.com/stretchr/testify/assert" 20 | "github.com/stretchr/testify/require" 21 | "go.uber.org/zap/zaptest" 22 | muxt "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorilla/mux" 23 | ) 24 | 25 | func structsMap(s interface{}) map[string]interface{} { 26 | t := structs.New(s) 27 | t.TagName = "json" 28 | 29 | return t.Map() 30 | } 31 | 32 | type providerParams struct 
{
	// fetcherFactory builds the credential fetcher under test for a resource.
	fetcherFactory func(resource string) (cache.Fetcher[*AzureCredentials], error)
	// subscriptionIDGetter supplies the subscription ID handler's backend.
	subscriptionIDGetter AzureSubscriptionIDGetter
}

// createAzureRouter wires an Azure provider built from params into a fresh
// router, mirroring the production registration path.
func createAzureRouter(t *testing.T, params *providerParams) (*muxt.Router, error) {
	t.Helper()

	log := zaptest.NewLogger(t)
	refreshFunc := cache.NewPercentageRemainingRefreshAt(1, 0)

	p, err := Azure(context.Background(), log, &metrics.BlackholeSink{}, refreshFunc, params.fetcherFactory, params.subscriptionIDGetter)
	if err != nil {
		return nil, err
	}

	r := muxt.NewRouter()

	factory := NewHandlerFactory(&metrics.BlackholeSink{}, log)
	err = p.RegisterHandlers(r, factory)
	if err != nil {
		return nil, err
	}

	return r, nil
}

// createAzureProviderVaultParams returns providerParams backed by the given
// Vault test client (mountPath/iamRole come from the shared test fixtures).
func createAzureProviderVaultParams(v *vault.Client) *providerParams {
	return &providerParams{
		fetcherFactory: func(resource string) (cache.Fetcher[*AzureCredentials], error) {
			return NewAzureVaultTokenFetcher(v, mountPath, "fake-iam-role", resource, &metrics.BlackholeSink{})
		},
		subscriptionIDGetter: NewAzureVaultSubscriptionIDGetter(v, mountPath),
	}
}

// Test_azureProvider_handleGetTokenVault exercises the token endpoint against
// a Vault-backed fetcher, then verifies the 499 path for a canceled request.
func Test_azureProvider_handleGetTokenVault(t *testing.T) {
	v := newVaultCluster(t)

	configureVaultCluster(t, v)

	tenant, resource := "test-tenant", "https://resource.endpoint/"
	_, err := v.Write(context.Background(), path.Join(mountPath, "config"), map[string]interface{}{
		"environment":     "AzurePublicCloud",
		"tenant_id":       tenant,
		"subscription_id": "test-subscription",
		"client_id":       "test-vault-backend-client-id",
	})
	require.NoError(t, err)

	expiresOn := time.Now().UTC().Add(1 * time.Hour)
	expiresOnStr := strconv.FormatInt(expiresOn.Unix(), 10)
	_, err = v.Write(context.Background(), path.Join(mountPath, "token", iamRole), structsMap(AzureCredentials{
		AccessToken: "test-access-token",
		ExpiresOn:   expiresOnStr,
		ExpiresIn:   "3600",
		NotBefore:   expiresOnStr,
		Resource:    resource,
		Type:        "Bearer",
	}))
	require.NoError(t, err)

	r, err := createAzureRouter(t, createAzureProviderVaultParams(v))
	assert.NoError(t, err)

	req, _ := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/metadata/identity/oauth2/token?resource=https%3A%2F%2Fresource.endpoint%2F&api-version=2020-02-02", nil)
	req.Header.Add("Metadata", "true")
	recorder := httptest.NewRecorder()

	r.ServeHTTP(recorder, req)
	resp := recorder.Result()

	assert.Equal(t, http.StatusOK, resp.StatusCode)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.NoError(t, resp.Body.Close())

	var received AzureCredentials
	err = json.Unmarshal(body, &received)
	require.NoError(t, err)

	receivedExpiresIn, err := strconv.ParseInt(received.ExpiresIn, 10, 0)
	require.NoError(t, err)

	assert.Equal(t, "test-access-token", received.AccessToken)
	assert.Equal(t, expiresOnStr, received.ExpiresOn)
	// expires_in is recomputed from remaining validity, so it must have
	// shrunk below the stored 3600s by the time the handler responds.
	assert.Less(t, receivedExpiresIn, int64(3600))
	assert.Equal(t, "https://resource.endpoint/", received.Resource)
	assert.Equal(t, "Bearer", received.Type)
	assert.NotEmpty(t, received.ExpiresIn)

	// canceled request
	canceledCtx, cancelCtxFn := context.WithCancel(context.Background())
	cancelCtxFn()

	req, _ = http.NewRequestWithContext(canceledCtx, "GET", "http://localhost/metadata/identity/oauth2/token?resource=https%3A%2F%2Fresource.endpoint%2F&api-version=2020-02-02", nil)
	req.Header.Add("Metadata", "true")
	recorder = httptest.NewRecorder()

	r.ServeHTTP(recorder, req)
	resp = recorder.Result()

	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.NoError(t, resp.Body.Close())
	assert.Equal(t, statusClientClosedRequest, resp.StatusCode)
	assert.Contains(t, string(body), statusClientClosedRequestText)
}

// Test_azureProvider_handleGetVaultSubscriptionID verifies the subscription ID
// endpoint returns the value stored in the Vault azure config.
func Test_azureProvider_handleGetVaultSubscriptionID(t *testing.T) {
	v := newVaultCluster(t)
	configureVaultCluster(t, v)

	_, err := v.Write(context.Background(), path.Join(mountPath, "config"), map[string]interface{}{
		"environment":     "AzurePublicCloud",
		"tenant_id":       "test-tenant",
		"subscription_id": "test-subscription",
		"client_id":       "test-vault-backend-client-id",
	})
	require.NoError(t, err)

	r, err := createAzureRouter(t, createAzureProviderVaultParams(v))
	assert.NoError(t, err)

	req, _ := http.NewRequestWithContext(context.TODO(), "GET", "http://localhost/metadata/instance/compute/subscriptionId?api-version=2017-08-01", nil)
	req.Header.Add("Metadata", "true")
	recorder := httptest.NewRecorder()

	r.ServeHTTP(recorder, req)
	resp := recorder.Result()

	assert.Equal(t, http.StatusOK, resp.StatusCode)

	body, err := io.ReadAll(resp.Body)
	assert.NoError(t, err)
	require.NoError(t, resp.Body.Close())

	assert.Equal(t, "test-subscription", string(body))
}

// Test_verifyMetadataHeader verifies both Azure endpoints reject requests
// that lack the required "Metadata: true" header.
func Test_verifyMetadataHeader(t *testing.T) {
	v := newVaultCluster(t)
	r, err := createAzureRouter(t, createAzureProviderVaultParams(v))
	assert.NoError(t, err)

	tests := map[string]struct {
		path string
	}{
		"handleSubscriptionID": {
			path: "/metadata/instance/compute/subscriptionId",
		},
		"handleToken": {
			path: "/metadata/identity/oauth2/token",
		},
	}
	for tn, tt := range tests {
		t.Run(tn, func(t *testing.T) {
			req, _ := http.NewRequestWithContext(context.Background(), "GET", "http://localhost"+tt.path, nil)
			recorder := httptest.NewRecorder()

			r.ServeHTTP(recorder, req)
			resp := recorder.Result()

			assert.Equal(t, http.StatusBadRequest, resp.StatusCode)

			body, err := io.ReadAll(resp.Body)
			assert.NoError(t, err)
			require.NoError(t, resp.Body.Close())
			assert.Equal(t, "{\"error\":\"invalid_request\",\"error_description\":\"Required metadata header not specified\"}\n", string(body))
		})
	}
}

// Test_verifyQueryParameterExists checks each endpoint's required query
// parameters, asserting the exact error payload when one is missing.
func Test_verifyQueryParameterExists(t *testing.T) {
	v := newVaultCluster(t)

	r, err := createAzureRouter(t, createAzureProviderVaultParams(v))
	assert.NoError(t, err)

	configureVaultCluster(t, v)

	tenant, resource := "test-tenant", "https://resource.endpoint/"
	_, err = v.Write(context.Background(), path.Join(mountPath, "config"), map[string]interface{}{
		"environment":     "AzurePublicCloud",
		"tenant_id":       tenant,
		"subscription_id": "test-subscription",
		"client_id":       "test-vault-backend-client-id",
	})
	require.NoError(t, err)

	nowTime := time.Now()
	_, err = v.Write(context.Background(), path.Join(mountPath, "token", iamRole), structsMap(AzureCredentials{
		AccessToken: "test-access-token",
		ExpiresIn:   "3600",
		ExpiresOn:   strconv.FormatInt(nowTime.Add(1*time.Hour).Unix(), 10),
		NotBefore:   strconv.FormatInt(nowTime.Unix(), 10),
		Resource:    resource,
		Type:        "Bearer",
	}))
	require.NoError(t, err)

	tests := map[string]struct {
		path       string
		missing    string
		httpStatus int
	}{
		"handleSubscriptionID missing api-version": {
			path:       "/metadata/instance/compute/subscriptionId?resource=blah",
			missing:    "api-version",
			httpStatus: http.StatusBadRequest,
		},
		"handleSubscriptionID valid request": {
			path:       "/metadata/instance/compute/subscriptionId?api-version=blah",
			missing:    "",
			httpStatus: http.StatusOK,
		},
		"handleToken missing resource": {
			path:       "/metadata/identity/oauth2/token?api-version=blah",
			missing:    "resource",
			httpStatus: http.StatusBadRequest,
		},
		"handleToken missing api-version": {
			path:       "/metadata/identity/oauth2/token?resource=blah",
			missing:    "api-version",
			httpStatus: http.StatusBadRequest,
		},
		"handleToken valid request": {
			path:       "/metadata/identity/oauth2/token?resource=blah&api-version=blah",
			missing:    "",
			httpStatus: http.StatusOK,
		},
	}
	for tn, tt := range tests {
		t.Run(tn, func(t *testing.T) {
			req, _ := http.NewRequestWithContext(context.Background(), "GET", "http://localhost"+tt.path, nil)
			req.Header.Add("Metadata", "true")
			recorder := httptest.NewRecorder()

			r.ServeHTTP(recorder, req)
			resp := recorder.Result()

			assert.Equal(t, tt.httpStatus, resp.StatusCode)

			if tt.httpStatus == http.StatusBadRequest {
				body, err := io.ReadAll(resp.Body)
				assert.NoError(t, err)
				require.NoError(t, resp.Body.Close())
				assert.Equal(t, fmt.Sprintf("{\"error\":\"invalid_request\",\"error_description\":\"Required query variable '%v' is missing\"}\n", tt.missing), string(body))
			}
		})
	}
}

// Test_parseExpiresOn covers the Unix-seconds form, both datetime layouts,
// and malformed inputs. NOTE(review): delta is never set, so WithinDuration
// effectively requires exact equality — fine since `now` is rounded to 1s.
func Test_parseExpiresOn(t *testing.T) {
	now := time.Now().UTC().Round(time.Second)

	tests := map[string]struct {
		input   string
		want    *time.Time
		delta   time.Duration
		wantErr bool
	}{
		"successful integer parse": {
			input: strconv.FormatInt(now.Unix(), 10),
			want:  timePtr(t, now),
		},
		"successful datetime format parse": {
			input: now.Format(expiresOnDateFormat),
			want:  timePtr(t, now),
		},
		"successful datetime PM format parse": {
			input: now.Format(expiresOnDateFormatPM),
			want:  timePtr(t, now),
		},
		"invalid number": {
			input:   "123.123.123",
			wantErr: true,
		},
		"invalid datetime format": {
			input:   now.Format(time.RFC1123Z),
			wantErr: true,
		},
	}
	for tn, tt := range tests {
		t.Run(tn, func(t *testing.T) {
			got, err := parseExpiresOn(tt.input)
			if (err != nil) != tt.wantErr {
				t.Errorf("parseExpiresOn() error = %v, wantErr %v", err, tt.wantErr)

				return
			}
			if tt.want == nil {
				require.Nil(t, got)
			} else {
				require.NotNil(t, got)
				require.WithinDuration(t, *tt.want, *got, tt.delta)
			}
		})
	}
}

// timePtr returns a pointer to b (test helper for table literals).
func timePtr(t *testing.T, b time.Time) *time.Time {
	t.Helper()

	return &b
}
-------------------------------------------------------------------------------- /internal/imds/doc.go: --------------------------------------------------------------------------------
// Package imds contains emulations of cloud provider IMDS APIs
package imds
-------------------------------------------------------------------------------- /internal/imds/metadataserver.go: --------------------------------------------------------------------------------
package imds

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/DataDog/attache/internal/cache"
	"github.com/DataDog/attache/internal/rate"
	"github.com/DataDog/attache/internal/retry"
	"github.com/DataDog/attache/internal/server"
	"github.com/DataDog/attache/internal/vault"
	"github.com/hashicorp/go-metrics"
	"github.com/hashicorp/go-multierror"
	"go.uber.org/zap"
	muxt "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorilla/mux"
)

// Config for metadata server.
type Config struct {
	// the cloud provider IAM role (e.g.
for GCP, the service account) 25 | IamRole string `yaml:"iam_role"` 26 | 27 | // By default AWS IMDSv1 is disabled 28 | IMDSv1Allowed bool `yaml:"imds_v1_allowed"` 29 | 30 | // the Vault mount path for GCP 31 | GcpVaultMountPath string `yaml:"gcp_vault_mount_path"` 32 | 33 | // mapping from gcp project name to the gcp project id 34 | GcpProjectIds map[string]string `yaml:"gcp_project_ids"` 35 | 36 | // the Vault mount path for AWS 37 | AwsVaultMountPath string `yaml:"aws_vault_mount_path"` 38 | 39 | // the Vault mount path for Azure 40 | AzureVaultMountPath string `yaml:"azure_vault_mount_path"` 41 | 42 | // Server configuration 43 | ServerConfig server.Config `yaml:"server"` 44 | 45 | // the cloud provider (e.g., "aws") 46 | Provider string `yaml:"provider"` 47 | 48 | // the cloud provider region (e.g., "us-east-1") 49 | Region string `yaml:"region"` 50 | 51 | // the cloud provider zone (e.g., "us-east-1a") 52 | Zone string `yaml:"zone"` 53 | } 54 | 55 | // Validate a config. 56 | func (c *Config) Validate() error { 57 | var result error 58 | 59 | if strings.TrimSpace(c.IamRole) == "" { 60 | result = multierror.Append(result, errors.New("IAM role cannot be empty")) 61 | } 62 | 63 | if err := c.ServerConfig.Validate(); err != nil { 64 | result = multierror.Append(result, fmt.Errorf("server configuration not valid: %w", err)) 65 | } 66 | 67 | return result 68 | } 69 | 70 | type MetadataServerConfig struct { 71 | CloudiamConf Config 72 | DDVaultClient *vault.Client 73 | 74 | MetricSink metrics.MetricSink 75 | Log *zap.Logger 76 | } 77 | 78 | // NewServer creates a new metadata server. 
func NewServer(ctx context.Context, conf *MetadataServerConfig) (*server.Server, func(), error) {
	log := decorateLog(conf.Log)
	router, closeFunc, err := newRouter(ctx, conf)
	if err != nil {
		return nil, func() {}, err
	}

	return server.NewServer(log, conf.CloudiamConf.ServerConfig, router, server.WithMetricSink(conf.MetricSink)), closeFunc, nil
}

// newRouter builds the IMDS router, registering one Provider per configured
// Vault mount path (GCP, AWS, Azure). The returned closeFunc releases any
// resources created here and is safe to call even when an error is returned.
func newRouter(ctx context.Context, conf *MetadataServerConfig) (*muxt.Router, func(), error) {
	// closeFunc is returned to the caller and handles cleaning up
	// any resources created in this method. (ex. cloudiam client managers)
	var cleanupFuncs []func()
	closeFunc := func() {
		for _, f := range cleanupFuncs {
			f()
		}
	}

	router := muxt.NewRouter(muxt.WithServiceName("attache.imds"), muxt.WithIgnoreRequest(func(req *http.Request) bool {
		// Skip tracing AWS API token requests because they are very frequent
		// and make no remote calls internally so aren't very interesting.
		// This saves on memory allocations.
		if req.RequestURI == "/latest/api/token" {
			return true
		}
		return false
	}))

	// NOTE: it would be better if the provider wasn't responsible for registering.
	// Then we could sanity check what was being registered to ensure that there were
	// no duplicates. However, that is not trivial because a handler can be registered
	// not only to a path but to a host, method, header, etc... therefore it is
	// imperative to have sufficient testing to ensure that routes are not being subsumed.
	factory := NewHandlerFactory(conf.MetricSink, conf.Log)
	p := []Provider{}

	// Refresh after 20m + jitter of [0,24s]
	// NOTE(review): the "[0,24s]" figure looks stale relative to the 0.10
	// jitter fraction passed below — confirm against
	// cache.NewPercentageRemainingRefreshAt semantics.
	//
	// The AWS SDKs for Go, Java, and Python have a common denominator of a
	// renewal window starting when there is 15min TTL remaining on the token.
	//
	// - aws-sdk-go 5m - https://github.com/aws/aws-sdk-go/blob/main/aws/defaults/defaults.go#L205
	// - botocore 15m - https://github.com/boto/botocore/blob/master/botocore/credentials.py#L377-L382
	// - aws-sdk-java 15m - https://github.com/aws/aws-sdk-java/blob/master/aws-java-sdk-core/src/main/java/com/amazonaws/auth/BaseCredentialsFetcher.java#L42-L46
	//
	// We do not want to increase this too close to the actual expiration of the
	// tokens (for all cloud providers this is currently 1hr) for the following
	// reasons:
	//
	// 1. If Attaché IMDS returns a token too near the expiration and a client
	//    uses it once expired without realizing that the client will receive
	//    an error due to the expired token.
	// 2. Cloud provider SDKs commonly try to retrieve a new token prior to the
	//    current active one expiring at a specific renew threshold. For AWS
	//    this is >10min before the current token expires. This means a call to
	//    AWS ends up first trying to retrieve a new token from Attaché. So
	//    for every request a client makes to the cloud provider between the
	//    renewal threshold and the token expiration, the client will try to call
	//    Attaché IMDS for a new token. This can result in performance issues
	//    in the application because of the extra call to linklocal Attaché
	//    IMDS and can also potentially cause Attaché to block if its request
	//    queue is backed up or even OOM.
	//
	// With a token refresh of 20m + jitter our Mean Time To Failure (MTTF) will
	// be 20m and a Least Time To Failure (LTTF) of 10m.
	refreshFunc := cache.NewPercentageRemainingRefreshAt(0.33333333, 0.10)
	retryOpts := []retry.Option{
		retry.MaxAttempts(4),
		retry.MaxJitter(2 * time.Minute),
		retry.InitialDelay(10 * time.Second),
	}

	cloudiamConf := conf.CloudiamConf

	// GCP: a single Vault fetcher serves both token and service-account info.
	if strings.TrimSpace(cloudiamConf.GcpVaultMountPath) != "" {
		var gcpServiceAccountInfoGetter GcpServiceAccountInfoGetter
		var gcpTokenGetter cache.Fetcher[*GcpCredentials]

		vaultFetcher, err := NewGcpVaultTokenFetcher(conf.DDVaultClient, cloudiamConf.IamRole, cloudiamConf.GcpVaultMountPath, cloudiamConf.GcpProjectIds, conf.Log, conf.MetricSink)
		if err != nil {
			return nil, closeFunc, fmt.Errorf("failed to create vault GCP token fetcher: %w", err)
		}

		gcpServiceAccountInfoGetter = vaultFetcher
		gcpTokenGetter = vaultFetcher

		gcpProvider, err := Gcp(ctx, conf.Log, conf.MetricSink, gcpTokenGetter, gcpServiceAccountInfoGetter, refreshFunc, retryOpts...)
		if err != nil {
			return nil, closeFunc, fmt.Errorf("unable to create GCP provider: %w", err)
		}
		p = append(p, gcpProvider)
	}

	// AWS: the same Vault fetcher acts as token fetcher and role getter.
	if strings.TrimSpace(cloudiamConf.AwsVaultMountPath) != "" {
		identifier := NewAwsInstanceIdentifier(cloudiamConf.Provider, cloudiamConf.Region, cloudiamConf.Zone)

		var awsRoleGetter AwsRoleGetter
		var awsTokenFetcher cache.Fetcher[*AwsCredentials]

		vaultFetcher, err := NewVaultAwsStsTokenFetcher(conf.DDVaultClient, cloudiamConf.IamRole, cloudiamConf.AwsVaultMountPath, conf.Log, conf.MetricSink)
		if err != nil {
			return nil, closeFunc, fmt.Errorf("failed to create vault AWS token fetcher: %w", err)
		}

		awsRoleGetter = vaultFetcher
		awsTokenFetcher = vaultFetcher

		awsProvider, err := Aws(ctx, conf.Log, conf.CloudiamConf.IMDSv1Allowed, conf.MetricSink, awsTokenFetcher, awsRoleGetter, identifier, refreshFunc, retryOpts...)
		if err != nil {
			return nil, closeFunc, fmt.Errorf("unable to create AWS provider: %w", err)
		}
		p = append(p, awsProvider)
	}

	// Azure: token fetchers are created per-resource via a factory closure.
	if strings.TrimSpace(cloudiamConf.AzureVaultMountPath) != "" {
		azureSubscriptionIDGetter := NewAzureVaultSubscriptionIDGetter(conf.DDVaultClient, cloudiamConf.AzureVaultMountPath)

		tokenFetcherFactory := func(resource string) (cache.Fetcher[*AzureCredentials], error) {
			vaultFetcher, err := NewAzureVaultTokenFetcher(conf.DDVaultClient, cloudiamConf.AzureVaultMountPath, cloudiamConf.IamRole, resource, conf.MetricSink)
			if err != nil {
				return nil, err
			}
			return vaultFetcher, nil
		}

		azureProvider, err := Azure(ctx, conf.Log, conf.MetricSink, refreshFunc, tokenFetcherFactory, azureSubscriptionIDGetter, retryOpts...)
		if err != nil {
			return nil, closeFunc, fmt.Errorf("unable to create Azure provider: %w", err)
		}
		p = append(p, azureProvider)
	}

	if len(p) == 0 {
		return nil, closeFunc, errors.New("no metadataserver providers registered")
	}

	for _, provider := range p {
		if err := provider.RegisterHandlers(router, factory); err != nil {
			return nil, closeFunc, fmt.Errorf("unable to register provider %v: %w", provider.Name(), err)
		}
	}

	// Limiter rate equal to len(p) for the steady state credential fetch from Vault,
	// effectively 1 request/s per cloud provider. Burst/bucket capacity of
	// (2 x len(p) + overhead) allowing for the initial credentials fetch which
	// requires an additional request for metadata plus a little overhead.
	totalProviders := len(p)
	conf.DDVaultClient.SetLimiter(rate.NewLimiter(rate.Limit(totalProviders), (2*totalProviders)+4))

	return router, closeFunc, nil
}

// decorateLog names the logger used by all IMDS components.
func decorateLog(log *zap.Logger) *zap.Logger {
	return log.Named("cloud-iam-server")
}
-------------------------------------------------------------------------------- /internal/imds/metadataserver_test.go: --------------------------------------------------------------------------------
package imds

import (
	"context"
	"testing"

	vaultclient "github.com/DataDog/attache/internal/vault"
	"github.com/hashicorp/go-metrics"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
	"go.uber.org/zap/zaptest"
)

// fromVaultClient returns a Vault client with TLS verification disabled,
// suitable for talking to the in-process test cluster.
func fromVaultClient(t *testing.T) *vaultclient.Client {
	t.Helper()

	config := vaultclient.DefaultConfig()
	config.Insecure = true

	ddClient, err := vaultclient.NewClient(config)
	require.NoError(t, err)

	return ddClient
}

// TestNewServer wires all three providers against a test Vault cluster and
// verifies the server constructs and starts.
func TestNewServer(t *testing.T) {
	_ = newVaultCluster(t)

	logger := zaptest.NewLogger(t, zaptest.Level(zap.DebugLevel))
	config := Config{
		IamRole:             "blah",
		AwsVaultMountPath:   "aws",
		GcpVaultMountPath:   "gcp",
		AzureVaultMountPath: "azure",
	}

	server, closeFunc, err := NewServer(context.Background(), &MetadataServerConfig{
		CloudiamConf:  config,
		DDVaultClient: fromVaultClient(t),
		MetricSink:    &metrics.BlackholeSink{},
		Log:           logger,
	})
	require.NoError(t, err)
	require.NotNil(t, closeFunc)
	defer closeFunc()

	eChan := make(chan error, 1)
	shutdown := server.Run(eChan)
	defer shutdown()
}
-------------------------------------------------------------------------------- /internal/imds/providers.go: --------------------------------------------------------------------------------
package imds

import (
"context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "reflect" 9 | "runtime" 10 | "strconv" 11 | "strings" 12 | "time" 13 | 14 | "github.com/hashicorp/go-metrics" 15 | "go.uber.org/zap" 16 | muxt "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorilla/mux" 17 | ) 18 | 19 | var ( 20 | statsdCloudCredRequest = []string{"cloudcred", "request"} 21 | 22 | labelVaultMethod = metrics.Label{Name: "method", Value: "vault"} 23 | labelSuccess = metrics.Label{Name: "status", Value: "success"} 24 | labelFail = metrics.Label{Name: "status", Value: "fail"} 25 | ) 26 | 27 | // Provider encapsulates all the parameters necessary for implementing the cloud 28 | // provider's Metadata Service backed by Vault. 29 | type Provider interface { 30 | 31 | // Name returns the name of the Provider. 32 | Name() string 33 | 34 | // RegisterHandlers registers HTTP handlers with the server. 35 | // 36 | // `mux.Router` parameter is the server and `HandlerFactory` is a factory 37 | // for creating HTTP Handlers for handling requests. 38 | RegisterHandlers(router *muxt.Router, factory *HandlerFactory) error 39 | } 40 | 41 | var _ handlerError = &roleDoesNotExistError{} 42 | 43 | type roleDoesNotExistError struct { 44 | roleName string 45 | } 46 | 47 | func (err *roleDoesNotExistError) Error() string { 48 | return fmt.Sprintf("role %q does not exist", err.roleName) 49 | } 50 | 51 | func (err *roleDoesNotExistError) Status() int { 52 | return 404 53 | } 54 | 55 | func newRoleDoesNotExistError(roleName string) error { 56 | return &roleDoesNotExistError{ 57 | roleName: roleName, 58 | } 59 | } 60 | 61 | // HandlerFactory struct for creating Handlers. 62 | type HandlerFactory struct { 63 | logger *zap.Logger 64 | metricSink metrics.MetricSink 65 | } 66 | 67 | // NewHandlerFactory creates a HandlerFactory. 
68 | func NewHandlerFactory(metricSink metrics.MetricSink, log *zap.Logger) *HandlerFactory { 69 | return &HandlerFactory{ 70 | logger: log, 71 | metricSink: metricSink, 72 | } 73 | } 74 | 75 | // CreateHTTPHandler for an HTTP server. 76 | func (f *HandlerFactory) CreateHTTPHandler(provider string, handlerFunc handlerFunc) http.Handler { 77 | return &handler{ 78 | function: handlerFunc, 79 | logger: f.logger, 80 | name: functionName(handlerFunc), 81 | provider: provider, 82 | metricSink: f.metricSink, 83 | } 84 | } 85 | 86 | func functionName(handlerFunc handlerFunc) string { 87 | dotName := runtime.FuncForPC(reflect.ValueOf(handlerFunc).Pointer()).Name() 88 | n := strings.Split(dotName, ".") 89 | name := strings.TrimSuffix(n[len(n)-1], "-fm") 90 | 91 | return name 92 | } 93 | 94 | type handlerError interface { 95 | error 96 | Status() int 97 | } 98 | 99 | // HTTPError should be returned for all HTTP handlers that need to return an error with a custom 100 | // HTTP status code or http response body. Otherwise, HTTP handlers should return `error`. 101 | type HTTPError struct { 102 | code int 103 | error error 104 | } 105 | 106 | func (he HTTPError) Error() string { 107 | return he.error.Error() 108 | } 109 | 110 | // Status code of the HTTP response. 
111 | func (he HTTPError) Status() int { 112 | return he.code 113 | } 114 | 115 | type handlerFunc func(*zap.Logger, http.ResponseWriter, *http.Request) error 116 | 117 | type handler struct { 118 | function handlerFunc 119 | logger *zap.Logger 120 | name string 121 | metricSink metrics.MetricSink 122 | provider string 123 | } 124 | 125 | var requestTagsKey = &struct{}{} 126 | 127 | func wrapRequestTag(r *http.Request, k, v string) *http.Request { 128 | tags, ok := r.Context().Value(requestTagsKey).(map[string]string) 129 | if ok && tags != nil { 130 | tags[k] = v 131 | return r 132 | } 133 | 134 | tags = map[string]string{k: v} 135 | 136 | return r.WithContext(context.WithValue(r.Context(), requestTagsKey, tags)) 137 | } 138 | 139 | func (h *handler) requestTags(w *responseWriter, r *http.Request) []metrics.Label { 140 | extra, _ := r.Context().Value(requestTagsKey).(map[string]string) 141 | 142 | tags := make([]metrics.Label, 4, 4+len(extra)) 143 | 144 | tags = append(tags, 145 | metrics.Label{Name: "method:%v", Value: r.Method}, 146 | metrics.Label{Name: "name:%v", Value: h.name}, 147 | metrics.Label{Name: "provider:%v", Value: h.provider}, 148 | metrics.Label{Name: "status_code:%v", Value: strconv.Itoa(w.statusCode)}, 149 | ) 150 | 151 | for k, v := range extra { 152 | tags = append(tags, metrics.Label{Name: k, Value: v}) 153 | } 154 | 155 | return tags 156 | } 157 | 158 | const ( 159 | statusClientClosedRequest = 499 160 | statusClientClosedRequestText = "Client Closed Request" 161 | ) 162 | 163 | // ServeHTTP serves HTTP requests. 
164 | func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 165 | rw := newResponseWriter(w) 166 | start := timeNow() 167 | 168 | err := h.function(h.logger, rw, r) 169 | if err != nil { 170 | var handlerError handlerError 171 | if errors.As(err, &handlerError) { 172 | // always log the full original error, but take the status from the underlying wrapped 173 | // httpError, and return the error message and status from the underlying http error only. 174 | h.logger.Error(http.StatusText(handlerError.Status()), zap.Error(err)) 175 | http.Error(rw, handlerError.Error(), handlerError.Status()) 176 | } else if errors.Is(err, context.Canceled) { 177 | // use non-standard 499 status code for instrumentation. This should be unseen by the client. 178 | h.logger.Warn(statusClientClosedRequestText, zap.Error(err)) 179 | http.Error(rw, statusClientClosedRequestText, statusClientClosedRequest) 180 | } else { 181 | h.logger.Error(http.StatusText(http.StatusInternalServerError), zap.Error(err)) 182 | http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 183 | } 184 | } 185 | 186 | tags := h.requestTags(rw, r) 187 | 188 | h.metricSink.AddSampleWithLabels([]string{"request_duration_seconds"}, float32(time.Since(start).Seconds()), tags) 189 | } 190 | 191 | type responseWriter struct { 192 | http.ResponseWriter 193 | statusCode int 194 | } 195 | 196 | // WriteHeader for HTTP responses. 
197 | func (rw *responseWriter) WriteHeader(statusCode int) { 198 | rw.statusCode = statusCode 199 | rw.ResponseWriter.WriteHeader(statusCode) 200 | } 201 | 202 | func newResponseWriter(w http.ResponseWriter) *responseWriter { 203 | return &responseWriter{w, http.StatusOK} 204 | } 205 | -------------------------------------------------------------------------------- /internal/imds/providers_test.go: -------------------------------------------------------------------------------- 1 | package imds 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/DataDog/attache/internal/cache" 8 | "github.com/DataDog/attache/internal/vault" 9 | "github.com/hashicorp/go-metrics" 10 | "github.com/stretchr/testify/require" 11 | "go.uber.org/zap/zaptest" 12 | ) 13 | 14 | func Test_functionName(t *testing.T) { 15 | v, err := vault.NewClient(vault.DefaultConfig()) 16 | require.NoError(t, err) 17 | 18 | fetcher, err := NewVaultAwsStsTokenFetcher(v, "role", "mount", zaptest.NewLogger(t), &metrics.BlackholeSink{}) 19 | require.NoError(t, err) 20 | 21 | p, err := Aws(context.Background(), zaptest.NewLogger(t), false, &metrics.BlackholeSink{}, fetcher, fetcher, &staticAwsIdentifier{}, cache.NewPercentageRemainingRefreshAt(1, 0)) 22 | require.NoError(t, err) 23 | 24 | tests := map[string]struct { 25 | handlerFunc handlerFunc 26 | want string 27 | }{ 28 | "provider function": { 29 | handlerFunc: p.handleSecurityCredentials, 30 | want: "handleSecurityCredentials", 31 | }, 32 | "nil": { 33 | handlerFunc: nil, 34 | want: "", 35 | }, 36 | } 37 | for tn, tt := range tests { 38 | t.Run(tn, func(t *testing.T) { 39 | if got := functionName(tt.handlerFunc); got != tt.want { 40 | t.Errorf("functionName() = %v, want %v", got, tt.want) 41 | } 42 | }) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /internal/imds/vault_test.go: -------------------------------------------------------------------------------- 1 | package imds 2 | 3 | import ( 4 
| "os" 5 | "sync" 6 | "testing" 7 | 8 | ddvault "github.com/DataDog/attache/internal/vault" 9 | "github.com/hashicorp/go-hclog" 10 | vaulthttp "github.com/hashicorp/vault/http" 11 | "github.com/hashicorp/vault/sdk/helper/logging" 12 | "github.com/hashicorp/vault/vault" 13 | "github.com/hashicorp/vault/vault/seal" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | var wg sync.WaitGroup 18 | 19 | func newVaultCluster(t *testing.T) *ddvault.Client { 20 | t.Helper() 21 | log := logging.NewVaultLogger(hclog.Warn) 22 | coreConfig := &vault.CoreConfig{} 23 | cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ 24 | HandlerFunc: vaulthttp.Handler, 25 | Logger: log, 26 | NumCores: 1, 27 | // SealFunc doesn't need to be set for all of this to work, but in order to avoid 28 | // 'no seal config found, can't determine if legacy or new-style shamir' being logged 29 | // out multiple times per test, we explicitly configure the test seal here 30 | SealFunc: func() vault.Seal { 31 | return vault.NewTestSeal(t, &seal.TestSealOpts{ 32 | StoredKeys: seal.StoredKeysSupportedShamirRoot, 33 | }) 34 | }, 35 | }) 36 | cluster.Start() 37 | vault.TestWaitActive(t, cluster.Cores[0].Core) 38 | core := cluster.Cores[0] 39 | client := core.Client 40 | 41 | require.NoError(t, os.Setenv("VAULT_ADDR", client.Address())) 42 | require.NoError(t, os.Setenv("VAULT_TOKEN", client.Token())) 43 | 44 | config := ddvault.DefaultConfig() 45 | 46 | // we have two options when testing: directly use the core client from the 47 | // test vault cluster, which will have all the right CA Certs configured 48 | // _or_ disable TLS verification if we need to configure out own client, 49 | // which for these tests of our client configuring paths, we need to do 50 | config.Insecure = true 51 | 52 | c, err := ddvault.NewClient(config) 53 | require.NoError(t, err) 54 | 55 | t.Cleanup(func() { 56 | // this call currently includes a needless 1 second time.Sleep call, 57 | // which may be an 
issue as we keep adding test cases, so we do the cleanup 58 | // in its own goroutine and register with a package level wait group. 59 | wg.Add(1) 60 | go func() { 61 | cluster.Cleanup() 62 | wg.Done() 63 | }() 64 | }) 65 | 66 | return c 67 | } 68 | -------------------------------------------------------------------------------- /internal/rate/doc.go: -------------------------------------------------------------------------------- 1 | package rate 2 | 3 | // Token bucket rate limiter from the x/time/rate golang library. 4 | // Alterations include: 5 | // - Limiter.WaitNWithCallback allowing callers to obtain the 6 | // time.Duration for which a request is rate limited. 7 | -------------------------------------------------------------------------------- /internal/rate/rate.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package rate provides a rate limiter. 6 | package rate 7 | 8 | import ( 9 | "context" 10 | "fmt" 11 | "math" 12 | "sync" 13 | "time" 14 | ) 15 | 16 | // Limit defines the maximum frequency of some events. 17 | // Limit is represented as number of events per second. 18 | // A zero Limit allows no events. 19 | type Limit float64 20 | 21 | // Inf is the infinite rate limit; it allows all events (even if burst is zero). 22 | const Inf = Limit(math.MaxFloat64) 23 | 24 | // Every converts a minimum time interval between events to a Limit. 25 | func Every(interval time.Duration) Limit { 26 | if interval <= 0 { 27 | return Inf 28 | } 29 | return 1 / Limit(interval.Seconds()) 30 | } 31 | 32 | // A Limiter controls how frequently events are allowed to happen. 33 | // It implements a "token bucket" of size b, initially full and refilled 34 | // at rate r tokens per second. 
// Informally, in any large enough time interval, the Limiter limits the
// rate to r tokens per second, with a maximum burst size of b events.
// As a special case, if r == Inf (the infinite rate), b is ignored.
// See https://en.wikipedia.org/wiki/Token_bucket for more about token buckets.
//
// The zero value is a valid Limiter, but it will reject all events.
// Use NewLimiter to create non-zero Limiters.
//
// Limiter has three main methods, Allow, Reserve, and Wait.
// Most callers should use Wait.
//
// Each of the three methods consumes a single token.
// They differ in their behavior when no token is available.
// If no token is available, Allow returns false.
// If no token is available, Reserve returns a reservation for a future token
// and the amount of time the caller must wait before using it.
// If no token is available, Wait blocks until one can be obtained
// or its associated context.Context is canceled.
//
// The methods AllowN, ReserveN, and WaitN consume n tokens.
type Limiter struct {
	// mu guards every field below.
	mu    sync.Mutex
	limit Limit
	burst int
	// tokens currently available; reserveN may drive this negative when a
	// reservation borrows tokens from the future.
	tokens float64
	// last is the last time the limiter's tokens field was updated
	last time.Time
	// lastEvent is the latest time of a rate-limited event (past or future)
	lastEvent time.Time
}

// Limit returns the maximum overall event rate.
func (lim *Limiter) Limit() Limit {
	lim.mu.Lock()
	defer lim.mu.Unlock()
	return lim.limit
}

// Burst returns the maximum burst size. Burst is the maximum number of tokens
// that can be consumed in a single call to Allow, Reserve, or Wait, so higher
// Burst values allow more events to happen at once.
// A zero Burst allows no events, unless limit == Inf.
func (lim *Limiter) Burst() int {
	lim.mu.Lock()
	defer lim.mu.Unlock()
	return lim.burst
}

// TokensAt returns the number of tokens available at time t.
func (lim *Limiter) TokensAt(t time.Time) float64 {
	lim.mu.Lock()
	_, tokens := lim.advance(t) // does not mutate lim
	lim.mu.Unlock()
	return tokens
}

// Tokens returns the number of tokens available now.
func (lim *Limiter) Tokens() float64 {
	return lim.TokensAt(time.Now())
}

// NewLimiter returns a new Limiter that allows events up to rate r and permits
// bursts of at most b tokens.
func NewLimiter(r Limit, b int) *Limiter {
	return &Limiter{
		limit: r,
		burst: b,
	}
}

// Allow reports whether an event may happen now.
func (lim *Limiter) Allow() bool {
	return lim.AllowN(time.Now(), 1)
}

// AllowN reports whether n events may happen at time t.
// Use this method if you intend to drop / skip events that exceed the rate limit.
// Otherwise use Reserve or Wait.
func (lim *Limiter) AllowN(t time.Time, n int) bool {
	// maxFutureReserve of 0: succeed only if the tokens are available now.
	return lim.reserveN(t, n, 0).ok
}

// A Reservation holds information about events that are permitted by a Limiter to happen after a delay.
// A Reservation may be canceled, which may enable the Limiter to permit additional events.
type Reservation struct {
	ok        bool
	lim       *Limiter
	tokens    int
	timeToAct time.Time
	// This is the Limit at reservation time, it can change later.
	limit Limit
}

// OK returns whether the limiter can provide the requested number of tokens
// within the maximum wait time. If OK is false, Delay returns InfDuration, and
// Cancel does nothing.
func (r *Reservation) OK() bool {
	return r.ok
}

// Delay is shorthand for DelayFrom(time.Now()).
func (r *Reservation) Delay() time.Duration {
	return r.DelayFrom(time.Now())
}

// InfDuration is the duration returned by Delay when a Reservation is not OK.
const InfDuration = time.Duration(math.MaxInt64)

// DelayFrom returns the duration for which the reservation holder must wait
// before taking the reserved action. Zero duration means act immediately.
// InfDuration means the limiter cannot grant the tokens requested in this
// Reservation within the maximum wait time.
func (r *Reservation) DelayFrom(t time.Time) time.Duration {
	if !r.ok {
		return InfDuration
	}
	delay := r.timeToAct.Sub(t)
	if delay < 0 {
		return 0
	}
	return delay
}

// Cancel is shorthand for CancelAt(time.Now()).
func (r *Reservation) Cancel() {
	r.CancelAt(time.Now())
}

// CancelAt indicates that the reservation holder will not perform the reserved action
// and reverses the effects of this Reservation on the rate limit as much as possible,
// considering that other reservations may have already been made.
func (r *Reservation) CancelAt(t time.Time) {
	if !r.ok {
		return
	}

	r.lim.mu.Lock()
	defer r.lim.mu.Unlock()

	// Nothing to restore if the limit is infinite, no tokens were reserved,
	// or the reserved action time has already passed.
	if r.lim.limit == Inf || r.tokens == 0 || r.timeToAct.Before(t) {
		return
	}

	// calculate tokens to restore
	// The duration between lim.lastEvent and r.timeToAct tells us how many tokens were reserved
	// after r was obtained. These tokens should not be restored.
	restoreTokens := float64(r.tokens) - r.limit.tokensFromDuration(r.lim.lastEvent.Sub(r.timeToAct))
	if restoreTokens <= 0 {
		return
	}
	// advance time to now
	t, tokens := r.lim.advance(t)
	// calculate new number of tokens
	tokens += restoreTokens
	if burst := float64(r.lim.burst); tokens > burst {
		tokens = burst
	}
	// update state
	r.lim.last = t
	r.lim.tokens = tokens
	if r.timeToAct == r.lim.lastEvent {
		// This reservation was the most recent event: roll lastEvent back as
		// if the reservation had never been made.
		prevEvent := r.timeToAct.Add(r.limit.durationFromTokens(float64(-r.tokens)))
		if !prevEvent.Before(t) {
			r.lim.lastEvent = prevEvent
		}
	}
}

// Reserve is shorthand for ReserveN(time.Now(), 1).
func (lim *Limiter) Reserve() *Reservation {
	return lim.ReserveN(time.Now(), 1)
}

// ReserveN returns a Reservation that indicates how long the caller must wait before n events happen.
// The Limiter takes this Reservation into account when allowing future events.
// The returned Reservation’s OK() method returns false if n exceeds the Limiter's burst size.
// Usage example:
//
//	r := lim.ReserveN(time.Now(), 1)
//	if !r.OK() {
//	  // Not allowed to act! Did you remember to set lim.burst to be > 0 ?
//	  return
//	}
//	time.Sleep(r.Delay())
//	Act()
//
// Use this method if you wish to wait and slow down in accordance with the rate limit without dropping events.
// If you need to respect a deadline or cancel the delay, use Wait instead.
// To drop or skip events exceeding rate limit, use Allow instead.
func (lim *Limiter) ReserveN(t time.Time, n int) *Reservation {
	r := lim.reserveN(t, n, InfDuration)
	return &r
}

// Wait is shorthand for WaitN(ctx, 1).
func (lim *Limiter) Wait(ctx context.Context) (err error) {
	return lim.WaitN(ctx, 1)
}

// WaitN blocks until lim permits n events to happen.
// It returns an error if n exceeds the Limiter's burst size, the Context is
// canceled, or the expected wait time exceeds the Context's Deadline.
// The burst limit is ignored if the rate limit is Inf.
func (lim *Limiter) WaitN(ctx context.Context, n int) (err error) {
	return lim.WaitNWithCallback(ctx, n, func(delay time.Duration) {})
}

// WaitNWithCallback behaves like WaitN but additionally invokes fn with the
// computed wait duration before sleeping (fn receives 0 when no wait is
// needed). This is the local alteration to the upstream x/time/rate code
// described in doc.go: it lets callers observe for how long a request is
// rate limited.
func (lim *Limiter) WaitNWithCallback(ctx context.Context, n int, fn func(delay time.Duration)) (err error) {
	// The test code calls lim.wait with a fake timer generator.
	// This is the real timer generator.
	newTimer := func(d time.Duration) (<-chan time.Time, func() bool, func()) {
		timer := time.NewTimer(d)
		return timer.C, timer.Stop, func() {}
	}

	return lim.wait(ctx, n, time.Now(), newTimer, fn)
}

// wait is the internal implementation of WaitN.
func (lim *Limiter) wait(ctx context.Context, n int, t time.Time, newTimer func(d time.Duration) (<-chan time.Time, func() bool, func()), fn func(delay time.Duration)) error {
	// Snapshot burst/limit under the lock; reserveN re-acquires it below.
	lim.mu.Lock()
	burst := lim.burst
	limit := lim.limit
	lim.mu.Unlock()

	if n > burst && limit != Inf {
		return fmt.Errorf("rate: Wait(n=%d) exceeds limiter's burst %d", n, burst)
	}
	// Check if ctx is already cancelled
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}
	// Determine wait limit
	waitLimit := InfDuration
	if deadline, ok := ctx.Deadline(); ok {
		waitLimit = deadline.Sub(t)
	}
	// Reserve
	r := lim.reserveN(t, n, waitLimit)
	if !r.ok {
		return fmt.Errorf("rate: Wait(n=%d) would exceed context deadline", n)
	}
	// Wait if necessary
	delay := r.DelayFrom(t)
	if fn != nil {
		// Local alteration: surface the computed delay to the caller.
		fn(delay)
	}
	if delay == 0 {
		return nil
	}
	ch, stop, advance := newTimer(delay)
	defer stop()
	advance() // only has an effect when testing
	select {
	case <-ch:
		// We can proceed.
		return nil
	case <-ctx.Done():
		// Context was canceled before we could proceed. Cancel the
		// reservation, which may permit other events to proceed sooner.
		r.Cancel()
		return ctx.Err()
	}
}

// SetLimit is shorthand for SetLimitAt(time.Now(), newLimit).
func (lim *Limiter) SetLimit(newLimit Limit) {
	lim.SetLimitAt(time.Now(), newLimit)
}

// SetLimitAt sets a new Limit for the limiter. The new Limit, and Burst, may be violated
// or underutilized by those which reserved (using Reserve or Wait) but did not yet act
// before SetLimitAt was called.
func (lim *Limiter) SetLimitAt(t time.Time, newLimit Limit) {
	lim.mu.Lock()
	defer lim.mu.Unlock()

	// Settle the token balance at the old rate before switching to the new one.
	t, tokens := lim.advance(t)

	lim.last = t
	lim.tokens = tokens
	lim.limit = newLimit
}

// SetBurst is shorthand for SetBurstAt(time.Now(), newBurst).
func (lim *Limiter) SetBurst(newBurst int) {
	lim.SetBurstAt(time.Now(), newBurst)
}

// SetBurstAt sets a new burst size for the limiter.
func (lim *Limiter) SetBurstAt(t time.Time, newBurst int) {
	lim.mu.Lock()
	defer lim.mu.Unlock()

	t, tokens := lim.advance(t)

	lim.last = t
	lim.tokens = tokens
	lim.burst = newBurst
}

// reserveN is a helper method for AllowN, ReserveN, and WaitN.
// maxFutureReserve specifies the maximum reservation wait duration allowed.
// reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN.
func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration) Reservation {
	lim.mu.Lock()
	defer lim.mu.Unlock()

	if lim.limit == Inf {
		// Infinite rate: always allowed; burst is ignored.
		return Reservation{
			ok:        true,
			lim:       lim,
			tokens:    n,
			timeToAct: t,
		}
	} else if lim.limit == 0 {
		// Zero rate: tokens never refill, so consume directly from burst.
		var ok bool
		if lim.burst >= n {
			ok = true
			lim.burst -= n
		}
		return Reservation{
			ok:        ok,
			lim:       lim,
			tokens:    lim.burst,
			timeToAct: t,
		}
	}

	t, tokens := lim.advance(t)

	// Calculate the remaining number of tokens resulting from the request.
	tokens -= float64(n)

	// Calculate the wait duration
	var waitDuration time.Duration
	if tokens < 0 {
		waitDuration = lim.limit.durationFromTokens(-tokens)
	}

	// Decide result
	ok := n <= lim.burst && waitDuration <= maxFutureReserve

	// Prepare reservation
	r := Reservation{
		ok:    ok,
		lim:   lim,
		limit: lim.limit,
	}
	if ok {
		r.tokens = n
		r.timeToAct = t.Add(waitDuration)

		// Update state
		lim.last = t
		lim.tokens = tokens
		lim.lastEvent = r.timeToAct
	}

	return r
}

// advance calculates and returns an updated state for lim resulting from the passage of time.
// lim is not changed.
// advance requires that lim.mu is held.
func (lim *Limiter) advance(t time.Time) (newT time.Time, newTokens float64) {
	last := lim.last
	if t.Before(last) {
		// Never let time run backwards relative to the last update.
		last = t
	}

	// Calculate the new number of tokens, due to time that passed.
	elapsed := t.Sub(last)
	delta := lim.limit.tokensFromDuration(elapsed)
	tokens := lim.tokens + delta
	if burst := float64(lim.burst); tokens > burst {
		tokens = burst
	}
	return t, tokens
}

// durationFromTokens is a unit conversion function from the number of tokens to the duration
// of time it takes to accumulate them at a rate of limit tokens per second.
func (limit Limit) durationFromTokens(tokens float64) time.Duration {
	if limit <= 0 {
		return InfDuration
	}
	seconds := tokens / float64(limit)
	return time.Duration(float64(time.Second) * seconds)
}

// tokensFromDuration is a unit conversion function from a time duration to the number of tokens
// which could be accumulated during that duration at a rate of limit tokens per second.
func (limit Limit) tokensFromDuration(d time.Duration) float64 {
	if limit <= 0 {
		return 0
	}
	return d.Seconds() * float64(limit)
}
--------------------------------------------------------------------------------
/internal/retry/retry.go:
--------------------------------------------------------------------------------
package retry

import (
	"context"
	"fmt"
	"math/rand"
	"time"

	"go.uber.org/zap"
)

// RetryableFunc is the operation retried by Do; returning nil stops retrying.
type RetryableFunc func() error

// Config holds the retry behavior; adjust it via the Option helpers below.
type Config struct {
	maxAttempts  int           // total attempts, including the initial one
	initialDelay time.Duration // base delay; each retry waits a multiple of it
	logger       *zap.Logger   // receives per-attempt debug output
	maxJitter    time.Duration // upper bound of random jitter added per delay
}

// Option mutates a Config before Do starts retrying.
type Option func(*Config)

// MaxAttempts is the total number of attempts including the initial attempt.
func MaxAttempts(maxAttempts int) Option {
	return func(c *Config) {
		c.maxAttempts = maxAttempts
	}
}

// MaxJitter the maximum amount of time between [0, maxJitter] to add to a delay.
func MaxJitter(maxJitter time.Duration) Option {
	return func(c *Config) {
		c.maxJitter = maxJitter
	}
}

// InitialDelay sets the base delay used to compute each retry's wait.
func InitialDelay(initialDelay time.Duration) Option {
	return func(c *Config) {
		c.initialDelay = initialDelay
	}
}

// Logger sets the logger used for per-attempt debug messages.
func Logger(logger *zap.Logger) Option {
	return func(c *Config) {
		c.logger = logger
	}
}

// Do will perform N attempts to execute RetryableFunc.
50 | func Do(ctx context.Context, retryable RetryableFunc, opts ...Option) error { 51 | config := &Config{ 52 | maxAttempts: 3, 53 | initialDelay: 200 * time.Millisecond, 54 | maxJitter: 100 * time.Millisecond, 55 | logger: zap.NewNop(), 56 | } 57 | 58 | for _, opt := range opts { 59 | opt(config) 60 | } 61 | 62 | var attempt int 63 | 64 | delay := calcDelayForNextRetry(attempt, config.initialDelay, config.maxJitter) 65 | ticker := time.NewTicker(delay) 66 | 67 | var err error 68 | for attempt < config.maxAttempts { 69 | attempt++ 70 | err = retryable() 71 | if err != nil { 72 | retryDelay := calcDelayForNextRetry(attempt, config.initialDelay, config.maxJitter) 73 | config.logger.Debug("retry failed", 74 | zap.Error(err), 75 | zap.Int("attempt", attempt), 76 | zap.Duration("retry_delay", retryDelay)) 77 | 78 | // on the last attempt we return the error right away rather than waiting to return 79 | if attempt < config.maxAttempts { 80 | ticker.Reset(retryDelay) 81 | select { 82 | case <-ctx.Done(): 83 | e := err.Error() 84 | return fmt.Errorf("%v: %w", e, ctx.Err()) 85 | case <-ticker.C: 86 | } 87 | } 88 | } else { 89 | return nil 90 | } 91 | } 92 | 93 | return err 94 | } 95 | 96 | // calcDelayForNextRetry calculates the delay to wait for a given retry attempt _after_ the current attempt. 
97 | func calcDelayForNextRetry(currentAttempt int, initialDelay time.Duration, maxJitter time.Duration) time.Duration { 98 | currentAttempt++ 99 | delay := time.Duration(currentAttempt) * initialDelay 100 | 101 | var jitter time.Duration 102 | if maxJitter > 0 { 103 | jitter = time.Duration(rand.Int63n(int64(maxJitter))) 104 | } 105 | 106 | return delay + jitter 107 | } 108 | -------------------------------------------------------------------------------- /internal/retry/retry_test.go: -------------------------------------------------------------------------------- 1 | package retry 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/mock" 11 | "go.uber.org/zap" 12 | ) 13 | 14 | func Test_calcDelayForNextRetry(t *testing.T) { 15 | type args struct { 16 | attempt int 17 | initialDelay time.Duration 18 | maxJitter time.Duration 19 | } 20 | tests := map[string]struct { 21 | args args 22 | want time.Duration 23 | }{ 24 | "attempt 0 (initial attempt)": { 25 | args: args{ 26 | attempt: 0, 27 | initialDelay: 3 * time.Second, 28 | maxJitter: 0, 29 | }, 30 | want: 3 * time.Second, 31 | }, 32 | "attempt 1": { 33 | args: args{ 34 | attempt: 1, 35 | initialDelay: 3 * time.Second, 36 | maxJitter: 0, 37 | }, 38 | want: 6 * time.Second, 39 | }, 40 | "attempt 2": { 41 | args: args{ 42 | attempt: 2, 43 | initialDelay: 3 * time.Second, 44 | maxJitter: 0, 45 | }, 46 | want: 9 * time.Second, 47 | }, 48 | "attempt 3": { 49 | args: args{ 50 | attempt: 3, 51 | initialDelay: 3 * time.Second, 52 | maxJitter: 0, 53 | }, 54 | want: 12 * time.Second, 55 | }, 56 | "attempt 4": { 57 | args: args{ 58 | attempt: 4, 59 | initialDelay: 3 * time.Second, 60 | maxJitter: 0, 61 | }, 62 | want: 15 * time.Second, 63 | }, 64 | } 65 | for tn, tt := range tests { 66 | t.Run(tn, func(t *testing.T) { 67 | got := calcDelayForNextRetry(tt.args.attempt, tt.args.initialDelay, tt.args.maxJitter) 68 | if got != tt.want 
{ 69 | t.Errorf("calcRetryDelay() got1 = %v, want %v", got, tt.want) 70 | } 71 | }) 72 | } 73 | } 74 | 75 | func Test_calcRetryDelayWithJitter(t *testing.T) { 76 | got := calcDelayForNextRetry(0, 1*time.Second, 1*time.Second) 77 | assert.GreaterOrEqual(t, got, 1*time.Second) 78 | assert.LessOrEqual(t, got, 2*time.Second) 79 | } 80 | 81 | func TestMaxAttempts(t *testing.T) { 82 | tests := map[string]struct { 83 | maxAttempts int 84 | want *Config 85 | }{ 86 | "success": { 87 | maxAttempts: 2, 88 | want: &Config{ 89 | maxAttempts: 2, 90 | }, 91 | }, 92 | } 93 | for tn, tt := range tests { 94 | t.Run(tn, func(t *testing.T) { 95 | f := MaxAttempts(tt.maxAttempts) 96 | c := &Config{} 97 | f(c) 98 | assert.Equal(t, tt.want, c) 99 | }) 100 | } 101 | } 102 | 103 | func TestInitialDelay(t *testing.T) { 104 | tests := map[string]struct { 105 | initialDelay time.Duration 106 | want *Config 107 | }{ 108 | "success": { 109 | initialDelay: 2 * time.Minute, 110 | want: &Config{ 111 | initialDelay: 2 * time.Minute, 112 | }, 113 | }, 114 | } 115 | for tn, tt := range tests { 116 | t.Run(tn, func(t *testing.T) { 117 | f := InitialDelay(tt.initialDelay) 118 | c := &Config{} 119 | f(c) 120 | assert.Equal(t, tt.want, c) 121 | }) 122 | } 123 | } 124 | 125 | func TestLogger(t *testing.T) { 126 | tests := map[string]struct { 127 | logger *zap.Logger 128 | want *Config 129 | }{ 130 | "success": { 131 | logger: zap.NewNop(), 132 | want: &Config{ 133 | logger: zap.NewNop(), 134 | }, 135 | }, 136 | } 137 | for tn, tt := range tests { 138 | t.Run(tn, func(t *testing.T) { 139 | f := Logger(tt.logger) 140 | c := &Config{} 141 | f(c) 142 | assert.Equal(t, tt.want, c) 143 | }) 144 | } 145 | } 146 | 147 | func TestMaxJitter(t *testing.T) { 148 | tests := map[string]struct { 149 | maxJitter time.Duration 150 | want *Config 151 | }{ 152 | "success": { 153 | maxJitter: 1 * time.Second, 154 | want: &Config{ 155 | maxJitter: 1 * time.Second, 156 | }, 157 | }, 158 | } 159 | for tn, tt := range tests { 
160 | t.Run(tn, func(t *testing.T) { 161 | f := MaxJitter(tt.maxJitter) 162 | c := &Config{} 163 | f(c) 164 | assert.Equal(t, tt.want, c) 165 | }) 166 | } 167 | } 168 | 169 | func TestDo(t *testing.T) { 170 | tests := map[string]struct { 171 | retryable RetryableFunc 172 | opts []Option 173 | wantErr error 174 | }{ 175 | "success": { 176 | retryable: func() RetryableFunc { 177 | m := &mockRetryable{} 178 | m.On("execute").Return(nil) 179 | 180 | return m.execute 181 | }(), 182 | opts: []Option{MaxAttempts(3), MaxJitter(0)}, 183 | wantErr: nil, 184 | }, 185 | "error on all retry": { 186 | retryable: func() RetryableFunc { 187 | m := &mockRetryable{} 188 | m.On("execute").Return(errors.New("failing")).Times(3) 189 | 190 | return m.execute 191 | }(), 192 | opts: []Option{MaxAttempts(3), MaxJitter(0)}, 193 | wantErr: errors.New("failing"), 194 | }, 195 | "success after error": { 196 | retryable: func() RetryableFunc { 197 | m := &mockRetryable{} 198 | m.On("execute").Return(errors.New("failing")).Times(2) 199 | m.On("execute").Return(nil) 200 | 201 | return m.execute 202 | }(), 203 | opts: []Option{MaxAttempts(3), MaxJitter(0)}, 204 | wantErr: nil, 205 | }, 206 | } 207 | for tn, tt := range tests { 208 | t.Run(tn, func(t *testing.T) { 209 | err := Do(context.TODO(), tt.retryable, tt.opts...) 
210 | assert.Equal(t, tt.wantErr, err) 211 | }) 212 | } 213 | } 214 | 215 | type mockRetryable struct { 216 | mock.Mock 217 | } 218 | 219 | func (m *mockRetryable) execute() error { 220 | args := m.Called() 221 | 222 | return args.Error(0) 223 | } 224 | -------------------------------------------------------------------------------- /internal/server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "strconv" 8 | "strings" 9 | "time" 10 | 11 | "github.com/hashicorp/go-metrics" 12 | "github.com/hashicorp/go-multierror" 13 | "go.uber.org/zap" 14 | "golang.org/x/time/rate" 15 | muxt "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorilla/mux" 16 | ) 17 | 18 | const ( 19 | path = "path" 20 | method = "method" 21 | statusCode = "statusCode" 22 | userAgent = "userAgent" 23 | ) 24 | 25 | var ( 26 | statsdHTTPRequestRateLimited = []string{"http", "rate_limited"} 27 | ) 28 | 29 | // Server struct for an HTTP server. 30 | type Server struct { 31 | log *zap.Logger 32 | gracefulTimeout time.Duration 33 | srv *http.Server 34 | rateLimiter *rate.Limiter 35 | } 36 | 37 | // Config struct for configuring a Server. 38 | type Config struct { 39 | // BindAddress that the server binds to 40 | BindAddress string `yaml:"bind_address"` 41 | 42 | // GracefulTimeout duration for which the server gracefully wait for 43 | // existing connections to finish before exiting. 44 | GracefulTimeout time.Duration `yaml:"graceful_timeout"` 45 | 46 | // RateLimit specifies requests per second and burst with the format 47 | // ':'. Empty string means no limit. 48 | RateLimit string `yaml:"rate_limit"` 49 | } 50 | 51 | // Validate a Config. 
func (c *Config) Validate() error {
	// Collect all problems into one multierror instead of failing fast.
	var result error
	if strings.TrimSpace(c.BindAddress) == "" {
		result = multierror.Append(result, errors.New("bind address cannot be empty"))
	}

	if c.GracefulTimeout == 0 {
		result = multierror.Append(result, errors.New("graceful timeout must be greater than 0"))
	}

	if c.RateLimit != "" {
		// Only the parse is checked here; NewServer re-derives the values.
		_, _, err := parseRate(c.RateLimit)
		if err != nil {
			result = multierror.Append(result, err)
		}
	}

	return result
}

// NewServer creates a new Server.
// The router gains recovery, request/response logging and — when
// config.RateLimit is set — rate-limiting middleware, in that order.
func NewServer(log *zap.Logger, config Config, router *muxt.Router, opts ...Option) *Server {
	srvConfig := newServerConfig()
	for _, fn := range opts {
		fn(srvConfig)
	}

	log = log.With(zap.String("address", config.BindAddress))
	router.Use(recovery(log))
	router.Use(logMiddleware(log, srvConfig))

	if config.RateLimit != "" {
		// NOTE(review): the parse error is discarded here, which is safe only
		// if Validate() ran on this Config first; otherwise reqs/burst are
		// zero values and the limiter would reject every request — confirm
		// that all callers validate before constructing the server.
		reqs, burst, _ := parseRate(config.RateLimit)
		limiter := rate.NewLimiter(rate.Limit(reqs), burst)

		labels := []metrics.Label{{Name: "addr", Value: config.BindAddress}}
		router.Use(rateLimiterMiddleware(limiter, srvConfig.metricSink, labels))
	}

	srv := &http.Server{
		Addr:         config.BindAddress,
		WriteTimeout: time.Second * 15,
		ReadTimeout:  time.Second * 15,
		IdleTimeout:  time.Second * 60,
		Handler:      router,
	}

	return &Server{
		log:             log,
		gracefulTimeout: config.GracefulTimeout,
		srv:             srv,
	}
}

// Run starts the server in a goroutine, and any errors returned by the server
// at shutdown time will be passed to the provided chan parameter. Run
// returns a shutdown callback that will safely stop the server when called.
109 | func (s *Server) Run(errs chan error) func() { 110 | s.log.Info("server starting") 111 | go func() { 112 | if err := s.srv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { 113 | s.log.Error("error running http server", zap.Error(err)) 114 | errs <- err 115 | } 116 | }() 117 | 118 | return s.shutdown 119 | } 120 | 121 | func (s *Server) shutdown() { 122 | ctx, cancel := context.WithTimeout(context.Background(), s.gracefulTimeout) 123 | defer cancel() 124 | 125 | s.log.Info("server stopping") 126 | if err := s.srv.Shutdown(ctx); err != nil { 127 | s.log.Error("error while shutting down", zap.Error(err)) 128 | 129 | return 130 | } 131 | 132 | s.log.Info("server stopped") 133 | } 134 | 135 | type serverConfig struct { 136 | ignoreRequest func(*http.Request) bool 137 | metricSink metrics.MetricSink 138 | } 139 | 140 | func newServerConfig() *serverConfig { 141 | return &serverConfig{ 142 | ignoreRequest: func(request *http.Request) bool { return false }, 143 | metricSink: &metrics.BlackholeSink{}, 144 | } 145 | } 146 | 147 | type Option func(*serverConfig) 148 | 149 | func WithIgnoreLoggingRequest(f func(*http.Request) bool) Option { 150 | return func(config *serverConfig) { 151 | config.ignoreRequest = f 152 | } 153 | } 154 | 155 | func WithMetricSink(metricSink metrics.MetricSink) Option { 156 | return func(config *serverConfig) { 157 | config.metricSink = metricSink 158 | } 159 | } 160 | 161 | func logMiddleware(log *zap.Logger, cfg *serverConfig) func(next http.Handler) http.Handler { 162 | return func(next http.Handler) http.Handler { 163 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 164 | //Sanitize out whitespace before logging 165 | //Failure to do this could allow users to generate false log lines through their user agent 166 | requestUserAgent := r.UserAgent() 167 | requestUserAgent = strings.ReplaceAll(requestUserAgent, "\n", "") 168 | requestUserAgent = strings.ReplaceAll(requestUserAgent, "\r", "") 169 | if 
!cfg.ignoreRequest(r) { 170 | log.Info("request", zap.String(path, r.RequestURI), zap.String(method, r.Method), zap.String(userAgent, requestUserAgent)) 171 | } 172 | 173 | wrappedW := &logMiddlewareHTTPResponseWriter{ResponseWriter: w} 174 | next.ServeHTTP(wrappedW, r) 175 | 176 | if !cfg.ignoreRequest(r) { 177 | log.Info("response", zap.String(path, r.RequestURI), zap.String(method, r.Method), zap.Int(statusCode, wrappedW.statusCode), zap.String(userAgent, requestUserAgent)) 178 | } 179 | }) 180 | } 181 | } 182 | 183 | type logMiddlewareHTTPResponseWriter struct { 184 | http.ResponseWriter 185 | statusCode int 186 | } 187 | 188 | func (w *logMiddlewareHTTPResponseWriter) WriteHeader(statusCode int) { 189 | w.statusCode = statusCode 190 | w.ResponseWriter.WriteHeader(statusCode) 191 | } 192 | 193 | func (w *logMiddlewareHTTPResponseWriter) Write(byt []byte) (int, error) { 194 | // write (& record) status line & headers if not previous done 195 | if w.statusCode == 0 { 196 | w.WriteHeader(http.StatusOK) 197 | } 198 | 199 | return w.ResponseWriter.Write(byt) 200 | } 201 | 202 | func recovery(log *zap.Logger) func(next http.Handler) http.Handler { 203 | return func(next http.Handler) http.Handler { 204 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 205 | defer func() { 206 | if rc := recover(); rc != nil { 207 | log.Error("handler panic", zap.Any("panic", rc)) 208 | 209 | w.WriteHeader(http.StatusInternalServerError) 210 | } 211 | }() 212 | 213 | next.ServeHTTP(w, r) 214 | }) 215 | } 216 | } 217 | 218 | func rateLimiterMiddleware(rateLimiter *rate.Limiter, metricSink metrics.MetricSink, labels []metrics.Label) func(next http.Handler) http.Handler { 219 | return func(next http.Handler) http.Handler { 220 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 221 | if !rateLimiter.Allow() { 222 | w.WriteHeader(http.StatusTooManyRequests) 223 | metricSink.IncrCounterWithLabels(statsdHTTPRequestRateLimited, 1.0, 
append([]metrics.Label{{Name: "path", Value: r.RequestURI}}, labels...)) 224 | return 225 | } 226 | 227 | next.ServeHTTP(w, r) 228 | }) 229 | } 230 | } 231 | 232 | func parseRate(rateStr string) (float64, int, error) { 233 | reqsStr, burstStr, found := strings.Cut(rateStr, ":") 234 | if !found { 235 | return 0, 0, errors.New("rate limit must be in the format ':'") 236 | } 237 | 238 | req, reqErr := strconv.ParseFloat(reqsStr, 64) 239 | burst, burstErr := strconv.Atoi(burstStr) 240 | if reqErr != nil || burstErr != nil { 241 | return 0, 0, errors.New("rate limit must be in the format ':'") 242 | } 243 | 244 | return req, burst, nil 245 | } 246 | -------------------------------------------------------------------------------- /internal/server/server_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | "go.uber.org/zap" 11 | "go.uber.org/zap/zapcore" 12 | "go.uber.org/zap/zaptest/observer" 13 | ) 14 | 15 | func Test_middleware_recovery(t *testing.T) { 16 | fn := func(w http.ResponseWriter, r *http.Request) { panic("test panic") } 17 | 18 | req := httptest.NewRequest("GET", "http://localhost/panic", nil) 19 | 20 | w := httptest.NewRecorder() 21 | 22 | assert.Panics(t, func() { 23 | http.HandlerFunc(fn).ServeHTTP(w, req) 24 | }, "panic handler panics") 25 | 26 | assert.NotPanics(t, func() { 27 | recovery(zap.NewNop())(http.HandlerFunc(fn)).ServeHTTP(w, req) 28 | }, "recovery handler does not panic") 29 | 30 | resp := w.Result() 31 | require.NoError(t, resp.Body.Close()) 32 | 33 | assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 34 | } 35 | 36 | func Test_logMiddleware(t *testing.T) { 37 | tests := map[string]struct { 38 | cfg *serverConfig 39 | hdlr http.HandlerFunc 40 | wantStatus int 41 | assertLogs func(*testing.T, 
*observer.ObservedLogs) 42 | }{ 43 | "log request": { 44 | cfg: &serverConfig{ 45 | ignoreRequest: func(request *http.Request) bool { return false }, 46 | }, 47 | hdlr: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) }, 48 | wantStatus: http.StatusNoContent, 49 | assertLogs: func(t *testing.T, logs *observer.ObservedLogs) { 50 | t.Helper() 51 | assert.Equal(t, 2, logs.Len()) 52 | assert.Equal(t, 1, logs.FilterField(zap.Int(statusCode, http.StatusNoContent)).Len()) 53 | }, 54 | }, 55 | "log request with implicit writeheader": { 56 | cfg: &serverConfig{ 57 | ignoreRequest: func(request *http.Request) bool { return false }, 58 | }, 59 | hdlr: func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte{}) }, 60 | wantStatus: http.StatusOK, 61 | assertLogs: func(t *testing.T, logs *observer.ObservedLogs) { 62 | t.Helper() 63 | assert.Equal(t, 2, logs.Len()) 64 | assert.Equal(t, 1, logs.FilterField(zap.Int(statusCode, http.StatusOK)).Len()) 65 | }, 66 | }, 67 | "skip log request": { 68 | cfg: &serverConfig{ 69 | ignoreRequest: func(request *http.Request) bool { return true }, 70 | }, 71 | hdlr: func(w http.ResponseWriter, r *http.Request) {}, 72 | wantStatus: http.StatusOK, 73 | assertLogs: func(t *testing.T, logs *observer.ObservedLogs) { 74 | t.Helper() 75 | assert.Equal(t, 0, logs.Len()) 76 | }, 77 | }, 78 | } 79 | for tn, tt := range tests { 80 | t.Run(tn, func(t *testing.T) { 81 | req := httptest.NewRequest("GET", "http://localhost/test", nil) 82 | w := httptest.NewRecorder() 83 | 84 | core, recorded := observer.New(zapcore.InfoLevel) 85 | logger := zap.New(core) 86 | 87 | logMiddleware(logger, tt.cfg)(tt.hdlr).ServeHTTP(w, req) 88 | resp := w.Result() 89 | require.NoError(t, resp.Body.Close()) 90 | assert.Equal(t, tt.wantStatus, resp.StatusCode) 91 | tt.assertLogs(t, recorded) 92 | }) 93 | } 94 | } 95 | 96 | func Test_parseRate(t *testing.T) { 97 | tests := map[string]struct { 98 | rateLimit string 99 | assert func(t 
*testing.T, rate float64, burst int, err error) 100 | }{ 101 | "valid rate": { 102 | rateLimit: "20:100", 103 | assert: func(t *testing.T, rate float64, burst int, err error) { 104 | t.Helper() 105 | assert.NoError(t, err) 106 | assert.Equal(t, float64(20), rate) 107 | assert.Equal(t, 100, burst) 108 | }, 109 | }, 110 | "bad burst": { 111 | rateLimit: "20:bad", 112 | assert: func(t *testing.T, rate float64, burst int, err error) { 113 | t.Helper() 114 | assert.Error(t, err) 115 | }, 116 | }, 117 | "bad rate": { 118 | rateLimit: "bad:100", 119 | assert: func(t *testing.T, rate float64, burst int, err error) { 120 | t.Helper() 121 | assert.Error(t, err) 122 | }, 123 | }, 124 | "no colon": { 125 | rateLimit: "12345", 126 | assert: func(t *testing.T, rate float64, burst int, err error) { 127 | t.Helper() 128 | assert.Error(t, err) 129 | }, 130 | }, 131 | } 132 | 133 | for tn, tt := range tests { 134 | t.Run(tn, func(t *testing.T) { 135 | rate, burst, err := parseRate(tt.rateLimit) 136 | tt.assert(t, rate, burst, err) 137 | }) 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /internal/vault/client.go: -------------------------------------------------------------------------------- 1 | package vault 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/tls" 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "net/url" 12 | "os" 13 | "sync" 14 | "time" 15 | 16 | "github.com/DataDog/attache/internal/rate" 17 | "github.com/hashicorp/go-cleanhttp" 18 | "github.com/hashicorp/go-metrics" 19 | "github.com/hashicorp/go-retryablehttp" 20 | vaultapi "github.com/hashicorp/vault/api" 21 | "go.uber.org/zap" 22 | "golang.org/x/net/http2" 23 | vaulttrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/hashicorp/vault" 24 | "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" 25 | ) 26 | 27 | var ( 28 | statsdVaultRequest = []string{"vault", "request"} 29 | statsdVaultRequestHTTPRetry = []string{"vault", "request", "http", "retry"} 30 
| rateLimitDelayed = metrics.Label{Name: "rate_limit_action", Value: "delayed"} 31 | rateLimitUntouched = metrics.Label{Name: "rate_limit_action", Value: "untouched"} 32 | ) 33 | 34 | const ( 35 | httpMethod = "http_method" 36 | httpPath = "http_path" 37 | httpStatusCode = "http_status_code" 38 | xVaultToken = "X-Vault-Token" 39 | xVaultRequest = "X-Vault-Request" 40 | ) 41 | 42 | type Client struct { 43 | address string 44 | retryableHTTPClient *retryablehttp.Client 45 | 46 | limiter *rate.Limiter 47 | 48 | // Timeout is for setting custom timeout parameter in the HttpClient 49 | timeout time.Duration 50 | 51 | token string 52 | 53 | modifyMutex sync.RWMutex 54 | 55 | metricSink metrics.MetricSink 56 | log *zap.Logger 57 | } 58 | 59 | type Config struct { 60 | Address string 61 | Token string 62 | Insecure bool 63 | MetricSink metrics.MetricSink 64 | Log *zap.Logger 65 | } 66 | 67 | func DefaultConfig() *Config { 68 | var vaultAddr string 69 | if envVaultAddr, ok := os.LookupEnv("VAULT_ADDR"); ok { 70 | vaultAddr = envVaultAddr 71 | } else { 72 | vaultAddr = "https://127.0.0.1:8500" 73 | } 74 | 75 | var vaultToken string 76 | if envVaultToken, ok := os.LookupEnv("VAULT_TOKEN"); ok { 77 | vaultToken = envVaultToken 78 | } 79 | 80 | return &Config{ 81 | Address: vaultAddr, 82 | Token: vaultToken, 83 | MetricSink: &metrics.BlackholeSink{}, 84 | Log: zap.NewNop(), 85 | } 86 | } 87 | 88 | func NewClient(config *Config) (*Client, error) { 89 | if config == nil { 90 | config = DefaultConfig() 91 | } 92 | 93 | httpClient := cleanhttp.DefaultPooledClient() 94 | transport, ok := httpClient.Transport.(*http.Transport) 95 | if !ok { 96 | return nil, fmt.Errorf("http transport %T not of type %T", httpClient.Transport, &http.Transport{}) 97 | } 98 | 99 | transport.TLSHandshakeTimeout = 10 * time.Second 100 | transport.TLSClientConfig = &tls.Config{ 101 | MinVersion: tls.VersionTLS12, 102 | } 103 | if err := http2.ConfigureTransport(transport); err != nil { 104 | return nil, err 
105 | } 106 | 107 | if config.Insecure { 108 | clientTLSConfig := transport.TLSClientConfig 109 | clientTLSConfig.InsecureSkipVerify = true 110 | } 111 | 112 | httpClient = vaulttrace.WrapHTTPClient(httpClient, vaulttrace.WithAnalytics(true)) 113 | 114 | return &Client{ 115 | address: config.Address, 116 | token: config.Token, 117 | retryableHTTPClient: &retryablehttp.Client{ 118 | HTTPClient: httpClient, 119 | RetryWaitMin: time.Millisecond * 1000, 120 | RetryWaitMax: time.Millisecond * 1500, 121 | RetryMax: 2, 122 | Backoff: retryablehttp.LinearJitterBackoff, 123 | CheckRetry: ObservedRetryPolicy(config.Log, config.MetricSink), //nolint:bodyclose 124 | ErrorHandler: retryablehttp.PassthroughErrorHandler, 125 | }, 126 | timeout: time.Second * 60, 127 | modifyMutex: sync.RWMutex{}, 128 | metricSink: config.MetricSink, 129 | log: config.Log, 130 | }, nil 131 | } 132 | 133 | func (d *Client) SetToken(token string) { 134 | d.modifyMutex.Lock() 135 | defer d.modifyMutex.Unlock() 136 | 137 | d.token = token 138 | } 139 | 140 | func (d *Client) Token() string { 141 | d.modifyMutex.RLock() 142 | defer d.modifyMutex.RUnlock() 143 | 144 | return d.token 145 | } 146 | 147 | func (d *Client) SetLimiter(limiter *rate.Limiter) { 148 | d.modifyMutex.Lock() 149 | defer d.modifyMutex.Unlock() 150 | 151 | d.limiter = limiter 152 | } 153 | 154 | func (d *Client) Read(ctx context.Context, p string) (*vaultapi.Secret, error) { 155 | return d.ReadWithData(ctx, p, nil) 156 | } 157 | 158 | func (d *Client) ReadWithData(ctx context.Context, p string, data map[string][]string) (*vaultapi.Secret, error) { 159 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.address+"/v1/"+p, http.NoBody) 160 | if err != nil { 161 | return nil, err 162 | } 163 | 164 | var values url.Values 165 | for k, v := range data { 166 | if values == nil { 167 | values = make(url.Values) 168 | } 169 | for _, val := range v { 170 | values.Add(k, val) 171 | } 172 | } 173 | req.URL.RawQuery = values.Encode() 
174 | 175 | resp, err := d.do(req) 176 | if resp != nil { 177 | defer resp.Body.Close() 178 | } 179 | if resp != nil && resp.StatusCode == 404 { 180 | secret, parseErr := vaultapi.ParseSecret(resp.Body) 181 | switch parseErr { //nolint:errorlint 182 | case nil: 183 | case io.EOF: 184 | return nil, nil 185 | default: 186 | return nil, parseErr 187 | } 188 | if secret != nil && (len(secret.Warnings) > 0 || len(secret.Data) > 0) { 189 | return secret, nil 190 | } 191 | return nil, nil 192 | } 193 | if err != nil { 194 | return nil, err 195 | } 196 | 197 | return vaultapi.ParseSecret(resp.Body) 198 | } 199 | 200 | func (d *Client) Write(ctx context.Context, p string, data map[string]interface{}) (*vaultapi.Secret, error) { 201 | marshalled, err := json.Marshal(data) 202 | if err != nil { 203 | return nil, err 204 | } 205 | 206 | req, err := http.NewRequestWithContext(ctx, http.MethodPut, d.address+"/v1/"+p, bytes.NewReader(marshalled)) 207 | if err != nil { 208 | return nil, err 209 | } 210 | 211 | resp, err := d.do(req) 212 | if resp != nil { 213 | defer resp.Body.Close() 214 | } 215 | if resp != nil && resp.StatusCode == 404 { 216 | secret, parseErr := vaultapi.ParseSecret(resp.Body) 217 | switch parseErr { //nolint:errorlint 218 | case nil: 219 | case io.EOF: 220 | return nil, nil 221 | default: 222 | return nil, parseErr 223 | } 224 | if secret != nil && (len(secret.Warnings) > 0 || len(secret.Data) > 0) { 225 | return secret, err 226 | } 227 | } 228 | if err != nil { 229 | return nil, err 230 | } 231 | 232 | return vaultapi.ParseSecret(resp.Body) 233 | } 234 | 235 | func (d *Client) Delete(ctx context.Context, p string) (*vaultapi.Secret, error) { 236 | return d.DeleteWithData(ctx, p, nil) 237 | } 238 | 239 | func (d *Client) DeleteWithData(ctx context.Context, p string, data map[string][]string) (*vaultapi.Secret, error) { 240 | req, err := http.NewRequestWithContext(ctx, http.MethodDelete, d.address+"/v1/"+p, http.NoBody) 241 | if err != nil { 242 | return nil, 
err 243 | } 244 | 245 | var values url.Values 246 | for k, v := range data { 247 | if values == nil { 248 | values = make(url.Values) 249 | } 250 | for _, val := range v { 251 | values.Add(k, val) 252 | } 253 | } 254 | req.URL.RawQuery = values.Encode() 255 | 256 | resp, err := d.do(req) 257 | if resp != nil { 258 | defer resp.Body.Close() 259 | } 260 | if resp != nil && resp.StatusCode == 404 { 261 | secret, parseErr := vaultapi.ParseSecret(resp.Body) 262 | switch parseErr { //nolint:errorlint 263 | case nil: 264 | case io.EOF: 265 | return nil, nil 266 | default: 267 | return nil, parseErr 268 | } 269 | if secret != nil && (len(secret.Warnings) > 0 || len(secret.Data) > 0) { 270 | return secret, nil 271 | } 272 | return nil, nil 273 | } 274 | if err != nil { 275 | return nil, err 276 | } 277 | 278 | return vaultapi.ParseSecret(resp.Body) 279 | } 280 | 281 | func (d *Client) List(ctx context.Context, p string) (*vaultapi.Secret, error) { 282 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.address+"/v1/"+p, http.NoBody) 283 | if err != nil { 284 | return nil, err 285 | } 286 | q := req.URL.Query() 287 | q.Add("list", "true") 288 | req.URL.RawQuery = q.Encode() 289 | 290 | resp, err := d.do(req) 291 | if resp != nil { 292 | defer resp.Body.Close() 293 | } 294 | if resp != nil && resp.StatusCode == 404 { 295 | secret, parseErr := vaultapi.ParseSecret(resp.Body) 296 | switch parseErr { //nolint:errorlint 297 | case nil: 298 | case io.EOF: 299 | return nil, nil 300 | default: 301 | return nil, parseErr 302 | } 303 | if secret != nil && (len(secret.Warnings) > 0 || len(secret.Data) > 0) { 304 | return secret, nil 305 | } 306 | return nil, nil 307 | } 308 | if err != nil { 309 | return nil, err 310 | } 311 | 312 | return vaultapi.ParseSecret(resp.Body) 313 | } 314 | 315 | func (d *Client) Raw(ctx context.Context, method string, path string, body any) (*vaultapi.Secret, error) { 316 | marshalled, err := json.Marshal(body) 317 | if err != nil { 318 | return 
nil, err 319 | } 320 | 321 | req, err := http.NewRequestWithContext(ctx, method, d.address+"/v1/"+path, bytes.NewReader(marshalled)) 322 | if err != nil { 323 | return nil, err 324 | } 325 | 326 | resp, err := d.do(req) 327 | if err != nil { 328 | return nil, err 329 | } 330 | 331 | if resp != nil { 332 | defer resp.Body.Close() 333 | } 334 | 335 | return vaultapi.ParseSecret(resp.Body) 336 | } 337 | 338 | func (d *Client) do(req *http.Request) (*http.Response, error) { 339 | d.modifyMutex.RLock() 340 | limiter := d.limiter 341 | token := d.token 342 | timeout := d.timeout 343 | d.modifyMutex.RUnlock() 344 | 345 | reqCtx := req.Context() 346 | if limiter != nil { 347 | limiterSpan, limiterCtx := tracer.StartSpanFromContext(reqCtx, "rate.Limiter") 348 | err := limiter.WaitNWithCallback(limiterCtx, 1, func(delay time.Duration) { 349 | escapedPath := req.URL.EscapedPath() 350 | metricTags := []metrics.Label{ 351 | {Name: httpPath, Value: escapedPath}, 352 | {Name: httpMethod, Value: req.Method}, 353 | } 354 | if delay > 0 { 355 | d.log.Warn("Vault request delayed due to rate limiting", 356 | zap.Int64("delay_milliseconds", delay.Milliseconds()), 357 | zap.String(httpMethod, req.Method), 358 | zap.String(httpPath, escapedPath), 359 | ) 360 | 361 | metricTags = append(metricTags, rateLimitDelayed) 362 | } else { 363 | metricTags = append(metricTags, rateLimitUntouched) 364 | } 365 | 366 | d.metricSink.IncrCounterWithLabels(statsdVaultRequest, 1.0, metricTags) 367 | }) 368 | limiterSpan.Finish() 369 | if err != nil { 370 | return nil, err 371 | } 372 | } 373 | 374 | req.Header.Add(xVaultRequest, "true") 375 | req.Header.Add(xVaultToken, token) 376 | 377 | retryableReq, err := retryablehttp.FromRequest(req) 378 | if err != nil { 379 | return nil, err 380 | } 381 | 382 | if timeout != 0 { 383 | // Note: we purposefully do not call cancel manually. The reason is 384 | // when canceled, the request.Body will EOF when reading due to the way 385 | // it streams data in. 
Cancel will still be run when the timeout is 386 | // hit, so this doesn't really harm anything. 387 | ctx, _ := context.WithTimeout(reqCtx, timeout) //nolint:govet 388 | _ = req.WithContext(ctx) 389 | } 390 | 391 | resp, err := d.retryableHTTPClient.Do(retryableReq) 392 | if err != nil { 393 | return nil, err 394 | } 395 | 396 | return resp, nil 397 | } 398 | 399 | func ObservedRetryPolicy(log *zap.Logger, metricSink metrics.MetricSink) func(ctx context.Context, resp *http.Response, err error) (bool, error) { 400 | return func(ctx context.Context, resp *http.Response, err error) (bool, error) { 401 | retry, err := vaultapi.DefaultRetryPolicy(ctx, resp, err) 402 | if retry { 403 | logFields, metricLabels := make([]zap.Field, 0, 3), make([]metrics.Label, 0, 2) 404 | 405 | if resp != nil && resp.Request != nil { 406 | escapedPath := resp.Request.URL.EscapedPath() 407 | 408 | logFields = append(logFields, 409 | zap.String(httpMethod, resp.Request.Method), 410 | zap.String(httpPath, escapedPath), 411 | zap.Int(httpStatusCode, resp.StatusCode), 412 | ) 413 | 414 | metricLabels = append(metricLabels, 415 | metrics.Label{Name: httpPath, Value: escapedPath}, 416 | metrics.Label{Name: httpMethod, Value: resp.Request.Method}, 417 | ) 418 | } 419 | 420 | log.Warn("retrying Vault request due to failure", logFields...) 
421 | 422 | metricSink.IncrCounterWithLabels(statsdVaultRequestHTTPRetry, 1.0, metricLabels) 423 | } 424 | 425 | return retry, err 426 | } 427 | } 428 | -------------------------------------------------------------------------------- /internal/vault/client_test.go: -------------------------------------------------------------------------------- 1 | package vault 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "sync" 10 | "testing" 11 | 12 | vaultapi "github.com/hashicorp/vault/api" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func TestRead(t *testing.T) { 18 | t.Run("reads secret", func(t *testing.T) { 19 | handler := func(w http.ResponseWriter, r *http.Request) { 20 | assert.Equal(t, http.MethodGet, r.Method) 21 | assert.Equal(t, "/v1/test/path/read", r.URL.Path) 22 | assert.Empty(t, r.URL.RawQuery) 23 | _, err := r.Body.Read(nil) 24 | assert.Equal(t, io.EOF, err) // no request body 25 | 26 | err = json.NewEncoder(w).Encode(&vaultapi.Secret{ 27 | Data: map[string]interface{}{ 28 | "secretkey": "secretvalue", 29 | }, 30 | }) 31 | require.NoError(t, err) 32 | } 33 | 34 | client, done := testClientServer(t, DefaultConfig(), handler) 35 | defer done() 36 | 37 | want := map[string]interface{}{"secretkey": "secretvalue"} 38 | 39 | secret, err := client.ReadWithData(context.Background(), "test/path/read", nil) 40 | assert.NoError(t, err) 41 | assert.NotNil(t, secret) 42 | assert.Equal(t, want, secret.Data) 43 | }) 44 | 45 | t.Run("encodes data as query string", func(t *testing.T) { 46 | handler := func(w http.ResponseWriter, r *http.Request) { 47 | assert.Equal(t, "testvalue", r.URL.Query().Get("testkey")) 48 | _, err := r.Body.Read(nil) 49 | assert.Equal(t, io.EOF, err) // no request body 50 | } 51 | 52 | client, done := testClientServer(t, DefaultConfig(), handler) 53 | defer done() 54 | 55 | data := map[string][]string{"testkey": {"testvalue"}} 56 | _, err 
:= client.ReadWithData(context.Background(), "test/path/read", data) 57 | assert.NoError(t, err) 58 | }) 59 | 60 | t.Run("canceled context", func(t *testing.T) { 61 | var waitForHandler sync.WaitGroup 62 | waitForHandler.Add(1) 63 | 64 | handler := func(w http.ResponseWriter, r *http.Request) { 65 | waitForHandler.Done() 66 | <-r.Context().Done() 67 | } 68 | 69 | client, done := testClientServer(t, DefaultConfig(), handler) 70 | defer done() 71 | 72 | ctx, cancel := context.WithCancel(context.Background()) 73 | 74 | go func() { 75 | // cancel after handler entry 76 | waitForHandler.Wait() 77 | cancel() 78 | }() 79 | 80 | _, err := client.ReadWithData(ctx, "testpath", nil) 81 | assert.Equal(t, context.Canceled, err) 82 | }) 83 | } 84 | 85 | func TestWrite(t *testing.T) { 86 | t.Run("writes secret", func(t *testing.T) { 87 | handler := func(w http.ResponseWriter, r *http.Request) { 88 | assert.Equal(t, http.MethodPut, r.Method) 89 | assert.Equal(t, "/v1/test/path/write", r.URL.Path) 90 | 91 | body, err := io.ReadAll(r.Body) 92 | assert.NoError(t, err) 93 | assert.JSONEq(t, `{"testkey": "testvalue"}`, string(body)) 94 | 95 | err = json.NewEncoder(w).Encode(&vaultapi.Secret{ 96 | Data: map[string]interface{}{ 97 | "secretkey": "secretvalue", 98 | }, 99 | }) 100 | require.NoError(t, err) 101 | } 102 | 103 | client, done := testClientServer(t, DefaultConfig(), handler) 104 | defer done() 105 | 106 | want := map[string]interface{}{"secretkey": "secretvalue"} 107 | 108 | secret, err := client.Write(context.Background(), "test/path/write", map[string]interface{}{"testkey": "testvalue"}) 109 | assert.NoError(t, err) 110 | 111 | assert.NoError(t, err) 112 | assert.NotNil(t, secret) 113 | assert.Equal(t, want, secret.Data) 114 | }) 115 | } 116 | 117 | func TestDelete(t *testing.T) { 118 | handler := func(w http.ResponseWriter, r *http.Request) { 119 | assert.Equal(t, http.MethodDelete, r.Method) 120 | assert.Equal(t, "/v1/test/path/delete", r.URL.Path) 121 | 122 | body, err := 
io.ReadAll(r.Body) 123 | assert.NoError(t, err) 124 | assert.Empty(t, body) 125 | 126 | respData := map[string]interface{}{} 127 | for k, v := range r.URL.Query() { 128 | respData[k] = v 129 | } 130 | 131 | err = json.NewEncoder(w).Encode(&vaultapi.Secret{ 132 | Data: respData, 133 | }) 134 | require.NoError(t, err) 135 | } 136 | 137 | t.Run("delete secret", func(t *testing.T) { 138 | client, done := testClientServer(t, DefaultConfig(), handler) 139 | defer done() 140 | 141 | secret, err := client.Delete(context.Background(), "test/path/delete") 142 | assert.NoError(t, err) 143 | 144 | assert.NoError(t, err) 145 | assert.NotNil(t, secret) 146 | assert.Equal(t, map[string]interface{}{}, secret.Data) 147 | }) 148 | 149 | t.Run("delete secret with data", func(t *testing.T) { 150 | client, done := testClientServer(t, DefaultConfig(), handler) 151 | defer done() 152 | 153 | data := map[string][]string{"testkey": {"testvalue"}} 154 | secret, err := client.DeleteWithData(context.Background(), "test/path/delete", data) 155 | assert.NoError(t, err) 156 | 157 | expected := map[string]interface{}{ 158 | "testkey": []interface{}{"testvalue"}, 159 | } 160 | assert.NoError(t, err) 161 | assert.NotNil(t, secret) 162 | assert.Equal(t, expected, secret.Data) 163 | }) 164 | } 165 | 166 | func TestList(t *testing.T) { 167 | expected := map[string]interface{}{"keys": []interface{}{"the", "list", "of", "strings"}} 168 | handler := func(w http.ResponseWriter, r *http.Request) { 169 | assert.Equal(t, http.MethodGet, r.Method) 170 | assert.Equal(t, "true", r.URL.Query().Get("list")) 171 | assert.Equal(t, "/v1/test/path/list", r.URL.Path) 172 | 173 | body, err := io.ReadAll(r.Body) 174 | assert.NoError(t, err) 175 | assert.Empty(t, body) 176 | 177 | err = json.NewEncoder(w).Encode(&vaultapi.Secret{ 178 | Data: expected, 179 | }) 180 | require.NoError(t, err) 181 | } 182 | 183 | t.Run("list", func(t *testing.T) { 184 | client, done := testClientServer(t, DefaultConfig(), handler) 185 | 
defer done() 186 | 187 | secret, err := client.List(context.Background(), "test/path/list") 188 | assert.NoError(t, err) 189 | 190 | assert.NoError(t, err) 191 | assert.NotNil(t, secret) 192 | assert.Equal(t, expected, secret.Data) 193 | }) 194 | } 195 | 196 | func testClientServer(t *testing.T, cfg *Config, handler http.HandlerFunc) (*Client, func()) { 197 | t.Helper() 198 | 199 | ts := httptest.NewServer(handler) 200 | cfg.Address = ts.URL 201 | 202 | client, err := NewClient(cfg) 203 | require.NoError(t, err) 204 | require.NotNil(t, client) 205 | 206 | return client, ts.Close 207 | } 208 | -------------------------------------------------------------------------------- /scripts/add-license-copyright.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | # This script adds the 'Copyright' column to the auto-generated LICENSE-3rdparty file 4 | with open('LICENSE-3rdparty.csv', 'r') as file: 5 | thirdparty_licenses = file.read().rstrip('\n').split('\n') 6 | 7 | 8 | 9 | # Input format is URL, link to license, License type 10 | repos = { 11 | 'golang.org/x/': 'The Go Authors', 12 | 'github.com/hashicorp/': 'HashiCorp, Inc.', 13 | 'github.com/DataDog/': 'Datadog, Inc.', 14 | 'github.com/uber-go/': 'Uber Technologies, Inc.', 15 | 'go.uber.org/': 'Uber Technologies, Inc.', 16 | 'github.com/aws/': 'Amazon.com, Inc. or its affiliates', 17 | 'github.com/cenkalti/backoff/': 'Cenk Altı', 18 | 'github.com/cespare/xxhash/': 'Caleb Spare', 19 | 'github.com/dustin/go-humanize': 'Dustin Sallings ', 20 | 'github.com/ebitengine/purego': 'Ebitengine', 21 | 'github.com/go-jose/go-jose': 'Square Inc. 
and The Go Authors', 22 | 'github.com/golang/protobuf': 'The Go Authors', 23 | 'github.com/google/uuid': 'Google Inc.', 24 | 'google.golang.org/': 'Google Inc.', 25 | 'github.com/gorilla/mux': 'The Gorilla Authors', 26 | 'github.com/mitchellh/': 'Mitchell Hashimoto', 27 | 'github.com/outcaste-io/': 'Outcaste LLC', 28 | 'github.com/philhofer/fwd': 'Phil Hofer', 29 | 'github.com/pkg/errors': 'Dave Cheney ', 30 | 'github.com/ryanuber/go-glob': 'Ryan Uber', 31 | 'github.com/secure-systems-lab/go-securesystemslib': 'NYU Secure Systems Lab', 32 | 'github.com/tinylib/msgp': 'Philip Hofer and The Go Authors', 33 | 'gopkg.in/DataDog/dd-trace-go': 'Datadog, Inc.' 34 | } 35 | for dependency in thirdparty_licenses: 36 | package, license_url, license_type = dependency.strip().split(',') 37 | author = None 38 | for repo_pattern, candidate_author in repos.items(): 39 | if package.startswith(repo_pattern): 40 | author = candidate_author 41 | break 42 | 43 | if author is None: 44 | raise ValueError(f'No author found for {package}') 45 | 46 | print(f'{package},{license_url},{license_type},"{author}"') 47 | --------------------------------------------------------------------------------