├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── Bug_Report.md │ ├── Feature_Request.md │ └── config.yml ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ └── pull_request.yml ├── .golangci.yml ├── .markdownlint.yml ├── .semgrep └── imports.yml ├── .semgrepignore ├── CHANGELOG.md ├── GNUmakefile ├── LICENSE ├── META.d └── _summary.yaml ├── README.md ├── aws_config.go ├── aws_config_test.go ├── awsauth.go ├── awsauth_test.go ├── clients.go ├── config.go ├── configtesting ├── config.go └── file_parsing.go ├── credentials.go ├── credentials_test.go ├── diag ├── diagnostic.go ├── diagnostics.go ├── diagnostics_test.go ├── error_diagnostic.go ├── error_diagnostic_test.go ├── native_error_diagnostic.go ├── severity.go ├── warning_diagnostic.go └── warning_diagnostic_test.go ├── endpoints.go ├── endpoints ├── endpoints.go ├── endpoints_gen.go ├── generate.go ├── partition.go ├── partition_test.go ├── region.go └── service.go ├── errors.go ├── errors_test.go ├── go.mod ├── go.sum ├── go.work ├── http_client.go ├── http_client_test.go ├── internal ├── awsconfig │ └── resolvers.go ├── config │ ├── apn_info.go │ ├── config.go │ ├── config_test.go │ ├── errors.go │ └── user_agent.go ├── constants │ └── constants.go ├── errs │ └── errs.go ├── expand │ ├── filepath.go │ └── filepath_test.go ├── generate │ ├── common │ │ └── generator.go │ └── endpoints │ │ ├── main.go │ │ └── output.go.gtpl ├── slices │ └── slices.go └── test │ ├── context.go │ ├── http_client.go │ ├── user_agent.go │ └── validator.go ├── logger.go ├── logging ├── attributes.go ├── aws.go ├── aws_test.go ├── context.go ├── hc_logger.go ├── hc_logger_test.go ├── http.go ├── logger.go ├── logger_test.go ├── mask.go ├── null_logger.go ├── tf_logger.go └── tf_logger_test.go ├── mockdata └── mocks.go ├── s3attributes_test.go ├── servicemocks ├── mock.go ├── pem_file.go └── setup.go ├── tfawserr ├── awserr.go └── awserr_test.go ├── tools ├── go.mod ├── go.sum └── main.go ├── user_agent.go ├── 
user_agent_test.go ├── useragent ├── context.go └── context_test.go ├── v2 └── awsv1shim │ ├── credentials.go │ ├── credentials_test.go │ ├── go.mod │ ├── go.sum │ ├── http_client.go │ ├── http_client_test.go │ ├── logger.go │ ├── mockdata │ └── mocks.go │ ├── resolvers.go │ ├── session.go │ ├── session_test.go │ ├── tfawserr │ ├── awserr.go │ └── awserr_test.go │ └── user_agent.go └── validation ├── json.go ├── json_test.go ├── region.go └── region_test.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Default owner for all pull requests 2 | * @hashicorp/terraform-aws 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/Bug_Report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Bug Report 3 | about: If something isn't working as expected 🤔. 4 | 5 | --- 6 | 7 | 8 | 9 | ### Community Note 10 | 11 | * Please vote on this issue by adding a 👍 [reaction](https://blog.github.com/2016-03-10-add-reactions-to-pull-requests-issues-and-comments/) to the original issue to help the community and maintainers prioritize this request 12 | * Please do not leave "+1" or other comments that do not add relevant new information or questions, they generate extra noise for issue followers and do not help prioritize the request 13 | * If you are interested in working on this issue or have submitted a pull request, please leave a comment 14 | 15 | 16 | 17 | ### Environment and Versions 18 | 19 | * Terraform Executor: Local workstation / EC2 Instance / ECS / Terraform Cloud or Enterprise / Other (please specify) 20 | * If outside Terraform Cloud/Enterprise, which operating system and version: Linux/macOS/Windows 21 | * Terraform CLI version: `terraform -v` 22 | * Terraform AWS Provider version: `terraform providers -v` 23 | * Terraform Backend/Provider Configuration: 24 | 25 | ```hcl 26 | # Copy-paste your Terraform S3 
Backend or AWS Provider configurations here 27 | ``` 28 | 29 | * AWS environment variables (if any): 30 | 31 | ```sh 32 | AWS_XXX=example 33 | ``` 34 | 35 | * AWS configuration files (if any): 36 | 37 | ```txt 38 | # Copy-paste your AWS shared configuration file contents here 39 | ``` 40 | 41 | ### Expected Behavior 42 | 43 | What should have happened? 44 | 45 | ### Actual Behavior 46 | 47 | What actually happened? 48 | 49 | ### Steps to Reproduce 50 | 51 | 52 | 53 | 1. `terraform init` 54 | 1. `terraform apply` 55 | 56 | ### Debug Output 57 | 58 | Please provide a link to a GitHub Gist containing the complete debug output. Please do NOT paste the debug output in the issue; just paste a link to the Gist. 59 | 60 | To obtain the debug output, see the [Terraform documentation on debugging](https://www.terraform.io/docs/internals/debugging.html). 61 | 62 | ### References 63 | 64 | Are there any other GitHub issues (open or closed) or pull requests that should be linked here? Terraform or AWS documentation? 65 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/Feature_Request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature Request 3 | about: I have a suggestion (and might want to implement myself 🙂)! 
4 | labels: enhancement 5 | --- 6 | 7 | 8 | 9 | ### Community Note 10 | 11 | * Please vote on this issue by adding a 👍 [reaction](https://blog.github.com/2016-03-10-add-reactions-to-pull-requests-issues-and-comments/) to the original issue to help the community and maintainers prioritize this request 12 | * Please do not leave "+1" or other comments that do not add relevant new information or questions, they generate extra noise for issue followers and do not help prioritize the request 13 | * If you are interested in working on this issue or have submitted a pull request, please leave a comment 14 | 15 | 16 | 17 | ### Description 18 | 19 | Please leave a helpful description of the feature request here. 20 | 21 | ### Potential Library Implementation 22 | 23 | ```go 24 | // Please provide an example Go implementation showing 25 | // how this functionality may look within this codebase. 26 | ``` 27 | 28 | ### Potential Terraform Backend/Provider Configuration 29 | 30 | ```hcl 31 | # Please provide an example Terraform configuration showing 32 | # how this functionality will be implemented, if relevant. 33 | ``` 34 | 35 | ### References 36 | 37 | Are there any other GitHub issues (open or closed) or pull requests that should be linked here? Terraform or AWS documentation? 38 | 39 | Please ensure that Terraform S3 Backend and Terraform AWS Provider GitHub feature requests are created and linked here. 40 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) HashiCorp, Inc. 
2 | # SPDX-License-Identifier: MPL-2.0 3 | 4 | blank_issues_enabled: false 5 | contact_links: 6 | - name: Terraform AWS Provider Bug Reports and Feature Requests 7 | url: https://github.com/terraform-providers/terraform-provider-aws/issues/new/choose 8 | about: Terraform AWS Provider specific issues should be raised in the Terraform AWS Provider codebase. 9 | - name: Terraform AWS Provider Questions 10 | url: https://discuss.hashicorp.com/c/terraform-providers/tf-aws 11 | about: GitHub issues in this repository are only intended for bug reports and feature requests. Other issues will be closed. Please ask and answer questions through the Terraform AWS Provider Community Forum. 12 | - name: Terraform S3 Backend Bug Reports and Feature Requests 13 | url: https://github.com/hashicorp/terraform/issues/new/choose 14 | about: Terraform S3 Backend specific issues should be raised in the Terraform Core codebase. 15 | - name: Terraform S3 Backend, Configuration Language, or Workflow Questions 16 | url: https://discuss.hashicorp.com/c/terraform-core 17 | about: Please ask and answer S3 Backend, language, or workflow related questions through the Terraform Core Community Forum. 18 | - name: Security Vulnerability 19 | url: https://www.hashicorp.com/security.html 20 | about: Please report security vulnerabilities responsibly. 
21 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ### Community Note 4 | 5 | * Please vote on this pull request by adding a 👍 [reaction](https://blog.github.com/2016-03-10-add-reactions-to-pull-requests-issues-and-comments/) to the original pull request comment to help the community and maintainers prioritize this request 6 | * Please do not leave "+1" or other comments that do not add relevant new information or questions, they generate extra noise for pull request followers and do not help prioritize the request 7 | 8 | 9 | 10 | Reference: #0000 11 | 12 | 13 | 14 | ## Rollback Plan 15 | 16 | If a change needs to be reverted, we will publish an updated version of the library. 17 | 18 | ## Changes to Security Controls 19 | 20 | Are there any changes to security controls (access controls, encryption, logging) in this pull request? If so, explain. 21 | 22 | ## Description 23 | 24 | Please briefly describe the changes proposed in this pull request. -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) HashiCorp, Inc. 
2 | # SPDX-License-Identifier: MPL-2.0 3 | 4 | version: 2 5 | updates: 6 | - package-ecosystem: "github-actions" 7 | directory: "/" 8 | schedule: 9 | interval: "daily" 10 | 11 | - package-ecosystem: "gomod" 12 | directories: 13 | - "/" 14 | - "/tools" 15 | - "/v2/awsv1shim" 16 | groups: 17 | aws-sdk-go-v1: 18 | patterns: 19 | - "github.com/aws/aws-sdk-go" 20 | aws-sdk-go-v2: 21 | patterns: 22 | - "github.com/aws/aws-sdk-go-v2" 23 | - "github.com/aws/aws-sdk-go-v2/*" 24 | opentelemetry: 25 | patterns: 26 | - "go.opentelemetry.io/otel" 27 | - "go.opentelemetry.io/contrib/*" 28 | ignore: 29 | # aws/smithy-go should only be updated via aws/aws-sdk-go-v2 30 | - dependency-name: "github.com/aws/smithy-go" 31 | - dependency-name: "golang.org/x/tools" 32 | schedule: 33 | interval: "daily" 34 | -------------------------------------------------------------------------------- /.github/workflows/pull_request.yml: -------------------------------------------------------------------------------- 1 | name: Pull Request Checks 2 | on: [pull_request] 3 | 4 | jobs: 5 | go_mod: 6 | name: go mod 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | path: [".", "tools", "v2/awsv1shim"] 11 | steps: 12 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 13 | 14 | - uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 15 | with: 16 | go-version-file: ./go.mod 17 | 18 | - name: go mod 19 | working-directory: ${{ matrix.path }} 20 | run: | 21 | go mod tidy 22 | git diff --exit-code --quiet -- go.mod go.sum || \ 23 | (echo; echo "Unexpected difference in ${{ matrix.path }}/go.mod or ${{ matrix.path }}/go.sum files. 
Run 'go mod tidy' command or revert any go.mod/go.sum changes and commit."; exit 1) 24 | 25 | go_work_sync: 26 | name: go work sync 27 | runs-on: ubuntu-latest 28 | strategy: 29 | matrix: 30 | path: [".", "tools", "v2/awsv1shim"] 31 | steps: 32 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 33 | 34 | - uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 35 | with: 36 | go-version-file: ./go.mod 37 | 38 | - name: go work sync 39 | run: | 40 | go work sync 41 | git diff --exit-code --quiet -- ${{ matrix.path }}/go.mod ${{ matrix.path }}/go.sum || \ 42 | (echo; echo "Modules out of sync in ${{ matrix.path }}/. Run 'go work sync' and 'cd ${{ matrix.path }} && go mod tidy' to bring them in sync."; exit 1) 43 | 44 | go_test: 45 | name: go test 46 | runs-on: ubuntu-latest 47 | steps: 48 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 49 | 50 | - uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 51 | with: 52 | go-version-file: ./go.mod 53 | 54 | - run: | 55 | go test ./... 56 | cd v2/awsv1shim && go test ./... 57 | 58 | golangci-lint: 59 | runs-on: ubuntu-latest 60 | steps: 61 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 62 | 63 | - uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 64 | with: 65 | go-version-file: ./go.mod 66 | 67 | - run: cd tools && go install github.com/golangci/golangci-lint/cmd/golangci-lint 68 | 69 | - run: | 70 | golangci-lint run ./... 71 | cd v2/awsv1shim && golangci-lint run ./...
72 | 73 | import-lint: 74 | runs-on: ubuntu-latest 75 | steps: 76 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 77 | 78 | - uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 79 | with: 80 | go-version-file: ./go.mod 81 | 82 | - run: cd tools && go install github.com/pavius/impi/cmd/impi 83 | 84 | # impi runs against the whole directory tree, ignoring modules 85 | - run: impi --local . --scheme stdThirdPartyLocal ./... 86 | 87 | semgrep: 88 | runs-on: ubuntu-latest 89 | container: 90 | image: returntocorp/semgrep 91 | steps: 92 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 93 | 94 | - run: semgrep --error --quiet --config .semgrep 95 | env: 96 | REWRITE_RULE_IDS: 0 97 | 98 | markdown-lint: 99 | runs-on: ubuntu-latest 100 | steps: 101 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 102 | 103 | - uses: avto-dev/markdown-lint@04d43ee9191307b50935a753da3b775ab695eceb # v1.5.0 104 | with: 105 | config: ".markdownlint.yml" 106 | args: "./README.md" 107 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) HashiCorp, Inc. 
2 | # SPDX-License-Identifier: MPL-2.0 3 | 4 | issues: 5 | max-issues-per-linter: 0 6 | max-same-issues: 0 7 | exclude-rules: 8 | - path: configtesting 9 | linters: 10 | - goconst 11 | - linters: 12 | - staticcheck 13 | text: "SA1019: aws.Endpoint is deprecated" 14 | - linters: 15 | - staticcheck 16 | text: "SA1019: aws.EndpointResolverWithOptions is deprecated" 17 | - linters: 18 | - staticcheck 19 | text: "SA1019: aws.EndpointResolverWithOptionsFunc is deprecated" 20 | - linters: 21 | - staticcheck 22 | text: "SA1019: config.WithEndpointResolverWithOptions is deprecated" 23 | 24 | linters: 25 | disable-all: true 26 | enable: 27 | - copyloopvar 28 | - dogsled 29 | - errcheck 30 | - errname 31 | - goconst 32 | - gofmt 33 | - gosimple 34 | - govet 35 | - ineffassign 36 | - misspell 37 | - mnd 38 | - staticcheck 39 | - typecheck 40 | - unconvert 41 | - unparam 42 | - unused 43 | - usetesting 44 | - whitespace 45 | 46 | linters-settings: 47 | copyloopvar: 48 | check-alias: true 49 | goconst: 50 | ignore-tests: true 51 | mnd: 52 | ignored-functions: 53 | - strings.SplitN 54 | - strings.SplitAfterN 55 | - os.MkdirAll 56 | usetesting: 57 | os-create-temp: false 58 | os-mkdir-temp: false 59 | os-setenv: false 60 | -------------------------------------------------------------------------------- /.markdownlint.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) HashiCorp, Inc. 
2 | # SPDX-License-Identifier: MPL-2.0 3 | 4 | # Configuration for markdownlint 5 | # https://github.com/DavidAnson/markdownlint#configuration 6 | 7 | default: true 8 | 9 | # MD007 10 | ul-indent: 11 | indent: 4 12 | 13 | #MD010 14 | no-hard-tabs: 15 | code_blocks: false 16 | 17 | # MD046 18 | code-block-style: 19 | style: fenced 20 | 21 | # MD048 22 | code-fence-style: 23 | style: backtick 24 | 25 | # Disabled rules 26 | 27 | # MD013 28 | line-length: false 29 | # MD014 30 | commands-show-output: false 31 | -------------------------------------------------------------------------------- /.semgrep/imports.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) HashiCorp, Inc. 2 | # SPDX-License-Identifier: MPL-2.0 3 | 4 | rules: 5 | - id: no-sdkv1-imports 6 | languages: [go] 7 | message: The package `awsbase` should not include any references to the AWS SDK for Go v1 8 | paths: 9 | exclude: 10 | - awsv1shim 11 | - awsmocks 12 | patterns: 13 | - pattern: | 14 | import ("$X") 15 | - focus-metavariable: $X 16 | - metavariable-regex: 17 | metavariable: "$X" 18 | regex: 'github.com/aws/aws-sdk-go/.+' 19 | severity: ERROR 20 | 21 | - id: no-sdkv2-imports-in-awsv1shim 22 | languages: [go] 23 | message: The package `awsv1shim` should not include references to the AWS SDK for Go v2 24 | paths: 25 | include: 26 | - awsv1shim 27 | patterns: 28 | - pattern: | 29 | import ("$X") 30 | - focus-metavariable: $X 31 | - metavariable-regex: 32 | metavariable: "$X" 33 | regex: 'github.com/aws/aws-sdk-go-v2/.+' 34 | - pattern-not: | 35 | import ("github.com/aws/aws-sdk-go-v2/aws/transport/http") 36 | - pattern-not: | 37 | import ("github.com/aws/aws-sdk-go-v2/config") 38 | - pattern-not: | 39 | import ("github.com/aws/aws-sdk-go-v2/aws/retry") 40 | severity: ERROR 41 | -------------------------------------------------------------------------------- /.semgrepignore: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hashicorp/aws-sdk-go-base/3891601e36d64e62a6a7eb97c2f84559fcc2dd58/.semgrepignore -------------------------------------------------------------------------------- /GNUmakefile: -------------------------------------------------------------------------------- 1 | TIMEOUT ?= 30s 2 | 3 | default: test lint 4 | 5 | cleantidy: ## Tidy go modules 6 | @echo "make: tidying Go mods..." 7 | @cd tools && go mod tidy && cd .. 8 | @cd v2/awsv1shim && go mod tidy && cd ../.. 9 | @go mod tidy 10 | @echo "make: Go mods tidied" 11 | 12 | fmt: ## Run gofmt 13 | gofmt -s -w ./ 14 | 15 | gen: ## Run generators 16 | @echo "make: Running Go generators..." 17 | @go generate ./... 18 | 19 | golangci-lint: ## Run golangci-lint 20 | @golangci-lint run ./... 21 | @cd v2/awsv1shim && golangci-lint run ./... 22 | 23 | help: ## Display this help 24 | @grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-13s\033[0m %s\n", $$1, $$2}' 25 | 26 | importlint: ## Lint imports 27 | @impi --local . --scheme stdThirdPartyLocal ./... 28 | 29 | lint: golangci-lint importlint ## Run all linters 30 | 31 | semgrep: ## Run semgrep checks 32 | @docker run --rm --volume "${PWD}:/src" returntocorp/semgrep semgrep --config .semgrep --no-rewrite-rule-ids 33 | 34 | test: ## Run unit tests 35 | go test -timeout=$(TIMEOUT) -parallel=4 ./... 36 | cd v2/awsv1shim && go test -timeout=$(TIMEOUT) -parallel=4 ./... 
37 | 38 | tools: ## Install tools 39 | cd tools && go install github.com/golangci/golangci-lint/cmd/golangci-lint 40 | cd tools && go install github.com/pavius/impi/cmd/impi 41 | 42 | # Please keep targets in alphabetical order 43 | .PHONY: \ 44 | cleantidy \ 45 | fmt \ 46 | gen \ 47 | golangci-lint \ 48 | help \ 49 | importlint \ 50 | lint \ 51 | semgrep \ 52 | test \ 53 | tools \ 54 | -------------------------------------------------------------------------------- /META.d/_summary.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) HashiCorp, Inc. 2 | # SPDX-License-Identifier: MPL-2.0 3 | 4 | --- 5 | schema: 1.1 6 | 7 | partition: tf-ecosystem 8 | 9 | summary: 10 | owner: team-tf-aws 11 | description: | 12 | An opinionated AWS SDK for Go v2 library for consistent authentication configuration between projects plus additional 13 | helper functions. Relied on by the Terraform S3 Backend, terraform-provider-aws and terraform-provider-awscc. 14 | 15 | visibility: external 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # aws-sdk-go-base 2 | 3 | An opinionated [AWS SDK for Go v2](https://github.com/aws/aws-sdk-go-v2) library for consistent authentication configuration between projects plus additional helper functions. This library was originally started in [HashiCorp Terraform](https://github.com/hashicorp/terraform), migrated with the [Terraform AWS Provider](https://github.com/terraform-providers/terraform-provider-aws) during the Terraform 0.10 Core and Provider split, and now is offered as a separate library to allow easier dependency management in the Terraform ecosystem.
4 | 5 | **NOTE:** This library is not currently designed or intended for usage outside 6 | the [Terraform S3 Backend](https://www.terraform.io/docs/backends/types/s3.html), 7 | the [Terraform AWS Provider](https://www.terraform.io/docs/providers/aws), 8 | and the [Terraform AWS Cloud Control Provider](https://registry.terraform.io/providers/hashicorp/awscc). 9 | 10 | This project publishes two Go modules, `aws-sdk-go-base/v2` and `aws-sdk-go-base/v2/awsv1shim/v2`. 11 | The module `aws-sdk-go-base/v2` returns configuration compatible with the [AWS SDK for Go v2](https://github.com/aws/aws-sdk-go-v2). 12 | In order to assist with migrating large code bases using the AWS SDK for Go v1, the module `aws-sdk-go-base/v2/awsv1shim/v2` takes the AWS SDK for Go v2 configuration and returns configuration for the AWS SDK for Go v1. 13 | 14 | ## Requirements 15 | 16 | * [Go](https://golang.org/doc/install) 1.23 or higher 17 | 18 | ## Development 19 | 20 | Running `make test` will test both `aws-sdk-go-base/v2` and `aws-sdk-go-base/v2/awsv1shim/v2`. 21 | To test individual cases, `go test` works as well, but be aware that it only works in the current module. 22 | To test both modules, run: 23 | 24 | ```sh 25 | $ go test -v ./... 26 | $ cd v2/awsv1shim && go test -v ./... 27 | ``` 28 | 29 | Code is validated using 30 | [`golangci-lint`](https://github.com/golangci/golangci-lint) for general code quality, 31 | [`impi`](https://github.com/pavius/impi) for import block formatting, and 32 | [Semgrep](https://semgrep.dev) to validate additional rules. 33 | 34 | `golangci-lint` and `impi` are Go tools, and can be installed using either `make tools` or installing the Go packages. 35 | Installing the packages from the `tools` directory will ensure that you are using the expected version. 36 | `Semgrep` can be installed as described in [the documentation](https://semgrep.dev/docs/getting-started/) or using a Docker container. 
37 | 38 | Validation can be run using `make lint` to run `golangci-lint` and `impi`. 39 | `make semgrep` will run Semgrep using a Docker container. 40 | 41 | If running linters directly, be aware that `golangci-lint` will only run for the current module. 42 | To validate both modules, run: 43 | 44 | ```sh 45 | $ golangci-lint run ./... 46 | $ cd v2/awsv1shim && golangci-lint run ./... 47 | ``` 48 | 49 | ## Release Process 50 | 51 | The two modules can be released separately. 52 | If changes are only made to `awsv1shim`, `aws-sdk-go-base` should **not** be released. 53 | However, if changes are made to `aws-sdk-go-base`, both modules should be released. 54 | 55 | 1. If creating a new release of `aws-sdk-go-base` 56 | 1. Update the reference in the `awsv1shim` `go.mod` file 57 | 1. Run `go mod tidy` 58 | 1. Update the CHANGELOG.md file 59 | 1. Push the updated files to GitHub 60 | 1. Push new version tags to GitHub. For more details on Go module versioning, see the [Go module version numbering guide](https://go.dev/doc/modules/version-numbers). (Commands `git tag -a <tag> -m "<message>"`, `git push --tags`) 61 | * For `aws-sdk-go-base`, use the form `vX.Y.Z` 62 | * For `awsv1shim`, use the form `v2/awsv1shim/vX.Y.Z` 63 | 1. Close the associated GitHub milestone 64 | 1. Create the releases on GitHub 65 | 66 | ## AWS SDK Upgrade Policy 67 | 68 | `aws-sdk-go-base` will only upgrade AWS SDKs as needed to bring in bug fixes or required enhancements. 69 | This leaves software that uses this module free to manage its own SDK versions. 70 | -------------------------------------------------------------------------------- /awsauth.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "fmt" 10 | 11 | "github.com/aws/aws-sdk-go-v2/aws" 12 | "github.com/aws/aws-sdk-go-v2/aws/arn" 13 | "github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds" 14 | "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" 15 | "github.com/aws/aws-sdk-go-v2/service/iam" 16 | "github.com/aws/aws-sdk-go-v2/service/sts" 17 | "github.com/aws/smithy-go" 18 | "github.com/hashicorp/aws-sdk-go-base/v2/logging" 19 | multierror "github.com/hashicorp/go-multierror" 20 | ) 21 | 22 | // getAccountIDAndPartition gets the account ID and associated partition. 23 | func getAccountIDAndPartition(ctx context.Context, iamClient *iam.Client, stsClient *sts.Client, authProviderName string) (string, string, error) { 24 | var accountID, partition string 25 | var err, errors error 26 | 27 | if authProviderName == ec2rolecreds.ProviderName { 28 | accountID, partition, err = getAccountIDAndPartitionFromEC2Metadata(ctx) 29 | } else { 30 | accountID, partition, err = getAccountIDAndPartitionFromIAMGetUser(ctx, iamClient) 31 | } 32 | if accountID != "" { 33 | return accountID, partition, nil 34 | } 35 | errors = multierror.Append(errors, err) 36 | 37 | accountID, partition, err = getAccountIDAndPartitionFromSTSGetCallerIdentity(ctx, stsClient) 38 | if accountID != "" { 39 | return accountID, partition, nil 40 | } 41 | errors = multierror.Append(errors, err) 42 | 43 | accountID, partition, err = getAccountIDAndPartitionFromIAMListRoles(ctx, iamClient) 44 | if accountID != "" { 45 | return accountID, partition, nil 46 | } 47 | errors = multierror.Append(errors, err) 48 | 49 | return accountID, partition, errors 50 | } 51 | 52 | // getAccountIDAndPartitionFromEC2Metadata gets the account ID and associated 53 | // partition from EC2 metadata. 
func getAccountIDAndPartitionFromEC2Metadata(ctx context.Context) (accountID string, partition string, err error) {
	logger := logging.RetrieveLogger(ctx)

	logger.Debug(ctx, "Retrieving account information from EC2 Metadata")

	// The IMDS client is built from an empty aws.Config; no caller-supplied
	// configuration is passed through to it.
	cfg := aws.Config{}

	metadataClient := imds.NewFromConfig(cfg)
	info, err := metadataClient.GetIAMInfo(ctx, &imds.GetIAMInfoInput{})
	if err != nil {
		// We can end up here if there's an issue with the instance metadata service
		// or if we're getting credentials from AdRoll's Hologram (in which case IAMInfo will
		// error out).
		logger.Debug(ctx, "Unable to retrieve account information from EC2 Metadata", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving account information via EC2 Metadata IAM information: %w", err)
	}

	accountID, partition, err = parseAccountIDAndPartitionFromARN(info.InstanceProfileArn)
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve account information from EC2 Metadata", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving account information from EC2 Metadata: %w", err)
	}

	logger.Info(ctx, "Retrieved account information from EC2 Metadata")
	return accountID, partition, nil
}

// getAccountIDAndPartitionFromIAMGetUser gets the account ID and associated
// partition from the ARN of the user returned by iam:GetUser. The API error
// codes AccessDenied, InvalidClientTokenId and ValidationError are not treated
// as failures. For those codes, empty values and a nil error are returned.
87 | func getAccountIDAndPartitionFromIAMGetUser(ctx context.Context, iamClient iam.GetUserAPIClient) (accountID string, partition string, err error) { 88 | logger := logging.RetrieveLogger(ctx) 89 | 90 | logger.Debug(ctx, "Retrieving account information via iam:GetUser") 91 | 92 | output, err := iamClient.GetUser(ctx, &iam.GetUserInput{}) 93 | if err != nil { 94 | // AccessDenied and ValidationError can be raised 95 | // if credentials belong to federated profile, so we ignore these 96 | var apiErr smithy.APIError 97 | if errors.As(err, &apiErr) { 98 | switch apiErr.ErrorCode() { 99 | case "AccessDenied", "InvalidClientTokenId", "ValidationError": 100 | logger.Debug(ctx, "Retrieving account information via iam:GetUser: ignoring error", map[string]any{ 101 | "error": err, 102 | }) 103 | return "", "", nil 104 | } 105 | } 106 | logger.Debug(ctx, "Unable to retrieve account information via iam:GetUser", map[string]any{ 107 | "error": err, 108 | }) 109 | return "", "", fmt.Errorf("retrieving account information via iam:GetUser: %w", err) 110 | } 111 | 112 | if output == nil || output.User == nil { 113 | logger.Debug(ctx, "Unable to retrieve account information via iam:GetUser", map[string]any{ 114 | "error": "empty response", 115 | }) 116 | return "", "", errors.New("retrieving account information via iam:GetUser: empty response") 117 | } 118 | 119 | accountID, partition, err = parseAccountIDAndPartitionFromARN(aws.ToString(output.User.Arn)) 120 | if err != nil { 121 | logger.Debug(ctx, "Unable to retrieve account information via iam:GetUser", map[string]any{ 122 | "error": err, 123 | }) 124 | return "", "", fmt.Errorf("retrieving account information via iam:GetUser: %w", err) 125 | } else { 126 | logger.Info(ctx, "Retrieved account information via iam:GetUser") 127 | } 128 | return 129 | } 130 | 131 | // getAccountIDAndPartitionFromIAMListRoles gets the account ID and associated 132 | // partition from listing IAM roles. 
func getAccountIDAndPartitionFromIAMListRoles(ctx context.Context, iamClient iam.ListRolesAPIClient) (accountID string, partition string, err error) {
	logger := logging.RetrieveLogger(ctx)

	logger.Debug(ctx, "Retrieving account information via iam:ListRoles")

	// A single role is enough: only its ARN is needed to derive the account.
	output, err := iamClient.ListRoles(ctx, &iam.ListRolesInput{
		MaxItems: aws.Int32(1),
	})
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve account information via iam:ListRoles", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving account information via iam:ListRoles: %w", err)
	}

	if output == nil || len(output.Roles) < 1 {
		logger.Debug(ctx, "Unable to retrieve account information via iam:ListRoles", map[string]any{
			"error": "empty response",
		})
		return "", "", errors.New("retrieving account information via iam:ListRoles: empty response")
	}

	accountID, partition, err = parseAccountIDAndPartitionFromARN(aws.ToString(output.Roles[0].Arn))
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve account information via iam:ListRoles", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving account information via iam:ListRoles: %w", err)
	}

	logger.Info(ctx, "Retrieved account information via iam:ListRoles")
	return accountID, partition, nil
}

// getAccountIDAndPartitionFromSTSGetCallerIdentity gets the account ID and associated
// partition from the caller ARN returned by sts:GetCallerIdentity.
func getAccountIDAndPartitionFromSTSGetCallerIdentity(ctx context.Context, stsClient *sts.Client) (accountID string, partition string, err error) {
	logger := logging.RetrieveLogger(ctx)

	logger.Debug(ctx, "Retrieving caller identity from STS")

	output, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve caller identity from STS", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving caller identity from STS: %w", err)
	}

	if output == nil || output.Arn == nil {
		logger.Debug(ctx, "Unable to retrieve caller identity from STS", map[string]any{
			"error": "empty response",
		})
		return "", "", errors.New("retrieving caller identity from STS: empty response")
	}

	accountID, partition, err = parseAccountIDAndPartitionFromARN(aws.ToString(output.Arn))
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve caller identity from STS", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving caller identity from STS: %w", err)
	}

	logger.Info(ctx, "Retrieved caller identity from STS")
	return accountID, partition, nil
}

// parseAccountIDAndPartitionFromARN returns the account ID and partition
// components of inputARN.
//
// The parse error is wrapped with %w (not %s) so callers, which wrap again
// with %w, keep an unbroken chain for errors.Is / errors.As. The local result
// is named parsed so it does not shadow the imported arn package.
func parseAccountIDAndPartitionFromARN(inputARN string) (string, string, error) {
	parsed, err := arn.Parse(inputARN)
	if err != nil {
		return "", "", fmt.Errorf("parsing ARN (%s): %w", inputARN, err)
	}
	return parsed.AccountID, parsed.Partition, nil
}
-------------------------------------------------------------------------------- /clients.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | "context" 8 | 9 | "github.com/aws/aws-sdk-go-v2/aws" 10 | "github.com/aws/aws-sdk-go-v2/service/iam" 11 | "github.com/aws/aws-sdk-go-v2/service/sts" 12 | "github.com/hashicorp/aws-sdk-go-base/v2/logging" 13 | ) 14 | 15 | func iamClient(ctx context.Context, awsConfig aws.Config, c *Config) *iam.Client { 16 | logger := logging.RetrieveLogger(ctx) 17 | 18 | return iam.NewFromConfig(awsConfig, func(opts *iam.Options) { 19 | if c.IamEndpoint != "" { 20 | logger.Info(ctx, "IAM client: setting custom endpoint", map[string]any{ 21 | "tf_aws.iam_client.endpoint": c.IamEndpoint, 22 | }) 23 | opts.EndpointResolver = iam.EndpointResolverFromURL(c.IamEndpoint) //nolint:staticcheck // The replacement is not documented yet (2023/07/31) 24 | } 25 | }) 26 | } 27 | 28 | func stsClient(ctx context.Context, awsConfig aws.Config, c *Config) *sts.Client { 29 | logger := logging.RetrieveLogger(ctx) 30 | 31 | return sts.NewFromConfig(awsConfig, func(opts *sts.Options) { 32 | if c.StsRegion != "" { 33 | logger.Info(ctx, "STS client: setting region", map[string]any{ 34 | "tf_aws.sts_client.region": c.StsRegion, 35 | }) 36 | opts.Region = c.StsRegion 37 | } 38 | if c.StsEndpoint != "" { 39 | logger.Info(ctx, "STS client: setting custom endpoint", map[string]any{ 40 | "tf_aws.sts_client.endpoint": c.StsEndpoint, 41 | }) 42 | opts.EndpointResolver = sts.EndpointResolverFromURL(c.StsEndpoint) //nolint:staticcheck // The replacement is not documented yet (2023/07/31) 43 | } 44 | }) 45 | } 46 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" 8 | ) 9 | 10 | // Config, APNInfo, APNProduct, and AssumeRole are aliased to an internal package to break a dependency cycle 11 | // in internal/httpclient. 12 | 13 | type Config = config.Config 14 | 15 | type APNInfo = config.APNInfo 16 | 17 | type AssumeRole = config.AssumeRole 18 | 19 | type AssumeRoleWithWebIdentity = config.AssumeRoleWithWebIdentity 20 | 21 | type UserAgentProducts = config.UserAgentProducts 22 | 23 | type UserAgentProduct = config.UserAgentProduct 24 | 25 | const ( 26 | EC2MetadataEndpointModeIPv4 = "IPv4" 27 | EC2MetadataEndpointModeIPv6 = "IPv6" 28 | ) 29 | 30 | func EC2MetadataEndpointMode_Values() []string { 31 | return []string{ 32 | EC2MetadataEndpointModeIPv4, 33 | EC2MetadataEndpointModeIPv6, 34 | } 35 | } 36 | 37 | const ( 38 | HTTPProxyModeLegacy = config.HTTPProxyModeLegacy 39 | HTTPProxyModeSeparate = config.HTTPProxyModeSeparate 40 | ) 41 | -------------------------------------------------------------------------------- /configtesting/config.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package configtesting 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "os" 10 | "testing" 11 | 12 | "github.com/aws/aws-sdk-go-v2/aws" 13 | "github.com/google/go-cmp/cmp" 14 | "github.com/google/go-cmp/cmp/cmpopts" 15 | "github.com/hashicorp/aws-sdk-go-base/v2/mockdata" 16 | "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" 17 | ) 18 | 19 | type TestMode int 20 | 21 | const ( 22 | TestModeInvalid TestMode = 0 23 | TestModeLocal TestMode = 1 24 | TestModeAcceptanceTest TestMode = 2 25 | ) 26 | 27 | type TestDriver interface { 28 | Init(mode TestMode) 29 | TestCase() TestCaseDriver 30 | } 31 | 32 | type TestCaseDriver interface { 33 | Configuration(f []ConfigFunc) Configurer 34 | Setup(t *testing.T) 35 | Apply(ctx context.Context, t *testing.T) (context.Context, Thing) 36 | } 37 | 38 | type Configurer interface { 39 | SetAccessKey(s string) 40 | SetSecretKey(s string) 41 | SetProfile(s string) 42 | SetUseFIPSEndpoint(b bool) 43 | AddEndpoint(k, v string) 44 | AddSharedConfigFile(f string) 45 | } 46 | 47 | type Thing interface { 48 | GetCredentials() aws.CredentialsProvider 49 | GetRegion() string 50 | } 51 | 52 | type AwsConfigThing interface { 53 | GetAwsConfig() aws.Config 54 | } 55 | 56 | type ConfigFunc func(c Configurer) 57 | 58 | func WithProfile(s string) ConfigFunc { 59 | return func(c Configurer) { 60 | c.SetProfile(s) 61 | } 62 | } 63 | 64 | func WithUseFIPSEndpoint(b bool) ConfigFunc { 65 | return func(c Configurer) { 66 | c.SetUseFIPSEndpoint(b) 67 | } 68 | } 69 | 70 | func SSO(t *testing.T, driver TestDriver) { 71 | t.Helper() 72 | 73 | driver.Init(TestModeLocal) 74 | 75 | const ssoSessionName = "test-sso-session" 76 | 77 | testCases := map[string]struct { 78 | Configuration []ConfigFunc 79 | SharedConfigurationFile string 80 | ExpectedCredentialsValue aws.Credentials 81 | }{ 82 | "shared configuration file": { 83 | SharedConfigurationFile: fmt.Sprintf(` 84 | [default] 85 | sso_session = %s 86 | 
sso_account_id = 123456789012 87 | sso_role_name = testRole 88 | region = us-east-1 89 | 90 | [sso-session test-sso-session] 91 | sso_region = us-east-1 92 | sso_start_url = https://d-123456789a.awsapps.com/start 93 | sso_registration_scopes = sso:account:access 94 | `, ssoSessionName), 95 | ExpectedCredentialsValue: mockdata.MockSsoCredentials, 96 | }, 97 | 98 | "use FIPS": { 99 | Configuration: []ConfigFunc{ 100 | WithUseFIPSEndpoint(true), 101 | }, 102 | SharedConfigurationFile: fmt.Sprintf(` 103 | [default] 104 | sso_session = %s 105 | sso_account_id = 123456789012 106 | sso_role_name = testRole 107 | region = us-east-1 108 | 109 | [sso-session test-sso-session] 110 | sso_region = us-east-1 111 | sso_start_url = https://d-123456789a.awsapps.com/start 112 | sso_registration_scopes = sso:account:access 113 | `, ssoSessionName), 114 | ExpectedCredentialsValue: mockdata.MockSsoCredentials, 115 | }, 116 | } 117 | 118 | for name, tc := range testCases { 119 | t.Run(name, func(t *testing.T) { 120 | caseDriver := driver.TestCase() 121 | 122 | servicemocks.InitSessionTestEnv(t) 123 | 124 | ctx := context.TODO() 125 | 126 | err := servicemocks.SsoTestSetup(t, ssoSessionName) 127 | if err != nil { 128 | t.Fatalf("setup: %s", err) 129 | } 130 | 131 | config := caseDriver.Configuration(tc.Configuration) 132 | 133 | closeSso, ssoEndpoint := servicemocks.SsoCredentialsApiMock() 134 | defer closeSso() 135 | config.AddEndpoint("sso", ssoEndpoint) 136 | 137 | tempdir, err := os.MkdirTemp("", "temp") 138 | if err != nil { 139 | t.Fatalf("error creating temp dir: %s", err) 140 | } 141 | defer os.Remove(tempdir) 142 | t.Setenv("TMPDIR", tempdir) 143 | 144 | if tc.SharedConfigurationFile != "" { 145 | file, err := os.CreateTemp("", "aws-sdk-go-base-shared-configuration-file") 146 | 147 | if err != nil { 148 | t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) 149 | } 150 | 151 | defer os.Remove(file.Name()) 152 | 153 | err = os.WriteFile(file.Name(), 
[]byte(tc.SharedConfigurationFile), 0600) //nolint:mnd 154 | 155 | if err != nil { 156 | t.Fatalf("unexpected error writing shared configuration file: %s", err) 157 | } 158 | 159 | config.AddSharedConfigFile(file.Name()) 160 | } 161 | 162 | caseDriver.Setup(t) 163 | 164 | ctx, thing := caseDriver.Apply(ctx, t) 165 | 166 | credentials := thing.GetCredentials() 167 | if credentials == nil { 168 | t.Fatal("credentials are nil") 169 | } 170 | credentialsValue, err := credentials.Retrieve(ctx) 171 | 172 | if err != nil { 173 | t.Fatalf("retrieving credentials: %s", err) 174 | } 175 | 176 | if diff := cmp.Diff(credentialsValue, tc.ExpectedCredentialsValue, cmpopts.IgnoreFields(aws.Credentials{}, "Expires")); diff != "" { 177 | t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff) 178 | } 179 | }) 180 | } 181 | } 182 | 183 | func LegacySSO(t *testing.T, driver TestDriver) { 184 | t.Helper() 185 | 186 | driver.Init(TestModeLocal) 187 | 188 | const ssoStartUrl = "https://d-123456789a.awsapps.com/start" 189 | 190 | testCases := map[string]struct { 191 | Configuration []ConfigFunc 192 | SharedConfigurationFile string 193 | ExpectedCredentialsValue aws.Credentials 194 | }{ 195 | "shared configuration file": { 196 | SharedConfigurationFile: fmt.Sprintf(` 197 | [default] 198 | sso_start_url = %s 199 | sso_region = us-east-1 200 | sso_account_id = 123456789012 201 | sso_role_name = testRole 202 | region = us-east-1 203 | `, ssoStartUrl), 204 | ExpectedCredentialsValue: mockdata.MockSsoCredentials, 205 | }, 206 | 207 | "use FIPS": { 208 | Configuration: []ConfigFunc{ 209 | WithUseFIPSEndpoint(true), 210 | }, 211 | SharedConfigurationFile: fmt.Sprintf(` 212 | [default] 213 | sso_start_url = %s 214 | sso_region = us-east-1 215 | sso_account_id = 123456789012 216 | sso_role_name = testRole 217 | region = us-east-1 218 | `, ssoStartUrl), 219 | ExpectedCredentialsValue: mockdata.MockSsoCredentials, 220 | }, 221 | } 222 | 223 | for name, tc := range testCases { 224 | 
t.Run(name, func(t *testing.T) { 225 | caseDriver := driver.TestCase() 226 | 227 | servicemocks.InitSessionTestEnv(t) 228 | 229 | ctx := context.TODO() 230 | 231 | err := servicemocks.SsoTestSetup(t, ssoStartUrl) 232 | if err != nil { 233 | t.Fatalf("setup: %s", err) 234 | } 235 | 236 | config := caseDriver.Configuration(tc.Configuration) 237 | 238 | closeSso, ssoEndpoint := servicemocks.SsoCredentialsApiMock() 239 | defer closeSso() 240 | config.AddEndpoint("sso", ssoEndpoint) 241 | 242 | tempdir, err := os.MkdirTemp("", "temp") 243 | if err != nil { 244 | t.Fatalf("error creating temp dir: %s", err) 245 | } 246 | defer os.Remove(tempdir) 247 | t.Setenv("TMPDIR", tempdir) 248 | 249 | if tc.SharedConfigurationFile != "" { 250 | file, err := os.CreateTemp("", "aws-sdk-go-base-shared-configuration-file") 251 | 252 | if err != nil { 253 | t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) 254 | } 255 | 256 | defer os.Remove(file.Name()) 257 | 258 | err = os.WriteFile(file.Name(), []byte(tc.SharedConfigurationFile), 0600) //nolint:mnd 259 | 260 | if err != nil { 261 | t.Fatalf("unexpected error writing shared configuration file: %s", err) 262 | } 263 | 264 | config.AddSharedConfigFile(file.Name()) 265 | } 266 | 267 | caseDriver.Setup(t) 268 | 269 | ctx, thing := caseDriver.Apply(ctx, t) 270 | 271 | credentials := thing.GetCredentials() 272 | if credentials == nil { 273 | t.Fatal("credentials are nil") 274 | } 275 | credentialsValue, err := credentials.Retrieve(ctx) 276 | 277 | if err != nil { 278 | t.Fatalf("retrieving credentials: %s", err) 279 | } 280 | 281 | if diff := cmp.Diff(credentialsValue, tc.ExpectedCredentialsValue, cmpopts.IgnoreFields(aws.Credentials{}, "Expires")); diff != "" { 282 | t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff) 283 | } 284 | }) 285 | } 286 | } 287 | -------------------------------------------------------------------------------- /configtesting/file_parsing.go: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package configtesting 5 | 6 | import ( 7 | "context" 8 | "os" 9 | "testing" 10 | 11 | "github.com/aws/aws-sdk-go-v2/config" 12 | "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" 13 | ) 14 | 15 | // SharedConfigFileParsing prevents regression in shared config file parsing 16 | // * https://github.com/aws/aws-sdk-go-v2/issues/2349: indented keys 17 | func SharedConfigFileParsing(t *testing.T, driver TestDriver) { 18 | t.Helper() 19 | 20 | driver.Init(TestModeLocal) 21 | 22 | testcases := map[string]struct { 23 | Configuration []ConfigFunc 24 | SharedConfigurationFile string 25 | Check func(t *testing.T, thing Thing) 26 | }{ 27 | "leading newline": { 28 | SharedConfigurationFile: ` 29 | [default] 30 | region = us-west-2 31 | `, 32 | Check: func(t *testing.T, thing Thing) { 33 | region := thing.GetRegion() 34 | if a, e := region, "us-west-2"; a != e { 35 | t.Errorf("expected region %q, got %q", e, a) 36 | } 37 | }, 38 | }, 39 | 40 | "leading whitespace": { 41 | // Do not "fix" indentation! 42 | SharedConfigurationFile: ` [default] 43 | region = us-west-2 44 | `, 45 | Check: func(t *testing.T, thing Thing) { 46 | region := thing.GetRegion() 47 | if a, e := region, "us-west-2"; a != e { 48 | t.Errorf("expected region %q, got %q", e, a) 49 | } 50 | }, 51 | }, 52 | 53 | "leading newline and whitespace": { 54 | // Do not "fix" indentation! 55 | SharedConfigurationFile: ` 56 | [default] 57 | region = us-west-2 58 | `, 59 | Check: func(t *testing.T, thing Thing) { 60 | region := thing.GetRegion() 61 | if a, e := region, "us-west-2"; a != e { 62 | t.Errorf("expected region %q, got %q", e, a) 63 | } 64 | }, 65 | }, 66 | 67 | "named profile after leading newline and whitespace": { 68 | Configuration: []ConfigFunc{ 69 | WithProfile("test"), 70 | }, 71 | // Do not "fix" indentation!
72 | SharedConfigurationFile: ` 73 | [default] 74 | region = us-west-2 75 | 76 | [profile test] 77 | region = us-east-1 78 | `, 79 | Check: func(t *testing.T, thing Thing) { 80 | region := thing.GetRegion() 81 | if a, e := region, "us-east-1"; a != e { 82 | t.Errorf("expected region %q, got %q", e, a) 83 | } 84 | }, 85 | }, 86 | 87 | "named profile": { 88 | Configuration: []ConfigFunc{ 89 | WithProfile("test"), 90 | }, 91 | SharedConfigurationFile: ` 92 | [default] 93 | region = us-west-2 94 | 95 | [profile test] 96 | region = us-east-1 97 | `, 98 | Check: func(t *testing.T, thing Thing) { 99 | region := thing.GetRegion() 100 | if a, e := region, "us-east-1"; a != e { 101 | t.Errorf("expected region %q, got %q", e, a) 102 | } 103 | }, 104 | }, 105 | 106 | "trailing hash": { 107 | SharedConfigurationFile: ` 108 | [default] 109 | sso_start_url = https://d-123456789a.awsapps.com/start# 110 | `, 111 | Check: func(t *testing.T, thing Thing) { 112 | ct, ok := thing.(AwsConfigThing) 113 | if !ok { 114 | t.Skipf("Not an AwsConfigThing") 115 | } 116 | 117 | awsConfig := ct.GetAwsConfig() 118 | var ssoStartUrl string 119 | for _, source := range awsConfig.ConfigSources { 120 | if shared, ok := source.(config.SharedConfig); ok { 121 | ssoStartUrl = shared.SSOStartURL 122 | } 123 | } 124 | if a, e := ssoStartUrl, "https://d-123456789a.awsapps.com/start#"; a != e { 125 | t.Errorf("expected sso_start_url %q, got %q", e, a) 126 | } 127 | }, 128 | }, 129 | } 130 | 131 | for name, tc := range testcases { 132 | t.Run(name, func(t *testing.T) { 133 | ctx := context.TODO() 134 | 135 | caseDriver := driver.TestCase() 136 | 137 | servicemocks.InitSessionTestEnv(t) 138 | 139 | config := caseDriver.Configuration(tc.Configuration) 140 | 141 | config.SetAccessKey(servicemocks.MockStaticAccessKey) 142 | config.SetSecretKey(servicemocks.MockStaticSecretKey) 143 | 144 | if tc.SharedConfigurationFile != "" { 145 | file, err := os.CreateTemp("", "aws-sdk-go-base-shared-configuration-file") 146 |
147 | if err != nil { 148 | t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) 149 | } 150 | 151 | defer os.Remove(file.Name()) 152 | 153 | err = os.WriteFile(file.Name(), []byte(tc.SharedConfigurationFile), 0600) //nolint:mnd 154 | 155 | if err != nil { 156 | t.Fatalf("unexpected error writing shared configuration file: %s", err) 157 | } 158 | 159 | config.AddSharedConfigFile(file.Name()) 160 | } 161 | 162 | _, thing := caseDriver.Apply(ctx, t) 163 | 164 | tc.Check(t, thing) 165 | }) 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /credentials_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "os" 10 | "testing" 11 | 12 | "github.com/aws/aws-sdk-go-v2/aws" 13 | "github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds" 14 | "github.com/aws/aws-sdk-go-v2/credentials/stscreds" 15 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/test" 16 | "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" 17 | ) 18 | 19 | // TestAWSGetCredentials_ec2Imds is designed to test the scenario of running Terraform 20 | // from an EC2 instance, without environment variables or manually supplied 21 | // credentials.
22 | func TestAWSGetCredentials_ec2Imds(t *testing.T) { 23 | // clear AWS_* environment variables 24 | resetEnv := servicemocks.UnsetEnv(t) 25 | defer resetEnv() 26 | 27 | ctx := test.Context(t) 28 | 29 | // capture the test server's close method, to call after the test returns 30 | ts := servicemocks.AwsMetadataApiMock(append( 31 | servicemocks.Ec2metadata_securityCredentialsEndpoints, 32 | servicemocks.Ec2metadata_instanceIdEndpoint, 33 | servicemocks.Ec2metadata_iamInfoEndpoint, 34 | )) 35 | defer ts() 36 | 37 | // An empty config, no key supplied 38 | cfg := Config{} 39 | 40 | creds, source, err := getCredentialsProvider(ctx, &cfg) 41 | if err != nil { 42 | t.Fatalf("unexpected '%[1]T' error getting credentials provider: %[1]s", err) 43 | } 44 | 45 | if a, e := source, ec2rolecreds.ProviderName; a != e { 46 | t.Errorf("Expected initial source to be %q, %q given", e, a) 47 | } 48 | 49 | validateCredentialsProvider(ctx, creds, "Ec2MetadataAccessKey", "Ec2MetadataSecretKey", "Ec2MetadataSessionToken", ec2rolecreds.ProviderName, t) 50 | testCredentialsProviderWrappedWithCache(creds, t) 51 | } 52 | 53 | func TestAWSGetCredentials_shouldErrorWithInvalidEc2ImdsEndpoint(t *testing.T) { 54 | ctx := test.Context(t) 55 | 56 | resetEnv := servicemocks.UnsetEnv(t) 57 | defer resetEnv() 58 | // capture the test server's close method, to call after the test returns 59 | ts := servicemocks.InvalidEC2MetadataEndpoint(t) 60 | defer ts() 61 | 62 | // An empty config, no key supplied 63 | cfg := Config{} 64 | 65 | _, _, diags := getCredentialsProvider(ctx, &cfg) 66 | if diags == nil { 67 | t.Fatal("expected error returned when getting creds w/ invalid EC2 IMDS endpoint") 68 | } 69 | if !ContainsNoValidCredentialSourcesError(diags) { 70 | t.Fatalf("expected NoValidCredentialSourcesError, got '%[1]T': %[1]s", diags) 71 | } 72 | } 73 | 74 | func TestAWSGetCredentials_sharedCredentialsFile(t *testing.T) { 75 | ctx := test.Context(t) 76 | 77 | resetEnv := servicemocks.UnsetEnv(t) 78 | 
defer resetEnv() 79 | 80 | t.Setenv("AWS_PROFILE", "myprofile") 81 | 82 | fileEnvName := writeCredentialsFile(credentialsFileContentsEnv, t) 83 | defer os.Remove(fileEnvName) 84 | 85 | fileParamName := writeCredentialsFile(credentialsFileContentsParam, t) 86 | defer os.Remove(fileParamName) 87 | 88 | t.Setenv("AWS_SHARED_CREDENTIALS_FILE", fileEnvName) 89 | 90 | // Confirm AWS_SHARED_CREDENTIALS_FILE is working 91 | credsEnv, source, err := getCredentialsProvider(ctx, &Config{ 92 | Profile: "myprofile", 93 | }) 94 | if err != nil { 95 | t.Fatalf("unexpected '%[1]T' error getting credentials provider from environment: %[1]s", err) 96 | } 97 | if a, e := source, sharedConfigCredentialsSource(fileEnvName); a != e { 98 | t.Errorf("Expected initial source to be %q, %q given", e, a) 99 | } 100 | validateCredentialsProvider(ctx, credsEnv, "accesskey1", "secretkey1", "", sharedConfigCredentialsSource(fileEnvName), t) 101 | 102 | // Confirm CredsFilename overrides AWS_SHARED_CREDENTIALS_FILE 103 | credsParam, source, err := getCredentialsProvider(ctx, &Config{ 104 | Profile: "myprofile", 105 | SharedCredentialsFiles: []string{fileParamName}, 106 | }) 107 | if err != nil { 108 | t.Fatalf("unexpected '%[1]T' error getting credentials provider from configuration: %[1]s", err) 109 | } 110 | if a, e := source, sharedConfigCredentialsSource(fileParamName); a != e { 111 | t.Errorf("Expected initial source to be %q, %q given", e, a) 112 | } 113 | validateCredentialsProvider(ctx, credsParam, "accesskey2", "secretkey2", "", sharedConfigCredentialsSource(fileParamName), t) 114 | } 115 | 116 | func TestAWSGetCredentials_webIdentityToken(t *testing.T) { 117 | ctx := test.Context(t) 118 | 119 | cfg := Config{ 120 | AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{ 121 | RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn, 122 | SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName, 123 | WebIdentityToken: servicemocks.MockWebIdentityToken, 124 | }, 125 | } 126 
| 127 | ts := servicemocks.MockAwsApiServer("STS", []*servicemocks.MockEndpoint{ 128 | servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint, 129 | servicemocks.MockStsGetCallerIdentityValidAssumedRoleEndpoint, 130 | }) 131 | defer ts.Close() 132 | cfg.StsEndpoint = ts.URL 133 | 134 | creds, source, err := getCredentialsProvider(ctx, &cfg) 135 | if err != nil { 136 | t.Fatalf("unexpected '%[1]T' error getting credentials provider: %[1]s", err) 137 | } 138 | 139 | if a, e := source, stscreds.WebIdentityProviderName; a != e { 140 | t.Errorf("Expected initial source to be %q, %q given", e, a) 141 | } 142 | 143 | validateCredentialsProvider(ctx, creds, 144 | servicemocks.MockStsAssumeRoleWithWebIdentityAccessKey, 145 | servicemocks.MockStsAssumeRoleWithWebIdentitySecretKey, 146 | servicemocks.MockStsAssumeRoleWithWebIdentitySessionToken, 147 | stscreds.WebIdentityProviderName, t) 148 | testCredentialsProviderWrappedWithCache(creds, t) 149 | } 150 | 151 | var credentialsFileContentsEnv = `[myprofile] 152 | aws_access_key_id = accesskey1 153 | aws_secret_access_key = secretkey1 154 | ` 155 | 156 | var credentialsFileContentsParam = `[myprofile] 157 | aws_access_key_id = accesskey2 158 | aws_secret_access_key = secretkey2 159 | ` 160 | 161 | func writeCredentialsFile(credentialsFileContents string, t *testing.T) string { 162 | file, err := os.CreateTemp(os.TempDir(), "terraform_aws_cred") 163 | if err != nil { 164 | t.Fatalf("Error writing temporary credentials file: %s", err) 165 | } 166 | _, err = file.WriteString(credentialsFileContents) 167 | if err != nil { 168 | t.Fatalf("Error writing temporary credentials to file: %s", err) 169 | } 170 | err = file.Close() 171 | if err != nil { 172 | t.Fatalf("Error closing temporary credentials file: %s", err) 173 | } 174 | return file.Name() 175 | } 176 | 177 | func validateCredentialsProvider(ctx context.Context, creds aws.CredentialsProvider, accesskey, secretkey, token, source string, t *testing.T) { 178 | v, err := 
creds.Retrieve(ctx) 179 | if err != nil { 180 | t.Fatalf("Error retrieving credentials: %s", err) 181 | } 182 | 183 | if v.AccessKeyID != accesskey { 184 | t.Errorf("AccessKeyID mismatch, expected: %q, got %q", accesskey, v.AccessKeyID) 185 | } 186 | if v.SecretAccessKey != secretkey { 187 | t.Errorf("SecretAccessKey mismatch, expected: %q, got %q", secretkey, v.SecretAccessKey) 188 | } 189 | if v.SessionToken != token { 190 | t.Errorf("SessionToken mismatch, expected: %q, got %q", token, v.SessionToken) 191 | } 192 | if v.Source != source { 193 | t.Errorf("Expected provider name to be %q, %q given", source, v.Source) 194 | } 195 | } 196 | 197 | func testCredentialsProviderWrappedWithCache(creds aws.CredentialsProvider, t *testing.T) { 198 | switch creds.(type) { 199 | case *aws.CredentialsCache: 200 | break 201 | default: 202 | t.Error("expected credentials provider to be wrapped with aws.CredentialsCache") 203 | } 204 | } 205 | 206 | func sharedConfigCredentialsSource(filename string) string { 207 | return fmt.Sprintf(sharedConfigCredentialsProvider+": %s", filename) 208 | } 209 | -------------------------------------------------------------------------------- /diag/diagnostic.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package diag 5 | 6 | type Diagnostic interface { 7 | // Severity returns the desired level of feedback for the diagnostic. 8 | Severity() Severity 9 | 10 | // Summary is a short description for the diagnostic. 11 | // 12 | // Typically this is implemented as a title, such as "Invalid Resource Name", 13 | // or single line sentence. 14 | Summary() string 15 | 16 | // Detail is a long description for the diagnostic. 17 | // 18 | // This should contain all relevant information about why the diagnostic 19 | // was generated and if applicable, ways to prevent the diagnostic. 
It 20 | // should generally be written and formatted for human consumption by 21 | // practitioners or provider developers. 22 | Detail() string 23 | 24 | // Equal returns true if the other diagnostic is wholly equivalent. 25 | Equal(Diagnostic) bool 26 | } 27 | 28 | type DiagnosticWithErr interface { 29 | Diagnostic 30 | 31 | Err() error 32 | } 33 | -------------------------------------------------------------------------------- /diag/diagnostics.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package diag 5 | 6 | // Diagnostics represents a collection of diagnostics. 7 | // 8 | // While this collection is ordered, the order is not guaranteed as reliable 9 | // or consistent. 10 | type Diagnostics []Diagnostic 11 | 12 | // AddError adds a generic error diagnostic to the collection. 13 | func (diags Diagnostics) AddError(summary string, detail string) Diagnostics { 14 | return diags.Append(NewErrorDiagnostic(summary, detail)) 15 | } 16 | 17 | func (diags Diagnostics) AddSimpleError(err error) Diagnostics { 18 | return diags.Append(NewNativeErrorDiagnostic(err)) 19 | } 20 | 21 | // AddWarning adds a generic warning diagnostic to the collection. 22 | func (diags Diagnostics) AddWarning(summary string, detail string) Diagnostics { 23 | return diags.Append(NewWarningDiagnostic(summary, detail)) 24 | } 25 | 26 | // Append adds non-empty and non-duplicate diagnostics to the collection. 27 | func (diags Diagnostics) Append(in ...Diagnostic) Diagnostics { 28 | for _, diag := range in { 29 | if diag == nil { 30 | continue 31 | } 32 | 33 | if diags.Contains(diag) { 34 | continue 35 | } 36 | 37 | if diags == nil { 38 | diags = Diagnostics{diag} 39 | } else { 40 | diags = append(diags, diag) 41 | } 42 | } 43 | 44 | return diags 45 | } 46 | 47 | // Contains returns true if the collection contains an equal Diagnostic. 
48 | func (diags Diagnostics) Contains(in Diagnostic) bool { 49 | for _, diag := range diags { 50 | if diag.Equal(in) { 51 | return true 52 | } 53 | } 54 | 55 | return false 56 | } 57 | 58 | // Equal returns true if all given diagnostics are equivalent in order and 59 | // content, based on the underlying (Diagnostic).Equal() method of each. 60 | func (diags Diagnostics) Equal(other Diagnostics) bool { 61 | if len(diags) != len(other) { 62 | return false 63 | } 64 | 65 | for diagIndex, diag := range diags { 66 | if !diag.Equal(other[diagIndex]) { 67 | return false 68 | } 69 | } 70 | 71 | return true 72 | } 73 | 74 | // HasError returns true if the collection has an error severity Diagnostic. 75 | func (diags Diagnostics) HasError() bool { 76 | for _, diag := range diags { 77 | if diag.Severity() == SeverityError { 78 | return true 79 | } 80 | } 81 | 82 | return false 83 | } 84 | 85 | func (diags Diagnostics) Count() int { 86 | return len(diags) 87 | } 88 | 89 | // ErrorsCount returns the number of Diagnostic in Diagnostics that are SeverityError. 90 | func (diags Diagnostics) ErrorsCount() int { 91 | return len(diags.Errors()) 92 | } 93 | 94 | // WarningsCount returns the number of Diagnostic in Diagnostics that are SeverityWarning. 95 | func (diags Diagnostics) WarningsCount() int { 96 | return len(diags.Warnings()) 97 | } 98 | 99 | // Errors returns all the Diagnostic in Diagnostics that are SeverityError. 100 | func (diags Diagnostics) Errors() Diagnostics { 101 | dd := Diagnostics{} 102 | 103 | for _, d := range diags { 104 | if SeverityError == d.Severity() { 105 | dd = append(dd, d) 106 | } 107 | } 108 | 109 | return dd 110 | } 111 | 112 | // Warnings returns all the Diagnostic in Diagnostics that are SeverityWarning. 
113 | func (diags Diagnostics) Warnings() Diagnostics { 114 | dd := Diagnostics{} 115 | 116 | for _, d := range diags { 117 | if SeverityWarning == d.Severity() { 118 | dd = append(dd, d) 119 | } 120 | } 121 | 122 | return dd 123 | } 124 | -------------------------------------------------------------------------------- /diag/error_diagnostic.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package diag 5 | 6 | var _ Diagnostic = ErrorDiagnostic{} 7 | 8 | // ErrorDiagnostic is a generic diagnostic with error severity. 9 | type ErrorDiagnostic struct { 10 | detail string 11 | summary string 12 | } 13 | 14 | // NewErrorDiagnostic returns a new error severity diagnostic with the given summary and detail. 15 | func NewErrorDiagnostic(summary string, detail string) ErrorDiagnostic { 16 | return ErrorDiagnostic{ 17 | detail: detail, 18 | summary: summary, 19 | } 20 | } 21 | 22 | // Severity returns the diagnostic severity. 23 | func (d ErrorDiagnostic) Severity() Severity { 24 | return SeverityError 25 | } 26 | 27 | // Summary returns the diagnostic summary. 28 | func (d ErrorDiagnostic) Summary() string { 29 | return d.summary 30 | } 31 | 32 | // Detail returns the diagnostic detail. 33 | func (d ErrorDiagnostic) Detail() string { 34 | return d.detail 35 | } 36 | 37 | // Equal returns true if the other diagnostic is wholly equivalent. 38 | func (d ErrorDiagnostic) Equal(other Diagnostic) bool { 39 | ed, ok := other.(ErrorDiagnostic) 40 | 41 | if !ok { 42 | return false 43 | } 44 | 45 | return ed.Summary() == d.Summary() && ed.Detail() == d.Detail() 46 | } 47 | -------------------------------------------------------------------------------- /diag/error_diagnostic_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package diag_test 5 | 6 | import ( 7 | "testing" 8 | 9 | "github.com/hashicorp/aws-sdk-go-base/v2/diag" 10 | ) 11 | 12 | func TestErrorDiagnosticEqual(t *testing.T) { 13 | t.Parallel() 14 | 15 | testCases := map[string]struct { 16 | diag diag.ErrorDiagnostic 17 | other diag.Diagnostic 18 | expected bool 19 | }{ 20 | "matching": { 21 | diag: diag.NewErrorDiagnostic("test summary", "test detail"), 22 | other: diag.NewErrorDiagnostic("test summary", "test detail"), 23 | expected: true, 24 | }, 25 | "nil": { 26 | diag: diag.NewErrorDiagnostic("test summary", "test detail"), 27 | other: nil, 28 | expected: false, 29 | }, 30 | "different-detail": { 31 | diag: diag.NewErrorDiagnostic("test summary", "test detail"), 32 | other: diag.NewErrorDiagnostic("test summary", "different detail"), 33 | expected: false, 34 | }, 35 | "different-summary": { 36 | diag: diag.NewErrorDiagnostic("test summary", "test detail"), 37 | other: diag.NewErrorDiagnostic("different summary", "test detail"), 38 | expected: false, 39 | }, 40 | "different-type": { 41 | diag: diag.NewErrorDiagnostic("test summary", "test detail"), 42 | other: diag.NewWarningDiagnostic("test summary", "test detail"), 43 | expected: false, 44 | }, 45 | } 46 | 47 | for name, tc := range testCases { 48 | t.Run(name, func(t *testing.T) { 49 | t.Parallel() 50 | 51 | got := tc.diag.Equal(tc.other) 52 | 53 | if got != tc.expected { 54 | t.Errorf("Unexpected response: got: %t, wanted: %t", got, tc.expected) 55 | } 56 | }) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /diag/native_error_diagnostic.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package diag 5 | 6 | import "fmt" 7 | 8 | var _ DiagnosticWithErr = NativeErrorDiagnostic{} 9 | 10 | // NativeErrorDiagnostic is a diagnostic with error severity which wraps a Go error. 11 | type NativeErrorDiagnostic struct { 12 | // detail string 13 | // summary string 14 | err error 15 | } 16 | 17 | // NewNativeErrorDiagnostic returns a new error severity diagnostic with the given error. 18 | func NewNativeErrorDiagnostic(err error) NativeErrorDiagnostic { 19 | return NativeErrorDiagnostic{ 20 | err: err, 21 | } 22 | } 23 | 24 | // Severity returns the diagnostic severity. 25 | func (d NativeErrorDiagnostic) Severity() Severity { 26 | return SeverityError 27 | } 28 | 29 | // Summary returns the diagnostic summary. 30 | func (d NativeErrorDiagnostic) Summary() string { 31 | return d.err.Error() 32 | } 33 | 34 | // Detail returns the diagnostic detail. 35 | func (d NativeErrorDiagnostic) Detail() string { 36 | return "" 37 | } 38 | 39 | func (d NativeErrorDiagnostic) Err() error { 40 | return d.err 41 | } 42 | 43 | // Equal returns true if the other diagnostic is wholly equivalent. 44 | func (d NativeErrorDiagnostic) Equal(other Diagnostic) bool { 45 | ed, ok := other.(NativeErrorDiagnostic) 46 | 47 | if !ok { 48 | return false 49 | } 50 | 51 | return ed.Summary() == d.Summary() && ed.Detail() == d.Detail() 52 | } 53 | 54 | func (d NativeErrorDiagnostic) GoString() string { 55 | return fmt.Sprintf("NativeErrorDiagnostic: err: %s", d.err.Error()) 56 | } 57 | -------------------------------------------------------------------------------- /diag/severity.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package diag 5 | 6 | // Severity represents the level of feedback for a diagnostic. 
7 | // 8 | // Each severity implies behavior changes for the feedback and potentially the 9 | // further execution of logic. 10 | type Severity int 11 | 12 | const ( 13 | // SeverityInvalid represents an undefined severity. 14 | // 15 | // It should not be used directly in implementations. 16 | SeverityInvalid Severity = 0 17 | 18 | // SeverityError represents a terminating condition. 19 | // 20 | // This can cause a failing status code for command line programs. 21 | // 22 | // Most implementations should return early when encountering an error. 23 | SeverityError Severity = 1 24 | 25 | // SeverityWarning represents a condition with explicit feedback. 26 | // 27 | // Most implementations should continue when encountering a warning. 28 | SeverityWarning Severity = 2 29 | ) 30 | 31 | // String returns a textual representation of the severity. 32 | func (s Severity) String() string { 33 | switch s { 34 | case SeverityError: 35 | return "Error" 36 | case SeverityWarning: 37 | return "Warning" 38 | default: 39 | return "Invalid" 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /diag/warning_diagnostic.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package diag 5 | 6 | var _ Diagnostic = WarningDiagnostic{} 7 | 8 | // WarningDiagnostic is a generic diagnostic with warning severity. 9 | type WarningDiagnostic struct { 10 | detail string 11 | summary string 12 | } 13 | 14 | // NewWarningDiagnostic returns a new warning severity diagnostic with the given summary and detail. 15 | func NewWarningDiagnostic(summary string, detail string) WarningDiagnostic { 16 | return WarningDiagnostic{ 17 | detail: detail, 18 | summary: summary, 19 | } 20 | } 21 | 22 | // Severity returns the diagnostic severity. 
23 | func (d WarningDiagnostic) Severity() Severity { 24 | return SeverityWarning 25 | } 26 | 27 | // Summary returns the diagnostic summary. 28 | func (d WarningDiagnostic) Summary() string { 29 | return d.summary 30 | } 31 | 32 | // Detail returns the diagnostic detail. 33 | func (d WarningDiagnostic) Detail() string { 34 | return d.detail 35 | } 36 | 37 | // Equal returns true if the other diagnostic is wholly equivalent. 38 | func (d WarningDiagnostic) Equal(other Diagnostic) bool { 39 | wd, ok := other.(WarningDiagnostic) 40 | 41 | if !ok { 42 | return false 43 | } 44 | 45 | return wd.Summary() == d.Summary() && wd.Detail() == d.Detail() 46 | } 47 | -------------------------------------------------------------------------------- /diag/warning_diagnostic_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package diag_test 5 | 6 | import ( 7 | "testing" 8 | 9 | "github.com/hashicorp/aws-sdk-go-base/v2/diag" 10 | ) 11 | 12 | func TestWarningDiagnosticEqual(t *testing.T) { 13 | t.Parallel() 14 | 15 | testCases := map[string]struct { 16 | diag diag.WarningDiagnostic 17 | other diag.Diagnostic 18 | expected bool 19 | }{ 20 | "matching": { 21 | diag: diag.NewWarningDiagnostic("test summary", "test detail"), 22 | other: diag.NewWarningDiagnostic("test summary", "test detail"), 23 | expected: true, 24 | }, 25 | "nil": { 26 | diag: diag.NewWarningDiagnostic("test summary", "test detail"), 27 | other: nil, 28 | expected: false, 29 | }, 30 | "different-detail": { 31 | diag: diag.NewWarningDiagnostic("test summary", "test detail"), 32 | other: diag.NewWarningDiagnostic("test summary", "different detail"), 33 | expected: false, 34 | }, 35 | "different-summary": { 36 | diag: diag.NewWarningDiagnostic("test summary", "test detail"), 37 | other: diag.NewWarningDiagnostic("different summary", "test detail"), 38 | expected: false, 39 | }, 40 | 
"different-type": { 41 | diag: diag.NewWarningDiagnostic("test summary", "test detail"), 42 | other: diag.NewErrorDiagnostic("test summary", "test detail"), 43 | expected: false, 44 | }, 45 | } 46 | 47 | for name, tc := range testCases { 48 | t.Run(name, func(t *testing.T) { 49 | t.Parallel() 50 | 51 | got := tc.diag.Equal(tc.other) 52 | 53 | if got != tc.expected { 54 | t.Errorf("Unexpected response: got: %t, wanted: %t", got, tc.expected) 55 | } 56 | }) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /endpoints.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | "context" 8 | 9 | "github.com/aws/aws-sdk-go-v2/aws" 10 | "github.com/aws/aws-sdk-go-v2/service/iam" 11 | "github.com/aws/aws-sdk-go-v2/service/sso" 12 | "github.com/aws/aws-sdk-go-v2/service/sts" 13 | "github.com/hashicorp/aws-sdk-go-base/v2/logging" 14 | ) 15 | 16 | // This endpoint resolver is needed when authenticating because the AWS SDK makes internal 17 | // calls to STS. 
The resolver should not be attached to the aws.Config returned to the 18 | // client, since it should configure its own overrides 19 | func credentialsEndpointResolver(ctx context.Context, c *Config) aws.EndpointResolverWithOptions { 20 | logger := logging.RetrieveLogger(ctx) 21 | 22 | resolver := func(service, region string, options ...any) (aws.Endpoint, error) { 23 | switch service { 24 | case iam.ServiceID: 25 | if endpoint := c.IamEndpoint; endpoint != "" { 26 | logger.Info(ctx, "Credentials resolution: setting custom IAM endpoint", map[string]any{ 27 | "tf_aws.iam_client.endpoint": endpoint, 28 | }) 29 | return aws.Endpoint{ 30 | URL: endpoint, 31 | Source: aws.EndpointSourceCustom, 32 | SigningRegion: region, 33 | }, nil 34 | } 35 | case sso.ServiceID: 36 | if endpoint := c.SsoEndpoint; endpoint != "" { 37 | logger.Info(ctx, "Credentials resolution: setting custom SSO endpoint", map[string]any{ 38 | "tf_aws.sso_client.endpoint": endpoint, 39 | }) 40 | return aws.Endpoint{ 41 | URL: endpoint, 42 | Source: aws.EndpointSourceCustom, 43 | SigningRegion: region, 44 | }, nil 45 | } 46 | case sts.ServiceID: 47 | if endpoint := c.StsEndpoint; endpoint != "" { 48 | fields := map[string]any{ 49 | "tf_aws.sts_client.endpoint": endpoint, 50 | } 51 | if c.StsRegion != "" { 52 | fields["tf_aws.sts_client.signing_region"] = c.StsRegion 53 | region = c.StsRegion 54 | } 55 | logger.Info(ctx, "Credentials resolution: setting custom STS endpoint", fields) 56 | return aws.Endpoint{ 57 | URL: endpoint, 58 | Source: aws.EndpointSourceCustom, 59 | SigningRegion: region, 60 | }, nil 61 | } 62 | } 63 | 64 | return aws.Endpoint{}, &aws.EndpointNotFoundError{} 65 | } 66 | 67 | return aws.EndpointResolverWithOptionsFunc(resolver) 68 | } 69 | -------------------------------------------------------------------------------- /endpoints/endpoints.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package endpoints 5 | 6 | const ( 7 | AwsGlobalRegionID = "aws-global" // AWS Standard global region. 8 | ) 9 | -------------------------------------------------------------------------------- /endpoints/generate.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | //go:generate go run ../internal/generate/endpoints/main.go -- https://raw.githubusercontent.com/aws/aws-sdk-go-v2/main/codegen/smithy-aws-go-codegen/src/main/resources/software/amazon/smithy/aws/go/codegen/endpoints.json 5 | 6 | package endpoints 7 | -------------------------------------------------------------------------------- /endpoints/partition.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package endpoints 5 | 6 | import ( 7 | "maps" 8 | "regexp" 9 | ) 10 | 11 | // Partition represents an AWS partition. 12 | // See https://docs.aws.amazon.com/whitepapers/latest/aws-fault-isolation-boundaries/partitions.html. 13 | type Partition struct { 14 | id string 15 | name string 16 | dnsSuffix string 17 | regionRegex *regexp.Regexp 18 | regions map[string]Region 19 | services map[string]Service 20 | } 21 | 22 | // ID returns the identifier of the partition. 23 | func (p Partition) ID() string { 24 | return p.id 25 | } 26 | 27 | // Name returns the name of the partition. 28 | func (p Partition) Name() string { 29 | return p.name 30 | } 31 | 32 | // DNSSuffix returns the base domain name of the partition. 33 | func (p Partition) DNSSuffix() string { 34 | return p.dnsSuffix 35 | } 36 | 37 | // RegionRegex return the regular expression that matches Region IDs for the partition. 
38 | func (p Partition) RegionRegex() *regexp.Regexp { 39 | return p.regionRegex 40 | } 41 | 42 | // Regions returns a map of Regions for the partition, indexed by their ID. 43 | func (p Partition) Regions() map[string]Region { 44 | return maps.Clone(p.regions) 45 | } 46 | 47 | // Services returns a map of service endpoints for the partition, indexed by their ID. 48 | func (p Partition) Services() map[string]Service { 49 | return maps.Clone(p.services) 50 | } 51 | 52 | // DefaultPartitions returns a list of the partitions. 53 | func DefaultPartitions() []Partition { 54 | ps := make([]Partition, 0, len(partitions)) 55 | 56 | for _, p := range partitions { 57 | ps = append(ps, p) 58 | } 59 | 60 | return ps 61 | } 62 | 63 | // PartitionForRegion returns the first partition which includes the specific Region. 64 | func PartitionForRegion(ps []Partition, regionID string) (Partition, bool) { 65 | for _, p := range ps { 66 | if _, ok := p.regions[regionID]; ok || p.regionRegex.MatchString(regionID) { 67 | return p, true 68 | } 69 | } 70 | 71 | return Partition{}, false 72 | } 73 | -------------------------------------------------------------------------------- /endpoints/partition_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package endpoints_test 5 | 6 | import ( 7 | "testing" 8 | 9 | "github.com/hashicorp/aws-sdk-go-base/v2/endpoints" 10 | ) 11 | 12 | func TestDefaultPartitions(t *testing.T) { 13 | t.Parallel() 14 | 15 | got := endpoints.DefaultPartitions() 16 | if len(got) == 0 { 17 | t.Fatalf("expected partitions, got none") 18 | } 19 | } 20 | 21 | func TestPartitionForRegion(t *testing.T) { 22 | t.Parallel() 23 | 24 | testcases := map[string]struct { 25 | expectedFound bool 26 | expectedID string 27 | }{ 28 | "us-east-1": { 29 | expectedFound: true, 30 | expectedID: "aws", 31 | }, 32 | "us-gov-west-1": { 33 | expectedFound: true, 34 | expectedID: "aws-us-gov", 35 | }, 36 | "not-found": { 37 | expectedFound: false, 38 | }, 39 | "us-east-17": { 40 | expectedFound: true, 41 | expectedID: "aws", 42 | }, 43 | } 44 | 45 | ps := endpoints.DefaultPartitions() 46 | for region, testcase := range testcases { 47 | gotID, gotFound := endpoints.PartitionForRegion(ps, region) 48 | 49 | if gotFound != testcase.expectedFound { 50 | t.Errorf("expected PartitionFound %t for Region %q, got %t", testcase.expectedFound, region, gotFound) 51 | } 52 | if gotID.ID() != testcase.expectedID { 53 | t.Errorf("expected PartitionID %q for Region %q, got %q", testcase.expectedID, region, gotID.ID()) 54 | } 55 | } 56 | } 57 | 58 | func TestPartitionRegions(t *testing.T) { 59 | t.Parallel() 60 | 61 | testcases := map[string]struct { 62 | expectedRegions bool 63 | }{ 64 | "us-east-1": { 65 | expectedRegions: true, 66 | }, 67 | "us-gov-west-1": { 68 | expectedRegions: true, 69 | }, 70 | "not-found": { 71 | expectedRegions: false, 72 | }, 73 | } 74 | 75 | ps := endpoints.DefaultPartitions() 76 | for region, testcase := range testcases { 77 | gotID, _ := endpoints.PartitionForRegion(ps, region) 78 | 79 | if got, want := len(gotID.Regions()) > 0, testcase.expectedRegions; got != want { 80 | t.Errorf("expected Regions %t for Region %q, got %t", want, region, got) 81 | } 82 
| } 83 | } 84 | 85 | func TestPartitionServices(t *testing.T) { 86 | t.Parallel() 87 | 88 | testcases := map[string]struct { 89 | expectedServices bool 90 | }{ 91 | "us-east-1": { 92 | expectedServices: true, 93 | }, 94 | "us-gov-west-1": { 95 | expectedServices: true, 96 | }, 97 | "not-found": { 98 | expectedServices: false, 99 | }, 100 | } 101 | 102 | ps := endpoints.DefaultPartitions() 103 | for region, testcase := range testcases { 104 | gotID, _ := endpoints.PartitionForRegion(ps, region) 105 | 106 | if got, want := len(gotID.Services()) > 0, testcase.expectedServices; got != want { 107 | t.Errorf("expected services %t for Region %q, got %t", want, region, got) 108 | } 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /endpoints/region.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package endpoints 5 | 6 | // Region represents an AWS Region. 7 | // See https://docs.aws.amazon.com/whitepapers/latest/aws-fault-isolation-boundaries/regions.html. 8 | type Region struct { 9 | id string 10 | description string 11 | } 12 | 13 | // ID returns the Region's identifier. 14 | func (r Region) ID() string { 15 | return r.id 16 | } 17 | 18 | // Description returns the Region's description. 19 | func (r Region) Description() string { 20 | return r.description 21 | } 22 | -------------------------------------------------------------------------------- /endpoints/service.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package endpoints 5 | 6 | // Service represents an AWS service endpoint. 7 | type Service struct { 8 | id string 9 | } 10 | 11 | // ID returns the service endpoint's identifier. 
12 | func (s Service) ID() string { 13 | return s.id 14 | } 15 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | "fmt" 8 | "slices" 9 | 10 | "github.com/hashicorp/aws-sdk-go-base/v2/diag" 11 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" 12 | ) 13 | 14 | // cannotAssumeRoleError occurs when AssumeRole cannot complete. 15 | type cannotAssumeRoleError struct { 16 | ar config.AssumeRole 17 | err error 18 | } 19 | 20 | func (e cannotAssumeRoleError) Severity() diag.Severity { 21 | return diag.SeverityError 22 | } 23 | 24 | func (e cannotAssumeRoleError) Summary() string { 25 | return "Cannot assume IAM Role" 26 | } 27 | 28 | func (e cannotAssumeRoleError) Detail() string { 29 | return fmt.Sprintf(`IAM Role (%s) cannot be assumed. 30 | 31 | There are a number of possible causes of this - the most common are: 32 | * The credentials used in order to assume the role are invalid 33 | * The credentials do not have appropriate permission to assume the role 34 | * The role ARN is not valid 35 | 36 | Error: %s 37 | `, e.ar.RoleARN, e.err) 38 | } 39 | 40 | func (e cannotAssumeRoleError) Equal(other diag.Diagnostic) bool { 41 | ed, ok := other.(cannotAssumeRoleError) 42 | if !ok { 43 | return false 44 | } 45 | 46 | return ed.Summary() == e.Summary() && ed.Detail() == e.Detail() 47 | } 48 | 49 | func (e cannotAssumeRoleError) Err() error { 50 | return e.err 51 | } 52 | 53 | func newCannotAssumeRoleError(ar AssumeRole, err error) cannotAssumeRoleError { 54 | return cannotAssumeRoleError{ 55 | ar: ar, 56 | err: err, 57 | } 58 | } 59 | 60 | var _ diag.DiagnosticWithErr = cannotAssumeRoleError{} 61 | 62 | // IsCannotAssumeRoleError returns true if the error contains the CannotAssumeRoleError type. 
63 | func IsCannotAssumeRoleError(diag diag.Diagnostic) bool { 64 | _, ok := diag.(cannotAssumeRoleError) 65 | return ok 66 | } 67 | 68 | // NoValidCredentialSourcesError occurs when all credential lookup methods have been exhausted without results. 69 | type NoValidCredentialSourcesError = config.NoValidCredentialSourcesError 70 | 71 | // IsNoValidCredentialSourcesError returns true if the diagnostic is a NoValidCredentialSourcesError. 72 | func IsNoValidCredentialSourcesError(diag diag.Diagnostic) bool { 73 | _, ok := diag.(NoValidCredentialSourcesError) 74 | return ok 75 | } 76 | 77 | // ContainsNoValidCredentialSourcesError returns true if the diagnostics contains a NoValidCredentialSourcesError type. 78 | func ContainsNoValidCredentialSourcesError(diags diag.Diagnostics) bool { 79 | return slices.ContainsFunc(diags, IsNoValidCredentialSourcesError) 80 | } 81 | -------------------------------------------------------------------------------- /errors_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | "testing" 8 | 9 | "github.com/hashicorp/aws-sdk-go-base/v2/diag" 10 | ) 11 | 12 | func TestIsCannotAssumeRoleError(t *testing.T) { 13 | testCases := []struct { 14 | Name string 15 | Diag diag.Diagnostic 16 | Expected bool 17 | }{ 18 | { 19 | Name: "nil error", 20 | }, 21 | { 22 | Name: "Top-level NoValidCredentialSourcesError", 23 | Diag: NoValidCredentialSourcesError{}, 24 | }, 25 | { 26 | Name: "Top-level CannotAssumeRoleError", 27 | Diag: cannotAssumeRoleError{}, 28 | Expected: true, 29 | }, 30 | } 31 | 32 | for _, testCase := range testCases { 33 | t.Run(testCase.Name, func(t *testing.T) { 34 | got := IsCannotAssumeRoleError(testCase.Diag) 35 | 36 | if got != testCase.Expected { 37 | t.Errorf("got %t, expected %t", got, testCase.Expected) 38 | } 39 | }) 40 | } 41 | } 42 | 43 | func TestIsNoValidCredentialSourcesError(t *testing.T) { 44 | testCases := []struct { 45 | Name string 46 | Diag diag.Diagnostic 47 | Expected bool 48 | }{ 49 | { 50 | Name: "nil error", 51 | }, 52 | { 53 | Name: "Top-level CannotAssumeRoleError", 54 | Diag: cannotAssumeRoleError{}, 55 | }, 56 | { 57 | Name: "Top-level NoValidCredentialSourcesError", 58 | Diag: NoValidCredentialSourcesError{}, 59 | Expected: true, 60 | }, 61 | } 62 | 63 | for _, testCase := range testCases { 64 | t.Run(testCase.Name, func(t *testing.T) { 65 | got := IsNoValidCredentialSourcesError(testCase.Diag) 66 | 67 | if got != testCase.Expected { 68 | t.Errorf("got %t, expected %t", got, testCase.Expected) 69 | } 70 | }) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/aws-sdk-go-base/v2 2 | 3 | go 1.23.6 4 | 5 | require ( 6 | github.com/aws/aws-sdk-go-v2 v1.36.3 7 | github.com/aws/aws-sdk-go-v2/config v1.29.14 8 | github.com/aws/aws-sdk-go-v2/credentials v1.17.67 
9 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 10 | github.com/aws/aws-sdk-go-v2/service/dynamodb v1.43.1 11 | github.com/aws/aws-sdk-go-v2/service/iam v1.42.0 12 | github.com/aws/aws-sdk-go-v2/service/s3 v1.79.4 13 | github.com/aws/aws-sdk-go-v2/service/sqs v1.38.5 14 | github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 15 | github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 16 | github.com/aws/smithy-go v1.22.3 17 | github.com/google/go-cmp v0.7.0 18 | github.com/hashicorp/go-hclog v1.6.3 19 | github.com/hashicorp/go-multierror v1.1.1 20 | github.com/hashicorp/terraform-plugin-log v0.9.0 21 | github.com/mitchellh/go-homedir v1.1.0 22 | go.opentelemetry.io/contrib/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws v0.61.0 23 | go.opentelemetry.io/otel v1.36.0 24 | golang.org/x/net v0.40.0 25 | golang.org/x/text v0.25.0 26 | ) 27 | 28 | require ( 29 | github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect 30 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect 31 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect 32 | github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect 33 | github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 // indirect 34 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect 35 | github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.2 // indirect 36 | github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.15 // indirect 37 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect 38 | github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 // indirect 39 | github.com/aws/aws-sdk-go-v2/service/sns v1.34.4 // indirect 40 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect 41 | github.com/fatih/color v1.18.0 // indirect 42 | github.com/go-logr/logr v1.4.2 // indirect 43 | github.com/go-logr/stdr v1.2.2 // indirect 44 | github.com/hashicorp/errwrap v1.1.0 // indirect 45 | 
github.com/mattn/go-colorable v0.1.14 // indirect 46 | github.com/mattn/go-isatty v0.0.20 // indirect 47 | github.com/mitchellh/go-testing-interface v1.14.1 // indirect 48 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect 49 | go.opentelemetry.io/otel/metric v1.36.0 // indirect 50 | go.opentelemetry.io/otel/trace v1.36.0 // indirect 51 | golang.org/x/sys v0.33.0 // indirect 52 | ) 53 | -------------------------------------------------------------------------------- /go.work: -------------------------------------------------------------------------------- 1 | go 1.23.6 2 | 3 | use ( 4 | . 5 | ./tools 6 | ./v2/awsv1shim 7 | ) 8 | -------------------------------------------------------------------------------- /http_client.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" 8 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" 9 | ) 10 | 11 | func defaultHttpClient(c *config.Config) (*awshttp.BuildableClient, error) { 12 | opts, err := c.HTTPTransportOptions() 13 | if err != nil { 14 | return nil, err 15 | } 16 | 17 | httpClient := awshttp.NewBuildableClient().WithTransportOptions(opts) 18 | 19 | return httpClient, err 20 | } 21 | -------------------------------------------------------------------------------- /http_client_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" 11 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/test" 12 | ) 13 | 14 | func TestHTTPClientConfiguration_basic(t *testing.T) { 15 | client, err := defaultHttpClient(&config.Config{}) 16 | if err != nil { 17 | t.Fatalf("unexpected error: %s", err) 18 | } 19 | 20 | transport := client.GetTransport() 21 | 22 | test.HTTPClientConfigurationTest_basic(t, transport) 23 | } 24 | 25 | func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) { 26 | client, err := defaultHttpClient(&config.Config{ 27 | Insecure: true, 28 | }) 29 | if err != nil { 30 | t.Fatalf("unexpected error: %s", err) 31 | } 32 | 33 | transport := client.GetTransport() 34 | 35 | test.HTTPClientConfigurationTest_insecureHTTPS(t, transport) 36 | } 37 | 38 | func TestHTTPClientConfiguration_proxy(t *testing.T) { 39 | test.HTTPClientConfigurationTest_proxy(t, transport) 40 | } 41 | 42 | func transport(t *testing.T, config *config.Config) *http.Transport { 43 | t.Helper() 44 | 45 | client, err := defaultHttpClient(config) 46 | if err != nil { 47 | t.Fatalf("creating client: %s", err) 48 | } 49 | 50 | return client.GetTransport() 51 | } 52 | -------------------------------------------------------------------------------- /internal/awsconfig/resolvers.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsconfig 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | 10 | "github.com/aws/aws-sdk-go-v2/aws" 11 | "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" 12 | ) 13 | 14 | // Copied from https://github.com/aws/aws-sdk-go-v2/blob/main/internal/configsources/config.go 15 | type UseFIPSEndpointProvider interface { 16 | GetUseFIPSEndpoint(context.Context) (value aws.FIPSEndpointState, found bool, err error) 17 | } 18 | 19 | // Copied from https://github.com/aws/aws-sdk-go-v2/blob/main/internal/configsources/config.go 20 | func ResolveUseFIPSEndpoint(ctx context.Context, configSources []any) (value aws.FIPSEndpointState, found bool, err error) { 21 | for _, cfg := range configSources { 22 | if p, ok := cfg.(UseFIPSEndpointProvider); ok { 23 | value, found, err = p.GetUseFIPSEndpoint(ctx) 24 | if err != nil || found { 25 | break 26 | } 27 | } 28 | } 29 | return 30 | } 31 | 32 | func FIPSEndpointStateString(state aws.FIPSEndpointState) string { 33 | switch state { 34 | case aws.FIPSEndpointStateUnset: 35 | return "FIPSEndpointStateUnset" 36 | case aws.FIPSEndpointStateEnabled: 37 | return "FIPSEndpointStateEnabled" 38 | case aws.FIPSEndpointStateDisabled: 39 | return "FIPSEndpointStateDisabled" 40 | } 41 | return fmt.Sprintf("unknown aws.FIPSEndpointState (%d)", state) 42 | } 43 | 44 | // Copied from https://github.com/aws/aws-sdk-go-v2/blob/main/internal/configsources/config.go 45 | type UseDualStackEndpointProvider interface { 46 | GetUseDualStackEndpoint(context.Context) (value aws.DualStackEndpointState, found bool, err error) 47 | } 48 | 49 | // Copied from https://github.com/aws/aws-sdk-go-v2/blob/main/internal/configsources/config.go 50 | func ResolveUseDualStackEndpoint(ctx context.Context, configSources []any) (value aws.DualStackEndpointState, found bool, err error) { 51 | for _, cfg := range configSources { 52 | if p, ok := cfg.(UseDualStackEndpointProvider); ok { 53 | value, found, err = 
p.GetUseDualStackEndpoint(ctx) 54 | if err != nil || found { 55 | break 56 | } 57 | } 58 | } 59 | return 60 | } 61 | 62 | func DualStackEndpointStateString(state aws.DualStackEndpointState) string { 63 | switch state { 64 | case aws.DualStackEndpointStateUnset: 65 | return "DualStackEndpointStateUnset" 66 | case aws.DualStackEndpointStateEnabled: 67 | return "DualStackEndpointStateEnabled" 68 | case aws.DualStackEndpointStateDisabled: 69 | return "DualStackEndpointStateDisabled" 70 | } 71 | return fmt.Sprintf("unknown aws.FIPSEndpointStateUnset (%d)", state) 72 | } 73 | 74 | // Copied and renamed from https://github.com/aws/aws-sdk-go-v2/blob/main/feature/ec2/imds/internal/config/resolvers.go 75 | type EC2IMDSClientEnableStateResolver interface { 76 | GetEC2IMDSClientEnableState() (imds.ClientEnableState, bool, error) 77 | } 78 | 79 | // Copied and renamed from https://github.com/aws/aws-sdk-go-v2/blob/main/feature/ec2/imds/internal/config/resolvers.go 80 | func ResolveEC2IMDSClientEnableState(sources []any) (value imds.ClientEnableState, found bool, err error) { 81 | for _, source := range sources { 82 | if resolver, ok := source.(EC2IMDSClientEnableStateResolver); ok { 83 | value, found, err = resolver.GetEC2IMDSClientEnableState() 84 | if err != nil || found { 85 | return value, found, err 86 | } 87 | } 88 | } 89 | return value, found, err 90 | } 91 | 92 | func EC2IMDSClientEnableStateString(state imds.ClientEnableState) string { 93 | switch state { 94 | case imds.ClientDefaultEnableState: 95 | return "ClientDefaultEnableState" 96 | case imds.ClientDisabled: 97 | return "ClientDisabled" 98 | case imds.ClientEnabled: 99 | return "ClientEnabled" 100 | } 101 | return fmt.Sprintf("unknown imds.ClientEnableState (%d)", state) 102 | } 103 | 104 | // Copied and renamed from https://github.com/aws/aws-sdk-go-v2/blob/main/feature/ec2/imds/internal/config/resolvers.go 105 | type EC2IMDSEndpointResolver interface { 106 | GetEC2IMDSEndpoint() (value string, found bool, err 
error) 107 | } 108 | 109 | // Copied and renamed from https://github.com/aws/aws-sdk-go-v2/blob/main/feature/ec2/imds/internal/config/resolvers.go 110 | func ResolveEC2IMDSEndpointConfig(configSources []any) (value string, found bool, err error) { 111 | for _, cfg := range configSources { 112 | if p, ok := cfg.(EC2IMDSEndpointResolver); ok { 113 | value, found, err = p.GetEC2IMDSEndpoint() 114 | if err != nil || found { 115 | break 116 | } 117 | } 118 | } 119 | return 120 | } 121 | 122 | // Copied and renamed from https://github.com/aws/aws-sdk-go-v2/blob/main/feature/ec2/imds/internal/config/resolvers.go 123 | type EC2IMDSEndpointModeResolver interface { 124 | GetEC2IMDSEndpointMode() (imds.EndpointModeState, bool, error) 125 | } 126 | 127 | // Copied and renamed from https://github.com/aws/aws-sdk-go-v2/blob/main/feature/ec2/imds/internal/config/resolvers.go 128 | func ResolveEC2IMDSEndpointModeConfig(sources []any) (value imds.EndpointModeState, found bool, err error) { 129 | for _, source := range sources { 130 | if resolver, ok := source.(EC2IMDSEndpointModeResolver); ok { 131 | value, found, err = resolver.GetEC2IMDSEndpointMode() 132 | if err != nil || found { 133 | return value, found, err 134 | } 135 | } 136 | } 137 | return value, found, err 138 | } 139 | 140 | func EC2IMDSEndpointModeString(state imds.EndpointModeState) string { 141 | switch state { 142 | case imds.EndpointModeStateUnset: 143 | return "EndpointModeStateUnset" 144 | case imds.EndpointModeStateIPv4: 145 | return "EndpointModeStateIPv4" 146 | case imds.EndpointModeStateIPv6: 147 | return "EndpointModeStateIPv6" 148 | } 149 | return fmt.Sprintf("unknown imds.EndpointModeState (%d)", state) 150 | } 151 | 152 | // Copied and renamed from https://github.com/aws/aws-sdk-go-v2/blob/main/config/provider.go 153 | type RetryMaxAttemptsProvider interface { 154 | GetRetryMaxAttempts(context.Context) (int, bool, error) 155 | } 156 | 157 | // Copied and renamed from 
https://github.com/aws/aws-sdk-go-v2/blob/main/config/provider.go 158 | func GetRetryMaxAttempts(ctx context.Context, sources []any) (v int, found bool, err error) { 159 | for _, c := range sources { 160 | if p, ok := c.(RetryMaxAttemptsProvider); ok { 161 | v, found, err = p.GetRetryMaxAttempts(ctx) 162 | if err != nil || found { 163 | break 164 | } 165 | } 166 | } 167 | return v, found, err 168 | } 169 | -------------------------------------------------------------------------------- /internal/config/apn_info.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package config 5 | 6 | import ( 7 | smithyhttp "github.com/aws/smithy-go/transport/http" 8 | ) 9 | 10 | type APNInfo struct { 11 | PartnerName string 12 | Products []UserAgentProduct 13 | } 14 | 15 | // Builds the user-agent string for APN 16 | func (apn APNInfo) BuildUserAgentString() string { 17 | builder := smithyhttp.NewUserAgentBuilder() 18 | builder.AddKeyValue("APN", "1.0") 19 | builder.AddKeyValue(apn.PartnerName, "1.0") 20 | for _, p := range apn.Products { 21 | p.buildUserAgentPart(builder) 22 | } 23 | return builder.Build() 24 | } 25 | -------------------------------------------------------------------------------- /internal/config/config_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package config 5 | 6 | import ( 7 | "fmt" 8 | "net/url" 9 | "testing" 10 | 11 | "github.com/aws/aws-sdk-go-v2/aws" 12 | "github.com/google/go-cmp/cmp" 13 | "github.com/hashicorp/aws-sdk-go-base/v2/diag" 14 | "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" 15 | ) 16 | 17 | func TestConfig_VerifyAccountIDAllowed(t *testing.T) { 18 | tests := []struct { 19 | name string 20 | config Config 21 | accountID string 22 | wantErr bool 23 | }{ 24 | { 25 | "empty", 26 | Config{}, 27 | "1234", 28 | false, 29 | }, 30 | { 31 | "allowed", 32 | Config{ 33 | AllowedAccountIds: []string{"1234"}, 34 | }, 35 | "1234", 36 | false, 37 | }, 38 | { 39 | "not allowed", 40 | Config{ 41 | AllowedAccountIds: []string{"5678"}, 42 | }, 43 | "1234", 44 | true, 45 | }, 46 | { 47 | "forbidden", 48 | Config{ 49 | ForbiddenAccountIds: []string{"1234"}, 50 | }, 51 | "1234", 52 | true, 53 | }, 54 | { 55 | "not forbidden", 56 | Config{ 57 | ForbiddenAccountIds: []string{"5678"}, 58 | }, 59 | "1234", 60 | false, 61 | }, 62 | { 63 | // In practice the upstream interfaces (AWS Provider, S3 Backend, etc.) should make 64 | // these conflict, but documenting the behavior for completeness. 
65 | "allowed and forbidden", 66 | Config{ 67 | AllowedAccountIds: []string{"1234"}, 68 | ForbiddenAccountIds: []string{"1234"}, 69 | }, 70 | "1234", 71 | true, 72 | }, 73 | } 74 | for _, tt := range tests { 75 | t.Run(tt.name, func(t *testing.T) { 76 | if err := tt.config.VerifyAccountIDAllowed(tt.accountID); (err != nil) != tt.wantErr { 77 | t.Errorf("Config.VerifyAccountIDAllowed() error = %v, wantErr %v", err, tt.wantErr) 78 | } 79 | }) 80 | } 81 | } 82 | 83 | func foo(_ *url.URL, err error) error { 84 | return err 85 | } 86 | 87 | func TestValidateProxyConfig(t *testing.T) { 88 | testcases := map[string]struct { 89 | config Config 90 | environmentVariables map[string]string 91 | expectedDiags diag.Diagnostics 92 | }{ 93 | "no config": {}, 94 | 95 | "invalid HTTP proxy": { 96 | config: Config{ 97 | HTTPProxy: aws.String(" http://invalid.test"), // explicit URL parse failure 98 | HTTPSProxy: aws.String("http://valid.test"), 99 | }, 100 | expectedDiags: diag.Diagnostics{ 101 | diag.NewErrorDiagnostic( 102 | "Invalid HTTP Proxy", 103 | fmt.Sprintf("Unable to parse URL: %s", foo(url.Parse(" http://invalid.test"))), //nolint:staticcheck 104 | ), 105 | }, 106 | }, 107 | 108 | "invalid HTTPS proxy": { 109 | config: Config{ 110 | HTTPProxy: aws.String("http://valid.test"), 111 | HTTPSProxy: aws.String(" http://invalid.test"), // explicit URL parse failure 112 | }, 113 | expectedDiags: diag.Diagnostics{ 114 | diag.NewErrorDiagnostic( 115 | "Invalid HTTPS Proxy", 116 | fmt.Sprintf("Unable to parse URL: %s", foo(url.Parse(" http://invalid.test"))), //nolint:staticcheck 117 | ), 118 | }, 119 | }, 120 | 121 | "invalid both proxies": { 122 | config: Config{ 123 | HTTPProxy: aws.String(" http://invalid.test"), // explicit URL parse failure 124 | HTTPSProxy: aws.String(" http://invalid.test"), // explicit URL parse failure 125 | }, 126 | expectedDiags: diag.Diagnostics{ 127 | diag.NewErrorDiagnostic( 128 | "Invalid HTTP Proxy", 129 | fmt.Sprintf("Unable to parse URL: %s", 
foo(url.Parse(" http://invalid.test"))), //nolint:staticcheck 130 | ), 131 | diag.NewErrorDiagnostic( 132 | "Invalid HTTPS Proxy", 133 | fmt.Sprintf("Unable to parse URL: %s", foo(url.Parse(" http://invalid.test"))), //nolint:staticcheck 134 | ), 135 | }, 136 | }, 137 | 138 | "HTTP proxy without HTTPS proxy Legacy": { 139 | config: Config{ 140 | HTTPProxy: aws.String("http://valid.test"), 141 | HTTPProxyMode: HTTPProxyModeLegacy, 142 | }, 143 | expectedDiags: diag.Diagnostics{ 144 | diag.NewWarningDiagnostic( 145 | missingHttpsProxyWarningSummary, 146 | fmt.Sprintf( 147 | "An HTTP proxy was set but no HTTPS proxy was. Using HTTP proxy %q for HTTPS requests. This behavior may change in future versions.\n\n"+ 148 | "To specify no proxy for HTTPS, set the HTTPS to an empty string.", 149 | "http://valid.test"), 150 | ), 151 | }, 152 | }, 153 | 154 | "HTTP proxy empty string": { 155 | config: Config{ 156 | HTTPProxy: aws.String(""), 157 | }, 158 | expectedDiags: diag.Diagnostics{}, 159 | }, 160 | 161 | "HTTP proxy with HTTPS proxy empty string Legacy": { 162 | config: Config{ 163 | HTTPProxy: aws.String("http://valid.test"), 164 | HTTPSProxy: aws.String(""), 165 | HTTPProxyMode: HTTPProxyModeLegacy, 166 | }, 167 | expectedDiags: diag.Diagnostics{}, 168 | }, 169 | 170 | "HTTP proxy config with HTTPS_PROXY envvar": { 171 | config: Config{ 172 | HTTPProxy: aws.String("http://valid.test"), 173 | }, 174 | environmentVariables: map[string]string{ 175 | "HTTPS_PROXY": "http://envvar-proxy.test:1234", 176 | }, 177 | expectedDiags: diag.Diagnostics{}, 178 | }, 179 | 180 | "HTTP proxy config with https_proxy envvar": { 181 | config: Config{ 182 | HTTPProxy: aws.String("http://valid.test"), 183 | }, 184 | environmentVariables: map[string]string{ 185 | "https_proxy": "http://envvar-proxy.test:1234", 186 | }, 187 | expectedDiags: diag.Diagnostics{}, 188 | }, 189 | 190 | "HTTP proxy without HTTPS proxy Separate": { 191 | config: Config{ 192 | HTTPProxy: 
aws.String("http://valid.test"), 193 | HTTPProxyMode: HTTPProxyModeSeparate, 194 | }, 195 | expectedDiags: diag.Diagnostics{ 196 | diag.NewWarningDiagnostic( 197 | missingHttpsProxyWarningSummary, 198 | "An HTTP proxy was set but no HTTPS proxy was.\n\n"+ 199 | "To specify no proxy for HTTPS, set the HTTPS to an empty string.", 200 | ), 201 | }, 202 | }, 203 | } 204 | 205 | for name, testcase := range testcases { 206 | t.Run(name, func(t *testing.T) { 207 | servicemocks.InitSessionTestEnv(t) 208 | 209 | for k, v := range testcase.environmentVariables { 210 | t.Setenv(k, v) 211 | } 212 | 213 | var diags diag.Diagnostics 214 | 215 | testcase.config.ValidateProxySettings(&diags) 216 | 217 | if diff := cmp.Diff(diags, testcase.expectedDiags); diff != "" { 218 | t.Errorf("Unexpected response (+wanted, -got): %s", diff) 219 | } 220 | }) 221 | } 222 | } 223 | -------------------------------------------------------------------------------- /internal/config/errors.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package config 5 | 6 | import ( 7 | "fmt" 8 | 9 | "github.com/hashicorp/aws-sdk-go-base/v2/diag" 10 | ) 11 | 12 | // CannotAssumeRoleWithWebIdentityError occurs when AssumeRoleWithWebIdentity cannot complete. 
13 | type CannotAssumeRoleWithWebIdentityError struct { 14 | Config *Config 15 | err error 16 | } 17 | 18 | func (e CannotAssumeRoleWithWebIdentityError) Severity() diag.Severity { 19 | return diag.SeverityError 20 | } 21 | 22 | func (e CannotAssumeRoleWithWebIdentityError) Summary() string { 23 | return "Cannot assume IAM Role with web identity" 24 | } 25 | 26 | func (e CannotAssumeRoleWithWebIdentityError) Detail() string { 27 | if e.Config == nil || e.Config.AssumeRoleWithWebIdentity == nil { 28 | return fmt.Sprintf("cannot assume role with web identity: %s", e.err) 29 | } 30 | 31 | return fmt.Sprintf(`IAM Role (%s) cannot be assumed with web identity token. 32 | 33 | There are a number of possible causes of this - the most common are: 34 | * The web identity token used in order to assume the role is invalid 35 | * The web identity token does not have appropriate permission to assume the role 36 | * The role ARN is not valid 37 | 38 | Error: %s 39 | `, e.Config.AssumeRoleWithWebIdentity.RoleARN, e.err) 40 | } 41 | 42 | func (e CannotAssumeRoleWithWebIdentityError) Equal(other diag.Diagnostic) bool { 43 | ed, ok := other.(CannotAssumeRoleWithWebIdentityError) 44 | if !ok { 45 | return false 46 | } 47 | 48 | return ed.Summary() == e.Summary() && ed.Detail() == e.Detail() 49 | } 50 | 51 | func (e CannotAssumeRoleWithWebIdentityError) Err() error { 52 | return e.err 53 | } 54 | 55 | func (c *Config) NewCannotAssumeRoleWithWebIdentityError(err error) CannotAssumeRoleWithWebIdentityError { 56 | return CannotAssumeRoleWithWebIdentityError{ 57 | Config: c, 58 | err: err, 59 | } 60 | } 61 | 62 | var _ diag.DiagnosticWithErr = CannotAssumeRoleWithWebIdentityError{} 63 | 64 | // NoValidCredentialSourcesError occurs when all credential lookup methods have been exhausted without results. 
65 | type NoValidCredentialSourcesError struct { 66 | Config *Config 67 | err error 68 | } 69 | 70 | func (e NoValidCredentialSourcesError) Severity() diag.Severity { 71 | return diag.SeverityError 72 | } 73 | 74 | func (e NoValidCredentialSourcesError) Summary() string { 75 | return "No valid credential sources found" 76 | } 77 | 78 | func (e NoValidCredentialSourcesError) Detail() string { 79 | if e.Config == nil { 80 | return e.err.Error() 81 | } 82 | 83 | return fmt.Sprintf(`Please see %[1]s 84 | for more information about providing credentials. 85 | 86 | Error: %[2]s 87 | `, e.Config.CallerDocumentationURL, e.err) 88 | } 89 | 90 | func (e NoValidCredentialSourcesError) Equal(other diag.Diagnostic) bool { 91 | ed, ok := other.(NoValidCredentialSourcesError) 92 | if !ok { 93 | return false 94 | } 95 | 96 | return ed.Summary() == e.Summary() && ed.Detail() == e.Detail() 97 | } 98 | 99 | func (e NoValidCredentialSourcesError) Err() error { 100 | return e.err 101 | } 102 | 103 | func (c *Config) NewNoValidCredentialSourcesError(err error) NoValidCredentialSourcesError { 104 | return NoValidCredentialSourcesError{ 105 | Config: c, 106 | err: err, 107 | } 108 | } 109 | 110 | var _ diag.DiagnosticWithErr = NoValidCredentialSourcesError{} 111 | -------------------------------------------------------------------------------- /internal/config/user_agent.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package config 5 | 6 | import ( 7 | smithyhttp "github.com/aws/smithy-go/transport/http" 8 | ) 9 | 10 | type UserAgentProduct struct { 11 | Name string 12 | Version string 13 | Comment string 14 | } 15 | 16 | type UserAgentProducts []UserAgentProduct 17 | 18 | func (ua UserAgentProducts) BuildUserAgentString() string { 19 | builder := smithyhttp.NewUserAgentBuilder() 20 | for _, p := range ua { 21 | p.buildUserAgentPart(builder) 22 | } 23 | return builder.Build() 24 | } 25 | 26 | func (p UserAgentProduct) buildUserAgentPart(b *smithyhttp.UserAgentBuilder) { 27 | if p.Name != "" { 28 | if p.Version != "" { 29 | b.AddKeyValue(p.Name, p.Version) 30 | } else { 31 | b.AddKey(p.Name) 32 | } 33 | } 34 | if p.Comment != "" { 35 | b.AddKey("(" + p.Comment + ")") 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /internal/constants/constants.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package constants 5 | 6 | const ( 7 | // AppendUserAgentEnvVar is a conventionally used environment variable 8 | // containing additional HTTP User-Agent information. 9 | // If present and its value is non-empty, it is directly appended to the 10 | // User-Agent header for HTTP requests. 11 | AppendUserAgentEnvVar = "TF_APPEND_USER_AGENT" 12 | 13 | // Maximum network retries. 14 | // We depend on the AWS Go SDK DefaultRetryer exponential backoff. 15 | // Ensure that if the AWS Config MaxRetries is set high (which it is by 16 | // default), that we only retry for a few seconds with typically 17 | // unrecoverable network errors, such as DNS lookup failures. 
18 | MaxNetworkRetryCount = 9 19 | ) 20 | -------------------------------------------------------------------------------- /internal/errs/errs.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package errs 5 | 6 | import ( 7 | "errors" 8 | ) 9 | 10 | // IsA indicates whether an error matches an error type. 11 | func IsA[T error](err error) bool { 12 | _, ok := As[T](err) 13 | return ok 14 | } 15 | 16 | // As is equivalent to errors.As(), but returns the value in-line. 17 | func As[T error](err error) (T, bool) { 18 | var as T 19 | ok := errors.As(err, &as) 20 | return as, ok 21 | } 22 | -------------------------------------------------------------------------------- /internal/expand/filepath.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package expand 5 | 6 | import ( 7 | "os" 8 | 9 | "github.com/hashicorp/go-multierror" 10 | "github.com/mitchellh/go-homedir" 11 | ) 12 | 13 | func FilePaths(in []string) ([]string, error) { 14 | var errs *multierror.Error 15 | result := make([]string, 0, len(in)) 16 | for _, v := range in { 17 | p, err := FilePath(v) 18 | if err != nil { 19 | errs = multierror.Append(errs, err) 20 | continue 21 | } 22 | result = append(result, p) 23 | } 24 | return result, errs.ErrorOrNil() 25 | } 26 | 27 | func FilePath(in string) (s string, err error) { 28 | e := os.ExpandEnv(in) 29 | s, err = homedir.Expand(e) 30 | return 31 | } 32 | -------------------------------------------------------------------------------- /internal/expand/filepath_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package expand_test 5 | 6 | import ( 7 | "os" 8 | "testing" 9 | 10 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" 11 | "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" 12 | ) 13 | 14 | func TestExpandFilePath(t *testing.T) { 15 | testcases := map[string]struct { 16 | path string 17 | expected string 18 | envvars map[string]string 19 | }{ 20 | "filename": { 21 | path: "file", 22 | expected: "file", 23 | }, 24 | "file in current dir": { 25 | path: "./file", 26 | expected: "./file", 27 | }, 28 | "file with tilde": { 29 | path: "~/file", 30 | expected: "/my/home/dir/file", 31 | envvars: map[string]string{ 32 | "HOME": "/my/home/dir", 33 | }, 34 | }, 35 | "file with envvar": { 36 | path: "$HOME/file", 37 | expected: "/home/dir/file", 38 | envvars: map[string]string{ 39 | "HOME": "/home/dir", 40 | }, 41 | }, 42 | "full file in envvar": { 43 | path: "$CONF_FILE", 44 | expected: "/path/to/conf/file", 45 | envvars: map[string]string{ 46 | "CONF_FILE": "/path/to/conf/file", 47 | }, 48 | }, 49 | } 50 | 51 | for name, testcase := range testcases { 52 | t.Run(name, func(t *testing.T) { 53 | servicemocks.StashEnv(t) 54 | 55 | for k, v := range testcase.envvars { 56 | os.Setenv(k, v) 57 | } 58 | 59 | a, err := expand.FilePath(testcase.path) 60 | if err != nil { 61 | t.Fatalf("unexpected error: %s", err) 62 | } 63 | 64 | if a != testcase.expected { 65 | t.Errorf("expected expansion to %q, got %q", testcase.expected, a) 66 | } 67 | }) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /internal/generate/common/generator.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package common 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "go/format" 10 | "io" 11 | "maps" 12 | "os" 13 | "path/filepath" 14 | "strings" 15 | "text/template" 16 | "unicode" 17 | "unicode/utf8" 18 | 19 | "golang.org/x/text/cases" 20 | "golang.org/x/text/language" 21 | ) 22 | 23 | type Generator struct{} 24 | 25 | func NewGenerator() *Generator { 26 | return &Generator{} 27 | } 28 | 29 | func (g *Generator) Infof(format string, a ...any) { 30 | g.output(os.Stdout, format, a...) 31 | } 32 | 33 | func (g *Generator) Warnf(format string, a ...any) { 34 | g.Errorf(format, a...) 35 | } 36 | 37 | func (g *Generator) Errorf(format string, a ...any) { 38 | g.output(os.Stderr, format, a...) 39 | } 40 | 41 | func (g *Generator) Fatalf(format string, a ...any) { 42 | g.Errorf(format, a...) 43 | os.Exit(1) 44 | } 45 | 46 | func (g *Generator) output(w io.Writer, format string, a ...any) { 47 | fmt.Fprintf(w, format, a...) 48 | fmt.Fprint(w, "\n") 49 | } 50 | 51 | type Destination interface { 52 | CreateDirectories() error 53 | Write() error 54 | WriteBytes(body []byte) error 55 | WriteTemplate(templateName, templateBody string, templateData any, funcMaps ...template.FuncMap) error 56 | WriteTemplateSet(templates *template.Template, templateData any) error 57 | } 58 | 59 | func (g *Generator) NewGoFileDestination(filename string) Destination { 60 | return &fileDestination{ 61 | baseDestination: baseDestination{formatter: format.Source}, 62 | filename: filename, 63 | } 64 | } 65 | 66 | func (g *Generator) NewUnformattedFileDestination(filename string) Destination { 67 | return &fileDestination{ 68 | filename: filename, 69 | } 70 | } 71 | 72 | type fileDestination struct { 73 | baseDestination 74 | append bool 75 | filename string 76 | } 77 | 78 | func (d *fileDestination) CreateDirectories() error { 79 | const ( 80 | perm os.FileMode = 0755 81 | ) 82 | dirname := filepath.Dir(d.filename) 83 | err := os.MkdirAll(dirname, perm) 84 | 85 | if err !=
nil { 86 | return fmt.Errorf("creating target directory %s: %w", dirname, err) 87 | } 88 | 89 | return nil 90 | } 91 | 92 | func (d *fileDestination) Write() (err error) { 93 | flags := os.O_TRUNC | os.O_CREATE | os.O_WRONLY 94 | if d.append { 95 | flags = os.O_APPEND | os.O_CREATE | os.O_WRONLY 96 | } 97 | f, err := os.OpenFile(d.filename, flags, 0644) //nolint:mnd // good protection for new files 98 | if err != nil { 99 | return fmt.Errorf("opening file (%s): %w", d.filename, err) 100 | } 101 | 102 | // Close can report I/O failures for data just written, so don't discard its error. 103 | defer func() { 104 | if cerr := f.Close(); cerr != nil && err == nil { 105 | err = fmt.Errorf("closing file (%s): %w", d.filename, cerr) 106 | } 107 | }() 108 | 109 | if _, err = f.WriteString(d.buffer.String()); err != nil { 110 | return fmt.Errorf("writing to file (%s): %w", d.filename, err) 111 | } 112 | 113 | return nil 114 | } 115 | 116 | type baseDestination struct { 117 | formatter func([]byte) ([]byte, error) 118 | buffer strings.Builder 119 | } 120 | 121 | func (d *baseDestination) WriteBytes(body []byte) error { 122 | _, err := d.buffer.Write(body) 123 | return err 124 | } 125 | 126 | func (d *baseDestination) WriteTemplate(templateName, templateBody string, templateData any, funcMaps ...template.FuncMap) error { 127 | body, err := parseTemplate(templateName, templateBody, templateData, funcMaps...) 128 | 129 | if err != nil { 130 | return err 131 | } 132 | 133 | body, err = d.format(body) 134 | if err != nil { 135 | return err 136 | } 137 | 138 | return d.WriteBytes(body) 139 | } 140 | 141 | func parseTemplate(templateName, templateBody string, templateData any, funcMaps ...template.FuncMap) ([]byte, error) { 142 | funcMap := template.FuncMap{ 143 | "FirstUpper": FirstUpper, 144 | // Title returns a string with the first character of each word as upper case. 145 | "Title": cases.Title(language.Und, cases.NoLower).String, 146 | } 147 | for _, v := range funcMaps { 148 | maps.Copy(funcMap, v) // Extras overwrite defaults.
149 | } 150 | tmpl, err := template.New(templateName).Funcs(funcMap).Parse(templateBody) 151 | 152 | if err != nil { 153 | return nil, fmt.Errorf("parsing function template: %w", err) 154 | } 155 | 156 | return executeTemplate(tmpl, templateData) 157 | } 158 | 159 | func executeTemplate(tmpl *template.Template, templateData any) ([]byte, error) { 160 | var buffer bytes.Buffer 161 | err := tmpl.Execute(&buffer, templateData) 162 | 163 | if err != nil { 164 | return nil, fmt.Errorf("executing template: %w", err) 165 | } 166 | 167 | return buffer.Bytes(), nil 168 | } 169 | 170 | func (d *baseDestination) WriteTemplateSet(templates *template.Template, templateData any) error { 171 | body, err := executeTemplate(templates, templateData) 172 | if err != nil { 173 | return err 174 | } 175 | 176 | body, err = d.format(body) 177 | if err != nil { 178 | return err 179 | } 180 | 181 | return d.WriteBytes(body) 182 | } 183 | 184 | func (d *baseDestination) format(body []byte) ([]byte, error) { 185 | if d.formatter == nil { 186 | return body, nil 187 | } 188 | 189 | unformattedBody := body 190 | body, err := d.formatter(unformattedBody) 191 | if err != nil { 192 | return nil, fmt.Errorf("formatting parsed template:\n%s\n%w", unformattedBody, err) 193 | } 194 | 195 | return body, nil 196 | } 197 | 198 | // FirstUpper returns a string with the first character as upper case. 199 | func FirstUpper(s string) string { 200 | if s == "" { 201 | return "" 202 | } 203 | r, n := utf8.DecodeRuneInString(s) 204 | return string(unicode.ToUpper(r)) + s[n:] 205 | } 206 | -------------------------------------------------------------------------------- /internal/generate/endpoints/main.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | //go:build generate 5 | // +build generate 6 | 7 | package main 8 | 9 | import ( 10 | _ "embed" 11 | "encoding/json" 12 | "flag" 13 | "fmt" 14 | "io" 15 | "net/http" 16 | "os" 17 | "sort" 18 | "strings" 19 | "text/template" 20 | 21 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/generate/common" 22 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/slices" 23 | ) 24 | 25 | type PartitionDatum struct { 26 | ID string 27 | Name string 28 | DNSSuffix string 29 | RegionRegex string 30 | Regions []RegionDatum 31 | Services []ServiceDatum 32 | } 33 | 34 | type RegionDatum struct { 35 | ID string 36 | Description string 37 | } 38 | 39 | type ServiceDatum struct { 40 | ID string 41 | } 42 | 43 | type TemplateData struct { 44 | Partitions []PartitionDatum 45 | } 46 | 47 | func usage() { 48 | fmt.Fprintf(os.Stderr, "Usage:\n") 49 | fmt.Fprintf(os.Stderr, "\tmain.go <endpoints-json-url>\n\n") 50 | } 51 | 52 | func main() { 53 | flag.Usage = usage 54 | flag.Parse() 55 | 56 | args := flag.Args() 57 | 58 | if len(args) < 1 { 59 | flag.Usage() 60 | os.Exit(2) 61 | } 62 | 63 | inputURL := args[0] 64 | filename := `endpoints_gen.go` 65 | target := map[string]any{} 66 | 67 | g := common.NewGenerator() 68 | g.Infof("Generating endpoints/%s", filename) 69 | 70 | if err := readHTTPJSON(inputURL, &target); err != nil { 71 | g.Fatalf("error reading JSON from %s: %s", inputURL, err) 72 | } 73 | 74 | td := TemplateData{} 75 | templateFuncMap := template.FuncMap{ 76 | // IDToTitle splits a '-' or '.' separated string and returns a string with each part title cased.
77 | "IDToTitle": func(s string) (string, error) { 78 | parts := strings.Split(s, "-") 79 | if len(parts) == 1 { 80 | parts = strings.Split(s, ".") 81 | } 82 | return strings.Join(slices.ApplyToAll(parts, func(s string) string { 83 | return common.FirstUpper(s) 84 | }), ""), nil 85 | }, 86 | } 87 | 88 | if version, ok := target["version"].(float64); ok { 89 | if version != 3.0 { 90 | g.Fatalf("unsupported endpoints document version: %d", int(version)) 91 | } 92 | } else { 93 | g.Fatalf("can't parse endpoints document version") 94 | } 95 | 96 | /* 97 | See https://github.com/aws/aws-sdk-go/blob/main/aws/endpoints/v3model.go. 98 | e.g. 99 | { 100 | "partitions": [{ 101 | "partition": "aws", 102 | "partitionName": "AWS Standard", 103 | "regions" : { 104 | "af-south-1" : { 105 | "description" : "Africa (Cape Town)" 106 | }, 107 | ... 108 | }, 109 | "services" : { 110 | "access-analyzer" : { 111 | "endpoints" : { 112 | "af-south-1" : { }, 113 | ... 114 | }, 115 | }, 116 | ... 117 | }, 118 | ... 119 | }, ...] 
120 | } 121 | */ 122 | if partitions, ok := target["partitions"].([]any); ok { 123 | for _, partition := range partitions { 124 | if partition, ok := partition.(map[string]any); ok { 125 | partitionDatum := PartitionDatum{} 126 | 127 | if id, ok := partition["partition"].(string); ok { 128 | partitionDatum.ID = id 129 | } 130 | if name, ok := partition["partitionName"].(string); ok { 131 | partitionDatum.Name = name 132 | } 133 | if dnsSuffix, ok := partition["dnsSuffix"].(string); ok { 134 | partitionDatum.DNSSuffix = dnsSuffix 135 | } 136 | if regionRegex, ok := partition["regionRegex"].(string); ok { 137 | partitionDatum.RegionRegex = regionRegex 138 | } 139 | if regions, ok := partition["regions"].(map[string]any); ok { 140 | for id, region := range regions { 141 | regionDatum := RegionDatum{ 142 | ID: id, 143 | } 144 | 145 | if region, ok := region.(map[string]any); ok { 146 | if description, ok := region["description"].(string); ok { 147 | regionDatum.Description = description 148 | } 149 | } 150 | 151 | partitionDatum.Regions = append(partitionDatum.Regions, regionDatum) 152 | } 153 | } 154 | if services, ok := partition["services"].(map[string]any); ok { 155 | for id := range services { 156 | serviceDatum := ServiceDatum{ 157 | ID: id, 158 | } 159 | 160 | partitionDatum.Services = append(partitionDatum.Services, serviceDatum) 161 | } 162 | } 163 | 164 | td.Partitions = append(td.Partitions, partitionDatum) 165 | } 166 | } 167 | } 168 | 169 | sort.SliceStable(td.Partitions, func(i, j int) bool { 170 | return td.Partitions[i].ID < td.Partitions[j].ID 171 | }) 172 | 173 | for i := 0; i < len(td.Partitions); i++ { 174 | sort.SliceStable(td.Partitions[i].Regions, func(j, k int) bool { 175 | return td.Partitions[i].Regions[j].ID < td.Partitions[i].Regions[k].ID 176 | }) 177 | } 178 | 179 | for i := 0; i < len(td.Partitions); i++ { 180 | sort.SliceStable(td.Partitions[i].Services, func(j, k int) bool { 181 | return td.Partitions[i].Services[j].ID < 
td.Partitions[i].Services[k].ID 182 | }) 183 | } 184 | 185 | d := g.NewGoFileDestination(filename) 186 | 187 | if err := d.WriteTemplate("endpoints", tmpl, td, templateFuncMap); err != nil { 188 | g.Fatalf("error generating endpoint resolver: %s", err) 189 | } 190 | 191 | if err := d.Write(); err != nil { 192 | g.Fatalf("generating file (%s): %s", filename, err) 193 | } 194 | } 195 | 196 | func readHTTPJSON(url string, to any) error { 197 | r, err := http.Get(url) 198 | if err != nil { 199 | return err 200 | } 201 | defer r.Body.Close() 202 | 203 | // A non-200 body (e.g. an HTML error page) is not the endpoints document; fail with a clear error. 204 | if r.StatusCode != http.StatusOK { 205 | return fmt.Errorf("fetching %s: unexpected HTTP status %s", url, r.Status) 206 | } 207 | 208 | return decodeFromReader(r.Body, to) 209 | } 210 | 211 | func decodeFromReader(r io.Reader, to any) error { 212 | dec := json.NewDecoder(r) 213 | 214 | for { 215 | if err := dec.Decode(to); err == io.EOF { 216 | break 217 | } else if err != nil { 218 | return err 219 | } 220 | } 221 | 222 | return nil 223 | } 224 | 225 | //go:embed output.go.gtpl 226 | var tmpl string 227 | -------------------------------------------------------------------------------- /internal/generate/endpoints/output.go.gtpl: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | // Code generated by internal/generate/endpoints/main.go; DO NOT EDIT. 5 | 6 | package endpoints 7 | 8 | import ( 9 | "regexp" 10 | ) 11 | 12 | // All known partition IDs. 13 | const ( 14 | {{- range .Partitions }} 15 | {{ .ID | IDToTitle}}PartitionID = "{{ .ID }}" // {{ .Name }} 16 | {{- end }} 17 | ) 18 | 19 | // All known Region IDs. 20 | const ( 21 | {{- range .Partitions }} 22 | // {{ .Name }} partition's Regions.
23 | {{- range .Regions }} 24 | {{ .ID | IDToTitle}}RegionID = "{{ .ID }}" // {{ .Description }} 25 | {{- end }} 26 | {{- end }} 27 | ) 28 | 29 | var ( 30 | partitions = map[string]Partition{ 31 | {{- range .Partitions }} 32 | {{ .ID | IDToTitle}}PartitionID: { 33 | id: {{ .ID | IDToTitle}}PartitionID, 34 | name: "{{ .Name }}", 35 | dnsSuffix: "{{ .DNSSuffix }}", 36 | regionRegex: regexp.MustCompile(`{{ .RegionRegex }}`), 37 | regions: map[string]Region{ 38 | {{- range .Regions }} 39 | {{ .ID | IDToTitle}}RegionID: { 40 | id: {{ .ID | IDToTitle}}RegionID, 41 | description: "{{ .Description }}", 42 | }, 43 | {{- end }} 44 | }, 45 | services: map[string]Service{ 46 | {{- range .Services }} 47 | "{{ .ID }}": { 48 | id: "{{ .ID }}", 49 | }, 50 | {{- end }} 51 | }, 52 | }, 53 | {{- end }} 54 | } 55 | ) -------------------------------------------------------------------------------- /internal/slices/slices.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package slices 5 | 6 | import "slices" 7 | 8 | // ApplyToAll returns a new slice containing the results of applying the function `f` to each element of the original slice `s`. 
9 | func ApplyToAll[T, U any](s []T, f func(T) U) []U { 10 | v := make([]U, len(s)) 11 | 12 | for i, e := range s { 13 | v[i] = f(e) 14 | } 15 | 16 | return v 17 | } 18 | 19 | type FilterFunc[T any] func(T) bool 20 | 21 | // Filter returns a new slice containing all values that return `true` for the filter function `f` 22 | func Filter[T any](s []T, f FilterFunc[T]) []T { 23 | v := make([]T, 0, len(s)) 24 | 25 | for _, e := range s { 26 | if f(e) { 27 | v = append(v, e) 28 | } 29 | } 30 | 31 | return slices.Clip(v) 32 | } 33 | -------------------------------------------------------------------------------- /internal/test/context.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package test 5 | 6 | import ( 7 | "context" 8 | "testing" 9 | 10 | "github.com/hashicorp/aws-sdk-go-base/v2/logging" 11 | ) 12 | 13 | func Context(t *testing.T) context.Context { 14 | return logging.RegisterLogger(context.Background(), logging.TfLogger(t.Name())) 15 | } 16 | -------------------------------------------------------------------------------- /internal/test/user_agent.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package test 5 | 6 | import ( 7 | "os" 8 | "testing" 9 | 10 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" 11 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" 12 | "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" 13 | ) 14 | // UserAgentTestCase is one row of the shared User-Agent table: the client Config, User-Agent products supplied via Context (presumably attached to the request context by the caller — confirm), environment variables set before the case runs, and the exact User-Agent expected. 15 | type UserAgentTestCase struct { 16 | Config *config.Config 17 | Context config.UserAgentProducts 18 | EnvironmentVariables map[string]string 19 | ExpectedUserAgent string 20 | } 21 | // TestUserAgentProducts runs every case in the table below as a subtest. awsSdkGoUserAgent supplies the SDK's base User-Agent that the expectations are built around; testUserAgentProducts performs the SDK-specific check for one case (presumably building a client from testCase.Config and comparing its User-Agent with ExpectedUserAgent — confirm in callers). 22 | func TestUserAgentProducts(t *testing.T, awsSdkGoUserAgent func() string, testUserAgentProducts func(t *testing.T, testCase UserAgentTestCase)) { 23 | t.Helper() 24 | 25 | testCases := map[string]UserAgentTestCase{ 26 | "standard User-Agent": { 27 | Config: &config.Config{ 28 | AccessKey: servicemocks.MockStaticAccessKey, 29 | Region: "us-east-1", 30 | SecretKey: servicemocks.MockStaticSecretKey, 31 | }, 32 | ExpectedUserAgent: awsSdkGoUserAgent(), 33 | }, 34 | 35 | "customized User-Agent TF_APPEND_USER_AGENT product": { 36 | Config: &config.Config{ 37 | AccessKey: servicemocks.MockStaticAccessKey, 38 | Region: "us-east-1", 39 | SecretKey: servicemocks.MockStaticSecretKey, 40 | }, 41 | EnvironmentVariables: map[string]string{ 42 | constants.AppendUserAgentEnvVar: "Env", 43 | }, 44 | ExpectedUserAgent: awsSdkGoUserAgent() + " Env", 45 | }, 46 | 47 | "customized User-Agent TF_APPEND_USER_AGENT product version": { 48 | Config: &config.Config{ 49 | AccessKey: servicemocks.MockStaticAccessKey, 50 | Region: "us-east-1", 51 | SecretKey: servicemocks.MockStaticSecretKey, 52 | }, 53 | EnvironmentVariables: map[string]string{ 54 | constants.AppendUserAgentEnvVar: "Env/1.2", 55 | }, 56 | ExpectedUserAgent: awsSdkGoUserAgent() + " Env/1.2", 57 | }, 58 | 59 | "customized User-Agent TF_APPEND_USER_AGENT multi product": { 60 | Config: &config.Config{ 61 | AccessKey: servicemocks.MockStaticAccessKey, 62 | Region: "us-east-1", 63 | SecretKey: servicemocks.MockStaticSecretKey, 64 | }, 65 |
EnvironmentVariables: map[string]string{ 66 | constants.AppendUserAgentEnvVar: "Env1/1.2 Env2", 67 | }, 68 | ExpectedUserAgent: awsSdkGoUserAgent() + " Env1/1.2 Env2", 69 | }, 70 | 71 | "customized User-Agent TF_APPEND_USER_AGENT with comment": { 72 | Config: &config.Config{ 73 | AccessKey: servicemocks.MockStaticAccessKey, 74 | Region: "us-east-1", 75 | SecretKey: servicemocks.MockStaticSecretKey, 76 | }, 77 | EnvironmentVariables: map[string]string{ 78 | constants.AppendUserAgentEnvVar: "Env1/1.2 (comment) Env2", 79 | }, 80 | ExpectedUserAgent: awsSdkGoUserAgent() + " Env1/1.2 (comment) Env2", 81 | }, 82 | 83 | "APN User-Agent Products": { 84 | Config: &config.Config{ 85 | AccessKey: servicemocks.MockStaticAccessKey, 86 | Region: "us-east-1", 87 | SecretKey: servicemocks.MockStaticSecretKey, 88 | APNInfo: &config.APNInfo{ 89 | PartnerName: "partner", 90 | Products: []config.UserAgentProduct{ 91 | { 92 | Name: "first", 93 | Version: "1.2.3", 94 | }, 95 | { 96 | Name: "second", 97 | Version: "1.0.2", 98 | Comment: "a comment", 99 | }, 100 | }, 101 | }, 102 | }, 103 | ExpectedUserAgent: "APN/1.0 partner/1.0 first/1.2.3 second/1.0.2 (a comment) " + awsSdkGoUserAgent(), 104 | }, 105 | 106 | "APN User-Agent Products and TF_APPEND_USER_AGENT": { 107 | Config: &config.Config{ 108 | AccessKey: servicemocks.MockStaticAccessKey, 109 | Region: "us-east-1", 110 | SecretKey: servicemocks.MockStaticSecretKey, 111 | APNInfo: &config.APNInfo{ 112 | PartnerName: "partner", 113 | Products: []config.UserAgentProduct{ 114 | { 115 | Name: "first", 116 | Version: "1.2.3", 117 | }, 118 | { 119 | Name: "second", 120 | Version: "1.0.2", 121 | }, 122 | }, 123 | }, 124 | }, 125 | EnvironmentVariables: map[string]string{ 126 | constants.AppendUserAgentEnvVar: "Last/9.0.0", 127 | }, 128 | ExpectedUserAgent: "APN/1.0 partner/1.0 first/1.2.3 second/1.0.2 " + awsSdkGoUserAgent() + " Last/9.0.0", 129 | }, 130 | 131 | "User-Agent Products": { 132 | Config: &config.Config{ 133 | AccessKey:
servicemocks.MockStaticAccessKey, 134 | Region: "us-east-1", 135 | SecretKey: servicemocks.MockStaticSecretKey, 136 | UserAgent: []config.UserAgentProduct{ 137 | { 138 | Name: "first", 139 | Version: "1.2.3", 140 | }, 141 | { 142 | Name: "second", 143 | Version: "1.0.2", 144 | Comment: "a comment", 145 | }, 146 | }, 147 | }, 148 | ExpectedUserAgent: awsSdkGoUserAgent() + " first/1.2.3 second/1.0.2 (a comment)", 149 | }, 150 | 151 | "APN and User-Agent Products": { 152 | Config: &config.Config{ 153 | AccessKey: servicemocks.MockStaticAccessKey, 154 | Region: "us-east-1", 155 | SecretKey: servicemocks.MockStaticSecretKey, 156 | APNInfo: &config.APNInfo{ 157 | PartnerName: "partner", 158 | Products: []config.UserAgentProduct{ 159 | { 160 | Name: "first", 161 | Version: "1.2.3", 162 | }, 163 | { 164 | Name: "second", 165 | Version: "1.0.2", 166 | Comment: "a comment", 167 | }, 168 | }, 169 | }, 170 | UserAgent: []config.UserAgentProduct{ 171 | { 172 | Name: "third", 173 | Version: "4.5.6", 174 | }, 175 | { 176 | Name: "fourth", 177 | Version: "2.1", 178 | }, 179 | }, 180 | }, 181 | ExpectedUserAgent: "APN/1.0 partner/1.0 first/1.2.3 second/1.0.2 (a comment) " + awsSdkGoUserAgent() + " third/4.5.6 fourth/2.1", 182 | }, 183 | 184 | "context": { 185 | Config: &config.Config{ 186 | AccessKey: servicemocks.MockStaticAccessKey, 187 | Region: "us-east-1", 188 | SecretKey: servicemocks.MockStaticSecretKey, 189 | }, 190 | Context: []config.UserAgentProduct{ 191 | { 192 | Name: "first", 193 | Version: "1.2.3", 194 | }, 195 | { 196 | Name: "second", 197 | Version: "1.0.2", 198 | Comment: "a comment", 199 | }, 200 | }, 201 | ExpectedUserAgent: awsSdkGoUserAgent() + " first/1.2.3 second/1.0.2 (a comment)", 202 | }, 203 | 204 | "User-Agent Products and context": { 205 | Config: &config.Config{ 206 | AccessKey: servicemocks.MockStaticAccessKey, 207 | Region: "us-east-1", 208 | SecretKey: servicemocks.MockStaticSecretKey, 209 | UserAgent: []config.UserAgentProduct{ 210 | { 211 | Name:
"first", 212 | Version: "1.2.3", 213 | }, 214 | { 215 | Name: "second", 216 | Version: "1.0.2", 217 | Comment: "a comment", 218 | }, 219 | }, 220 | }, 221 | Context: []config.UserAgentProduct{ 222 | { 223 | Name: "third", 224 | Version: "4.5.6", 225 | }, 226 | { 227 | Name: "fourth", 228 | Version: "2.1", 229 | }, 230 | }, 231 | ExpectedUserAgent: awsSdkGoUserAgent() + " first/1.2.3 second/1.0.2 (a comment) third/4.5.6 fourth/2.1", 232 | }, 233 | } 234 | // Map iteration order is random, so the subtests run in no fixed order. 235 | for name, testCase := range testCases { 236 | t.Run(name, func(t *testing.T) { 237 | servicemocks.InitSessionTestEnv(t) 238 | // InitSessionTestEnv clears the environment and restores it when the subtest ends, so the os.Setenv calls below do not leak into other subtests. 239 | for k, v := range testCase.EnvironmentVariables { 240 | os.Setenv(k, v) 241 | } 242 | // Skip credential validation; these cases only check the User-Agent. Note: Config is a pointer shared with the table entry, so this also mutates the table. 243 | testCase.Config.SkipCredsValidation = true 244 | 245 | testUserAgentProducts(t, testCase) 246 | }) 247 | } 248 | } 249 | -------------------------------------------------------------------------------- /internal/test/validator.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package test 5 | 6 | import ( 7 | "slices" 8 | "testing" 9 | 10 | "github.com/hashicorp/aws-sdk-go-base/v2/diag" 11 | ) 12 | 13 | // DiagsValidator asserts on a complete set of diagnostics, failing the test on mismatch. 14 | type DiagsValidator func(*testing.T, diag.Diagnostics) 15 | 16 | // ErrValidator reports whether the error carried by a diagnostic is the expected one. 17 | type ErrValidator func(error) bool 18 | 19 | // DiagValidator reports whether a single diagnostic is the expected one. 20 | type DiagValidator func(diag.Diagnostic) bool 21 | 22 | // ExpectNoDiags fails the test unless diags is empty. 23 | func ExpectNoDiags(t *testing.T, diags diag.Diagnostics) { 24 | t.Helper() 25 | 26 | expectDiagsCount(t, diags, 0) 27 | } 28 | 29 | // ExpectErrDiagValidator returns a DiagsValidator requiring an error diagnostic whose error satisfies ev, 30 | // and requiring that it be the only diagnostic. msg describes the expected error in failure output. 31 | func ExpectErrDiagValidator(msg string, ev ErrValidator) DiagsValidator { 32 | return func(t *testing.T, diags diag.Diagnostics) { 33 | t.Helper() 34 | 35 | // Check for the correct type of error before checking for single diagnostic 36 | if !expectDiagsContainsErr(diags, ev) { 37 | t.Fatalf("expected %s, got %#v", msg, diags) 38 | } 39 | 40 | expectDiagsCount(t, diags, 1) 41 | } 42 | } 43 | 44 | // ExpectDiagValidator returns a DiagsValidator requiring a diagnostic that satisfies dv, 45 | // and requiring that it be the only diagnostic. msg describes the expected diagnostic in failure output. 46 | func ExpectDiagValidator(msg string, dv DiagValidator) DiagsValidator { 47 | return func(t *testing.T, diags diag.Diagnostics) { 48 | t.Helper() 49 | 50 | // Check for the correct type of error before checking for single diagnostic 51 | if !expectDiagsContainsDiagFunc(diags, dv) { 52 | t.Fatalf("expected %s, got %#v", msg, diags) 53 | } 54 | 55 | expectDiagsCount(t, diags, 1) 56 | } 57 | } 58 | 59 | // expectDiagsCount fails the test unless diags holds exactly c diagnostics. 60 | func expectDiagsCount(t *testing.T, diags diag.Diagnostics, c int) { 61 | t.Helper() 62 | 63 | if l := diags.Count(); l != c { 64 | t.Fatalf("Diagnostics: expected %d element, got %d\n%#v", c, l, diags) 65 | } 66 | } 67 | 68 | // expectDiagsContainsErr reports whether any diagnostic returned by diags.Errors() implements 69 | // diag.DiagnosticWithErr and carries an error accepted by ev. 70 | func expectDiagsContainsErr(diags diag.Diagnostics, ev ErrValidator) bool { 71 | for _, d := range diags.Errors() { 72 | if e, ok := d.(diag.DiagnosticWithErr); ok && ev(e.Err()) { 73 | return true 74 | } 75 | } 76 | return false 77 | } 78 | 79 | // expectDiagsContainsDiagFunc reports whether any diagnostic in diags satisfies dv. 80 | func expectDiagsContainsDiagFunc(diags diag.Diagnostics, dv DiagValidator) bool { 81 | return slices.ContainsFunc(diags, dv) 82 | } 83 | -------------------------------------------------------------------------------- /logging/attributes.go:
-------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import "go.opentelemetry.io/otel/attribute" 7 | 8 | // OpenTelemetry attribute keys defined by this package, namespaced with the "tf_aws." prefix. 9 | const ( 10 | AwsSdkKey attribute.Key = "tf_aws.sdk" 11 | SigningRegionKey attribute.Key = "tf_aws.signing_region" 12 | CustomEndpointKey attribute.Key = "tf_aws.custom_endpoint" 13 | ) 14 | 15 | // SigningRegion returns a SigningRegionKey attribute holding region. 16 | func SigningRegion(region string) attribute.KeyValue { 17 | return SigningRegionKey.String(region) 18 | } 19 | 20 | // CustomEndpoint returns a CustomEndpointKey attribute holding custom (presumably whether a custom endpoint was configured — confirm with callers). 21 | func CustomEndpoint(custom bool) attribute.KeyValue { 22 | return CustomEndpointKey.Bool(custom) 23 | } 24 | -------------------------------------------------------------------------------- /logging/aws.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import ( 7 | "regexp" 8 | "unsafe" 9 | ) 10 | 11 | // IAM Unique ID prefixes from 12 | // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_identifiers.html#identifiers-unique-ids 13 | var UniqueIDRegex = regexp.MustCompile(`(A3T[A-Z0-9]` + 14 | `|ABIA` + // STS service bearer token 15 | `|ACCA` + // Context-specific credential 16 | `|AGPA` + // User group 17 | `|AIDA` + // IAM user 18 | `|AIPA` + // EC2 instance profile 19 | `|AKIA` + // Access key 20 | `|ANPA` + // Managed policy 21 | `|ANVA` + // Version in a managed policy 22 | `|APKA` + // Public key 23 | `|AROA` + // Role 24 | `|ASCA` + // Certificate 25 | `|ASIA` + // STS temporary access key 26 | `)[A-Z0-9]{16,}`) 27 | 28 | // Number of leading and trailing characters left visible in each masked value. 29 | const ( 30 | unmaskedFirst = 4 31 | unmaskedLast = 4 32 | ) 33 | 34 | // MaskAWSAccessKey returns a copy of field in which every UniqueIDRegex match is masked, 35 | // keeping only its first unmaskedFirst and last unmaskedLast characters. 36 | func MaskAWSAccessKey(field []byte) []byte { 37 | field = UniqueIDRegex.ReplaceAllFunc(field, func(s []byte) []byte { 38 | return partialMaskString(s, unmaskedFirst, unmaskedLast) 39 | }) 40 | return field 41 | } 42 |
43 | // MaskAWSSensitiveValues masks IAM unique IDs and likely secret access keys in field. 44 | // 45 | // field is viewed as a []byte without copying, which is safe only because that view is never written: 46 | // ReplaceAllFunc (via MaskAWSAccessKey) returns a copy of its input, and only that copy is 47 | // masked in place and then converted back to a string, again without copying. 48 | func MaskAWSSensitiveValues(field string) string { 49 | b := unsafe.Slice(unsafe.StringData(field), len(field)) 50 | b = MaskAWSAccessKey(b) 51 | MaskAWSSecretKeys(b) 52 | return unsafe.String(unsafe.SliceData(b), len(b)) 53 | } 54 | 55 | // MaskAWSSecretKeys masks likely AWS secret access keys in the input. 56 | // See https://aws.amazon.com/blogs/security/a-safer-way-to-distribute-aws-credentials-to-ec2/: 57 | // "Find me 40-character, base-64 strings that don’t have any base 64 characters immediately before or after". 58 | // The input is modified in place; each match keeps its first unmaskedFirst and last unmaskedLast characters. 59 | func MaskAWSSecretKeys(in []byte) { 60 | const ( 61 | secretKeyLen = 40 62 | ) 63 | n := len(in) 64 | base64Characters := 0 65 | 66 | // mask hides the interior of the secretKeyLen-byte run that ends just before index end. 67 | mask := func(end int) { 68 | for j := end - secretKeyLen + unmaskedFirst; j < end-unmaskedLast; j++ { 69 | in[j] = '*' 70 | } 71 | } 72 | 73 | for i := range n { 74 | b := in[i] 75 | 76 | if (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') || b == '/' || b == '+' || b == '=' { 77 | // base64 character. 78 | base64Characters++ 79 | continue 80 | } 81 | 82 | // A non-base64 byte ends the current run; only a run of exactly secretKeyLen is masked. 83 | if base64Characters == secretKeyLen { 84 | mask(i) 85 | } 86 | base64Characters = 0 87 | } 88 | 89 | // A run that reaches the end of the input has no trailing delimiter. 90 | if base64Characters == secretKeyLen { 91 | mask(n) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /logging/aws_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import ( 7 | "testing" 8 | "unsafe" 9 | ) 10 | // TestMaskAWSSensitiveValues checks that IAM unique IDs and 40-character base64 runs are partially masked, while base64 runs of other lengths and ordinary values pass through unchanged. 11 | func TestMaskAWSSensitiveValues(t *testing.T) { 12 | t.Parallel() 13 | 14 | type testCase struct { 15 | input string 16 | expected string 17 | } 18 | 19 | tests := map[string]testCase{ 20 | "mask_simple": { 21 | input: "MfP3tIG15gibzIx7CSbhSNkgD5sSV4k2tWXgN8U8", 22 | expected: "MfP3********************************N8U8", 23 | }, 24 | "mask_complex_json": { 25 | input: ` 26 | { 27 | "AWSSecretKey": "LEfH8nZmFN4BGIJnku6lkChHydRN5B/YlWCIjOte", 28 | "BucketName": "test-bucket", 29 | "AWSKeyId": "AIDACKCEVSQ6C2EXAMPLE", 30 | } 31 | `, 32 | expected: ` 33 | { 34 | "AWSSecretKey": "LEfH********************************jOte", 35 | "BucketName": "test-bucket", 36 | "AWSKeyId": "AIDA*************MPLE", 37 | } 38 | `, 39 | }, 40 | "mask_multiple_json": { 41 | input: ` 42 | { 43 | "AWSSecretKey": "LEfH8nZmFN4BGIJnku6lkChHydRN5B/YlWCIjOte", 44 | "BucketName": "test-bucket-1", 45 | "AWSKeyId": "AIDACKCEVSQ6C2EXAMPLE", 46 | }, 47 | { 48 | "Key": "ABCDEFGH!JKLMNOPQRSTUVWXYZ012345678901234567890123456789", 49 | }, 50 | { 51 | "AWSSecretKey": "MfP3tIG15gibzIx7CSbhSNkgD5sSV4k2tWXgN8U8", 52 | "BucketName": "test-bucket-2", 53 | "AWSKeyId": "AKIA5PX2H2S3LHEXAMPLE", 54 | } 55 | `, 56 | expected: ` 57 | { 58 | "AWSSecretKey": "LEfH********************************jOte", 59 | "BucketName": "test-bucket-1", 60 | "AWSKeyId": "AIDA*************MPLE", 61 | }, 62 | { 63 | "Key": "ABCDEFGH!JKLMNOPQRSTUVWXYZ012345678901234567890123456789", 64 | }, 65 | { 66 | "AWSSecretKey": "MfP3********************************N8U8", 67 | "BucketName": "test-bucket-2", 68 | "AWSKeyId": "AKIA*************MPLE", 69 | } 70 | `, 71 | }, 72 | "no_mask": { 73 | input: "test-bucket", 74 | expected: "test-bucket", 75 | }, 76 | "mask_xml": { 77 | input: ` 78 | 8/AiP0ofCD/YOAqXWrungQt/Y4BkTj1UOjZ0MqBs 79 | test-bucket 80 | AIDACKCEVSQ6C2EXAMPLE 81 | `, 82 | expected: ` 83 |
8/Ai********************************MqBs 84 | test-bucket 85 | AIDA*************MPLE 86 | `, 87 | }, 88 | } 89 | // Subtests call t.Parallel and only read their own table entry. 90 | for name, test := range tests { 91 | t.Run(name, func(t *testing.T) { 92 | t.Parallel() 93 | 94 | got := MaskAWSSensitiveValues(test.input) 95 | 96 | if got != test.expected { 97 | t.Errorf("unexpected diff +wanted: %s, -got: %s", test.expected, got) 98 | } 99 | }) 100 | } 101 | } 102 | // NOTE(review): unlike BenchmarkPartialMaskString and BenchmarkMaskAWSSensitiveValues, this benchmark discards its result instead of storing it in dump. 103 | func BenchmarkMaskAWSAccessKey(b *testing.B) { 104 | b.ReportAllocs() 105 | for n := 0; n < b.N; n++ { 106 | MaskAWSAccessKey([]byte(` 107 | { 108 | "AWSSecretKey": "LEfH8nZmFN4BGIJnku6lkChHydRN5B/YlWCIjOte", 109 | "BucketName": "test-bucket", 110 | "AWSKeyId": "AIDACKCEVSQ6C2EXAMPLE", 111 | } 112 | `)) 113 | } 114 | } 115 | // BenchmarkPartialMaskString measures masking a single 21-character ID; the last result is stored in dump so the loop has an observable effect. 116 | func BenchmarkPartialMaskString(b *testing.B) { 117 | var s []byte 118 | b.ReportAllocs() 119 | for n := 0; n < b.N; n++ { 120 | s = partialMaskString([]byte("AIDACKCEVSQ6C2EXAMPLE"), 4, 4) 121 | } 122 | dump = unsafe.String(unsafe.SliceData(s), len(s)) 123 | } 124 | // BenchmarkMaskAWSSecretKeys measures in-place secret masking; the []byte conversion inside the loop is part of what is timed. 125 | func BenchmarkMaskAWSSecretKeys(b *testing.B) { 126 | b.ReportAllocs() 127 | for n := 0; n < b.N; n++ { 128 | MaskAWSSecretKeys([]byte(` 129 | { 130 | "AWSSecretKey": "LEfH8nZmFN4BGIJnku6lkChHydRN5B/YlWCIjOte", 131 | "BucketName": "test-bucket", 132 | "AWSKeyId": "AIDACKCEVSQ6C2EXAMPLE", 133 | } 134 | `)) 135 | } 136 | } 137 | // BenchmarkMaskAWSSensitiveValues measures the combined string-level masking path; the last result is stored in dump. 138 | func BenchmarkMaskAWSSensitiveValues(b *testing.B) { 139 | var s string 140 | b.ReportAllocs() 141 | for n := 0; n < b.N; n++ { 142 | s = MaskAWSSensitiveValues(` 143 | { 144 | "AWSSecretKey": "LEfH8nZmFN4BGIJnku6lkChHydRN5B/YlWCIjOte", 145 | "BucketName": "test-bucket", 146 | "AWSKeyId": "AIDACKCEVSQ6C2EXAMPLE", 147 | } 148 | `) 149 | } 150 | dump = s 151 | } 152 | // dump is a package-level sink for benchmark results. 153 | var dump string 154 | -------------------------------------------------------------------------------- /logging/context.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import ( 7 | "context" 8 | ) 9 | 10 | // loggerKeyT is an unexported key type, so values stored under loggerKey cannot collide with context keys from other packages. 11 | type loggerKeyT string 12 | 13 | const loggerKey loggerKeyT = "logger-key" 14 | 15 | // RegisterLogger returns a copy of ctx that carries logger. 16 | func RegisterLogger(ctx context.Context, logger Logger) context.Context { 17 | return context.WithValue(ctx, loggerKey, logger) 18 | } 19 | 20 | // RetrieveLogger returns the Logger stored by RegisterLogger, or a NullLogger if ctx carries none. 21 | func RetrieveLogger(ctx context.Context) Logger { 22 | logger, ok := ctx.Value(loggerKey).(Logger) 23 | if !ok { 24 | return NullLogger{} 25 | } 26 | return logger 27 | } 28 | -------------------------------------------------------------------------------- /logging/hc_logger.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import ( 7 | "context" 8 | 9 | "github.com/hashicorp/go-hclog" 10 | ) 11 | 12 | // HcLogger is a stateless Logger that writes through the hclog.Logger stored in the context. 13 | type HcLogger struct{} 14 | 15 | var _ Logger = HcLogger{} 16 | 17 | // NewHcLogger stores logger in ctx so that HcLogger methods called with the returned context use it. 18 | func NewHcLogger(ctx context.Context, logger hclog.Logger) (context.Context, HcLogger) { 19 | ctx = hclog.WithContext(ctx, logger) 20 | 21 | return ctx, HcLogger{} 22 | } 23 | 24 | // SubLogger stores a child of the context's hclog.Logger, derived with Named(name), in the returned context. 25 | func (l HcLogger) SubLogger(ctx context.Context, name string) (context.Context, Logger) { 26 | logger := hclog.FromContext(ctx) 27 | logger = logger.Named(name) 28 | ctx = hclog.WithContext(ctx, logger) 29 | 30 | return ctx, HcLogger{} 31 | } 32 | 33 | func (l HcLogger) Warn(ctx context.Context, msg string, fields ...map[string]any) { 34 | logger := hclog.FromContext(ctx) 35 | logger.Warn(msg, flattenFields(fields...)...) 36 | } 37 | 38 | func (l HcLogger) Info(ctx context.Context, msg string, fields ...map[string]any) { 39 | logger := hclog.FromContext(ctx) 40 | logger.Info(msg, flattenFields(fields...)...) 41 | } 42 | 43 | func (l HcLogger) Debug(ctx context.Context, msg string, fields ...map[string]any) { 44 | logger := hclog.FromContext(ctx) 45 | logger.Debug(msg, flattenFields(fields...)...)
46 | } 47 | 48 | func (l HcLogger) Trace(ctx context.Context, msg string, fields ...map[string]any) { 49 | logger := hclog.FromContext(ctx) 50 | logger.Trace(msg, flattenFields(fields...)...) 51 | } 52 | 53 | // flattenFields merges the field maps into the alternating key/value argument list hclog takes. 54 | // The result is presized for the combined size of all maps. Keys appear in map iteration 55 | // (random) order, and a key present in several maps is emitted once per map. 56 | // TODO: how to handle duplicates 57 | func flattenFields(fields ...map[string]any) []any { 58 | var totalLen int 59 | for _, m := range fields { 60 | totalLen += len(m) 61 | } 62 | f := make([]any, 0, totalLen*2) //nolint:mnd 63 | 64 | for _, m := range fields { 65 | for k, v := range m { 66 | f = append(f, k, v) 67 | } 68 | } 69 | return f 70 | } 71 | 72 | // SetField returns a context whose hclog.Logger attaches key=value to every entry. 73 | func (l HcLogger) SetField(ctx context.Context, key string, value any) context.Context { 74 | logger := hclog.FromContext(ctx) 75 | logger = logger.With(key, value) 76 | ctx = hclog.WithContext(ctx, logger) 77 | return ctx 78 | } 79 | -------------------------------------------------------------------------------- /logging/hc_logger_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import ( 7 | "context" 8 | "io" 9 | "testing" 10 | 11 | "github.com/hashicorp/go-hclog" 12 | ) 13 | // hclogRootName is the root logger name; the shared tests expect sub-logger entries under the module hclogRootName + "." + name. 14 | const hclogRootName = "hc-log-test" 15 | 16 | func TestHcLoggerWarn(t *testing.T) { 17 | testLoggerWarn(t, hclogRootName, hcLoggerFactory) 18 | } 19 | 20 | func TestHcLoggerSetField(t *testing.T) { 21 | testLoggerSetField(t, hclogRootName, hcLoggerFactory) 22 | } 23 | // hcLoggerFactory builds an HcLogger that writes JSON to output and returns a sub-logger named name, in the factory shape the shared logger tests expect. 24 | func hcLoggerFactory(ctx context.Context, name string, output io.Writer) (context.Context, Logger) { 25 | hclogger := configureHcLogger(output) 26 | 27 | ctx, rootLogger := NewHcLogger(ctx, hclogger) 28 | ctx, logger := rootLogger.SubLogger(ctx, name) 29 | 30 | return ctx, logger 31 | } 32 | 33 | // configureHcLogger returns a new logger with settings suitable for testing: 34 | // 35 | // - Log level set to TRACE 36 | // - Written to the io.Writer passed in, such as a bytes.Buffer 37 | // - Log entries are in JSON format, and can be decoded using tflogtest.MultilineJSONDecode 38 | // - Caller information is not included 39 | // - Timestamp is not included 40 | func configureHcLogger(output io.Writer) hclog.Logger { 41 | logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{ 42 | Name: hclogRootName, 43 | Level: hclog.Trace, 44 | Output: output, 45 | IndependentLevels: true, 46 | JSONFormat: true, 47 | IncludeLocation: false, 48 | DisableTime: true, 49 | }) 50 | 51 | return logger 52 | } 53 | -------------------------------------------------------------------------------- /logging/logger.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import ( 7 | "context" 8 | ) 9 | 10 | // Logger is the logging abstraction used by this package. Every method takes the context that 11 | // carries logger state; SetField and SubLogger return a derived context, which callers must pass 12 | // to later calls for the change to take effect. 13 | type Logger interface { 14 | Warn(ctx context.Context, msg string, fields ...map[string]any) 15 | Info(ctx context.Context, msg string, fields ...map[string]any) 16 | Debug(ctx context.Context, msg string, fields ...map[string]any) 17 | Trace(ctx context.Context, msg string, fields ...map[string]any) 18 | 19 | // SetField returns a context in which key=value is attached to subsequent entries. 20 | SetField(ctx context.Context, key string, value any) context.Context 21 | 22 | // SubLogger returns a context and Logger for a child logger named name. 23 | SubLogger(ctx context.Context, name string) (context.Context, Logger) 24 | } 25 | -------------------------------------------------------------------------------- /logging/logger_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "io" 10 | "testing" 11 | 12 | "github.com/google/go-cmp/cmp" 13 | "github.com/hashicorp/terraform-plugin-log/tflogtest" 14 | ) 15 | // testLoggerWarn checks that a Warn entry from a sub-logger built by factory carries the call's fields and the module name rootName + ".test". 16 | func testLoggerWarn(t *testing.T, rootName string, factory func(ctx context.Context, name string, output io.Writer) (context.Context, Logger)) { 17 | t.Helper() 18 | 19 | loggerName := "test" 20 | expectedModule := rootName + "."
+ loggerName 21 | 22 | var buf bytes.Buffer 23 | ctx := context.Background() 24 | ctx, logger := factory(ctx, loggerName, &buf) 25 | 26 | logger.Warn(ctx, "message", map[string]any{ 27 | "one": int(1), 28 | "two": "two", 29 | }) 30 | 31 | lines, err := tflogtest.MultilineJSONDecode(&buf) 32 | if err != nil { 33 | t.Fatalf("decoding log lines: %s", err) 34 | } 35 | 36 | expected := []map[string]any{ 37 | { 38 | "@level": "warn", 39 | "@module": expectedModule, 40 | "@message": "message", 41 | "one": float64(1), 42 | "two": "two", 43 | }, 44 | } 45 | 46 | if diff := cmp.Diff(expected, lines); diff != "" { 47 | t.Errorf("unexpected logger output difference: %s", diff) 48 | } 49 | } 50 | // testLoggerSetField checks that SetField affects only the returned context and that a per-call field with the same key takes precedence in that entry. 51 | func testLoggerSetField(t *testing.T, rootName string, factory func(ctx context.Context, name string, output io.Writer) (context.Context, Logger)) { 52 | t.Helper() 53 | 54 | loggerName := "test" 55 | expectedModule := rootName + "." + loggerName 56 | 57 | var buf bytes.Buffer 58 | originalCtx := context.Background() 59 | originalCtx, logger := factory(originalCtx, loggerName, &buf) 60 | 61 | newCtx := logger.SetField(originalCtx, "key", "value") 62 | 63 | logger.Warn(newCtx, "new logger") 64 | logger.Warn(newCtx, "new logger", map[string]any{ 65 | "key": "other value", 66 | }) 67 | logger.Warn(originalCtx, "original logger") 68 | 69 | lines, err := tflogtest.MultilineJSONDecode(&buf) 70 | if err != nil { 71 | t.Fatalf("ctxWithField: decoding log lines: %s", err) 72 | } 73 | 74 | expected := []map[string]any{ 75 | { 76 | "@level": "warn", 77 | "@module": expectedModule, 78 | "@message": "new logger", 79 | "key": "value", 80 | }, 81 | { 82 | "@level": "warn", 83 | "@module": expectedModule, 84 | "@message": "new logger", 85 | "key": "other value", 86 | }, 87 | { 88 | "@level": "warn", 89 | "@module": expectedModule, 90 | "@message": "original logger", 91 | }, 92 | } 93 | 94 | if diff := cmp.Diff(expected, lines); diff != "" { 95 | t.Errorf("unexpected logger output difference: %s",
diff) 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /logging/mask.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | // partialMaskString returns a new slice holding s with every byte except the first `first` and 7 | // last `last` replaced by '*'. When s is too short to leave anything masked in between 8 | // (len(s) <= first+last), every byte is masked so the value is never revealed in full. 9 | func partialMaskString(s []byte, first, last int) []byte { 10 | l := len(s) 11 | result := make([]byte, 0, l) 12 | if l <= first+last { 13 | for range l { 14 | result = append(result, '*') 15 | } 16 | return result 17 | } 18 | result = append(result, s[:first]...) 19 | for range l - first - last { 20 | result = append(result, '*') 21 | } 22 | result = append(result, s[l-last:]...) 23 | return result 24 | } 25 | -------------------------------------------------------------------------------- /logging/null_logger.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import ( 7 | "context" 8 | ) 9 | // NullLogger is a Logger that discards all entries and leaves contexts unchanged; RetrieveLogger falls back to it when no Logger is registered. 10 | type NullLogger struct { 11 | } 12 | 13 | var _ Logger = NullLogger{} 14 | 15 | func (l NullLogger) SubLogger(ctx context.Context, name string) (context.Context, Logger) { 16 | return ctx, l 17 | } 18 | 19 | func (l NullLogger) Warn(ctx context.Context, msg string, fields ...map[string]any) { 20 | } 21 | 22 | func (l NullLogger) Info(ctx context.Context, msg string, fields ...map[string]any) { 23 | } 24 | 25 | func (l NullLogger) Debug(ctx context.Context, msg string, fields ...map[string]any) { 26 | } 27 | 28 | func (l NullLogger) Trace(ctx context.Context, msg string, fields ...map[string]any) { 29 | } 30 | 31 | func (l NullLogger) SetField(ctx context.Context, key string, value any) context.Context { 32 | return ctx 33 | } 34 | -------------------------------------------------------------------------------- /logging/tf_logger.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import ( 7 | "context" 8 | 9 | "github.com/hashicorp/terraform-plugin-log/tflog" 10 | ) 11 | 12 | // TfLogger is a Logger backed by terraform-plugin-log. The empty value logs through the root 13 | // tflog functions; a non-empty value is the name of the tflog subsystem to log to. 14 | type TfLogger string 15 | 16 | var _ Logger = TfLogger("") 17 | 18 | // NewTfLogger returns ctx unchanged together with a root (empty-named) TfLogger. 19 | func NewTfLogger(ctx context.Context) (context.Context, TfLogger) { 20 | return ctx, TfLogger("") 21 | } 22 | 23 | // SubLogger registers a tflog subsystem named name (created with tflog.WithRootFields()) and returns a TfLogger bound to it. 24 | func (l TfLogger) SubLogger(ctx context.Context, name string) (context.Context, Logger) { 25 | ctx = tflog.NewSubsystem(ctx, name, tflog.WithRootFields()) 26 | logger := TfLogger(name) 27 | 28 | return ctx, logger 29 | } 30 | 31 | func (l TfLogger) Warn(ctx context.Context, msg string, fields ...map[string]any) { 32 | if l == "" { 33 | tflog.Warn(ctx, msg, fields...) 34 | } else { 35 | tflog.SubsystemWarn(ctx, string(l), msg, fields...) 36 | } 37 | } 38 | 39 | func (l TfLogger) Info(ctx context.Context, msg string, fields ...map[string]any) { 40 | if l == "" { 41 | tflog.Info(ctx, msg, fields...) 42 | } else { 43 | tflog.SubsystemInfo(ctx, string(l), msg, fields...) 44 | } 45 | } 46 | 47 | func (l TfLogger) Debug(ctx context.Context, msg string, fields ...map[string]any) { 48 | if l == "" { 49 | tflog.Debug(ctx, msg, fields...) 50 | } else { 51 | tflog.SubsystemDebug(ctx, string(l), msg, fields...) 52 | } 53 | } 54 | 55 | func (l TfLogger) Trace(ctx context.Context, msg string, fields ...map[string]any) { 56 | if l == "" { 57 | tflog.Trace(ctx, msg, fields...) 58 | } else { 59 | tflog.SubsystemTrace(ctx, string(l), msg, fields...) 60 | } 61 | } 62 | 63 | // SetField returns a context in which key=value is attached to this logger's entries (root or subsystem). 64 | func (l TfLogger) SetField(ctx context.Context, key string, value any) context.Context { 65 | if l == "" { 66 | return tflog.SetField(ctx, key, value) 67 | } 68 | return tflog.SubsystemSetField(ctx, string(l), key, value) 69 | } 70 | 71 | // func (l TfLogger) MaskAllFieldValuesRegexes(ctx context.Context, expressions ...*regexp.Regexp) context.Context { 72 | // if l == "" { 73 | // return tflog.MaskAllFieldValuesRegexes(ctx, expressions...)
74 | // } else { 75 | // return tflog.SubsystemMaskAllFieldValuesRegexes(ctx, string(l), expressions...) 76 | // } 77 | // } 78 | -------------------------------------------------------------------------------- /logging/tf_logger_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package logging 5 | 6 | import ( 7 | "context" 8 | "io" 9 | "testing" 10 | 11 | "github.com/hashicorp/terraform-plugin-log/tflogtest" 12 | ) 13 | 14 | const tflogRootName = "provider" 15 | 16 | func TestTfLoggerWarn(t *testing.T) { 17 | testLoggerWarn(t, tflogRootName, tfLoggerFactory) 18 | } 19 | 20 | func TestTfLoggerSetField(t *testing.T) { 21 | testLoggerSetField(t, tflogRootName, tfLoggerFactory) 22 | } 23 | 24 | func tfLoggerFactory(ctx context.Context, name string, output io.Writer) (context.Context, Logger) { 25 | ctx = tflogtest.RootLogger(ctx, output) 26 | 27 | ctx, rootLogger := NewTfLogger(ctx) 28 | ctx, logger := rootLogger.SubLogger(ctx, name) 29 | 30 | return ctx, logger 31 | } 32 | -------------------------------------------------------------------------------- /mockdata/mocks.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package mockdata 5 | 6 | import ( 7 | "github.com/aws/aws-sdk-go-v2/aws" 8 | "github.com/aws/aws-sdk-go-v2/config" 9 | "github.com/aws/aws-sdk-go-v2/credentials" 10 | "github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds" 11 | "github.com/aws/aws-sdk-go-v2/credentials/endpointcreds" 12 | "github.com/aws/aws-sdk-go-v2/credentials/ssocreds" 13 | "github.com/aws/aws-sdk-go-v2/credentials/stscreds" 14 | "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" 15 | ) 16 | 17 | // GetMockedAwsApiSession establishes an AWS session to a simulated AWS API server for a given service and route endpoints.
18 | func GetMockedAwsApiSession(svcName string, endpoints []*servicemocks.MockEndpoint) (func(), aws.Config, string) { 19 | ts := servicemocks.MockAwsApiServer(svcName, endpoints) 20 | // Fixed placeholder static credentials for talking to the mock server. 21 | sc := credentials.NewStaticCredentialsProvider("accessKey", "secretKey", "") 22 | // Route every request, whatever its service or region, to the mock server's URL. NOTE(review): aws-sdk-go-v2 marks the global EndpointResolverWithOptions as deprecated in favor of per-service BaseEndpoint/EndpointResolverV2 — confirm before migrating. 23 | awsConfig := aws.Config{ 24 | Credentials: sc, 25 | Region: "us-east-1", 26 | EndpointResolverWithOptions: aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...any) (aws.Endpoint, error) { 27 | return aws.Endpoint{ 28 | URL: ts.URL, 29 | Source: aws.EndpointSourceCustom, 30 | }, nil 31 | }), 32 | } 33 | // The returned func is ts.Close and shuts the mock server down; callers should invoke it when done. 34 | return ts.Close, awsConfig, ts.URL 35 | } 36 | // Expected credential values per credentials provider, including the Source each provider reports (presumably compared against in credential-resolution tests — confirm). 37 | var ( 38 | MockEc2MetadataCredentials = aws.Credentials{ 39 | AccessKeyID: servicemocks.MockEc2MetadataAccessKey, 40 | Source: ec2rolecreds.ProviderName, 41 | SecretAccessKey: servicemocks.MockEc2MetadataSecretKey, 42 | SessionToken: servicemocks.MockEc2MetadataSessionToken, 43 | CanExpire: true, 44 | } 45 | 46 | MockEcsCredentialsCredentials = aws.Credentials{ 47 | AccessKeyID: servicemocks.MockEcsCredentialsAccessKey, 48 | SecretAccessKey: servicemocks.MockEcsCredentialsSecretKey, 49 | SessionToken: servicemocks.MockEcsCredentialsSessionToken, 50 | CanExpire: true, 51 | Source: endpointcreds.ProviderName, 52 | } 53 | 54 | MockEnvCredentials = aws.Credentials{ 55 | AccessKeyID: servicemocks.MockEnvAccessKey, 56 | SecretAccessKey: servicemocks.MockEnvSecretKey, 57 | Source: config.CredentialsSourceName, 58 | } 59 | 60 | MockEnvCredentialsWithSessionToken = aws.Credentials{ 61 | AccessKeyID: servicemocks.MockEnvAccessKey, 62 | SecretAccessKey: servicemocks.MockEnvSecretKey, 63 | SessionToken: servicemocks.MockEnvSessionToken, 64 | Source: config.CredentialsSourceName, 65 | } 66 | 67 | MockStaticCredentials = aws.Credentials{ 68 | AccessKeyID: servicemocks.MockStaticAccessKey, 69 | SecretAccessKey: servicemocks.MockStaticSecretKey, 70 | Source: credentials.StaticCredentialsName, 71 | } 72 | 73 |
MockStsAssumeRoleCredentials = aws.Credentials{ 74 | AccessKeyID: servicemocks.MockStsAssumeRoleAccessKey, 75 | AccountID: "555555555555", 76 | SecretAccessKey: servicemocks.MockStsAssumeRoleSecretKey, 77 | SessionToken: servicemocks.MockStsAssumeRoleSessionToken, 78 | Source: stscreds.ProviderName, 79 | CanExpire: true, 80 | } 81 | 82 | MockStsAssumeRoleWithWebIdentityCredentials = aws.Credentials{ 83 | AccessKeyID: servicemocks.MockStsAssumeRoleWithWebIdentityAccessKey, 84 | AccountID: "666666666666", 85 | SecretAccessKey: servicemocks.MockStsAssumeRoleWithWebIdentitySecretKey, 86 | SessionToken: servicemocks.MockStsAssumeRoleWithWebIdentitySessionToken, 87 | Source: stscreds.WebIdentityProviderName, 88 | CanExpire: true, 89 | } 90 | 91 | MockSsoCredentials = aws.Credentials{ 92 | AccessKeyID: servicemocks.MockSsoAccessKeyID, 93 | AccountID: "123456789012", 94 | SecretAccessKey: servicemocks.MockSsoSecretAccessKey, 95 | SessionToken: servicemocks.MockSsoSessionToken, 96 | Source: ssocreds.ProviderName, 97 | CanExpire: true, 98 | } 99 | ) 100 | -------------------------------------------------------------------------------- /servicemocks/pem_file.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package servicemocks 5 | 6 | import ( 7 | "os" 8 | ) 9 | 10 | // TempPEMFile writes TLSBundleCA to a new temporary file and returns its path. Removing the 11 | // file is the caller's responsibility. On error no file is left behind. 12 | func TempPEMFile() (string, error) { 13 | file, err := os.CreateTemp("", "bundle-*.pem") 14 | if err != nil { 15 | return "", err 16 | } 17 | 18 | if _, err = file.Write(TLSBundleCA); err != nil { 19 | // Best-effort cleanup; the write error is the one worth reporting. 20 | _ = file.Close() 21 | _ = os.Remove(file.Name()) 22 | return "", err 23 | } 24 | 25 | // Close explicitly rather than via defer so errors surfacing at close time are not lost. 26 | if err = file.Close(); err != nil { 27 | _ = os.Remove(file.Name()) 28 | return "", err 29 | } 30 | 31 | return file.Name(), nil 32 | } 33 | 34 | var ( 35 | // TLSBundleCA ca.crt 36 | TLSBundleCA = []byte(`-----BEGIN CERTIFICATE----- 37 | MIICiTCCAfKgAwIBAgIJAJ5X1olt05XjMA0GCSqGSIb3DQEBCwUAMDgxCzAJBgNV 38 | BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD 39 | QTAeFw0xNzAzMDkwMDAyMDZaFw0yNzAzMDcwMDAyMDZaMDgxCzAJBgNVBAYTAkdP 40 | MQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBDQTCBnzAN 41 | BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAw/8DN+t9XQR60jx42rsQ2WE2Dx85rb3n 42 | GQxnKZZLNddsT8rDyxJNP18aFalbRbFlyln5fxWxZIblu9Xkm/HRhOpbSimSqo1y 43 | uDx21NVZ1YsOvXpHby71jx3gPrrhSc/t/zikhi++6D/C6m1CiIGuiJ0GBiJxtrub 44 | UBMXT0QtI2ECAwEAAaOBmjCBlzAdBgNVHQ4EFgQU8XG3X/YHBA6T04kdEkq6+4GV 45 | YykwaAYDVR0jBGEwX4AU8XG3X/YHBA6T04kdEkq6+4GVYymhPKQ6MDgxCzAJBgNV 46 | BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD 47 | QYIJAJ5X1olt05XjMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQELBQADgYEAeILv 48 | z49+uxmPcfOZzonuOloRcpdvyjiXblYxbzz6ch8GsE7Q886FTZbvwbgLhzdwSVgG 49 | G8WHkodDUsymVepdqAamS3f8PdCUk8xIk9mop8LgaB9Ns0/TssxDvMr3sOD2Grb3 50 | xyWymTWMcj6uCiEBKtnUp4rPiefcvCRYZ17/hLE= 51 | -----END CERTIFICATE----- 52 | `) 53 | ) 54 | -------------------------------------------------------------------------------- /servicemocks/setup.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package servicemocks 5 | 6 | import ( 7 | "crypto/sha1" 8 | "encoding/hex" 9 | "fmt" 10 | "net/http" 11 | "net/http/httptest" 12 | "os" 13 | "path/filepath" 14 | "runtime" 15 | "strings" 16 | "testing" 17 | "time" 18 | ) 19 | // InitSessionTestEnv empties the process environment for the duration of the test (see StashEnv) and points AWS_CONFIG_FILE and AWS_SHARED_CREDENTIALS_FILE at the placeholder path "file_not_exists". 20 | func InitSessionTestEnv(t *testing.T) { 21 | StashEnv(t) 22 | t.Setenv("AWS_CONFIG_FILE", "file_not_exists") 23 | t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "file_not_exists") 24 | } 25 | // StashEnv snapshots the process environment, clears it, and registers a t.Cleanup that restores the snapshot with PopEnv. 26 | func StashEnv(t *testing.T) { 27 | env := os.Environ() 28 | os.Clearenv() 29 | 30 | t.Cleanup(func() { 31 | PopEnv(env) 32 | }) 33 | } 34 | // PopEnv clears the process environment and then sets every "KEY=value" entry in env (the format returned by os.Environ). Restoration is best effort: errors from os.Setenv are ignored. 35 | func PopEnv(env []string) { 36 | os.Clearenv() 37 | 38 | for _, e := range env { 39 | p := strings.SplitN(e, "=", 2) 40 | k, v := p[0], "" 41 | if len(p) > 1 { 42 | v = p[1] 43 | } 44 | os.Setenv(k, v) 45 | } 46 | } 47 | 48 | // InvalidEC2MetadataEndpoint establishes a httptest server to simulate behaviour 49 | // when endpoint doesn't respond as expected 50 | func InvalidEC2MetadataEndpoint(t *testing.T) func() { 51 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 52 | t.Logf("Mock server received EC2 IMDS API %q request to %q", r.Method, r.RequestURI) 53 | w.WriteHeader(http.StatusBadRequest) 54 | })) 55 | 56 | t.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", ts.URL+"/latest") 57 | return ts.Close 58 | } 59 | 60 | // UnsetEnv unsets environment variables for testing a "clean slate" with no 61 | // credentials in the environment 62 | func UnsetEnv(t *testing.T) func() { 63 | t.Helper() 64 | 65 | // Grab any existing AWS 
keys and preserve.
In some tests we'll unset these, so 66 | // we need to have them and restore them after 67 | e := getEnv() 68 | if err := os.Unsetenv("AWS_ACCESS_KEY_ID"); err != nil { 69 | t.Fatalf("Error unsetting env var AWS_ACCESS_KEY_ID: %s", err) 70 | } 71 | if err := os.Unsetenv("AWS_SECRET_ACCESS_KEY"); err != nil { 72 | t.Fatalf("Error unsetting env var AWS_SECRET_ACCESS_KEY: %s", err) 73 | } 74 | if err := os.Unsetenv("AWS_SESSION_TOKEN"); err != nil { 75 | t.Fatalf("Error unsetting env var AWS_SESSION_TOKEN: %s", err) 76 | } 77 | if err := os.Unsetenv("AWS_PROFILE"); err != nil { 78 | t.Fatalf("Error unsetting env var AWS_PROFILE: %s", err) 79 | } 80 | if err := os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE"); err != nil { 81 | t.Fatalf("Error unsetting env var AWS_SHARED_CREDENTIALS_FILE: %s", err) 82 | } 83 | // The Shared Credentials Provider has a very reasonable fallback option of 84 | // checking the user's home directory for credentials, which may create 85 | // unexpected results for users running these tests 86 | t.Setenv("HOME", "/dev/null") 87 | 88 | return func() { 89 | // re-set all the envs we unset above 90 | if err := os.Setenv("AWS_ACCESS_KEY_ID", e.Key); err != nil { 91 | t.Fatalf("Error resetting env var AWS_ACCESS_KEY_ID: %s", err) 92 | } 93 | if err := os.Setenv("AWS_SECRET_ACCESS_KEY", e.Secret); err != nil { 94 | t.Fatalf("Error resetting env var AWS_SECRET_ACCESS_KEY: %s", err) 95 | } 96 | if err := os.Setenv("AWS_SESSION_TOKEN", e.Token); err != nil { 97 | t.Fatalf("Error resetting env var AWS_SESSION_TOKEN: %s", err) 98 | } 99 | if err := os.Setenv("AWS_PROFILE", e.Profile); err != nil { 100 | t.Fatalf("Error resetting env var AWS_PROFILE: %s", err) 101 | } 102 | if err := os.Setenv("AWS_SHARED_CREDENTIALS_FILE", e.CredsFilename); err != nil { 103 | t.Fatalf("Error resetting env var AWS_SHARED_CREDENTIALS_FILE: %s", err) 104 | } 105 | if err := os.Setenv("HOME", e.Home); err != nil { 106 | t.Fatalf("Error resetting env var HOME: %s", err) 107 | }
108 | } 109 | } 110 | // SetEnv sets AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN, AWS_PROFILE and AWS_SHARED_CREDENTIALS_FILE to s and returns a func that restores the values those variables had before the call. 111 | func SetEnv(s string, t *testing.T) func() { 112 | e := getEnv() 113 | // Set all the envs to a dummy value 114 | if err := os.Setenv("AWS_ACCESS_KEY_ID", s); err != nil { //nolint:tenv 115 | t.Fatalf("Error setting env var AWS_ACCESS_KEY_ID: %s", err) 116 | } 117 | if err := os.Setenv("AWS_SECRET_ACCESS_KEY", s); err != nil { //nolint:tenv 118 | t.Fatalf("Error setting env var AWS_SECRET_ACCESS_KEY: %s", err) 119 | } 120 | if err := os.Setenv("AWS_SESSION_TOKEN", s); err != nil { //nolint:tenv 121 | t.Fatalf("Error setting env var AWS_SESSION_TOKEN: %s", err) 122 | } 123 | if err := os.Setenv("AWS_PROFILE", s); err != nil { //nolint:tenv 124 | t.Fatalf("Error setting env var AWS_PROFILE: %s", err) 125 | } 126 | if err := os.Setenv("AWS_SHARED_CREDENTIALS_FILE", s); err != nil { //nolint:tenv 127 | t.Fatalf("Error setting env var AWS_SHARED_CREDENTIALS_FILE: %s", err) 128 | } 129 | 130 | return func() { 131 | // Restore the values captured by getEnv; a variable that was originally unset comes back as an empty string. 132 | if err := os.Setenv("AWS_ACCESS_KEY_ID", e.Key); err != nil { 133 | t.Fatalf("Error resetting env var AWS_ACCESS_KEY_ID: %s", err) 134 | } 135 | if err := os.Setenv("AWS_SECRET_ACCESS_KEY", e.Secret); err != nil { 136 | t.Fatalf("Error resetting env var AWS_SECRET_ACCESS_KEY: %s", err) 137 | } 138 | if err := os.Setenv("AWS_SESSION_TOKEN", e.Token); err != nil { 139 | t.Fatalf("Error resetting env var AWS_SESSION_TOKEN: %s", err) 140 | } 141 | if err := os.Setenv("AWS_PROFILE", e.Profile); err != nil { 142 | t.Fatalf("Error resetting env var AWS_PROFILE: %s", err) 143 | } 144 | if err := os.Setenv("AWS_SHARED_CREDENTIALS_FILE", e.CredsFilename); err != nil { 145 | t.Fatalf("Error resetting env var AWS_SHARED_CREDENTIALS_FILE: %s", err) 146 | } 147 | } 148 | } 149 | // getEnv snapshots the AWS credential-related variables and HOME; variables that are unset read as "". 
150 | func getEnv() *currentEnv { 151 | // Grab any existing AWS keys and preserve.
In some tests we'll unset these, so 152 | // we need to have them and restore them after 153 | return &currentEnv{ 154 | Key: os.Getenv("AWS_ACCESS_KEY_ID"), 155 | Secret: os.Getenv("AWS_SECRET_ACCESS_KEY"), 156 | Token: os.Getenv("AWS_SESSION_TOKEN"), 157 | Profile: os.Getenv("AWS_PROFILE"), 158 | CredsFilename: os.Getenv("AWS_SHARED_CREDENTIALS_FILE"), 159 | Home: os.Getenv("HOME"), 160 | } 161 | } 162 | 163 | // currentEnv holds the environment values captured by getEnv so they can be restored later. 164 | type currentEnv struct { 165 | Key, Secret, Token, Profile, CredsFilename, Home string 166 | } 167 | // SsoTestSetup writes a cached SSO access token (expiring 15 minutes from now) for ssoKey under a fresh t.TempDir() and points HOME (USERPROFILE on Windows) at that directory for the rest of the test. 168 | // Copied and adapted from https://github.com/aws/aws-sdk-go-v2/blob/ee5e3f05637540596cc7aab1359742000a8d533a/config/resolve_credentials_test.go#L127 169 | func SsoTestSetup(t *testing.T, ssoKey string) (err error) { 170 | t.Helper() 171 | 172 | dir := t.TempDir() 173 | 174 | cacheDir := filepath.Join(dir, ".aws", "sso", "cache") 175 | err = os.MkdirAll(cacheDir, 0750) 176 | if err != nil { 177 | return err 178 | } 179 | 180 | hash := sha1.New() 181 | if _, err := hash.Write([]byte(ssoKey)); err != nil { 182 | t.Fatalf("computing hash: %s", err) 183 | } 184 | 185 | cacheFilename := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json" 186 | 187 | tokenFile, err := os.Create(filepath.Join(cacheDir, cacheFilename)) 188 | if err != nil { 189 | return err 190 | } 191 | 192 | defer func() { 193 | closeErr := tokenFile.Close() 194 | if err == nil { 195 | err = closeErr 196 | } else if closeErr != nil { 197 | err = fmt.Errorf("close error: %v, original error: %w", closeErr, err) 198 | } 199 | }() 200 | 201 | _, err = tokenFile.WriteString(fmt.Sprintf(ssoTokenCacheFile, time.Now(). 202 | Add(15*time.Minute).
//nolint:mnd 203 | Format(time.RFC3339))) 204 | if err != nil { 205 | return err 206 | } 207 | 208 | if runtime.GOOS == "windows" { 209 | t.Setenv("USERPROFILE", dir) 210 | } else { 211 | t.Setenv("HOME", dir) 212 | } 213 | 214 | return nil 215 | } 216 | 217 | const ssoTokenCacheFile = `{ 218 | "accessToken": "ssoAccessToken", 219 | "expiresAt": "%s" 220 | }` 221 | -------------------------------------------------------------------------------- /tfawserr/awserr.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfawserr 5 | 6 | import ( 7 | "slices" 8 | "strings" 9 | 10 | smithy "github.com/aws/smithy-go" 11 | smithyhttp "github.com/aws/smithy-go/transport/http" 12 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/errs" 13 | ) 14 | 15 | // ErrCodeEquals returns true if the error matches all these conditions: 16 | // - err is of type smithy.APIError 17 | // - APIError.ErrorCode() equals one of the passed codes 18 | func ErrCodeEquals(err error, codes ...string) bool { 19 | if apiErr, ok := errs.As[smithy.APIError](err); ok { 20 | if slices.Contains(codes, apiErr.ErrorCode()) { 21 | return true 22 | } 23 | } 24 | return false 25 | } 26 | 27 | // ErrCodeContains returns true if the error matches all these conditions: 28 | // - err is of type smithy.APIError 29 | // - APIError.ErrorCode() contains code 30 | func ErrCodeContains(err error, code string) bool { 31 | if apiErr, ok := errs.As[smithy.APIError](err); ok { 32 | return strings.Contains(apiErr.ErrorCode(), code) 33 | } 34 | return false 35 | } 36 | 37 | // ErrMessageContains returns true if the error matches all these conditions: 38 | // - err is of type smithy.APIError 39 | // - APIError.ErrorCode() equals code 40 | // - APIError.ErrorMessage() contains message 41 | func ErrMessageContains(err error, code string, message string) bool { 42 | if apiErr, ok :=
errs.As[smithy.APIError](err); ok { 43 | return apiErr.ErrorCode() == code && strings.Contains(apiErr.ErrorMessage(), message) 44 | } 45 | return false 46 | } 47 | 48 | // ErrHTTPStatusCodeEquals returns true if the error matches all these conditions: 49 | // - err is of type smithyhttp.ResponseError 50 | // - ResponseError.HTTPStatusCode() equals one of the passed status codes 51 | func ErrHTTPStatusCodeEquals(err error, statusCodes ...int) bool { 52 | if respErr, ok := errs.As[*smithyhttp.ResponseError](err); ok { 53 | if slices.Contains(statusCodes, respErr.HTTPStatusCode()) { 54 | return true 55 | } 56 | } 57 | return false 58 | } 59 | -------------------------------------------------------------------------------- /tools/main.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | //go:build tools 5 | // +build tools 6 | 7 | package main 8 | 9 | import ( 10 | _ "github.com/golangci/golangci-lint/cmd/golangci-lint" 11 | _ "github.com/pavius/impi/cmd/impi" 12 | ) 13 | -------------------------------------------------------------------------------- /user_agent.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | 10 | "github.com/aws/smithy-go/middleware" 11 | smithyhttp "github.com/aws/smithy-go/transport/http" 12 | "github.com/hashicorp/aws-sdk-go-base/v2/useragent" 13 | ) 14 | // apnUserAgentMiddleware returns a Build-step middleware named "tfAPNUserAgent" that prepends apn.BuildUserAgentString() to the request's User-Agent header; it returns an error for requests that are not *smithyhttp.Request. 15 | func apnUserAgentMiddleware(apn APNInfo) middleware.BuildMiddleware { 16 | return middleware.BuildMiddlewareFunc("tfAPNUserAgent", 17 | func(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (middleware.BuildOutput, middleware.Metadata, error) { 18 | request, ok := in.Request.(*smithyhttp.Request) 19 | if !ok { 20 | return middleware.BuildOutput{}, middleware.Metadata{}, fmt.Errorf("unknown request type %T", in.Request) 21 | } 22 | 23 | prependUserAgentHeader(request, apn.BuildUserAgentString()) 24 | 25 | return next.HandleBuild(ctx, in) 26 | }, 27 | ) 28 | } 29 | 30 | // Because the default User-Agent middleware prepends itself to the contents of the User-Agent header, 31 | // we have to run after it and also prepend our custom User-Agent 32 | func prependUserAgentHeader(request *smithyhttp.Request, value string) { 33 | current := request.Header.Get("User-Agent") 34 | if len(current) > 0 { 35 | current = value + " " + current 36 | } else { 37 | current = value 38 | } 39 | request.Header["User-Agent"] = append(request.Header["User-Agent"][:0], current) // collapse to a single header value; NOTE(review): Header.Get returns only the first value, so any additional User-Agent values are discarded - confirm intended 40 | } 41 | // withUserAgentAppender returns a stack option that adds userAgentMiddleware(ua) to the Build step with middleware.After positioning. 42 | func withUserAgentAppender(ua string) func(*middleware.Stack) error { 43 | return func(stack *middleware.Stack) error { 44 | return stack.Build.Add(userAgentMiddleware(ua), middleware.After) 45 | } 46 | } 47 | // userAgentMiddleware returns a 
Build-step middleware named "tfUserAgentAppender" that appends ua to the request's User-Agent header; it returns an error for requests that are not *smithyhttp.Request. 48 | func userAgentMiddleware(ua string) middleware.BuildMiddleware { 49 | return middleware.BuildMiddlewareFunc("tfUserAgentAppender", 50 | func(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (middleware.BuildOutput, middleware.Metadata, error) { 51 | request, ok := in.Request.(*smithyhttp.Request) 52 | if !ok { 53 | return middleware.BuildOutput{}, 
middleware.Metadata{}, fmt.Errorf("unknown request type %T", in.Request) 54 | } 55 | 56 | appendUserAgentHeader(request, ua) 57 | 58 | return next.HandleBuild(ctx, in) 59 | }, 60 | ) 61 | } 62 | 63 | func userAgentFromContextMiddleware() middleware.BuildMiddleware { 64 | return middleware.BuildMiddlewareFunc("tfCtxUserAgentAppender", 65 | func(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (middleware.BuildOutput, middleware.Metadata, error) { 66 | request, ok := in.Request.(*smithyhttp.Request) 67 | if !ok { 68 | return middleware.BuildOutput{}, middleware.Metadata{}, fmt.Errorf("unknown request type %T", in.Request) 69 | } 70 | 71 | if v := useragent.BuildFromContext(ctx); v != "" { 72 | appendUserAgentHeader(request, v) 73 | } 74 | 75 | return next.HandleBuild(ctx, in) 76 | }, 77 | ) 78 | } 79 | 80 | func appendUserAgentHeader(request *smithyhttp.Request, value string) { 81 | current := request.Header.Get("User-Agent") 82 | if len(current) > 0 { 83 | current = current + " " + value 84 | } else { 85 | current = value 86 | } 87 | request.Header["User-Agent"] = append(request.Header["User-Agent"][:0], current) 88 | } 89 | -------------------------------------------------------------------------------- /user_agent_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsbase 5 | 6 | import ( 7 | "fmt" 8 | "runtime" 9 | "strings" 10 | 11 | "github.com/aws/aws-sdk-go-v2/aws" 12 | ) 13 | // awsSdkGoUserAgent builds a User-Agent prefix in the AWS SDK for Go v2 format (SDK name/version, normalized OS, Go version, GOOS, GOARCH) for use as an expected value in tests. 14 | func awsSdkGoUserAgent() string { 15 | // See https://github.com/aws/aws-sdk-go-v2/blob/4051ca807a0308bc9f169ca308262b328c2692a3/aws/middleware/user_agent_test.go#L18C1-L18C1 16 | return fmt.Sprintf("%s/%s os/%s lang/go#%s md/GOOS#%s md/GOARCH#%s", aws.SDKName, aws.SDKVersion, getNormalizedOSName(), strings.TrimPrefix(runtime.Version(), "go"), runtime.GOOS, runtime.GOARCH) 17 | } 18 | 19 | // getNormalizedOSName maps runtime.GOOS to the OS name used in the SDK User-Agent, falling back to "other". Copied from https://github.com/aws/aws-sdk-go-v2/blob/main/aws/middleware/osname.go 20 | func getNormalizedOSName() (os string) { 21 | switch runtime.GOOS { 22 | case "android": 23 | os = "android" 24 | case "linux": 25 | os = "linux" 26 | case "windows": 27 | os = "windows" 28 | case "darwin": 29 | os = "macos" 30 | case "ios": 31 | os = "ios" 32 | default: 33 | os = "other" 34 | } 35 | return os 36 | } 37 | 38 | // cleanUserAgent removes: 39 | // * the "api/" product that the AWS SDK adds to the user-agent string 40 | // * the "ua/" product that contains the User-Agent string version 41 | // * the "m/" product that contains the feature flags 42 | func cleanUserAgent(ua string) string { 43 | var parts []string 44 | for _, v := range strings.Split(ua, " ") { 45 | if strings.HasPrefix(v, "api/") { 46 | continue 47 | } 48 | if strings.HasPrefix(v, "ua/") { 49 | continue 50 | } 51 | if strings.HasPrefix(v, "m/") { 52 | continue 53 | } 54 | parts = append(parts, v) 55 | } 56 | return strings.Join(parts, " ") 57 | } 58 | -------------------------------------------------------------------------------- /useragent/context.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package useragent 5 | 6 | import ( 7 | "context" 8 | 9 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" 10 | ) 11 | // userAgentKey is an unexported context-key type, so values stored by this package cannot collide with context keys defined in other packages. 12 | type userAgentKey string 13 | 14 | const ( 15 | contextScopedUserAgent userAgentKey = "ContextScopedUserAgent" 16 | ) 17 | // Context returns a copy of ctx that carries products, for later use by BuildFromContext. 18 | func Context(ctx context.Context, products config.UserAgentProducts) context.Context { 19 | return context.WithValue(ctx, contextScopedUserAgent, products) 20 | } 21 | // BuildFromContext returns the User-Agent string built from the products attached to ctx by Context, or "" when ctx carries no UserAgentProducts value. 22 | func BuildFromContext(ctx context.Context) string { 23 | ps, ok := ctx.Value(contextScopedUserAgent).(config.UserAgentProducts) 24 | if !ok { 25 | return "" 26 | } 27 | 28 | return ps.BuildUserAgentString() 29 | } 30 | -------------------------------------------------------------------------------- /useragent/context_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package useragent 5 | 6 | import ( 7 | "context" 8 | "testing" 9 | 10 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" 11 | ) 12 | 13 | func TestFromContext(t *testing.T) { 14 | testcases := map[string]struct { 15 | setup func() context.Context 16 | expected string 17 | }{ 18 | "empty": { 19 | setup: func() context.Context { 20 | return context.Background() 21 | }, 22 | expected: "", 23 | }, 24 | "UserAgentProducts": { 25 | setup: func() context.Context { 26 | return Context(context.Background(), config.UserAgentProducts{ 27 | { 28 | Name: "first", 29 | Version: "1.2.3", 30 | }, 31 | { 32 | Name: "second", 33 | Version: "1.0.2", 34 | Comment: "a comment", 35 | }, 36 | }) 37 | }, 38 | expected: "first/1.2.3 second/1.0.2 (a comment)", 39 | }, 40 | "[]UserAgentProduct": { 41 | setup: func() context.Context { 42 | return Context(context.Background(), []config.UserAgentProduct{ 43 | { 44 | Name: "first", 45 | 
Version: "1.2.3", 46 | }, 47 | { 48 | Name: "second", 49 | Version: "1.0.2", 50 | Comment: "a
comment", 51 | }, 52 | }) 53 | }, 54 | expected: "first/1.2.3 second/1.0.2 (a comment)", 55 | }, 56 | } 57 | 58 | for name, testcase := range testcases { 59 | t.Run(name, func(t *testing.T) { 60 | ctx := testcase.setup() 61 | 62 | v := BuildFromContext(ctx) 63 | 64 | if v != testcase.expected { 65 | t.Errorf("expected %q, got %q", testcase.expected, v) 66 | } 67 | }) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /v2/awsv1shim/credentials.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsv1shim 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "sync/atomic" 10 | "time" 11 | 12 | awsv2 "github.com/aws/aws-sdk-go-v2/aws" // nosemgrep: no-sdkv2-imports-in-awsv1shim 13 | "github.com/aws/aws-sdk-go/aws/credentials" 14 | ) 15 | // v2CredentialsProvider adapts an AWS SDK for Go v2 awsv2.CredentialsProvider for use as an SDK v1 credentials provider (see newV2Credentials and the design notes below). 16 | type v2CredentialsProvider struct { 17 | provider awsv2.CredentialsProvider 18 | 19 | v2creds atomic.Value // holds the *awsv2.Credentials from the most recent successful retrieval 20 | } 21 | 22 | // This adapter deals with multiple levels of caching and a slight mismatch between the AWS SDK for Go v1 and v2 credentials models. 23 | // The SDK v1 model has a root `credentials.Credentials` struct that handles caching. The `credentials.Value` contains only keys. 24 | // The `credentials.Credentials` struct handles expiry information by calling the credentials provider. 25 | // In the SDK v2 model, the SDK returns an `aws.CredentialsCache` which handles caching. The `aws.Credentials` value contains keys 26 | // as well as the expiry information. 27 | // 28 | // The `v2CredentialsProvider` will typically be used with the following layout: 29 | // (v1)`credentials.Credentials` ==> `v2CredentialsProvider` ==> (v2)`aws.CredentialsCache` ==> (v2) 30 | // 31 | // Since the SDK v1 `credentials.Credentials` handles expiry, it has an `Expire` function to explicitly expire credentials.
This is 32 | // used, for example, in the SDK v1 default retry handler to catch an expired credentials error. Because of this, the result of 33 | // `RetrieveWithContext` cannot be cached in `v2CredentialsProvider`. 34 | // NOTE: Since the `Expire()` call is not passed up the chain, the (v2)`aws.CredentialsCache` will not have its cache cleared. This 35 | // may cause problems if a credential is revoked early. If this becomes a problem, every call to `RetrieveWithContext` may need to 36 | // call `Invalidate()` on the (v2)`aws.CredentialsCache`. In practice, `RetrieveWithContext` is rarely called, so this is not likely 37 | // to have a significant impact. 38 | // 39 | // The expiry information is cached in `v2CredentialsProvider` because the SDK v1 model handles expiry separately from the credential 40 | // information, and otherwise calling `IsExpired()` and `ExpiresAt()` would potentially call the actual credential provider on each call. 41 | 42 | func (p *v2CredentialsProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { 43 | v2creds, err := p.provider.Retrieve(ctx) 44 | if err != nil { 45 | return credentials.Value{}, err 46 | } 47 | p.v2creds.Store(&v2creds) 48 | 49 | return credentials.Value{ 50 | AccessKeyID: v2creds.AccessKeyID, 51 | SecretAccessKey: v2creds.SecretAccessKey, 52 | SessionToken: v2creds.SessionToken, 53 | ProviderName: fmt.Sprintf("v2Credentials(%s)", v2creds.Source), 54 | }, nil 55 | } 56 | // IsExpired reports whether there are no usable cached credentials: none retrieved yet, missing keys, or expired. 57 | func (p *v2CredentialsProvider) IsExpired() bool { 58 | v2creds := p.credentials() 59 | if v2creds != nil { 60 | return v2creds.Expired() 61 | } 62 | return true 63 | } 64 | // ExpiresAt returns the expiry time of the cached credentials, or the zero time when there are no usable cached credentials. 
65 | func (p *v2CredentialsProvider) ExpiresAt() time.Time { 66 | v2creds := p.credentials() 67 | if v2creds != nil { 68 | return v2creds.Expires 69 | } 70 | return time.Time{} 71 | } 72 | // Retrieve calls RetrieveWithContext with context.Background(). 73 | func (p *v2CredentialsProvider) Retrieve() (credentials.Value, error) { 74 | return p.RetrieveWithContext(context.Background()) 75 | } 76 | // credentials returns the cached v2 credentials if they are present, have keys, and are not expired; otherwise it returns nil. 77 | func (p
*v2CredentialsProvider) credentials() *awsv2.Credentials { 78 | v := p.v2creds.Load() 79 | if v == nil { 80 | return nil 81 | } 82 | 83 | c := v.(*awsv2.Credentials) 84 | if c != nil && c.HasKeys() && !c.Expired() { 85 | return c 86 | } 87 | 88 | return nil 89 | } 90 | // newV2Credentials wraps v2provider in an SDK v1 *credentials.Credentials backed by a v2CredentialsProvider. 91 | func newV2Credentials(v2provider awsv2.CredentialsProvider) *credentials.Credentials { 92 | return credentials.NewCredentials(&v2CredentialsProvider{ 93 | provider: v2provider, 94 | }) 95 | } 96 | -------------------------------------------------------------------------------- /v2/awsv1shim/credentials_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsv1shim 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "testing" 10 | "time" 11 | 12 | awsv2 "github.com/aws/aws-sdk-go-v2/aws" // nosemgrep: no-sdkv2-imports-in-awsv1shim 13 | credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials" // nosemgrep: no-sdkv2-imports-in-awsv1shim 14 | stscredsv2 "github.com/aws/aws-sdk-go-v2/credentials/stscreds" // nosemgrep: no-sdkv2-imports-in-awsv1shim 15 | stsv2 "github.com/aws/aws-sdk-go-v2/service/sts" // nosemgrep: no-sdkv2-imports-in-awsv1shim 16 | ststypesv2 "github.com/aws/aws-sdk-go-v2/service/sts/types" // nosemgrep: no-sdkv2-imports-in-awsv1shim 17 | "github.com/aws/aws-sdk-go/aws" 18 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/test" 19 | ) 20 | 21 | func TestV2CredentialsProviderPassthrough(t *testing.T) { 22 | ctx := test.Context(t) 23 | 24 | v2creds := credentialsv2.NewStaticCredentialsProvider("key", "secret", "session") 25 | 26 | creds := newV2Credentials(v2creds) 27 | 28 | value, err := creds.GetWithContext(ctx) 29 | if err != nil { 30 | t.Fatalf("unexpected error: %s", err) 31 | } 32 | 33 | if a, e := value.AccessKeyID, "key"; a != e { 34 | t.Errorf("AccessKeyID: expected %q, got %q", e, a) 35 | } 36 | if a, e := value.SecretAccessKey, "secret"; a != e { 37 | 
t.Errorf("SecretAccessKey: expected %q, got %q", e, a) 38 | } 39 | if a, e := value.SessionToken, "session"; a != e { 40 | t.Errorf("SessionToken: expected %q, got %q", e, a) 41 | } 42 | if a, e := value.ProviderName, fmt.Sprintf("v2Credentials(%s)", credentialsv2.StaticCredentialsName); a != e { 43 | t.Errorf("ProviderName: expected %q, got %q", e, a) 44 | } 45 | } 46 | 47 | func TestV2CredentialsProviderExpriry(t *testing.T) { 48 | testcases := map[string]struct { 49 | v2creds awsv2.CredentialsProvider 50 | }{ 51 | credentialsv2.StaticCredentialsName: { 52 | v2creds: credentialsv2.NewStaticCredentialsProvider("key", "secret", "session"), 53 | }, 54 | } 55 | 56 | for name, testcase := range testcases { 57 | t.Run(name, func(t *testing.T) { 58 | ctx := test.Context(t) 59 | 60 | creds := newV2Credentials(testcase.v2creds) 61 | 62 | // Credentials need to be retrieved before we can check expiry information 63 | _, err := creds.GetWithContext(ctx) 64 | if err != nil { 65 | t.Fatalf("unexpected error: %s", err) 66 | } 67 | if creds.IsExpired() { 68 | t.Fatalf("did not expect creds to be expired") 69 | } 70 | expiry, err := creds.ExpiresAt() 71 | if err != nil { 72 | t.Fatalf("unexpected error getting expiry: %s", err) 73 | } 74 | if !expiry.Equal(time.Time{}) { 75 | t.Fatalf("expected no expiry time, got %s", expiry) 76 | } 77 | 78 | creds.Expire() 79 | if !creds.IsExpired() { 80 | t.Fatalf("expected creds to be expired") 81 | } 82 | 83 | value, err := creds.GetWithContext(ctx) 84 | if err != nil { 85 | t.Fatalf("unexpected error: %s", err) 86 | } 87 | if value.AccessKeyID == "" { 88 | t.Error("AccessKeyID: expected a value") 89 | } 90 | if value.SecretAccessKey == "" { 91 | t.Error("SecretAccessKey: expected a value") 92 | } 93 | if value.SessionToken == "" { 94 | t.Error("SessionToken: expected a value") 95 | } 96 | }) 97 | } 98 | } 99 | 100 | func TestV2CredentialsProviderExpriry_AssumeRole(t *testing.T) { 101 | ctx := test.Context(t) 102 | 103 | stsClient := &mockAssumeRole{} 104 
| v2creds := stscredsv2.NewAssumeRoleProvider(stsClient, "role") 105 | 106 | creds := newV2Credentials(v2creds) 107 | 108 | // Credentials need to be retrieved before we can check expiry information 109 | _, err := creds.GetWithContext(ctx) 110 | if err != nil { 111 | t.Fatalf("unexpected error: %s", err) 112 | } 113 | if creds.IsExpired() { 114 | t.Fatalf("did not expect creds to be expired") 115 | } 116 | expiry, err := creds.ExpiresAt() 117 | if err != nil { 118 | t.Fatalf("unexpected error getting expiry: %s", err) 119 | } 120 | if expiry.Equal(time.Time{}) { 121 | t.Fatal("expected expiry time, got none") 122 | } 123 | 124 | creds.Expire() 125 | if !creds.IsExpired() { 126 | t.Fatalf("expected creds to be expired") 127 | } 128 | 129 | value, err := creds.GetWithContext(ctx) 130 | if err != nil { 131 | t.Fatalf("unexpected error: %s", err) 132 | } 133 | if value.AccessKeyID == "" { 134 | t.Error("AccessKeyID: expected a value") 135 | } 136 | if value.SecretAccessKey == "" { 137 | t.Error("SecretAccessKey: expected a value") 138 | } 139 | if value.SessionToken == "" { 140 | t.Error("SessionToken: expected a value") 141 | } 142 | } 143 | 144 | func TestV2CredentialsProviderCaching(t *testing.T) { 145 | ctx := test.Context(t) 146 | 147 | stsClientCalls := 0 148 | expectedStsClientCalls := 0 149 | stsClient := &mockAssumeRole{ 150 | TestInput: func(in *stsv2.AssumeRoleInput) { 151 | stsClientCalls++ 152 | }, 153 | } 154 | v2creds := stscredsv2.NewAssumeRoleProvider(stsClient, "role") 155 | creds := newV2Credentials(v2creds) 156 | if stsClientCalls != expectedStsClientCalls { 157 | t.Errorf("did not expect call to STS client") 158 | expectedStsClientCalls = stsClientCalls 159 | } 160 | 161 | _, err := creds.GetWithContext(ctx) 162 | if err != nil { 163 | t.Fatalf("unexpected error: %s", err) 164 | } 165 | expectedStsClientCalls++ 166 | if stsClientCalls != expectedStsClientCalls { 167 | t.Errorf("expected call to STS client") 168 | expectedStsClientCalls =
stsClientCalls 169 | } 170 | 171 | _, err = creds.GetWithContext(ctx) 172 | if err != nil { 173 | t.Fatalf("unexpected error: %s", err) 174 | } 175 | if stsClientCalls != expectedStsClientCalls { 176 | t.Errorf("did not expect call to STS client") 177 | expectedStsClientCalls = stsClientCalls 178 | } 179 | 180 | creds.IsExpired() 181 | if stsClientCalls != expectedStsClientCalls { 182 | t.Errorf("did not expect call to STS client") 183 | expectedStsClientCalls = stsClientCalls 184 | } 185 | 186 | _, err = creds.ExpiresAt() 187 | if err != nil { 188 | t.Fatalf("unexpected error: %s", err) 189 | } 190 | if stsClientCalls != expectedStsClientCalls { 191 | t.Errorf("did not expect call to STS client") 192 | expectedStsClientCalls = stsClientCalls 193 | } 194 | 195 | creds.Expire() 196 | if stsClientCalls != expectedStsClientCalls { 197 | t.Errorf("did not expect call to STS client") 198 | expectedStsClientCalls = stsClientCalls 199 | } 200 | 201 | creds.IsExpired() 202 | if stsClientCalls != expectedStsClientCalls { 203 | t.Errorf("did not expect call to STS client") 204 | expectedStsClientCalls = stsClientCalls 205 | } 206 | 207 | _, err = creds.ExpiresAt() 208 | if err != nil { 209 | t.Fatalf("unexpected error: %s", err) 210 | } 211 | if stsClientCalls != expectedStsClientCalls { 212 | t.Errorf("did not expect call to STS client") 213 | expectedStsClientCalls = stsClientCalls 214 | } 215 | 216 | _, err = creds.GetWithContext(ctx) 217 | if err != nil { 218 | t.Fatalf("unexpected error: %s", err) 219 | } 220 | expectedStsClientCalls++ 221 | if stsClientCalls != expectedStsClientCalls { 222 | t.Errorf("expected call to STS client") 223 | } 224 | } 225 | // mockAssumeRole is a fake STS AssumeRole client: it passes each request to TestInput (when set) and returns credentials whose access key ID is the requested role ARN and which expire 60 minutes after the call. 
226 | type mockAssumeRole struct { 227 | TestInput func(*stsv2.AssumeRoleInput) 228 | } 229 | 230 | func (s *mockAssumeRole) AssumeRole(ctx context.Context, params *stsv2.AssumeRoleInput, optFns ...func(*stsv2.Options)) (*stsv2.AssumeRoleOutput, error) { 231 | if s.TestInput != nil { 232 | s.TestInput(params) 233 | } 234
| expiry := time.Now().Add(60 * time.Minute) 235 | 236 | return &stsv2.AssumeRoleOutput{ 237 | Credentials: &ststypesv2.Credentials{ 238 | // Just reflect the role arn to the provider. 239 | AccessKeyId: params.RoleArn, 240 | SecretAccessKey: aws.String("assumedSecretAccessKey"), 241 | SessionToken: aws.String("assumedSessionToken"), 242 | Expiration: &expiry, 243 | }, 244 | }, nil 245 | } 246 | -------------------------------------------------------------------------------- /v2/awsv1shim/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2 2 | 3 | go 1.23.6 4 | 5 | require ( 6 | github.com/aws/aws-sdk-go v1.55.7 7 | github.com/aws/aws-sdk-go-v2 v1.36.3 8 | github.com/aws/aws-sdk-go-v2/config v1.29.14 9 | github.com/aws/aws-sdk-go-v2/credentials v1.17.67 10 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 11 | github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 12 | github.com/google/go-cmp v0.7.0 13 | github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.64 14 | github.com/hashicorp/go-cleanhttp v0.5.2 15 | github.com/hashicorp/terraform-plugin-log v0.9.0 16 | go.opentelemetry.io/contrib/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws v0.61.0 17 | go.opentelemetry.io/otel v1.36.0 18 | ) 19 | 20 | require ( 21 | github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect 22 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect 23 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect 24 | github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect 25 | github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 // indirect 26 | github.com/aws/aws-sdk-go-v2/service/dynamodb v1.43.1 // indirect 27 | github.com/aws/aws-sdk-go-v2/service/iam v1.42.0 // indirect 28 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect 29 | github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.2 //
indirect 30 | github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.15 // indirect 31 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect 32 | github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 // indirect 33 | github.com/aws/aws-sdk-go-v2/service/s3 v1.79.4 // indirect 34 | github.com/aws/aws-sdk-go-v2/service/sns v1.34.4 // indirect 35 | github.com/aws/aws-sdk-go-v2/service/sqs v1.38.5 // indirect 36 | github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect 37 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect 38 | github.com/aws/smithy-go v1.22.3 // indirect 39 | github.com/fatih/color v1.18.0 // indirect 40 | github.com/go-logr/logr v1.4.2 // indirect 41 | github.com/go-logr/stdr v1.2.2 // indirect 42 | github.com/hashicorp/errwrap v1.1.0 // indirect 43 | github.com/hashicorp/go-hclog v1.6.3 // indirect 44 | github.com/hashicorp/go-multierror v1.1.1 // indirect 45 | github.com/jmespath/go-jmespath v0.4.0 // indirect 46 | github.com/mattn/go-colorable v0.1.14 // indirect 47 | github.com/mattn/go-isatty v0.0.20 // indirect 48 | github.com/mitchellh/go-homedir v1.1.0 // indirect 49 | github.com/mitchellh/go-testing-interface v1.14.1 // indirect 50 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect 51 | go.opentelemetry.io/otel/metric v1.36.0 // indirect 52 | go.opentelemetry.io/otel/trace v1.36.0 // indirect 53 | golang.org/x/net v0.40.0 // indirect 54 | golang.org/x/sys v0.33.0 // indirect 55 | golang.org/x/text v0.25.0 // indirect 56 | gopkg.in/yaml.v2 v2.4.0 // indirect 57 | ) 58 | 59 | replace github.com/hashicorp/aws-sdk-go-base/v2 => ../.. 60 | -------------------------------------------------------------------------------- /v2/awsv1shim/http_client.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package awsv1shim

import (
	"net/http"

	"github.com/hashicorp/aws-sdk-go-base/v2/internal/config"
	"github.com/hashicorp/go-cleanhttp"
)

// defaultHttpClient returns a pooled client from go-cleanhttp whose transport
// has been configured with the options returned by c.HTTPTransportOptions.
// Any error from building those options is returned unchanged.
// NOTE(review): the unchecked type assertion below panics if the pooled
// client's Transport is ever not an *http.Transport.
func defaultHttpClient(c *config.Config) (*http.Client, error) {
	opts, err := c.HTTPTransportOptions()
	if err != nil {
		return nil, err
	}

	httpClient := cleanhttp.DefaultPooledClient()
	opts(httpClient.Transport.(*http.Transport))

	return httpClient, nil
}

// ---- /v2/awsv1shim/http_client_test.go ----
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package awsv1shim

import (
	"net/http"
	"testing"

	"github.com/hashicorp/aws-sdk-go-base/v2/internal/config"
	"github.com/hashicorp/aws-sdk-go-base/v2/internal/test"
)

// TestHTTPClientConfiguration_basic runs the shared transport checks against a
// client built from an empty configuration.
func TestHTTPClientConfiguration_basic(t *testing.T) {
	client, err := defaultHttpClient(&config.Config{})
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}

	transport, ok := client.Transport.(*http.Transport)
	if !ok {
		t.Fatalf("Unexpected type for HTTP client transport: %T", client.Transport)
	}

	test.HTTPClientConfigurationTest_basic(t, transport)
}

// TestHTTPClientConfiguration_insecureHTTPS runs the shared insecure-TLS checks
// against a client built with Insecure set.
func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) {
	client, err := defaultHttpClient(&config.Config{
		Insecure: true,
	})
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}

	transport, ok := client.Transport.(*http.Transport)
	if !ok {
		t.Fatalf("Unexpected type for HTTP client transport: %T", client.Transport)
	}

	test.HTTPClientConfigurationTest_insecureHTTPS(t, transport)
}

// TestHTTPClientConfiguration_proxy delegates to the shared proxy tests, which
// build their own transports through the transport factory below.
func TestHTTPClientConfiguration_proxy(t *testing.T) {
	test.HTTPClientConfigurationTest_proxy(t, transport)
}

// transport builds a client for the given configuration via defaultHttpClient
// and returns its *http.Transport, failing the test on a construction error or
// an unexpected transport type.
func transport(t *testing.T, config *config.Config) *http.Transport {
	t.Helper()

	client, err := defaultHttpClient(config)
	if err != nil {
		t.Fatalf("creating client: %s", err)
	}

	transport, ok := client.Transport.(*http.Transport)
	if !ok {
		t.Fatalf("Unexpected type for HTTP client transport: %T", client.Transport)
	}

	return transport
}

// ---- /v2/awsv1shim/logger.go ----
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package awsv1shim

import (
	"bufio"
	"bytes"
	"context"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/textproto"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/hashicorp/aws-sdk-go-base/v2/logging"
	"github.com/hashicorp/terraform-plugin-log/tflog"
	"go.opentelemetry.io/contrib/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/semconv/v1.17.0/httpconv"
)

const (
	// responseBufferLen caps how many response-body bytes logResponse copies
	// aside for logging: logging.MaxResponseBodyLen plus an extra 1024 bytes.
	// NOTE(review): presumably headroom so truncation can be detected — confirm.
	responseBufferLen = logging.MaxResponseBodyLen + 1024
)

// debugLogger is installed as the SDK's aws.Config.Logger (see session.go).
// It writes through the standard library log package because no context is
// available at this call site.
type debugLogger struct{}

// Log keeps only the string arguments, joins them with spaces, strips carriage
// returns, and emits the line tagged with the AWS SDK key/value pair.
func (l debugLogger) Log(args ...any) {
	tokens := make([]string, 0, len(args))
	for _, arg := range args {
		if token, ok := arg.(string); ok {
			tokens = append(tokens, token)
		}
	}
	s := strings.Join(tokens, " ")
	s = strings.ReplaceAll(s, "\r", "") // Works around https://github.com/jen20/teamcity-go-test/pull/2
	log.Printf("missing_context: %s "+string(logging.AwsSdkKey)+"="+awsSdkGoV1Val, s)
}

// setAWSFields returns ctx annotated with tflog fields describing r: AWS
// system, service ID, region, operation name and SDK identity. The signing
// region is added only when it differs from the configured region.
func setAWSFields(ctx context.Context, r *request.Request) context.Context {
	region := aws.StringValue(r.Config.Region)

	attributes := []attribute.KeyValue{
		otelaws.SystemAttr(),
		otelaws.ServiceAttr(r.ClientInfo.ServiceID),
		otelaws.RegionAttr(region),
		otelaws.OperationAttr(r.Operation.Name),
		awsSDKv1Attr(),
	}
	if signingRegion := r.ClientInfo.SigningRegion; signingRegion != region {
		attributes = append(attributes, logging.SigningRegion(signingRegion))
	}

	for _, attribute := range attributes {
		ctx = tflog.SetField(ctx, string(attribute.Key), attribute.Value.AsInterface())
	}

	return ctx
}

// awsSdkGoV1Val identifies this SDK in log output.
const awsSdkGoV1Val = "aws-sdk-go"

// awsSDKv1Attr returns the logging.AwsSdkKey attribute set to awsSdkGoV1Val.
func awsSDKv1Attr() attribute.KeyValue {
	return logging.AwsSdkKey.String(awsSdkGoV1Val)
}

// durationKeyT is a private context-key type so durationKey cannot collide
// with keys from other packages.
type durationKeyT string

// durationKey is the context key under which logRequest stores the request
// start time; the response logger reads it to compute the elapsed duration.
const durationKey durationKeyT = "request-duration"

// Replaces the built-in logging middleware from https://github.com/aws/aws-sdk-go/blob/main/aws/client/logger.go
// We want access to the request struct, and cannot get it from the built-in.
// The typical route of adding logging to the http.RoundTripper doesn't work for the AWS SDK for Go v1 without forcing us to manually implement
// configuration that the SDK handles for us.
var requestLogger = request.NamedHandler{
	Name: "TF_AWS_RequestLogger",
	Fn:   logRequest,
}

// logRequest logs the outgoing HTTP request at debug level and records the
// send time in the request context.
// Ordering matters: seekability is checked before DecomposeHTTPRequest reads
// the body, and a non-seekable body is re-wrapped afterwards. On either error
// path the function returns before the start time is stored, so the response
// log will then report a zero duration.
func logRequest(r *request.Request) {
	ctx := r.Context()

	ctx = setAWSFields(ctx, r)

	bodySeekable := aws.IsReaderSeekable(r.Body)

	requestFields, err := logging.DecomposeHTTPRequest(ctx, r.HTTPRequest)
	if err != nil {
		tflog.Error(ctx, fmt.Sprintf("decomposing request: %s", err))
		return
	}

	if !bodySeekable {
		r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
	}
	// Reset the request body because dumpRequest will re-wrap the
	// r.HTTPRequest's Body as a NoOpCloser and will not be reset after
	// read by the HTTP client reader.
	if err := r.Error; err != nil {
		tflog.Error(ctx, fmt.Sprintf("decomposing request: %s", err))
		return
	}

	tflog.Debug(ctx, "HTTP Request Sent", requestFields)

	ctx = context.WithValue(ctx, durationKey, time.Now())

	r.SetContext(ctx)
}

// Replaces the built-in logging middleware from https://github.com/aws/aws-sdk-go/blob/main/aws/client/logger.go
// We want access to the response struct, and cannot get it from the built-in.
// The typical route of adding logging to the http.RoundTripper doesn't work for the AWS SDK for Go v1 without forcing us to manually implement
// configuration that the SDK handles for us.
var responseLogger = request.NamedHandler{
	Name: "TF_AWS_ResponseLogger",
	Fn:   logResponse,
}

// logResponse wraps the response body in a tee that copies up to
// responseBufferLen bytes into bodyBuffer while the SDK reads the body. It then
// registers a named handler at the back of the Unmarshal and UnmarshalError
// lists, so the response is logged only after the body has been consumed.
func logResponse(r *request.Request) {
	ctx := r.Context()

	ctx = setAWSFields(ctx, r)

	if r.HTTPResponse == nil {
		tflog.Error(ctx, "HTTP response is nil")
		return
	}

	bodyBuffer := bytes.NewBuffer(nil)

	r.HTTPResponse.Body = &teeReaderCloser{
		Reader: io.TeeReader(r.HTTPResponse.Body, limitWriter(bodyBuffer, responseBufferLen)),
		Source: r.HTTPResponse.Body,
	}

	// NOTE(review): the handler ignores its req argument and closes over r.
	handlerFn := func(req *request.Request) {
		ctx := r.Context()

		// durationKey is absent if logRequest bailed out early; elapsed stays 0.
		var elapsed time.Duration
		if start, ok := ctx.Value(durationKey).(time.Time); ok {
			elapsed = time.Since(start)
		}

		ctx = setAWSFields(ctx, r)

		responseFields, err := decomposeHTTPResponse(r.HTTPResponse, bodyBuffer, elapsed)
		if err != nil {
			tflog.Error(ctx, fmt.Sprintf("decomposing response: %s", err))
			return
		}
		tflog.Debug(ctx, "HTTP Response Received", responseFields)
	}

	const handlerName = "TF_AWS_ResponseBodyLogger"

	// A fixed handler name is used for both lists; presumably SetBackNamed
	// replaces an existing handler of the same name so retries do not stack
	// duplicates — confirm against the aws-sdk-go request.HandlerList docs.
	r.Handlers.Unmarshal.SetBackNamed(request.NamedHandler{
		Name: handlerName, Fn: handlerFn,
	})
	r.Handlers.UnmarshalError.SetBackNamed(request.NamedHandler{
		Name: handlerName, Fn: handlerFn,
	})
}

// teeReaderCloser pairs a tee reader with the original body so that closing
// the wrapper closes the underlying response body.
type teeReaderCloser struct {
	// io.Reader will be a tee reader that is used during logging.
	// This structure will read from a body and write the contents to a logger.
	io.Reader
	// Source is used just to close when we are done reading.
	Source io.ReadCloser
}

// Close closes the original body.
func (reader *teeReaderCloser) Close() error {
	return reader.Source.Close()
}

// decomposeHTTPResponse flattens the response into a field map for tflog. The
// fields are the elapsed time in milliseconds ("http.duration"), the semconv
// HTTP client response attributes, the response headers, and the (truncated)
// body read from body.
func decomposeHTTPResponse(resp *http.Response, body io.Reader, elapsed time.Duration) (map[string]any, error) {
	var attributes []attribute.KeyValue

	attributes = append(attributes, attribute.Int64("http.duration", elapsed.Milliseconds()))

	attributes = append(attributes, httpconv.ClientResponse(resp)...)

	attributes = append(attributes, logging.DecomposeResponseHeaders(resp)...)

	bodyAttribute, err := decomposeResponseBody(body)
	if err != nil {
		return nil, err
	}
	attributes = append(attributes, bodyAttribute)

	result := make(map[string]any, len(attributes))
	for _, attribute := range attributes {
		result[string(attribute.Key)] = attribute.Value.AsInterface()
	}

	return result, nil
}

// decomposeResponseBody reads all of bodyReader and returns it as the
// "http.response.body" attribute, truncated by logging.ReadTruncatedBody to
// logging.MaxResponseBodyLen.
func decomposeResponseBody(bodyReader io.Reader) (kv attribute.KeyValue, err error) {
	content, err := io.ReadAll(bodyReader)
	if err != nil {
		return kv, err
	}

	reader := textproto.NewReader(bufio.NewReader(bytes.NewReader(content)))

	body, err := logging.ReadTruncatedBody(reader, logging.MaxResponseBodyLen)
	if err != nil {
		return kv, err
	}

	return attribute.String("http.response.body", body), nil
}

// limitWriter returns a writer that forwards at most n bytes to w and silently
// discards the rest.
func limitWriter(w io.Writer, n int64) io.Writer {
	return &limitedWriter{w, n}
}

type limitedWriter struct {
	W io.Writer // the underlying writer
	N int64     // max bytes remaining
}

// Write writes data into the wrapped Writer up to a limit of N bytes
// Silently stops writing and returns full size of p to allow use with io.TeeReader
func (w *limitedWriter) Write(p []byte) (int, error) {
	if w.N <= 0 {
		return len(p), nil
	}
	if int64(len(p)) > w.N {
		n, err := w.W.Write(p[0:w.N])
		w.N -= int64(n)
		return len(p), err
	} else {
		n, err := w.W.Write(p)
		w.N -= int64(n)
		return n, err
	}
}

// ---- /v2/awsv1shim/mockdata/mocks.go ----
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package mockdata

import (
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
	"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/hashicorp/aws-sdk-go-base/v2/servicemocks"
)

// GetMockedAwsApiSession establishes an AWS session to a simulated AWS API server for a given service and route endpoints.
17 | func GetMockedAwsApiSession(svcName string, endpoints []*servicemocks.MockEndpoint) (func(), *session.Session, error) { 18 | ts := servicemocks.MockAwsApiServer(svcName, endpoints) 19 | 20 | sc := credentials.NewStaticCredentials("accessKey", "secretKey", "") 21 | 22 | sess, err := session.NewSession(&aws.Config{ 23 | Credentials: sc, 24 | Region: aws.String("us-east-1"), 25 | Endpoint: aws.String(ts.URL), 26 | CredentialsChainVerboseErrors: aws.Bool(true), 27 | }) 28 | 29 | return ts.Close, sess, err 30 | } 31 | 32 | var ( 33 | MockEc2MetadataCredentials = credentials.Value{ 34 | AccessKeyID: servicemocks.MockEc2MetadataAccessKey, 35 | ProviderName: ec2rolecreds.ProviderName, 36 | SecretAccessKey: servicemocks.MockEc2MetadataSecretKey, 37 | SessionToken: servicemocks.MockEc2MetadataSessionToken, 38 | } 39 | 40 | MockEcsCredentialsCredentials = credentials.Value{ 41 | AccessKeyID: servicemocks.MockEcsCredentialsAccessKey, 42 | ProviderName: endpointcreds.ProviderName, 43 | SecretAccessKey: servicemocks.MockEcsCredentialsSecretKey, 44 | SessionToken: servicemocks.MockEcsCredentialsSessionToken, 45 | } 46 | 47 | MockEnvCredentials = credentials.Value{ 48 | AccessKeyID: servicemocks.MockEnvAccessKey, 49 | ProviderName: credentials.EnvProviderName, 50 | SecretAccessKey: servicemocks.MockEnvSecretKey, 51 | } 52 | 53 | MockEnvCredentialsWithSessionToken = credentials.Value{ 54 | AccessKeyID: servicemocks.MockEnvAccessKey, 55 | ProviderName: credentials.EnvProviderName, 56 | SecretAccessKey: servicemocks.MockEnvSecretKey, 57 | SessionToken: servicemocks.MockEnvSessionToken, 58 | } 59 | 60 | MockStaticCredentials = credentials.Value{ 61 | AccessKeyID: servicemocks.MockStaticAccessKey, 62 | ProviderName: credentials.StaticProviderName, 63 | SecretAccessKey: servicemocks.MockStaticSecretKey, 64 | } 65 | 66 | MockStsAssumeRoleCredentials = credentials.Value{ 67 | AccessKeyID: servicemocks.MockStsAssumeRoleAccessKey, 68 | ProviderName: stscreds.ProviderName, 69 | 
SecretAccessKey: servicemocks.MockStsAssumeRoleSecretKey, 70 | SessionToken: servicemocks.MockStsAssumeRoleSessionToken, 71 | } 72 | 73 | MockStsAssumeRoleWithWebIdentityCredentials = credentials.Value{ 74 | AccessKeyID: servicemocks.MockStsAssumeRoleWithWebIdentityAccessKey, 75 | ProviderName: stscreds.WebIdentityProviderName, 76 | SecretAccessKey: servicemocks.MockStsAssumeRoleWithWebIdentitySecretKey, 77 | SessionToken: servicemocks.MockStsAssumeRoleWithWebIdentitySessionToken, 78 | } 79 | ) 80 | -------------------------------------------------------------------------------- /v2/awsv1shim/resolvers.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsv1shim 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "io" 10 | "os" 11 | 12 | configv2 "github.com/aws/aws-sdk-go-v2/config" 13 | "github.com/hashicorp/aws-sdk-go-base/v2/logging" 14 | ) 15 | 16 | func resolveCustomCABundle(ctx context.Context, configSources []any) (value io.Reader, found bool, err error) { 17 | for _, source := range configSources { 18 | switch cfg := source.(type) { 19 | case configv2.LoadOptions: 20 | value, found, err = loadOptionsGetCustomCABundle(ctx, cfg) 21 | case configv2.EnvConfig: 22 | value, found, err = envConfigGetCustomCABundle(ctx, cfg) 23 | case configv2.SharedConfig: 24 | value, found, err = sharedConfigGetCustomCABundle(ctx, cfg) 25 | default: 26 | logger := logging.RetrieveLogger(ctx) 27 | logger.Warn(ctx, "Unrecognized config source", map[string]any{ 28 | "source": source, 29 | }) 30 | continue 31 | } 32 | if err != nil || found { 33 | break 34 | } 35 | } 36 | 37 | return 38 | } 39 | 40 | // Copied from https://github.com/aws/aws-sdk-go-v2/blob/889e1da2776ae5bd6d056cf44f6ce6d043237769/config/load_options.go#L334-L340 41 | func loadOptionsGetCustomCABundle(_ context.Context, o configv2.LoadOptions) (io.Reader, bool, error) { //nolint:unparam 42 | if 
o.CustomCABundle == nil { 43 | return nil, false, nil 44 | } 45 | 46 | return o.CustomCABundle, true, nil 47 | } 48 | 49 | // Copied from https://github.com/aws/aws-sdk-go-v2/blob/889e1da2776ae5bd6d056cf44f6ce6d043237769/config/env_config.go#L463-L473 50 | func envConfigGetCustomCABundle(_ context.Context, c configv2.EnvConfig) (io.Reader, bool, error) { 51 | if len(c.CustomCABundle) == 0 { 52 | return nil, false, nil 53 | } 54 | 55 | b, err := os.ReadFile(c.CustomCABundle) 56 | if err != nil { 57 | return nil, false, err 58 | } 59 | return bytes.NewReader(b), true, nil 60 | } 61 | 62 | // Copied from https://github.com/aws/aws-sdk-go-v2/blob/889e1da2776ae5bd6d056cf44f6ce6d043237769/config/shared_config.go#L350-L360 63 | func sharedConfigGetCustomCABundle(_ context.Context, c configv2.SharedConfig) (io.Reader, bool, error) { 64 | if len(c.CustomCABundle) == 0 { 65 | return nil, false, nil 66 | } 67 | 68 | b, err := os.ReadFile(c.CustomCABundle) 69 | if err != nil { 70 | return nil, false, err 71 | } 72 | return bytes.NewReader(b), true, nil 73 | } 74 | -------------------------------------------------------------------------------- /v2/awsv1shim/session.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsv1shim 5 | 6 | import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim 7 | "context" 8 | "fmt" 9 | "os" 10 | 11 | awsv2 "github.com/aws/aws-sdk-go-v2/aws" // nosemgrep: no-sdkv2-imports-in-awsv1shim 12 | "github.com/aws/aws-sdk-go/aws" 13 | "github.com/aws/aws-sdk-go/aws/endpoints" 14 | "github.com/aws/aws-sdk-go/aws/request" 15 | "github.com/aws/aws-sdk-go/aws/session" 16 | awsbase "github.com/hashicorp/aws-sdk-go-base/v2" 17 | "github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2/tfawserr" 18 | "github.com/hashicorp/aws-sdk-go-base/v2/diag" 19 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/awsconfig" 20 | "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" 21 | "github.com/hashicorp/aws-sdk-go-base/v2/logging" 22 | ) 23 | 24 | // getSessionOptions attempts to return valid AWS Go SDK session authentication 25 | // options based on pre-existing credential provider, configured profile, or 26 | // fallback to automatically a determined session via the AWS Go SDK. 
27 | func getSessionOptions(ctx context.Context, awsC *awsv2.Config, c *awsbase.Config) (*session.Options, error) { 28 | useFIPSEndpoint, _, err := awsconfig.ResolveUseFIPSEndpoint(ctx, awsC.ConfigSources) 29 | if err != nil { 30 | return nil, fmt.Errorf("error resolving FIPS endpoint configuration: %w", err) 31 | } 32 | 33 | useDualStackEndpoint, _, err := awsconfig.ResolveUseDualStackEndpoint(ctx, awsC.ConfigSources) 34 | if err != nil { 35 | return nil, fmt.Errorf("error resolving dual-stack endpoint configuration: %w", err) 36 | } 37 | 38 | httpClient := c.HTTPClient 39 | if httpClient == nil { 40 | httpClient, err = defaultHttpClient(c) 41 | if err != nil { 42 | return nil, err 43 | } 44 | } 45 | 46 | options := &session.Options{ 47 | Config: aws.Config{ 48 | Credentials: newV2Credentials(awsC.Credentials), 49 | HTTPClient: httpClient, 50 | MaxRetries: aws.Int(0), 51 | Region: aws.String(awsC.Region), 52 | UseFIPSEndpoint: convertFIPSEndpointState(useFIPSEndpoint), 53 | UseDualStackEndpoint: convertDualStackEndpointState(useDualStackEndpoint), 54 | }, 55 | SharedConfigState: session.SharedConfigEnable, 56 | SharedConfigFiles: append(c.SharedCredentialsFiles, c.SharedConfigFiles...), 57 | } 58 | 59 | if !c.SuppressDebugLog { 60 | options.Config.LogLevel = aws.LogLevel(aws.LogOff) 61 | options.Config.Logger = debugLogger{} 62 | } 63 | 64 | // We can't reuse the io.Reader from the awsv2.Config, because it's already been read. 65 | // Re-create it here from the filename. 
66 | if c.CustomCABundle != "" { 67 | reader, err := c.CustomCABundleReader() 68 | if err != nil { 69 | return nil, err 70 | } 71 | options.CustomCABundle = reader 72 | } else if reader, found, err := resolveCustomCABundle(ctx, awsC.ConfigSources); err != nil { 73 | return nil, fmt.Errorf("error resolving custom CA bundle configuration: %w", err) 74 | } else if found { 75 | options.CustomCABundle = reader 76 | } 77 | 78 | return options, nil 79 | } 80 | 81 | const loggerName string = "aws-base-v1" 82 | 83 | // GetSession returns an AWS Go SDK session. 84 | func GetSession(ctx context.Context, awsC *awsv2.Config, c *awsbase.Config) (*session.Session, diag.Diagnostics) { 85 | var diags diag.Diagnostics 86 | 87 | var logger logging.Logger = logging.NullLogger{} 88 | if c.Logger != nil { 89 | logger = c.Logger 90 | } 91 | ctx, logger = logger.SubLogger(ctx, loggerName) 92 | ctx = logging.RegisterLogger(ctx, logger) 93 | 94 | options, err := getSessionOptions(ctx, awsC, c) 95 | if err != nil { 96 | return nil, diags.AddSimpleError(err) 97 | } 98 | 99 | sess, err := session.NewSessionWithOptions(*options) 100 | if err != nil { 101 | if tfawserr.ErrCodeEquals(err, "NoCredentialProviders") { 102 | return nil, diags.Append(c.NewNoValidCredentialSourcesError(err)) 103 | } 104 | return nil, diags.AddSimpleError(fmt.Errorf("creating AWS session: %w", err)) 105 | } 106 | 107 | // Set retries after resolving credentials to prevent retries during resolution 108 | if retryer := awsC.Retryer(); retryer != nil { 109 | sess = sess.Copy(&aws.Config{MaxRetries: aws.Int(retryer.MaxAttempts())}) 110 | } 111 | 112 | SetSessionUserAgent(sess, c.APNInfo, c.UserAgent) 113 | 114 | sess.Handlers.Build.PushBack(userAgentFromContextHandler) 115 | 116 | if !c.SuppressDebugLog { 117 | sess.Handlers.Send.PushFrontNamed(requestLogger) 118 | sess.Handlers.Send.PushBackNamed(responseLogger) 119 | } 120 | 121 | // Add custom input from ENV to the User-Agent request header 122 | // Reference: 
https://github.com/terraform-providers/terraform-provider-aws/issues/9149 123 | if v := os.Getenv(constants.AppendUserAgentEnvVar); v != "" { 124 | logger.Debug(ctx, "Adding User-Agent info", map[string]any{ 125 | "source": fmt.Sprintf("envvar(%q)", constants.AppendUserAgentEnvVar), 126 | "value": v, 127 | }) 128 | sess.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler(v)) 129 | } 130 | 131 | // Generally, we want to configure a lower retry theshold for networking issues 132 | // as the session retry threshold is very high by default and can mask permanent 133 | // networking failures, such as a non-existent service endpoint. 134 | // MaxRetries will override this logic if it has a lower retry threshold. 135 | // NOTE: This logic can be fooled by other request errors raising the retry count 136 | // before any networking error occurs 137 | sess.Handlers.Retry.PushBack(func(r *request.Request) { 138 | logger := logging.RetrieveLogger(r.Context()) 139 | 140 | if r.IsErrorExpired() { 141 | logger.Warn(ctx, "Disabling retries after next request due to expired credentials", map[string]any{ 142 | "error": r.Error, 143 | }) 144 | r.Retryable = aws.Bool(false) 145 | } 146 | 147 | if r.RetryCount < constants.MaxNetworkRetryCount { 148 | return 149 | } 150 | 151 | // RequestError: send request failed 152 | // caused by: Post https://FQDN/: dial tcp: lookup FQDN: no such host 153 | if tfawserr.ErrMessageAndOrigErrContain(r.Error, request.ErrCodeRequestError, "send request failed", "no such host") { 154 | logger.Warn(ctx, "Disabling retries after next request due to networking error", map[string]any{ 155 | "error": r.Error, 156 | }) 157 | r.Retryable = aws.Bool(false) 158 | } 159 | // RequestError: send request failed 160 | // caused by: Post https://FQDN/: dial tcp IPADDRESS:443: connect: connection refused 161 | if tfawserr.ErrMessageAndOrigErrContain(r.Error, request.ErrCodeRequestError, "send request failed", "connection refused") { 162 | logger.Warn(ctx, 
"Disabling retries after next request due to networking error", map[string]any{ 163 | "error": r.Error, 164 | }) 165 | r.Retryable = aws.Bool(false) 166 | } 167 | }) 168 | 169 | return sess, nil 170 | } 171 | 172 | func convertFIPSEndpointState(value awsv2.FIPSEndpointState) endpoints.FIPSEndpointState { 173 | switch value { 174 | case awsv2.FIPSEndpointStateEnabled: 175 | return endpoints.FIPSEndpointStateEnabled 176 | case awsv2.FIPSEndpointStateDisabled: 177 | return endpoints.FIPSEndpointStateDisabled 178 | default: 179 | return endpoints.FIPSEndpointStateUnset 180 | } 181 | } 182 | 183 | func convertDualStackEndpointState(value awsv2.DualStackEndpointState) endpoints.DualStackEndpointState { 184 | switch value { 185 | case awsv2.DualStackEndpointStateEnabled: 186 | return endpoints.DualStackEndpointStateEnabled 187 | case awsv2.DualStackEndpointStateDisabled: 188 | return endpoints.DualStackEndpointStateDisabled 189 | default: 190 | return endpoints.DualStackEndpointStateUnset 191 | } 192 | } 193 | 194 | func SetSessionUserAgent(sess *session.Session, apnInfo *awsbase.APNInfo, userAgentProducts awsbase.UserAgentProducts) { 195 | // AWS SDK Go automatically adds a User-Agent product to HTTP requests, 196 | // which contains helpful information about the SDK version and runtime. 197 | // The configuration of additional User-Agent header products should take 198 | // precedence over that product. Since the AWS SDK Go request package 199 | // functions only append, we must PushFront on the build handlers instead 200 | // of PushBack. 
201 | if apnInfo != nil { 202 | sess.Handlers.Build.PushFront( 203 | request.MakeAddToUserAgentFreeFormHandler(apnInfo.BuildUserAgentString()), 204 | ) 205 | } 206 | 207 | if len(userAgentProducts) > 0 { 208 | sess.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler(userAgentProducts.BuildUserAgentString())) 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /v2/awsv1shim/tfawserr/awserr.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfawserr 5 | 6 | import ( 7 | "errors" 8 | "strings" 9 | 10 | "github.com/aws/aws-sdk-go/aws/awserr" 11 | ) 12 | 13 | // ErrMessageAndOrigErrContain returns true if the error matches all these conditions: 14 | // - err is of type awserr.Error 15 | // - Error.Code() matches code 16 | // - Error.Message() contains message 17 | // - Error.OrigErr() contains origErrMessage 18 | func ErrMessageAndOrigErrContain(err error, code string, message string, origErrMessage string) bool { 19 | if !ErrMessageContains(err, code, message) { 20 | return false 21 | } 22 | 23 | if origErrMessage == "" { 24 | return true 25 | } 26 | 27 | // Ensure OrigErr() is non-nil, to prevent panics 28 | if origErr := err.(awserr.Error).OrigErr(); origErr != nil { 29 | return strings.Contains(origErr.Error(), origErrMessage) 30 | } 31 | 32 | return false 33 | } 34 | 35 | // ErrCodeEquals returns true if the error matches all these conditions: 36 | // - err is of type awserr.Error 37 | // - Error.Code() equals one of the passed codes 38 | func ErrCodeEquals(err error, codes ...string) bool { 39 | var awsErr awserr.Error 40 | if errors.As(err, &awsErr) { 41 | for _, code := range codes { 42 | if awsErr.Code() == code { 43 | return true 44 | } 45 | } 46 | } 47 | return false 48 | } 49 | 50 | // ErrCodeContains returns true if the error matches all these conditions: 51 | // - err 
is of type awserr.Error 52 | // - Error.Code() contains code 53 | func ErrCodeContains(err error, code string) bool { 54 | var awsErr awserr.Error 55 | if errors.As(err, &awsErr) { 56 | return strings.Contains(awsErr.Code(), code) 57 | } 58 | return false 59 | } 60 | 61 | // ErrMessageContains returns true if the error matches all these conditions: 62 | // - err is of type awserr.Error 63 | // - Error.Code() equals code 64 | // - Error.Message() contains message 65 | func ErrMessageContains(err error, code string, message string) bool { 66 | var awsErr awserr.Error 67 | if errors.As(err, &awsErr) { 68 | return awsErr.Code() == code && strings.Contains(awsErr.Message(), message) 69 | } 70 | return false 71 | } 72 | 73 | // ErrStatusCodeEquals returns true if the error matches all these conditions: 74 | // - err is of type awserr.RequestFailure 75 | // - RequestFailure.StatusCode() equals statusCode 76 | // 77 | // It is always preferable to use ErrMessageContains() except in older APIs (e.g. S3) 78 | // that sometimes only respond with status codes. 79 | func ErrStatusCodeEquals(err error, statusCode int) bool { 80 | var awsErr awserr.RequestFailure 81 | if errors.As(err, &awsErr) { 82 | return awsErr.StatusCode() == statusCode 83 | } 84 | return false 85 | } 86 | -------------------------------------------------------------------------------- /v2/awsv1shim/user_agent.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package awsv1shim 5 | 6 | import ( 7 | "github.com/aws/aws-sdk-go/aws/request" 8 | "github.com/hashicorp/aws-sdk-go-base/v2/useragent" 9 | ) 10 | 11 | func userAgentFromContextHandler(r *request.Request) { 12 | ctx := r.Context() 13 | 14 | if v := useragent.BuildFromContext(ctx); v != "" { 15 | request.AddToUserAgent(r, v) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /validation/json.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package validation 5 | 6 | import ( 7 | "encoding/json" 8 | "errors" 9 | "fmt" 10 | "strconv" 11 | "strings" 12 | ) 13 | 14 | // DuplicateKeyError is returned when duplicate key names are detected 15 | // inside a JSON object 16 | type DuplicateKeyError struct { 17 | path []string 18 | key string 19 | } 20 | 21 | func (e *DuplicateKeyError) Error() string { 22 | return fmt.Sprintf(`duplicate key "%s"`, strings.Join(append(e.path, e.key), ".")) 23 | } 24 | 25 | // JSONNoDuplicateKeys verifies the provided JSON object contains 26 | // no duplicated keys 27 | // 28 | // The function expects a single JSON object, and will error prior to 29 | // checking for duplicate keys should an invalid input be provided. 30 | func JSONNoDuplicateKeys(s string) error { 31 | var out map[string]any 32 | if err := json.Unmarshal([]byte(s), &out); err != nil { 33 | return fmt.Errorf("unmarshaling input: %w", err) 34 | } 35 | 36 | dec := json.NewDecoder(strings.NewReader(s)) 37 | return checkToken(dec, nil) 38 | } 39 | 40 | // checkToken walks a JSON object checking for duplicated keys 41 | // 42 | // The function is called recursively on the value of each key 43 | // inside and object, or item inside an array. 
44 | // 45 | // Adapted from: https://stackoverflow.com/a/50109335 46 | func checkToken(dec *json.Decoder, path []string) error { 47 | t, err := dec.Token() 48 | if err != nil { 49 | return err 50 | } 51 | 52 | delim, ok := t.(json.Delim) 53 | if !ok { 54 | // non-delimiter, nothing to do 55 | return nil 56 | } 57 | 58 | var dupErrs []error 59 | switch delim { 60 | case '{': 61 | keys := make(map[string]bool) 62 | for dec.More() { 63 | // Get the field key 64 | t, err := dec.Token() 65 | if err != nil { 66 | return err 67 | } 68 | key := t.(string) 69 | 70 | if keys[key] { 71 | // Duplicate found 72 | dupErrs = append(dupErrs, &DuplicateKeyError{path: path, key: key}) 73 | } 74 | keys[key] = true 75 | 76 | // Check the keys value 77 | if err := checkToken(dec, append(path, key)); err != nil { 78 | dupErrs = append(dupErrs, err) 79 | } 80 | } 81 | 82 | // consume trailing "}" 83 | _, err := dec.Token() 84 | if err != nil { 85 | return err 86 | } 87 | case '[': 88 | i := 0 89 | for dec.More() { 90 | // Check each items value 91 | if err := checkToken(dec, append(path, strconv.Itoa(i))); err != nil { 92 | dupErrs = append(dupErrs, err) 93 | } 94 | i++ 95 | } 96 | 97 | // consume trailing "]" 98 | _, err := dec.Token() 99 | if err != nil { 100 | return err 101 | } 102 | } 103 | 104 | return errors.Join(dupErrs...) 105 | } 106 | -------------------------------------------------------------------------------- /validation/json_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package validation 5 | 6 | import ( 7 | "testing" 8 | ) 9 | 10 | func TestJSONNoDuplicateKeys(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | s string 14 | wantErr bool 15 | }{ 16 | { 17 | name: "invalid", 18 | s: "{{{", 19 | wantErr: true, 20 | }, 21 | { 22 | name: "valid", 23 | s: `{ 24 | "a": "foo", 25 | "b": { 26 | "c": "bar", 27 | "d": [ 28 | { 29 | "e": "baz" 30 | }, 31 | { 32 | "f": "qux", 33 | "g": "foo" 34 | } 35 | ] 36 | } 37 | }`, 38 | wantErr: false, 39 | }, 40 | { 41 | name: "root", 42 | s: `{ 43 | "a": "foo", 44 | "a": "bar" 45 | }`, 46 | wantErr: true, 47 | }, 48 | { 49 | name: "nested object", 50 | s: `{ 51 | "a": "foo", 52 | "b": { 53 | "c": "bar", 54 | "c": "baz" 55 | } 56 | }`, 57 | wantErr: true, 58 | }, 59 | { 60 | name: "nested array", 61 | s: `{ 62 | "a": "foo", 63 | "b": { 64 | "c": "bar", 65 | "d": [ 66 | { 67 | "e": "foo", 68 | "e": "bar" 69 | }, 70 | { 71 | "f": "baz", 72 | "g": "qux" 73 | } 74 | ] 75 | } 76 | }`, 77 | wantErr: true, 78 | }, 79 | { 80 | name: "multiple", 81 | s: `{ 82 | "a": "foo", 83 | "a": "bar", 84 | "b": { 85 | "c": "baz", 86 | "c": "qux", 87 | "d": [ 88 | { 89 | "e": "foo" 90 | }, 91 | { 92 | "f": "bar", 93 | "f": "baz", 94 | "g": "qux" 95 | } 96 | ] 97 | } 98 | }`, 99 | wantErr: true, 100 | }, 101 | { 102 | name: "aws iam condition keys", 103 | s: `{ 104 | "Version": "2012-10-17", 105 | "Statement": [ 106 | { 107 | "Effect": "Allow", 108 | "Action": "iam:PassRole", 109 | "Resource": "*", 110 | "Condition": { 111 | "StringEquals": { 112 | "iam:PassedToService": "cloudwatch.amazonaws.com" 113 | }, 114 | "StringEquals": { 115 | "iam:PassedToService": "ec2.amazonaws.com" 116 | } 117 | } 118 | } 119 | ] 120 | }`, 121 | 122 | wantErr: true, 123 | }, 124 | } 125 | 126 | for _, tt := range tests { 127 | t.Run(tt.name, func(t *testing.T) { 128 | if err := JSONNoDuplicateKeys(tt.s); (err != nil) != tt.wantErr { 129 | t.Errorf("JSONNoDuplicateKeys() error = %v, 
wantErr %v", err, tt.wantErr) 130 | } 131 | }) 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /validation/region.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package validation 5 | 6 | import ( 7 | "fmt" 8 | "slices" 9 | 10 | "github.com/hashicorp/aws-sdk-go-base/v2/endpoints" 11 | ) 12 | 13 | type InvalidRegionError struct { 14 | region string 15 | } 16 | 17 | func (e *InvalidRegionError) Error() string { 18 | return fmt.Sprintf("invalid AWS Region: %s", e.region) 19 | } 20 | 21 | // SupportedRegion checks if the given region is a valid AWS region. 22 | func SupportedRegion(region string) error { 23 | if slices.ContainsFunc(endpoints.DefaultPartitions(), func(p endpoints.Partition) bool { 24 | _, ok := p.Regions()[region] 25 | return ok 26 | }) { 27 | return nil 28 | } 29 | 30 | return &InvalidRegionError{ 31 | region: region, 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /validation/region_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package validation 5 | 6 | import ( 7 | "testing" 8 | ) 9 | 10 | func TestSupportedRegion(t *testing.T) { 11 | var testCases = []struct { 12 | Region string 13 | ExpectError bool 14 | }{ 15 | { 16 | Region: "us-east-1", 17 | ExpectError: false, 18 | }, 19 | { 20 | Region: "ap-northeast-3", 21 | ExpectError: false, 22 | }, 23 | { 24 | Region: "us-gov-west-1", 25 | ExpectError: false, 26 | }, 27 | { 28 | Region: "cn-northwest-1", 29 | ExpectError: false, 30 | }, 31 | { 32 | Region: "invalid", 33 | ExpectError: true, 34 | }, 35 | } 36 | 37 | for _, testCase := range testCases { 38 | t.Run(testCase.Region, func(t *testing.T) { 39 | err := SupportedRegion(testCase.Region) 40 | if err != nil && !testCase.ExpectError { 41 | t.Fatalf("Expected no error, received error: %s", err) 42 | } 43 | if err == nil && testCase.ExpectError { 44 | t.Fatal("Expected error, received none") 45 | } 46 | }) 47 | } 48 | } 49 | --------------------------------------------------------------------------------