├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ └── go.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── awsutil ├── LICENSE ├── clients.go ├── clients_test.go ├── error.go ├── error_test.go ├── generate_credentials.go ├── go.mod ├── go.sum ├── mocks.go ├── mocks_test.go ├── options.go ├── options_test.go ├── region.go ├── region_test.go ├── rotate.go └── rotate_test.go ├── base62 ├── LICENSE ├── base62.go ├── base62_test.go ├── go.mod └── go.sum ├── configutil ├── LICENSE ├── Makefile ├── config.go ├── config_test.go ├── config_util.go ├── encrypt_decrypt.go ├── encrypt_decrypt_test.go ├── file_plugin_test.go ├── go.mod ├── go.sum ├── kms.go ├── kms_test.go ├── merge.go ├── options.go ├── options_test.go └── testplugins │ └── aead │ └── main.go ├── fileutil ├── LICENSE ├── caching_file_reader.go ├── caching_file_reader_test.go ├── go.mod └── go.sum ├── gatedwriter ├── LICENSE ├── go.mod ├── writer.go └── writer_test.go ├── kv-builder ├── LICENSE ├── builder.go ├── builder_test.go ├── go.mod └── go.sum ├── listenerutil ├── LICENSE ├── error.go ├── forwarded_for.go ├── forwarded_for_test.go ├── go.mod ├── go.sum ├── listener.go ├── listener_test.go └── parse.go ├── mlock ├── LICENSE ├── go.mod ├── go.sum ├── mlock.go ├── mlock_unavail.go └── mlock_unix.go ├── parseutil ├── LICENSE ├── go.mod ├── go.sum ├── parsepath.go ├── parsepath_test.go ├── parseutil.go └── parseutil_test.go ├── password ├── LICENSE ├── go.mod ├── go.sum ├── password.go ├── password_solaris.go ├── password_test.go ├── password_unix.go └── password_windows.go ├── pluginutil ├── doc.go ├── go.mod ├── go.sum ├── options.go ├── options_test.go └── pluginutil.go ├── reloadutil ├── LICENSE ├── go.mod ├── go.sum ├── reload.go └── reload_test.go ├── strutil ├── LICENSE ├── go.mod ├── go.sum ├── strutil.go ├── strutil_benchmark_test.go └── strutil_test.go └── tlsutil ├── LICENSE ├── go.mod ├── go.sum ├── tlsutil.go └── tlsutil_test.go 
/.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Additional context** 27 | Add any other context about the problem here. 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | 15 | - name: Set up Go 16 | uses: actions/setup-go@v2 17 | with: 18 | go-version: 1.17 19 | 20 | - name: Build 21 | run: rc=0; for m in $(find . -name go.mod); do (cd "$(dirname "$m")" && go build ./...) || rc=1; done; exit $rc 22 | 23 | - name: Test 24 | run: rc=0; for m in $(find . -name go.mod); do (cd "$(dirname "$m")" && go test ./...) || rc=1; done; exit $rc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Folders 2 | _obj 3 | _test 4 | .cover 5 | 6 | # IntelliJ IDEA project files 7 | .idea 8 | *.ipr 9 | *.iml 10 | *.iws 11 | 12 | ### Logs ### 13 | *.log 14 | logs/ 15 | 16 | ### direnv ### 17 | .envrc 18 | .direnv/ 19 | 20 | ### Temp directories ### 21 | tmp/ 22 | temp/ 23 | 24 | ### Visual Studio ### 25 | .vscode/ 26 | 27 | ### macOS ### 28 | # General 29 | .DS_Store 30 | .AppleDouble 31 | .LSOverride 32 | 33 | # Icon must end with two \r 34 | Icon 35 | 36 | 37 | # Thumbnails 38 | ._* 39 | 40 | ### Git ### 41 | # Created by git for backups.
To disable backups in Git: 42 | # $ git config --global mergetool.keepBackup false 43 | *.orig 44 | 45 | # Created by git when using merge tools for conflicts 46 | *.BACKUP.* 47 | *.BASE.* 48 | *.LOCAL.* 49 | *.REMOTE.* 50 | *_BACKUP_*.txt 51 | *_BASE_*.txt 52 | *_LOCAL_*.txt 53 | *_REMOTE_*.txt 54 | 55 | ### Go ### 56 | # Binaries for programs and plugins 57 | *.exe 58 | *.exe~ 59 | *.dll 60 | *.so 61 | *.dylib 62 | 63 | # Test binary, built with `go test -c` 64 | *.test 65 | 66 | # Output of the go coverage tool, specifically when used with LiteIDE 67 | *.out 68 | 69 | ### Tags ### 70 | # Ignore tags created by etags, ctags, gtags (GNU global) and cscope 71 | TAGS 72 | .TAGS 73 | !TAGS/ 74 | tags 75 | .tags 76 | !tags/ 77 | gtags.files 78 | GTAGS 79 | GRTAGS 80 | GPATH 81 | GSYMS 82 | cscope.files 83 | cscope.out 84 | cscope.in.out 85 | cscope.po.out 86 | 87 | ### Vagrant ### 88 | # General 89 | .vagrant/ 90 | 91 | # Log files (if you are creating logs in debug mode, uncomment this) 92 | # *.log 93 | 94 | ### Vagrant Patch ### 95 | *.box 96 | 97 | ### Vim ### 98 | # Swap 99 | [._]*.s[a-v][a-z] 100 | [._]*.sw[a-p] 101 | [._]s[a-rt-v][a-z] 102 | [._]ss[a-gi-z] 103 | [._]sw[a-p] 104 | 105 | # Session 106 | Session.vim 107 | Sessionx.vim 108 | 109 | # Temporary 110 | .netrwhist 111 | *~ 112 | 113 | # Auto-generated tag files 114 | # Persistent undo 115 | [._]*.un~ 116 | 117 | # Test config file 118 | test*.hcl 119 | 120 | # vim: set filetype=conf : -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to go-secure-stdlib 2 | 3 | Thank you for contributing! Here you can find common questions around reporting 4 | issues and opening pull requests to our project. 5 | 6 | Please note that these modules are all actively used in multiple 7 | projects/products. 
Please ensure that engineering leads from affected 8 | projects/products are aware of proposed changes and given a chance to review 9 | them. 10 | 11 | When contributing in any way to the project (new issue, PR, etc.), please be 12 | aware that our team identifies with many gender pronouns. Please remember to use 13 | nonbinary pronouns (they/them) and gender-neutral language ("Hello folks") when 14 | addressing our team. For more reading on our code of conduct, please see the 15 | [HashiCorp community 16 | guidelines](https://www.hashicorp.com/community-guidelines). 17 | 18 | ## Issue Reporting 19 | ### Reporting Security Related Vulnerabilities 20 | 21 | We take security and our users' trust very seriously. If you believe you have 22 | found a security issue, please responsibly disclose by contacting us at 23 | security@hashicorp.com. Do not open an issue on our GitHub issue tracker if you 24 | believe you've found a security-related issue. Thank you! 25 | 26 | ### Bug Fixes 27 | 28 | If you believe you found a bug, please: 29 | 30 | 1. Build from the latest `main` HEAD commit to attempt to reproduce the issue. 31 | It's possible we've already fixed the bug, and this is a good first step to 32 | ensuring that's not the case. 33 | 1. Ensure a similar ticket has not already been opened by searching our open issues 34 | on GitHub. 35 | 36 | 37 | Once you've verified the above, feel free to open an issue using the bug report template 38 | from our [issue 39 | selector](https://github.com/hashicorp/go-secure-stdlib/issues/new/choose) and 40 | we'll do our best to triage it as quickly as possible. 41 | 42 | ## Pull Requests 43 | 44 | ### New Features & Improvements 45 | 46 | Before writing a line of code, please ask us about a potential improvement or 47 | feature that you want to write.
We may already be working on it; even if we 48 | aren't, we need to ensure that both the feature and its proposed implementation 49 | are aligned with our roadmap, vision, and standards for the project. We're happy 50 | to help walk through that via a [feature request 51 | issue](https://github.com/hashicorp/go-secure-stdlib/issues/new/choose). 52 | 53 | ### Submitting a New Pull Request 54 | 55 | When submitting a pull request, please ensure: 56 | 57 | 1. You've added a changelog line clearly describing the new addition under the 58 | correct changelog sub-section. 59 | 2. You've followed the above guidelines for contributing. 60 | 61 | Once you open your PR, our auto-labeling will add labels to help us triage and 62 | prioritize your contribution. Please allow us a couple of days to comment, 63 | request changes, or approve your PR. Thank you for your contribution! 64 | 65 | ## Testing 66 | 67 | You can run the GitHub Actions workflows locally using 68 | [act](https://github.com/nektos/act). 69 | 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Stdlib for HashiCorp Secure products 2 | ================= 3 | 4 | These libraries are maintained by engineers in HashiCorp's Secure division 5 | as a stdlib for its projects -- Vault, Vault plugins, Boundary, etc. -- to 6 | reduce code duplication and increase consistency. 7 | 8 | Each library is its own Go module, although some of them may have dependencies 9 | on others within the repo. The libraries follow Go module versioning rules. 10 | 11 | Most of the libraries in here were originally pulled from 12 | vault/helper/metricsutil, vault/sdk/helper, and vault/internalshared; see there 13 | for contribution and change history prior to their move here. 14 | 15 | All modules are licensed according to MPLv2 as contained in the LICENSE file; 16 | this file is duplicated in each module.
17 | -------------------------------------------------------------------------------- /awsutil/clients.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/aws/aws-sdk-go/aws/session" 8 | "github.com/aws/aws-sdk-go/service/iam" 9 | "github.com/aws/aws-sdk-go/service/iam/iamiface" 10 | "github.com/aws/aws-sdk-go/service/sts" 11 | "github.com/aws/aws-sdk-go/service/sts/stsiface" 12 | ) 13 | 14 | // IAMAPIFunc is a factory function for returning an IAM interface, 15 | // useful for supplying mock interfaces for testing IAM. The session 16 | // is passed into the function in the same way as done with the 17 | // standard iam.New() constructor. 18 | type IAMAPIFunc func(sess *session.Session) (iamiface.IAMAPI, error) 19 | 20 | // STSAPIFunc is a factory function for returning a STS interface, 21 | // useful for supplying mock interfaces for testing STS. The session 22 | // is passed into the function in the same way as done with the 23 | // standard sts.New() constructor. 24 | type STSAPIFunc func(sess *session.Session) (stsiface.STSAPI, error) 25 | 26 | // IAMClient returns an IAM client. 27 | // 28 | // Supported options: WithSession, WithIAMAPIFunc. 29 | // 30 | // If WithIAMAPIFunc is supplied, the included function is used as 31 | // the IAM client constructor instead. This can be used for Mocking 32 | // the IAM API. 33 | func (c *CredentialsConfig) IAMClient(opt ...Option) (iamiface.IAMAPI, error) { 34 | opts, err := getOpts(opt...) 35 | if err != nil { 36 | return nil, fmt.Errorf("error reading options: %w", err) 37 | } 38 | 39 | sess := opts.withAwsSession 40 | if sess == nil { 41 | sess, err = c.GetSession(opt...) 
42 | if err != nil { 43 | return nil, fmt.Errorf("error calling GetSession: %w", err) 44 | } 45 | } 46 | 47 | if opts.withIAMAPIFunc != nil { 48 | return opts.withIAMAPIFunc(sess) 49 | } 50 | 51 | client := iam.New(sess) 52 | if client == nil { 53 | return nil, errors.New("could not obtain iam client from session") 54 | } 55 | 56 | return client, nil 57 | } 58 | 59 | // STSClient returns a STS client. 60 | // 61 | // Supported options: WithSession, WithSTSAPIFunc. 62 | // 63 | // If WithSTSAPIFunc is supplied, the included function is used as 64 | // the STS client constructor instead. This can be used for Mocking 65 | // the STS API. 66 | func (c *CredentialsConfig) STSClient(opt ...Option) (stsiface.STSAPI, error) { 67 | opts, err := getOpts(opt...) 68 | if err != nil { 69 | return nil, fmt.Errorf("error reading options: %w", err) 70 | } 71 | 72 | sess := opts.withAwsSession 73 | if sess == nil { 74 | sess, err = c.GetSession(opt...) 75 | if err != nil { 76 | return nil, fmt.Errorf("error calling GetSession: %w", err) 77 | } 78 | } 79 | 80 | if opts.withSTSAPIFunc != nil { 81 | return opts.withSTSAPIFunc(sess) 82 | } 83 | 84 | client := sts.New(sess) 85 | if client == nil { 86 | return nil, errors.New("could not obtain sts client from session") 87 | } 88 | 89 | return client, nil 90 | } 91 | -------------------------------------------------------------------------------- /awsutil/clients_test.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/aws/aws-sdk-go/service/iam" 9 | "github.com/aws/aws-sdk-go/service/iam/iamiface" 10 | "github.com/aws/aws-sdk-go/service/sts" 11 | "github.com/aws/aws-sdk-go/service/sts/stsiface" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | const testOptionErr = "test option error" 16 | const testBadClientType = "badclienttype" 17 | 18 | func testWithBadClientType(o *options) error { 19 | 
o.withClientType = testBadClientType 20 | return nil 21 | } 22 | 23 | func TestCredentialsConfigIAMClient(t *testing.T) { 24 | cases := []struct { 25 | name string 26 | credentialsConfig *CredentialsConfig 27 | opts []Option 28 | require func(t *testing.T, actual iamiface.IAMAPI) 29 | requireErr string 30 | }{ 31 | { 32 | name: "options error", 33 | credentialsConfig: &CredentialsConfig{}, 34 | opts: []Option{MockOptionErr(errors.New(testOptionErr))}, 35 | requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), 36 | }, 37 | { 38 | name: "session error", 39 | credentialsConfig: &CredentialsConfig{}, 40 | opts: []Option{testWithBadClientType}, 41 | requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType), 42 | }, 43 | { 44 | name: "with mock IAM session", 45 | credentialsConfig: &CredentialsConfig{}, 46 | opts: []Option{WithIAMAPIFunc(NewMockIAM())}, 47 | require: func(t *testing.T, actual iamiface.IAMAPI) { 48 | t.Helper() 49 | require := require.New(t) 50 | require.Equal(&MockIAM{}, actual) 51 | }, 52 | }, 53 | { 54 | name: "no mock client", 55 | credentialsConfig: &CredentialsConfig{}, 56 | opts: []Option{}, 57 | require: func(t *testing.T, actual iamiface.IAMAPI) { 58 | t.Helper() 59 | require := require.New(t) 60 | require.IsType(&iam.IAM{}, actual) 61 | }, 62 | }, 63 | } 64 | 65 | for _, tc := range cases { 66 | tc := tc 67 | t.Run(tc.name, func(t *testing.T) { 68 | require := require.New(t) 69 | actual, err := tc.credentialsConfig.IAMClient(tc.opts...) 
70 | if tc.requireErr != "" { 71 | require.EqualError(err, tc.requireErr) 72 | return 73 | } 74 | 75 | require.NoError(err) 76 | tc.require(t, actual) 77 | }) 78 | } 79 | } 80 | 81 | func TestCredentialsConfigSTSClient(t *testing.T) { 82 | cases := []struct { 83 | name string 84 | credentialsConfig *CredentialsConfig 85 | opts []Option 86 | require func(t *testing.T, actual stsiface.STSAPI) 87 | requireErr string 88 | }{ 89 | { 90 | name: "options error", 91 | credentialsConfig: &CredentialsConfig{}, 92 | opts: []Option{MockOptionErr(errors.New(testOptionErr))}, 93 | requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), 94 | }, 95 | { 96 | name: "session error", 97 | credentialsConfig: &CredentialsConfig{}, 98 | opts: []Option{testWithBadClientType}, 99 | requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType), 100 | }, 101 | { 102 | name: "with mock STS session", 103 | credentialsConfig: &CredentialsConfig{}, 104 | opts: []Option{WithSTSAPIFunc(NewMockSTS())}, 105 | require: func(t *testing.T, actual stsiface.STSAPI) { 106 | t.Helper() 107 | require := require.New(t) 108 | require.Equal(&MockSTS{}, actual) 109 | }, 110 | }, 111 | { 112 | name: "no mock client", 113 | credentialsConfig: &CredentialsConfig{}, 114 | opts: []Option{}, 115 | require: func(t *testing.T, actual stsiface.STSAPI) { 116 | t.Helper() 117 | require := require.New(t) 118 | require.IsType(&sts.STS{}, actual) 119 | }, 120 | }, 121 | } 122 | 123 | for _, tc := range cases { 124 | tc := tc 125 | t.Run(tc.name, func(t *testing.T) { 126 | require := require.New(t) 127 | actual, err := tc.credentialsConfig.STSClient(tc.opts...) 
128 | if tc.requireErr != "" { 129 | require.EqualError(err, tc.requireErr) 130 | return 131 | } 132 | 133 | require.NoError(err) 134 | tc.require(t, actual) 135 | }) 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /awsutil/error.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "errors" 5 | 6 | awsRequest "github.com/aws/aws-sdk-go/aws/request" 7 | multierror "github.com/hashicorp/go-multierror" 8 | ) 9 | 10 | var ErrUpstreamRateLimited = errors.New("upstream rate limited") 11 | 12 | // CheckAWSError will examine an error and convert to a logical error if 13 | // appropriate. If no appropriate error is found, return nil 14 | func CheckAWSError(err error) error { 15 | // IsErrorThrottle will check if the error returned is one that matches 16 | // known request limiting errors: 17 | // https://github.com/aws/aws-sdk-go/blob/488d634b5a699b9118ac2befb5135922b4a77210/aws/request/retryer.go#L35 18 | if awsRequest.IsErrorThrottle(err) { 19 | return ErrUpstreamRateLimited 20 | } 21 | return nil 22 | } 23 | 24 | // AppendAWSError checks if the given error is a known AWS error we modify, 25 | // and if so then returns a go-multierror, appending the original and the 26 | // AWS error. 27 | // If the error is not an AWS error, or not an error we wish to modify, then 28 | // return the original error. 
29 | func AppendAWSError(err error) error { 30 | if awserr := CheckAWSError(err); awserr != nil { 31 | err = multierror.Append(err, awserr) 32 | } 33 | return err 34 | } 35 | -------------------------------------------------------------------------------- /awsutil/error_test.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/aws/aws-sdk-go/aws/awserr" 8 | multierror "github.com/hashicorp/go-multierror" 9 | ) 10 | 11 | func Test_CheckAWSError(t *testing.T) { 12 | testCases := []struct { 13 | Name string 14 | Err error 15 | Expected error 16 | }{ 17 | { 18 | Name: "Something not checked", 19 | Err: fmt.Errorf("something"), 20 | }, 21 | { 22 | Name: "Upstream throttle error", 23 | Err: awserr.New("Throttling", "", nil), 24 | Expected: ErrUpstreamRateLimited, 25 | }, 26 | { 27 | Name: "Upstream RequestLimitExceeded", 28 | Err: awserr.New("RequestLimitExceeded", "Request rate limited", nil), 29 | Expected: ErrUpstreamRateLimited, 30 | }, 31 | } 32 | 33 | for _, tc := range testCases { 34 | t.Run(tc.Name, func(t *testing.T) { 35 | err := CheckAWSError(tc.Err) 36 | if err == nil && tc.Expected != nil { 37 | t.Fatalf("expected non-nil error (%#v), got nil", tc.Expected) 38 | } 39 | if err != nil && tc.Expected == nil { 40 | t.Fatalf("expected nil error, got (%#v)", err) 41 | } 42 | if err != tc.Expected { 43 | t.Fatalf("expected error (%#v), got (%#v)", tc.Expected, err) 44 | } 45 | }) 46 | } 47 | } 48 | 49 | func Test_AppendRateLimitedError(t *testing.T) { 50 | awsErr := awserr.New("Throttling", "", nil) 51 | testCases := []struct { 52 | Name string 53 | Err error 54 | Expected error 55 | }{ 56 | { 57 | Name: "Something not checked", 58 | Err: fmt.Errorf("something"), 59 | Expected: fmt.Errorf("something"), 60 | }, 61 | { 62 | Name: "Upstream throttle error", 63 | Err: awsErr, 64 | Expected: multierror.Append(awsErr, ErrUpstreamRateLimited), 65 | }, 66 | { 67 
| Name: "Nil", 68 | }, 69 | } 70 | 71 | for _, tc := range testCases { 72 | t.Run(tc.Name, func(t *testing.T) { 73 | err := AppendAWSError(tc.Err) 74 | if err == nil && tc.Expected != nil { 75 | t.Fatalf("expected non-nil error (%#v), got nil", tc.Expected) 76 | } 77 | if err != nil && tc.Expected == nil { 78 | t.Fatalf("expected nil error, got (%#v)", err) 79 | } 80 | if err == nil && tc.Expected == nil { 81 | return 82 | } 83 | if err.Error() != tc.Expected.Error() { 84 | t.Fatalf("expected error (%#v), got (%#v)", tc.Expected.Error(), err.Error()) 85 | } 86 | }) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /awsutil/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/awsutil 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/aws/aws-sdk-go v1.30.27 7 | github.com/hashicorp/errwrap v1.1.0 8 | github.com/hashicorp/go-cleanhttp v0.5.2 9 | github.com/hashicorp/go-hclog v0.16.2 10 | github.com/hashicorp/go-multierror v1.1.1 11 | github.com/kr/pretty v0.3.0 // indirect 12 | github.com/mattn/go-colorable v0.1.6 // indirect 13 | github.com/pkg/errors v0.9.1 14 | github.com/stretchr/testify v1.5.1 15 | golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 // indirect 16 | gopkg.in/yaml.v2 v2.2.8 // indirect 17 | ) 18 | -------------------------------------------------------------------------------- /awsutil/go.sum: -------------------------------------------------------------------------------- 1 | github.com/aws/aws-sdk-go v1.30.27 h1:9gPjZWVDSoQrBO2AvqrWObS6KAZByfEJxQoCYo4ZfK0= 2 | github.com/aws/aws-sdk-go v1.30.27/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= 3 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 6 | 
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= 8 | github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= 9 | github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 10 | github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 11 | github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= 12 | github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 13 | github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= 14 | github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= 15 | github.com/hashicorp/go-hclog v0.16.2 h1:K4ev2ib4LdQETX5cSZBG0DVLk1jwGqSPXBjdah3veNs= 16 | github.com/hashicorp/go-hclog v0.16.2/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= 17 | github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= 18 | github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= 19 | github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= 20 | github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= 21 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 22 | github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= 23 | github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= 24 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 25 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 26 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 27 | github.com/kr/text v0.2.0/go.mod 
h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 28 | github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= 29 | github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= 30 | github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= 31 | github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= 32 | github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= 33 | github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= 34 | github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= 35 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 36 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 37 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 38 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 39 | github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= 40 | github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= 41 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 42 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 43 | github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= 44 | github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= 45 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 46 | golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 47 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 48 | golang.org/x/sys 
v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 49 | golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 50 | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 51 | golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 52 | golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 h1:OjiUf46hAmXblsZdnoSXsEUSKU8r1UEzcL5RVZ4gO9Y= 53 | golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 54 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 55 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 56 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 57 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 58 | gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= 59 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 60 | gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= 61 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 62 | -------------------------------------------------------------------------------- /awsutil/mocks.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "github.com/aws/aws-sdk-go/aws/session" 5 | "github.com/aws/aws-sdk-go/service/iam" 6 | "github.com/aws/aws-sdk-go/service/iam/iamiface" 7 | "github.com/aws/aws-sdk-go/service/sts" 8 | "github.com/aws/aws-sdk-go/service/sts/stsiface" 9 | ) 10 | 11 | // MockOptionErr provides a mock option error for use with testing. 
12 | func MockOptionErr(withErr error) Option { 13 | return func(_ *options) error { 14 | return withErr 15 | } 16 | } 17 | 18 | // MockIAM provides a way to mock the AWS IAM API. 19 | type MockIAM struct { 20 | iamiface.IAMAPI 21 | 22 | CreateAccessKeyOutput *iam.CreateAccessKeyOutput 23 | CreateAccessKeyError error 24 | DeleteAccessKeyError error 25 | GetUserOutput *iam.GetUserOutput 26 | GetUserError error 27 | } 28 | 29 | // MockIAMOption is a function for setting the various fields on a MockIAM 30 | // object. 31 | type MockIAMOption func(m *MockIAM) error 32 | 33 | // WithCreateAccessKeyOutput sets the output for the CreateAccessKey method. 34 | func WithCreateAccessKeyOutput(o *iam.CreateAccessKeyOutput) MockIAMOption { 35 | return func(m *MockIAM) error { 36 | m.CreateAccessKeyOutput = o 37 | return nil 38 | } 39 | } 40 | 41 | // WithCreateAccessKeyError sets the error output for the CreateAccessKey 42 | // method. 43 | func WithCreateAccessKeyError(e error) MockIAMOption { 44 | return func(m *MockIAM) error { 45 | m.CreateAccessKeyError = e 46 | return nil 47 | } 48 | } 49 | 50 | // WithDeleteAccessKeyError sets the error output for the DeleteAccessKey 51 | // method. 52 | func WithDeleteAccessKeyError(e error) MockIAMOption { 53 | return func(m *MockIAM) error { 54 | m.DeleteAccessKeyError = e 55 | return nil 56 | } 57 | } 58 | 59 | // WithGetUserOutput sets the output for the GetUser method. 60 | func WithGetUserOutput(o *iam.GetUserOutput) MockIAMOption { 61 | return func(m *MockIAM) error { 62 | m.GetUserOutput = o 63 | return nil 64 | } 65 | } 66 | 67 | // WithGetUserError sets the error output for the GetUser method. 68 | func WithGetUserError(e error) MockIAMOption { 69 | return func(m *MockIAM) error { 70 | m.GetUserError = e 71 | return nil 72 | } 73 | } 74 | 75 | // NewMockIAM provides a factory function to use with the WithIAMAPIFunc 76 | // option. 
77 | func NewMockIAM(opts ...MockIAMOption) IAMAPIFunc { 78 | return func(_ *session.Session) (iamiface.IAMAPI, error) { 79 | m := new(MockIAM) 80 | for _, opt := range opts { 81 | if err := opt(m); err != nil { 82 | return nil, err 83 | } 84 | } 85 | 86 | return m, nil 87 | } 88 | } 89 | 90 | func (m *MockIAM) CreateAccessKey(*iam.CreateAccessKeyInput) (*iam.CreateAccessKeyOutput, error) { 91 | if m.CreateAccessKeyError != nil { 92 | return nil, m.CreateAccessKeyError 93 | } 94 | 95 | return m.CreateAccessKeyOutput, nil 96 | } 97 | 98 | func (m *MockIAM) DeleteAccessKey(*iam.DeleteAccessKeyInput) (*iam.DeleteAccessKeyOutput, error) { 99 | return &iam.DeleteAccessKeyOutput{}, m.DeleteAccessKeyError 100 | } 101 | 102 | func (m *MockIAM) GetUser(*iam.GetUserInput) (*iam.GetUserOutput, error) { 103 | if m.GetUserError != nil { 104 | return nil, m.GetUserError 105 | } 106 | 107 | return m.GetUserOutput, nil 108 | } 109 | 110 | // MockSTS provides a way to mock the AWS STS API. 111 | type MockSTS struct { 112 | stsiface.STSAPI 113 | 114 | GetCallerIdentityOutput *sts.GetCallerIdentityOutput 115 | GetCallerIdentityError error 116 | } 117 | 118 | // MockSTSOption is a function for setting the various fields on a MockSTS 119 | // object. 120 | type MockSTSOption func(m *MockSTS) error 121 | 122 | // WithGetCallerIdentityOutput sets the output for the GetCallerIdentity 123 | // method. 124 | func WithGetCallerIdentityOutput(o *sts.GetCallerIdentityOutput) MockSTSOption { 125 | return func(m *MockSTS) error { 126 | m.GetCallerIdentityOutput = o 127 | return nil 128 | } 129 | } 130 | 131 | // WithGetCallerIdentityError sets the error output for the GetCallerIdentity 132 | // method. 133 | func WithGetCallerIdentityError(e error) MockSTSOption { 134 | return func(m *MockSTS) error { 135 | m.GetCallerIdentityError = e 136 | return nil 137 | } 138 | } 139 | 140 | // NewMockSTS provides a factory function to use with the WithSTSAPIFunc 141 | // option. 
142 | // 143 | // If withGetCallerIdentityError is supplied, calls to GetCallerIdentity will 144 | // return the supplied error. Otherwise, a basic mock API output is returned. 145 | func NewMockSTS(opts ...MockSTSOption) STSAPIFunc { 146 | return func(_ *session.Session) (stsiface.STSAPI, error) { 147 | m := new(MockSTS) 148 | for _, opt := range opts { 149 | if err := opt(m); err != nil { 150 | return nil, err 151 | } 152 | } 153 | 154 | return m, nil 155 | } 156 | } 157 | 158 | func (m *MockSTS) GetCallerIdentity(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { 159 | if m.GetCallerIdentityError != nil { 160 | return nil, m.GetCallerIdentityError 161 | } 162 | 163 | return m.GetCallerIdentityOutput, nil 164 | } 165 | -------------------------------------------------------------------------------- /awsutil/mocks_test.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/aws/aws-sdk-go/aws" 8 | "github.com/aws/aws-sdk-go/service/iam" 9 | "github.com/aws/aws-sdk-go/service/sts" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestMockIAM(t *testing.T) { 15 | cases := []struct { 16 | name string 17 | opts []MockIAMOption 18 | expectedCreateAccessKeyOutput *iam.CreateAccessKeyOutput 19 | expectedCreateAccessKeyError error 20 | expectedDeleteAccessKeyError error 21 | expectedGetUserOutput *iam.GetUserOutput 22 | expectedGetUserError error 23 | }{ 24 | { 25 | name: "CreateAccessKeyOutput", 26 | opts: []MockIAMOption{WithCreateAccessKeyOutput( 27 | &iam.CreateAccessKeyOutput{ 28 | AccessKey: &iam.AccessKey{ 29 | AccessKeyId: aws.String("foobar"), 30 | SecretAccessKey: aws.String("bazqux"), 31 | }, 32 | }, 33 | )}, 34 | expectedCreateAccessKeyOutput: &iam.CreateAccessKeyOutput{ 35 | AccessKey: &iam.AccessKey{ 36 | AccessKeyId: aws.String("foobar"), 37 | SecretAccessKey: 
aws.String("bazqux"), 38 | }, 39 | }, 40 | }, 41 | { 42 | name: "CreateAccessKeyError", 43 | opts: []MockIAMOption{WithCreateAccessKeyError(errors.New("testerr"))}, 44 | expectedCreateAccessKeyError: errors.New("testerr"), 45 | }, 46 | { 47 | name: "DeleteAccessKeyError", 48 | opts: []MockIAMOption{WithDeleteAccessKeyError(errors.New("testerr"))}, 49 | expectedDeleteAccessKeyError: errors.New("testerr"), 50 | }, 51 | { 52 | name: "GetUserOutput", 53 | opts: []MockIAMOption{WithGetUserOutput( 54 | &iam.GetUserOutput{ 55 | User: &iam.User{ 56 | Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), 57 | UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), 58 | UserName: aws.String("JohnDoe"), 59 | }, 60 | }, 61 | )}, 62 | expectedGetUserOutput: &iam.GetUserOutput{ 63 | User: &iam.User{ 64 | Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), 65 | UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), 66 | UserName: aws.String("JohnDoe"), 67 | }, 68 | }, 69 | }, 70 | { 71 | name: "GetUserError", 72 | opts: []MockIAMOption{WithGetUserError(errors.New("testerr"))}, 73 | expectedGetUserError: errors.New("testerr"), 74 | }, 75 | } 76 | 77 | for _, tc := range cases { 78 | tc := tc 79 | t.Run(tc.name, func(t *testing.T) { 80 | assert := assert.New(t) 81 | require := require.New(t) 82 | 83 | f := NewMockIAM(tc.opts...) 
84 | m, err := f(nil) 85 | require.NoError(err) // Nothing returns an error right now 86 | actualCreateAccessKeyOutput, actualCreateAccessKeyError := m.CreateAccessKey(nil) 87 | _, actualDeleteAccessKeyError := m.DeleteAccessKey(nil) 88 | actualGetUserOutput, actualGetUserError := m.GetUser(nil) 89 | assert.Equal(tc.expectedCreateAccessKeyOutput, actualCreateAccessKeyOutput) 90 | assert.Equal(tc.expectedCreateAccessKeyError, actualCreateAccessKeyError) 91 | assert.Equal(tc.expectedDeleteAccessKeyError, actualDeleteAccessKeyError) 92 | assert.Equal(tc.expectedGetUserOutput, actualGetUserOutput) 93 | assert.Equal(tc.expectedGetUserError, actualGetUserError) 94 | }) 95 | } 96 | } 97 | 98 | func TestMockSTS(t *testing.T) { 99 | cases := []struct { 100 | name string 101 | opts []MockSTSOption 102 | expectedGetCallerIdentityOutput *sts.GetCallerIdentityOutput 103 | expectedGetCallerIdentityError error 104 | }{ 105 | { 106 | name: "GetCallerIdentityOutput", 107 | opts: []MockSTSOption{WithGetCallerIdentityOutput( 108 | &sts.GetCallerIdentityOutput{ 109 | Account: aws.String("1234567890"), 110 | Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), 111 | UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), 112 | }, 113 | )}, 114 | expectedGetCallerIdentityOutput: &sts.GetCallerIdentityOutput{ 115 | Account: aws.String("1234567890"), 116 | Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), 117 | UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), 118 | }, 119 | }, 120 | { 121 | name: "GetCallerIdentityError", 122 | opts: []MockSTSOption{WithGetCallerIdentityError(errors.New("testerr"))}, 123 | expectedGetCallerIdentityError: errors.New("testerr"), 124 | }, 125 | } 126 | 127 | for _, tc := range cases { 128 | tc := tc 129 | t.Run(tc.name, func(t *testing.T) { 130 | assert := assert.New(t) 131 | require := require.New(t) 132 | 133 | f := NewMockSTS(tc.opts...) 
134 | m, err := f(nil) 135 | require.NoError(err) // Nothing returns an error right now 136 | actualGetCallerIdentityOutput, actualGetCallerIdentityError := m.GetCallerIdentity(nil) 137 | assert.Equal(tc.expectedGetCallerIdentityOutput, actualGetCallerIdentityOutput) 138 | assert.Equal(tc.expectedGetCallerIdentityError, actualGetCallerIdentityError) 139 | }) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /awsutil/options.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/aws/aws-sdk-go/aws/session" 9 | "github.com/hashicorp/go-hclog" 10 | ) 11 | 12 | // getOpts iterates the inbound Options and returns a struct 13 | func getOpts(opt ...Option) (options, error) { 14 | opts := getDefaultOptions() 15 | for _, o := range opt { 16 | if o == nil { 17 | continue 18 | } 19 | if err := o(&opts); err != nil { 20 | return options{}, err 21 | } 22 | } 23 | return opts, nil 24 | } 25 | 26 | // Option - how Options are passed as arguments 27 | type Option func(*options) error 28 | 29 | // options = how options are represented 30 | type options struct { 31 | withEnvironmentCredentials bool 32 | withSharedCredentials bool 33 | withAwsSession *session.Session 34 | withClientType string 35 | withUsername string 36 | withAccessKey string 37 | withSecretKey string 38 | withLogger hclog.Logger 39 | withStsEndpoint string 40 | withIamEndpoint string 41 | withMaxRetries *int 42 | withRegion string 43 | withHttpClient *http.Client 44 | withValidityCheckTimeout time.Duration 45 | withIAMAPIFunc IAMAPIFunc 46 | withSTSAPIFunc STSAPIFunc 47 | } 48 | 49 | func getDefaultOptions() options { 50 | return options{ 51 | withEnvironmentCredentials: true, 52 | withSharedCredentials: true, 53 | withClientType: "iam", 54 | } 55 | } 56 | 57 | // WithEnvironmentCredentials allows controlling whether environment credentials 
58 | // are used 59 | func WithEnvironmentCredentials(with bool) Option { 60 | return func(o *options) error { 61 | o.withEnvironmentCredentials = with 62 | return nil 63 | } 64 | } 65 | 66 | // WithSharedCredentials allows controlling whether shared credentials are used 67 | func WithSharedCredentials(with bool) Option { 68 | return func(o *options) error { 69 | o.withSharedCredentials = with 70 | return nil 71 | } 72 | } 73 | 74 | // WithAwsSession allows controlling the session passed into the client 75 | func WithAwsSession(with *session.Session) Option { 76 | return func(o *options) error { 77 | o.withAwsSession = with 78 | return nil 79 | } 80 | } 81 | 82 | // WithClientType allows choosing the client type to use 83 | func WithClientType(with string) Option { 84 | return func(o *options) error { 85 | switch with { 86 | case "iam", "sts": 87 | default: 88 | return fmt.Errorf("unsupported client type %q", with) 89 | } 90 | o.withClientType = with 91 | return nil 92 | } 93 | } 94 | 95 | // WithUsername allows passing the user name to use for an operation 96 | func WithUsername(with string) Option { 97 | return func(o *options) error { 98 | o.withUsername = with 99 | return nil 100 | } 101 | } 102 | 103 | // WithAccessKey allows passing an access key to use for operations 104 | func WithAccessKey(with string) Option { 105 | return func(o *options) error { 106 | o.withAccessKey = with 107 | return nil 108 | } 109 | } 110 | 111 | // WithSecretKey allows passing a secret key to use for operations 112 | func WithSecretKey(with string) Option { 113 | return func(o *options) error { 114 | o.withSecretKey = with 115 | return nil 116 | } 117 | } 118 | 119 | // WithStsEndpoint allows passing a custom STS endpoint 120 | func WithStsEndpoint(with string) Option { 121 | return func(o *options) error { 122 | o.withStsEndpoint = with 123 | return nil 124 | } 125 | } 126 | 127 | // WithIamEndpoint allows passing a custom IAM endpoint 128 | func WithIamEndpoint(with string) 
Option { 129 | return func(o *options) error { 130 | o.withIamEndpoint = with 131 | return nil 132 | } 133 | } 134 | 135 | // WithRegion allows passing a custom region 136 | func WithRegion(with string) Option { 137 | return func(o *options) error { 138 | o.withRegion = with 139 | return nil 140 | } 141 | } 142 | 143 | // WithLogger allows passing a logger to use 144 | func WithLogger(with hclog.Logger) Option { 145 | return func(o *options) error { 146 | o.withLogger = with 147 | return nil 148 | } 149 | } 150 | 151 | // WithMaxRetries allows passing custom max retries to set 152 | func WithMaxRetries(with *int) Option { 153 | return func(o *options) error { 154 | o.withMaxRetries = with 155 | return nil 156 | } 157 | } 158 | 159 | // WithHttpClient allows passing a custom client to use 160 | func WithHttpClient(with *http.Client) Option { 161 | return func(o *options) error { 162 | o.withHttpClient = with 163 | return nil 164 | } 165 | } 166 | 167 | // WithValidityCheckTimeout allows passing a timeout for operations that can wait 168 | // on success. 169 | func WithValidityCheckTimeout(with time.Duration) Option { 170 | return func(o *options) error { 171 | o.withValidityCheckTimeout = with 172 | return nil 173 | } 174 | } 175 | 176 | // WithIAMAPIFunc allows passing in an IAM interface constructor for mocking 177 | // the AWS IAM API. 178 | func WithIAMAPIFunc(with IAMAPIFunc) Option { 179 | return func(o *options) error { 180 | o.withIAMAPIFunc = with 181 | return nil 182 | } 183 | } 184 | 185 | // WithSTSAPIFunc allows passing in a STS interface constructor for mocking the 186 | // AWS STS API. 
187 | func WithSTSAPIFunc(with STSAPIFunc) Option { 188 | return func(o *options) error { 189 | o.withSTSAPIFunc = with 190 | return nil 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /awsutil/options_test.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | "time" 7 | 8 | "github.com/aws/aws-sdk-go/aws" 9 | "github.com/aws/aws-sdk-go/aws/session" 10 | "github.com/hashicorp/go-hclog" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func Test_GetOpts(t *testing.T) { 16 | t.Parallel() 17 | t.Run("default", func(t *testing.T) { 18 | testOpts := getDefaultOptions() 19 | assert.Equal(t, true, testOpts.withEnvironmentCredentials) 20 | assert.Equal(t, true, testOpts.withSharedCredentials) 21 | assert.Nil(t, testOpts.withAwsSession) 22 | assert.Equal(t, "iam", testOpts.withClientType) 23 | }) 24 | t.Run("withEnvironmentCredentials", func(t *testing.T) { 25 | opts, err := getOpts(WithEnvironmentCredentials(false)) 26 | require.NoError(t, err) 27 | testOpts := getDefaultOptions() 28 | testOpts.withEnvironmentCredentials = false 29 | assert.Equal(t, opts, testOpts) 30 | }) 31 | t.Run("withSharedCredentials", func(t *testing.T) { 32 | opts, err := getOpts(WithSharedCredentials(false)) 33 | require.NoError(t, err) 34 | testOpts := getDefaultOptions() 35 | testOpts.withSharedCredentials = false 36 | assert.Equal(t, opts, testOpts) 37 | }) 38 | t.Run("withAwsSession", func(t *testing.T) { 39 | sess := new(session.Session) 40 | opts, err := getOpts(WithAwsSession(sess)) 41 | require.NoError(t, err) 42 | testOpts := getDefaultOptions() 43 | testOpts.withAwsSession = sess 44 | assert.Equal(t, opts, testOpts) 45 | }) 46 | t.Run("withUsername", func(t *testing.T) { 47 | opts, err := getOpts(WithUsername("foobar")) 48 | require.NoError(t, err) 49 | testOpts := 
getDefaultOptions() 50 | testOpts.withUsername = "foobar" 51 | assert.Equal(t, opts, testOpts) 52 | }) 53 | t.Run("withClientType", func(t *testing.T) { 54 | _, err := getOpts(WithClientType("foobar")) 55 | require.Error(t, err) 56 | opts, err := getOpts(WithClientType("sts")) 57 | require.NoError(t, err) 58 | testOpts := getDefaultOptions() 59 | testOpts.withClientType = "sts" 60 | assert.Equal(t, opts, testOpts) 61 | }) 62 | t.Run("withAccessKey", func(t *testing.T) { 63 | opts, err := getOpts(WithAccessKey("foobar")) 64 | require.NoError(t, err) 65 | testOpts := getDefaultOptions() 66 | testOpts.withAccessKey = "foobar" 67 | assert.Equal(t, opts, testOpts) 68 | }) 69 | t.Run("withSecretKey", func(t *testing.T) { 70 | opts, err := getOpts(WithSecretKey("foobar")) 71 | require.NoError(t, err) 72 | testOpts := getDefaultOptions() 73 | testOpts.withSecretKey = "foobar" 74 | assert.Equal(t, opts, testOpts) 75 | }) 76 | t.Run("withStsEndpoint", func(t *testing.T) { 77 | opts, err := getOpts(WithStsEndpoint("foobar")) 78 | require.NoError(t, err) 79 | testOpts := getDefaultOptions() 80 | testOpts.withStsEndpoint = "foobar" 81 | assert.Equal(t, opts, testOpts) 82 | }) 83 | t.Run("withIamEndpoint", func(t *testing.T) { 84 | opts, err := getOpts(WithIamEndpoint("foobar")) 85 | require.NoError(t, err) 86 | testOpts := getDefaultOptions() 87 | testOpts.withIamEndpoint = "foobar" 88 | assert.Equal(t, opts, testOpts) 89 | }) 90 | t.Run("withLogger", func(t *testing.T) { 91 | logger := hclog.New(nil) 92 | opts, err := getOpts(WithLogger(logger)) 93 | require.NoError(t, err) 94 | assert.Equal(t, &opts.withLogger, &logger) 95 | }) 96 | t.Run("withRegion", func(t *testing.T) { 97 | opts, err := getOpts(WithRegion("foobar")) 98 | require.NoError(t, err) 99 | testOpts := getDefaultOptions() 100 | testOpts.withRegion = "foobar" 101 | assert.Equal(t, opts, testOpts) 102 | }) 103 | t.Run("withMaxRetries", func(t *testing.T) { 104 | opts, err := getOpts(WithMaxRetries(aws.Int(5))) 105 
| require.NoError(t, err) 106 | testOpts := getDefaultOptions() 107 | testOpts.withMaxRetries = aws.Int(5) 108 | assert.Equal(t, opts, testOpts) 109 | }) 110 | t.Run("withHttpClient", func(t *testing.T) { 111 | client := &http.Client{} 112 | opts, err := getOpts(WithHttpClient(client)) 113 | require.NoError(t, err) 114 | assert.Equal(t, &opts.withHttpClient, &client) 115 | }) 116 | t.Run("withValidityCheckTimeout", func(t *testing.T) { 117 | opts, err := getOpts(WithValidityCheckTimeout(time.Second)) 118 | require.NoError(t, err) 119 | assert.Equal(t, opts.withValidityCheckTimeout, time.Second) 120 | }) 121 | t.Run("withIAMIface", func(t *testing.T) { 122 | opts, err := getOpts(WithIAMAPIFunc(NewMockIAM())) 123 | require.NoError(t, err) 124 | assert.NotNil(t, opts.withIAMAPIFunc) 125 | }) 126 | t.Run("withSTSIface", func(t *testing.T) { 127 | opts, err := getOpts(WithSTSAPIFunc(NewMockSTS())) 128 | require.NoError(t, err) 129 | assert.NotNil(t, opts.withSTSAPIFunc) 130 | }) 131 | } 132 | -------------------------------------------------------------------------------- /awsutil/region.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | 7 | "github.com/aws/aws-sdk-go/aws" 8 | "github.com/aws/aws-sdk-go/aws/ec2metadata" 9 | "github.com/aws/aws-sdk-go/aws/session" 10 | "github.com/hashicorp/errwrap" 11 | ) 12 | 13 | // "us-east-1 is used because it's where AWS first provides support for new features, 14 | // is a widely used region, and is the most common one for some services like STS. 15 | const DefaultRegion = "us-east-1" 16 | 17 | // This is nil by default, but is exposed in case it needs to be changed for tests. 18 | var ec2Endpoint *string 19 | 20 | /* 21 | It's impossible to mimic "normal" AWS behavior here because it's not consistent 22 | or well-defined. 
For example, boto3, the Python SDK (which the aws cli uses), 23 | loads `~/.aws/config` by default and only reads the `AWS_DEFAULT_REGION` environment 24 | variable (and not `AWS_REGION`, while the golang SDK does _mostly_ the opposite -- it 25 | reads the region **only** from `AWS_REGION` and not at all `~/.aws/config`, **unless** 26 | the `AWS_SDK_LOAD_CONFIG` environment variable is set. So, we must define our own 27 | approach to walking AWS config and deciding what to use. 28 | 29 | Our chosen approach is: 30 | 31 | "More specific takes precedence over less specific." 32 | 33 | 1. User-provided configuration is the most explicit. 34 | 2. Environment variables are potentially shared across many invocations and so they have less precedence. 35 | 3. Configuration in `~/.aws/config` is shared across all invocations of a given user and so this has even less precedence. 36 | 4. Configuration retrieved from the EC2 instance metadata service is shared by all invocations on a given machine, and so it has the lowest precedence. 37 | 38 | This approach should be used in future updates to this logic. 
39 | */ 40 | func GetRegion(configuredRegion string) (string, error) { 41 | if configuredRegion != "" { 42 | return configuredRegion, nil 43 | } 44 | 45 | sess, err := session.NewSessionWithOptions(session.Options{ 46 | SharedConfigState: session.SharedConfigEnable, 47 | }) 48 | if err != nil { 49 | return "", errwrap.Wrapf("got error when starting session: {{err}}", err) 50 | } 51 | 52 | region := aws.StringValue(sess.Config.Region) 53 | if region != "" { 54 | return region, nil 55 | } 56 | 57 | metadata := ec2metadata.New(sess, &aws.Config{ 58 | Endpoint: ec2Endpoint, 59 | EC2MetadataDisableTimeoutOverride: aws.Bool(true), 60 | HTTPClient: &http.Client{ 61 | Timeout: time.Second, 62 | }, 63 | }) 64 | if !metadata.Available() { 65 | return DefaultRegion, nil 66 | } 67 | 68 | region, err = metadata.Region() 69 | if err != nil { 70 | return "", errwrap.Wrapf("unable to retrieve region from instance metadata: {{err}}", err) 71 | } 72 | 73 | return region, nil 74 | } 75 | -------------------------------------------------------------------------------- /awsutil/region_test.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "net/http" 7 | "net/http/httptest" 8 | "os" 9 | "os/user" 10 | "testing" 11 | 12 | "github.com/aws/aws-sdk-go/aws" 13 | ) 14 | 15 | const testConfigFile = `[default] 16 | region=%s 17 | output=json` 18 | 19 | var ( 20 | shouldTestFiles = os.Getenv("VAULT_ACC_AWS_FILES") == "1" 21 | 22 | expectedTestRegion = "us-west-2" 23 | unexpectedTestRegion = "us-east-2" 24 | regionEnvKeys = []string{"AWS_REGION", "AWS_DEFAULT_REGION"} 25 | ) 26 | 27 | func TestGetRegion_UserConfigPreferredFirst(t *testing.T) { 28 | configuredRegion := expectedTestRegion 29 | 30 | cleanupEnv := setEnvRegion(t, unexpectedTestRegion) 31 | defer cleanupEnv() 32 | 33 | cleanupFile := setConfigFileRegion(t, unexpectedTestRegion) 34 | defer cleanupFile() 35 | 36 | cleanupMetadata := 
setInstanceMetadata(t, unexpectedTestRegion) 37 | defer cleanupMetadata() 38 | 39 | result, err := GetRegion(configuredRegion) 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | if result != expectedTestRegion { 44 | t.Fatalf("expected: %s; actual: %s", expectedTestRegion, result) 45 | } 46 | } 47 | 48 | func TestGetRegion_EnvVarsPreferredSecond(t *testing.T) { 49 | configuredRegion := "" 50 | 51 | cleanupEnv := setEnvRegion(t, expectedTestRegion) 52 | defer cleanupEnv() 53 | 54 | cleanupFile := setConfigFileRegion(t, unexpectedTestRegion) 55 | defer cleanupFile() 56 | 57 | cleanupMetadata := setInstanceMetadata(t, unexpectedTestRegion) 58 | defer cleanupMetadata() 59 | 60 | result, err := GetRegion(configuredRegion) 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | if result != expectedTestRegion { 65 | t.Fatalf("expected: %s; actual: %s", expectedTestRegion, result) 66 | } 67 | } 68 | 69 | func TestGetRegion_ConfigFilesPreferredThird(t *testing.T) { 70 | if !shouldTestFiles { 71 | // In some test environments, like a CI environment, we may not have the 72 | // permissions to write to the ~/.aws/config file. Thus, this test is off 73 | // by default but can be set to on for local development. 
74 | t.SkipNow() 75 | } 76 | configuredRegion := "" 77 | 78 | cleanupEnv := setEnvRegion(t, "") 79 | defer cleanupEnv() 80 | 81 | cleanupFile := setConfigFileRegion(t, expectedTestRegion) 82 | defer cleanupFile() 83 | 84 | cleanupMetadata := setInstanceMetadata(t, unexpectedTestRegion) 85 | defer cleanupMetadata() 86 | 87 | result, err := GetRegion(configuredRegion) 88 | if err != nil { 89 | t.Fatal(err) 90 | } 91 | if result != expectedTestRegion { 92 | t.Fatalf("expected: %s; actual: %s", expectedTestRegion, result) 93 | } 94 | } 95 | 96 | func TestGetRegion_ConfigFileUnfound(t *testing.T) { 97 | if enabled := os.Getenv("VAULT_ACC"); enabled == "" { 98 | t.Skip() 99 | } 100 | 101 | configuredRegion := "" 102 | cleanupEnv := setEnvRegion(t, "") 103 | defer cleanupEnv() 104 | 105 | if err := os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "foo"); err != nil { 106 | t.Fatal(err) 107 | } 108 | defer func() { 109 | if err := os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE"); err != nil { 110 | t.Fatal(err) 111 | } 112 | }() 113 | 114 | result, err := GetRegion(configuredRegion) 115 | if err != nil { 116 | t.Fatal(err) 117 | } 118 | if result != DefaultRegion { 119 | t.Fatalf("expected: %s; actual: %s", DefaultRegion, result) 120 | } 121 | } 122 | 123 | func TestGetRegion_EC2InstanceMetadataPreferredFourth(t *testing.T) { 124 | if !shouldTestFiles { 125 | // In some test environments, like a CI environment, we may not have the 126 | // permissions to write to the ~/.aws/config file. Thus, this test is off 127 | // by default but can be set to on for local development. 
128 | t.SkipNow() 129 | } 130 | configuredRegion := "" 131 | 132 | cleanupEnv := setEnvRegion(t, "") 133 | defer cleanupEnv() 134 | 135 | cleanupFile := setConfigFileRegion(t, "") 136 | defer cleanupFile() 137 | 138 | cleanupMetadata := setInstanceMetadata(t, expectedTestRegion) 139 | defer cleanupMetadata() 140 | 141 | result, err := GetRegion(configuredRegion) 142 | if err != nil { 143 | t.Fatal(err) 144 | } 145 | if result != expectedTestRegion { 146 | t.Fatalf("expected: %s; actual: %s", expectedTestRegion, result) 147 | } 148 | } 149 | 150 | func TestGetRegion_DefaultsToDefaultRegionWhenRegionUnavailable(t *testing.T) { 151 | if enabled := os.Getenv("VAULT_ACC"); enabled == "" { 152 | t.Skip() 153 | } 154 | 155 | configuredRegion := "" 156 | 157 | cleanupEnv := setEnvRegion(t, "") 158 | defer cleanupEnv() 159 | 160 | cleanupFile := setConfigFileRegion(t, "") 161 | defer cleanupFile() 162 | 163 | result, err := GetRegion(configuredRegion) 164 | if err != nil { 165 | t.Fatal(err) 166 | } 167 | if result != DefaultRegion { 168 | t.Fatalf("expected: %s; actual: %s", DefaultRegion, result) 169 | } 170 | } 171 | 172 | func setEnvRegion(t *testing.T, region string) (cleanup func()) { 173 | for _, envKey := range regionEnvKeys { 174 | if err := os.Setenv(envKey, region); err != nil { 175 | t.Fatal(err) 176 | } 177 | } 178 | cleanup = func() { 179 | for _, envKey := range regionEnvKeys { 180 | if err := os.Unsetenv(envKey); err != nil { 181 | t.Fatal(err) 182 | } 183 | } 184 | } 185 | return 186 | } 187 | 188 | func setConfigFileRegion(t *testing.T, region string) (cleanup func()) { 189 | var cleanupFuncs []func() 190 | 191 | cleanup = func() { 192 | for _, f := range cleanupFuncs { 193 | f() 194 | } 195 | } 196 | 197 | if !shouldTestFiles { 198 | return 199 | } 200 | 201 | usr, err := user.Current() 202 | if err != nil { 203 | t.Fatal(err) 204 | } 205 | 206 | pathToAWSDir := usr.HomeDir + "/.aws" 207 | pathToConfig := pathToAWSDir + "/config" 208 | 209 | 
preExistingConfig, err := ioutil.ReadFile(pathToConfig) 210 | if err != nil { 211 | // File simply doesn't exist. 212 | if err := os.Mkdir(pathToAWSDir, os.ModeDir); err != nil { 213 | t.Fatal(err) 214 | } 215 | cleanupFuncs = append(cleanupFuncs, func() { 216 | if err := os.RemoveAll(pathToAWSDir); err != nil { 217 | t.Fatal(err) 218 | } 219 | }) 220 | } else { 221 | cleanupFuncs = append(cleanupFuncs, func() { 222 | if err := ioutil.WriteFile(pathToConfig, preExistingConfig, 0o644); err != nil { 223 | t.Fatal(err) 224 | } 225 | }) 226 | } 227 | fileBody := fmt.Sprintf(testConfigFile, region) 228 | if err := ioutil.WriteFile(pathToConfig, []byte(fileBody), 0o644); err != nil { 229 | t.Fatal(err) 230 | } 231 | 232 | if err := os.Setenv("AWS_SHARED_CREDENTIALS_FILE", pathToConfig); err != nil { 233 | t.Fatal(err) 234 | } 235 | cleanupFuncs = append(cleanupFuncs, func() { 236 | if err := os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE"); err != nil { 237 | t.Fatal(err) 238 | } 239 | }) 240 | 241 | return 242 | } 243 | 244 | func setInstanceMetadata(t *testing.T, region string) (cleanup func()) { 245 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 246 | reqPath := r.URL.String() 247 | switch reqPath { 248 | case "/latest/meta-data/instance-id": 249 | w.Write([]byte("i-1234567890abcdef0")) 250 | return 251 | case "/latest/meta-data/placement/availability-zone": 252 | // add a letter suffix, as a normal response is formatted like "us-east-1a" 253 | w.Write([]byte(region + "a")) 254 | return 255 | default: 256 | t.Fatalf("received unexpected request path: %s", reqPath) 257 | } 258 | })) 259 | ec2Endpoint = aws.String(ts.URL) 260 | cleanup = func() { 261 | ts.Close() 262 | ec2Endpoint = nil 263 | } 264 | return 265 | } 266 | -------------------------------------------------------------------------------- /awsutil/rotate.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | 
import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/aws/aws-sdk-go/aws" 10 | "github.com/aws/aws-sdk-go/aws/session" 11 | "github.com/aws/aws-sdk-go/service/iam" 12 | "github.com/aws/aws-sdk-go/service/sts" 13 | ) 14 | 15 | // RotateKeys takes the access key and secret key from this credentials config 16 | // and first creates a new access/secret key, then deletes the old access key. 17 | // If deletion of the old access key is successful, the new access key/secret 18 | // key are written into the credentials config and nil is returned. On any 19 | // error, the old credentials are not overwritten. This ensures that any 20 | // generated new secret key never leaves this function in case of an error, even 21 | // though it will still result in an extraneous access key existing; we do also 22 | // try to delete the new one to clean up, although it's unlikely that will work 23 | // if the old one could not be deleted. 24 | // 25 | // Supported options: WithEnvironmentCredentials, WithSharedCredentials, 26 | // WithAwsSession, WithUsername, WithValidityCheckTimeout, WithIAMAPIFunc, 27 | // WithSTSAPIFunc 28 | // 29 | // Note that WithValidityCheckTimeout here, when non-zero, controls the 30 | // WithValidityCheckTimeout option on access key creation. See CreateAccessKey 31 | // for more details. 32 | func (c *CredentialsConfig) RotateKeys(opt ...Option) error { 33 | if c.AccessKey == "" || c.SecretKey == "" { 34 | return errors.New("cannot rotate credentials when either access_key or secret_key is empty") 35 | } 36 | 37 | opts, err := getOpts(opt...) 38 | if err != nil { 39 | return fmt.Errorf("error reading options in RotateKeys: %w", err) 40 | } 41 | 42 | sess := opts.withAwsSession 43 | if sess == nil { 44 | sess, err = c.GetSession(opt...) 
45 | if err != nil { 46 | return fmt.Errorf("error calling GetSession: %w", err) 47 | } 48 | } 49 | 50 | sessOpt := append(opt, WithAwsSession(sess)) 51 | createAccessKeyRes, err := c.CreateAccessKey(sessOpt...) 52 | if err != nil { 53 | return fmt.Errorf("error calling CreateAccessKey: %w", err) 54 | } 55 | 56 | err = c.DeleteAccessKey(c.AccessKey, append(sessOpt, WithUsername(*createAccessKeyRes.AccessKey.UserName))...) 57 | if err != nil { 58 | return fmt.Errorf("error deleting old access key: %w", err) 59 | } 60 | 61 | c.AccessKey = *createAccessKeyRes.AccessKey.AccessKeyId 62 | c.SecretKey = *createAccessKeyRes.AccessKey.SecretAccessKey 63 | 64 | return nil 65 | } 66 | 67 | // CreateAccessKey creates a new access/secret key pair. 68 | // 69 | // Supported options: WithEnvironmentCredentials, WithSharedCredentials, 70 | // WithAwsSession, WithUsername, WithValidityCheckTimeout, WithIAMAPIFunc, 71 | // WithSTSAPIFunc 72 | // 73 | // When WithValidityCheckTimeout is non-zero, it specifies a timeout to wait on 74 | // the created credentials to be valid and ready for use. 75 | func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKeyOutput, error) { 76 | opts, err := getOpts(opt...) 77 | if err != nil { 78 | return nil, fmt.Errorf("error reading options in CreateAccessKey: %w", err) 79 | } 80 | 81 | client, err := c.IAMClient(opt...) 
82 | if err != nil { 83 | return nil, fmt.Errorf("error loading IAM client: %w", err) 84 | } 85 | 86 | var getUserInput iam.GetUserInput 87 | if opts.withUsername != "" { 88 | getUserInput.SetUserName(opts.withUsername) 89 | } // otherwise, empty input means get current user 90 | getUserRes, err := client.GetUser(&getUserInput) 91 | if err != nil { 92 | return nil, fmt.Errorf("error calling aws.GetUser: %w", err) 93 | } 94 | if getUserRes == nil { 95 | return nil, fmt.Errorf("nil response from aws.GetUser") 96 | } 97 | if getUserRes.User == nil { 98 | return nil, fmt.Errorf("nil user returned from aws.GetUser") 99 | } 100 | if getUserRes.User.UserName == nil { 101 | return nil, fmt.Errorf("nil UserName returned from aws.GetUser") 102 | } 103 | 104 | createAccessKeyInput := iam.CreateAccessKeyInput{ 105 | UserName: getUserRes.User.UserName, 106 | } 107 | createAccessKeyRes, err := client.CreateAccessKey(&createAccessKeyInput) 108 | if err != nil { 109 | return nil, fmt.Errorf("error calling aws.CreateAccessKey: %w", err) 110 | } 111 | if createAccessKeyRes == nil { 112 | return nil, fmt.Errorf("nil response from aws.CreateAccessKey") 113 | } 114 | if createAccessKeyRes.AccessKey == nil { 115 | return nil, fmt.Errorf("nil access key in response from aws.CreateAccessKey") 116 | } 117 | if createAccessKeyRes.AccessKey.AccessKeyId == nil || createAccessKeyRes.AccessKey.SecretAccessKey == nil { 118 | return nil, fmt.Errorf("nil AccessKeyId or SecretAccessKey returned from aws.CreateAccessKey") 119 | } 120 | 121 | // Check the credentials to make sure they are usable. We only do 122 | // this if withValidityCheckTimeout is non-zero to ensue that we don't 123 | // immediately fail due to eventual consistency. 
124 | if opts.withValidityCheckTimeout != 0 { 125 | newC := &CredentialsConfig{ 126 | AccessKey: *createAccessKeyRes.AccessKey.AccessKeyId, 127 | SecretKey: *createAccessKeyRes.AccessKey.SecretAccessKey, 128 | } 129 | 130 | if _, err := newC.GetCallerIdentity( 131 | WithValidityCheckTimeout(opts.withValidityCheckTimeout), 132 | WithSTSAPIFunc(opts.withSTSAPIFunc), 133 | ); err != nil { 134 | return nil, fmt.Errorf("error verifying new credentials: %w", err) 135 | } 136 | } 137 | 138 | return createAccessKeyRes, nil 139 | } 140 | 141 | // DeleteAccessKey deletes an access key. 142 | // 143 | // Supported options: WithEnvironmentCredentials, WithSharedCredentials, 144 | // WithAwsSession, WithUserName, WithIAMAPIFunc 145 | func (c *CredentialsConfig) DeleteAccessKey(accessKeyId string, opt ...Option) error { 146 | opts, err := getOpts(opt...) 147 | if err != nil { 148 | return fmt.Errorf("error reading options in RotateKeys: %w", err) 149 | } 150 | 151 | client, err := c.IAMClient(opt...) 152 | if err != nil { 153 | return fmt.Errorf("error loading IAM client: %w", err) 154 | } 155 | 156 | deleteAccessKeyInput := iam.DeleteAccessKeyInput{ 157 | AccessKeyId: aws.String(accessKeyId), 158 | } 159 | if opts.withUsername != "" { 160 | deleteAccessKeyInput.SetUserName(opts.withUsername) 161 | } 162 | 163 | _, err = client.DeleteAccessKey(&deleteAccessKeyInput) 164 | if err != nil { 165 | return fmt.Errorf("error deleting old access key: %w", err) 166 | } 167 | 168 | return nil 169 | } 170 | 171 | // GetSession returns an AWS session configured according to the various values 172 | // in the CredentialsConfig object. This can be passed into iam.New or sts.New 173 | // as appropriate. 174 | // 175 | // Supported options: WithEnvironmentCredentials, WithSharedCredentials, 176 | // WithAwsSession, WithClientType 177 | func (c *CredentialsConfig) GetSession(opt ...Option) (*session.Session, error) { 178 | opts, err := getOpts(opt...) 
179 | if err != nil { 180 | return nil, fmt.Errorf("error reading options in GetSession: %w", err) 181 | } 182 | 183 | creds, err := c.GenerateCredentialChain(opt...) 184 | if err != nil { 185 | return nil, err 186 | } 187 | 188 | var endpoint string 189 | switch opts.withClientType { 190 | case "sts": 191 | endpoint = c.STSEndpoint 192 | case "iam": 193 | endpoint = c.IAMEndpoint 194 | default: 195 | return nil, fmt.Errorf("unknown client type %q in GetSession", opts.withClientType) 196 | } 197 | 198 | awsConfig := &aws.Config{ 199 | Credentials: creds, 200 | Region: aws.String(c.Region), 201 | Endpoint: aws.String(endpoint), 202 | HTTPClient: c.HTTPClient, 203 | MaxRetries: c.MaxRetries, 204 | } 205 | 206 | sess, err := session.NewSession(awsConfig) 207 | if err != nil { 208 | return nil, fmt.Errorf("error getting new session: %w", err) 209 | } 210 | 211 | return sess, nil 212 | } 213 | 214 | // GetCallerIdentity runs sts.GetCallerIdentity for the current set 215 | // credentials. This can be used to check that credentials are valid, 216 | // in addition to checking details about the effective logged in 217 | // account and user ID. 218 | // 219 | // Supported options: WithEnvironmentCredentials, 220 | // WithSharedCredentials, WithAwsSession, WithValidityCheckTimeout 221 | func (c *CredentialsConfig) GetCallerIdentity(opt ...Option) (*sts.GetCallerIdentityOutput, error) { 222 | opts, err := getOpts(opt...) 223 | if err != nil { 224 | return nil, fmt.Errorf("error reading options in GetCallerIdentity: %w", err) 225 | } 226 | 227 | client, err := c.STSClient(opt...) 
228 | if err != nil { 229 | return nil, fmt.Errorf("error loading STS client: %w", err) 230 | } 231 | 232 | delay := time.Second 233 | timeoutCtx, cancel := context.WithTimeout(context.Background(), opts.withValidityCheckTimeout) 234 | defer cancel() 235 | for { 236 | cid, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{}) 237 | if err == nil { 238 | return cid, nil 239 | } 240 | 241 | // TODO: can add a context here for external cancellation in the future 242 | select { 243 | case <-time.After(delay): 244 | // pass 245 | 246 | case <-timeoutCtx.Done(): 247 | // Format our error based on how we were called. 248 | if opts.withValidityCheckTimeout == 0 { 249 | // There was no timeout, just return the error unwrapped. 250 | return nil, err 251 | } 252 | 253 | // Otherwise, return the error wrapped in a timeout error. 254 | return nil, fmt.Errorf("timeout after %s waiting for success: %w", opts.withValidityCheckTimeout, err) 255 | } 256 | } 257 | } 258 | -------------------------------------------------------------------------------- /awsutil/rotate_test.go: -------------------------------------------------------------------------------- 1 | package awsutil 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | "strings" 7 | "testing" 8 | "time" 9 | 10 | "github.com/aws/aws-sdk-go/aws" 11 | "github.com/aws/aws-sdk-go/aws/awserr" 12 | "github.com/aws/aws-sdk-go/service/iam" 13 | "github.com/aws/aws-sdk-go/service/sts" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | const testRotationWaitTimeout = time.Second * 30 19 | 20 | func TestRotation(t *testing.T) { 21 | require, assert := require.New(t), assert.New(t) 22 | 23 | rootKey, rootSecretKey, sessionToken := os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("AWS_SESSION_TOKEN") 24 | if rootKey == "" || rootSecretKey == "" { 25 | t.Skip("missing AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY") 26 | } 27 | 28 | credsConfig := &CredentialsConfig{ 
29 | AccessKey: rootKey, 30 | SecretKey: rootSecretKey, 31 | SessionToken: sessionToken, 32 | } 33 | 34 | username := os.Getenv("AWS_USERNAME") 35 | if username == "" { 36 | username = "aws-iam-kms-testing" 37 | } 38 | 39 | // Create an initial key 40 | out, err := credsConfig.CreateAccessKey(WithUsername(username), WithValidityCheckTimeout(testRotationWaitTimeout)) 41 | require.NoError(err) 42 | require.NotNil(out) 43 | 44 | cleanupKey := out.AccessKey.AccessKeyId 45 | 46 | defer func() { 47 | assert.NoError(credsConfig.DeleteAccessKey(*cleanupKey, WithUsername(username))) 48 | }() 49 | 50 | // Run rotation 51 | accessKey, secretKey := *out.AccessKey.AccessKeyId, *out.AccessKey.SecretAccessKey 52 | c, err := NewCredentialsConfig( 53 | WithAccessKey(accessKey), 54 | WithSecretKey(secretKey), 55 | ) 56 | require.NoError(err) 57 | require.NoError(c.RotateKeys(WithValidityCheckTimeout(testRotationWaitTimeout))) 58 | assert.NotEqual(accessKey, c.AccessKey) 59 | assert.NotEqual(secretKey, c.SecretKey) 60 | cleanupKey = &c.AccessKey 61 | } 62 | 63 | func TestCallerIdentity(t *testing.T) { 64 | require, assert := require.New(t), assert.New(t) 65 | 66 | key, secretKey, sessionToken := os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("AWS_SESSION_TOKEN") 67 | if key == "" || secretKey == "" { 68 | t.Skip("missing AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY") 69 | } 70 | 71 | c := &CredentialsConfig{ 72 | AccessKey: key, 73 | SecretKey: secretKey, 74 | SessionToken: sessionToken, 75 | } 76 | 77 | cid, err := c.GetCallerIdentity() 78 | require.NoError(err) 79 | assert.NotEmpty(cid.Account) 80 | assert.NotEmpty(cid.Arn) 81 | assert.NotEmpty(cid.UserId) 82 | } 83 | 84 | func TestCallerIdentityWithSession(t *testing.T) { 85 | require, assert := require.New(t), assert.New(t) 86 | 87 | key, secretKey, sessionToken := os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("AWS_SESSION_TOKEN") 88 | if key == "" || secretKey == "" { 
89 | t.Skip("missing AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY") 90 | } 91 | 92 | c := &CredentialsConfig{ 93 | AccessKey: key, 94 | SecretKey: secretKey, 95 | SessionToken: sessionToken, 96 | } 97 | 98 | sess, err := c.GetSession() 99 | require.NoError(err) 100 | require.NotNil(sess) 101 | 102 | cid, err := c.GetCallerIdentity(WithAwsSession(sess)) 103 | require.NoError(err) 104 | assert.NotEmpty(cid.Account) 105 | assert.NotEmpty(cid.Arn) 106 | assert.NotEmpty(cid.UserId) 107 | } 108 | 109 | func TestCallerIdentityErrorNoTimeout(t *testing.T) { 110 | require := require.New(t) 111 | 112 | c := &CredentialsConfig{ 113 | AccessKey: "bad", 114 | SecretKey: "badagain", 115 | } 116 | 117 | _, err := c.GetCallerIdentity() 118 | require.NotNil(err) 119 | require.Implements((*awserr.Error)(nil), err) 120 | } 121 | 122 | func TestCallerIdentityErrorWithValidityCheckTimeout(t *testing.T) { 123 | require := require.New(t) 124 | 125 | c := &CredentialsConfig{ 126 | AccessKey: "bad", 127 | SecretKey: "badagain", 128 | } 129 | 130 | _, err := c.GetCallerIdentity(WithValidityCheckTimeout(time.Second * 10)) 131 | require.NotNil(err) 132 | require.True(strings.HasPrefix(err.Error(), "timeout after 10s waiting for success")) 133 | err = errors.Unwrap(err) 134 | require.NotNil(err) 135 | require.Implements((*awserr.Error)(nil), err) 136 | } 137 | 138 | func TestCallerIdentityWithSTSMockError(t *testing.T) { 139 | require := require.New(t) 140 | 141 | expectedErr := errors.New("this is the expected error") 142 | c, err := NewCredentialsConfig() 143 | require.NoError(err) 144 | _, err = c.GetCallerIdentity(WithSTSAPIFunc(NewMockSTS(WithGetCallerIdentityError(expectedErr)))) 145 | require.EqualError(err, expectedErr.Error()) 146 | } 147 | 148 | func TestCallerIdentityWithSTSMockNoErorr(t *testing.T) { 149 | require := require.New(t) 150 | 151 | expectedOut := &sts.GetCallerIdentityOutput{ 152 | Account: aws.String("1234567890"), 153 | Arn: 
aws.String("arn:aws:iam::123456789012:user/JohnDoe"), 154 | UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), 155 | } 156 | 157 | c, err := NewCredentialsConfig() 158 | require.NoError(err) 159 | out, err := c.GetCallerIdentity(WithSTSAPIFunc(NewMockSTS(WithGetCallerIdentityOutput(expectedOut)))) 160 | require.NoError(err) 161 | require.Equal(expectedOut, out) 162 | } 163 | 164 | func TestDeleteAccessKeyWithIAMMock(t *testing.T) { 165 | require := require.New(t) 166 | 167 | mockErr := errors.New("this is the expected error") 168 | expectedErr := "error deleting old access key: this is the expected error" 169 | c, err := NewCredentialsConfig() 170 | require.NoError(err) 171 | err = c.DeleteAccessKey("foobar", WithIAMAPIFunc(NewMockIAM(WithDeleteAccessKeyError(mockErr)))) 172 | require.EqualError(err, expectedErr) 173 | } 174 | 175 | func TestCreateAccessKeyWithIAMMockGetUserError(t *testing.T) { 176 | require := require.New(t) 177 | 178 | mockErr := errors.New("this is the expected error") 179 | expectedErr := "error calling aws.GetUser: this is the expected error" 180 | c, err := NewCredentialsConfig() 181 | require.NoError(err) 182 | _, err = c.CreateAccessKey(WithIAMAPIFunc(NewMockIAM(WithGetUserError(mockErr)))) 183 | require.EqualError(err, expectedErr) 184 | } 185 | 186 | func TestCreateAccessKeyWithIAMMockCreateAccessKeyError(t *testing.T) { 187 | require := require.New(t) 188 | 189 | mockErr := errors.New("this is the expected error") 190 | expectedErr := "error calling aws.CreateAccessKey: this is the expected error" 191 | c, err := NewCredentialsConfig() 192 | require.NoError(err) 193 | _, err = c.CreateAccessKey(WithIAMAPIFunc(NewMockIAM( 194 | WithGetUserOutput(&iam.GetUserOutput{ 195 | User: &iam.User{ 196 | UserName: aws.String("foobar"), 197 | }, 198 | }), 199 | WithCreateAccessKeyError(mockErr), 200 | ))) 201 | require.EqualError(err, expectedErr) 202 | } 203 | 204 | func TestCreateAccessKeyWithIAMAndSTSMockGetCallerIdentityError(t *testing.T) { 205 | 
require := require.New(t) 206 | 207 | mockErr := errors.New("this is the expected error") 208 | expectedErr := "error verifying new credentials: timeout after 1ns waiting for success: this is the expected error" 209 | c, err := NewCredentialsConfig() 210 | require.NoError(err) 211 | _, err = c.CreateAccessKey( 212 | WithValidityCheckTimeout(time.Nanosecond), 213 | WithIAMAPIFunc(NewMockIAM( 214 | WithGetUserOutput(&iam.GetUserOutput{ 215 | User: &iam.User{ 216 | UserName: aws.String("foobar"), 217 | }, 218 | }), 219 | WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ 220 | AccessKey: &iam.AccessKey{ 221 | AccessKeyId: aws.String("foobar"), 222 | SecretAccessKey: aws.String("bazqux"), 223 | }, 224 | }), 225 | )), 226 | WithSTSAPIFunc(NewMockSTS( 227 | WithGetCallerIdentityError(mockErr), 228 | )), 229 | ) 230 | require.EqualError(err, expectedErr) 231 | } 232 | 233 | func TestCreateAccessKeyNilResponse(t *testing.T) { 234 | require := require.New(t) 235 | 236 | expectedErr := "nil response from aws.CreateAccessKey" 237 | c, err := NewCredentialsConfig() 238 | require.NoError(err) 239 | _, err = c.CreateAccessKey( 240 | WithValidityCheckTimeout(time.Nanosecond), 241 | WithIAMAPIFunc(NewMockIAM( 242 | WithGetUserOutput(&iam.GetUserOutput{ 243 | User: &iam.User{ 244 | UserName: aws.String("foobar"), 245 | }, 246 | }), 247 | )), 248 | ) 249 | require.EqualError(err, expectedErr) 250 | } 251 | 252 | func TestRotateKeysWithMocks(t *testing.T) { 253 | mockErr := errors.New("this is the expected error") 254 | cases := []struct { 255 | name string 256 | mockIAMOpts []MockIAMOption 257 | mockSTSOpts []MockSTSOption 258 | require func(t *testing.T, actual *CredentialsConfig) 259 | requireErr string 260 | }{ 261 | { 262 | name: "CreateAccessKey IAM error", 263 | mockIAMOpts: []MockIAMOption{WithGetUserError(mockErr)}, 264 | requireErr: "error calling CreateAccessKey: error calling aws.GetUser: this is the expected error", 265 | }, 266 | { 267 | name: "CreateAccessKey STS 
error", 268 | mockIAMOpts: []MockIAMOption{ 269 | WithGetUserOutput(&iam.GetUserOutput{ 270 | User: &iam.User{ 271 | UserName: aws.String("foobar"), 272 | }, 273 | }), 274 | WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ 275 | AccessKey: &iam.AccessKey{ 276 | AccessKeyId: aws.String("foobar"), 277 | SecretAccessKey: aws.String("bazqux"), 278 | }, 279 | }), 280 | }, 281 | mockSTSOpts: []MockSTSOption{WithGetCallerIdentityError(mockErr)}, 282 | requireErr: "error calling CreateAccessKey: error verifying new credentials: timeout after 1ns waiting for success: this is the expected error", 283 | }, 284 | { 285 | name: "DeleteAccessKey IAM error", 286 | mockIAMOpts: []MockIAMOption{ 287 | WithGetUserOutput(&iam.GetUserOutput{ 288 | User: &iam.User{ 289 | UserName: aws.String("foobar"), 290 | }, 291 | }), 292 | WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ 293 | AccessKey: &iam.AccessKey{ 294 | AccessKeyId: aws.String("foobar"), 295 | SecretAccessKey: aws.String("bazqux"), 296 | UserName: aws.String("foouser"), 297 | }, 298 | }), 299 | // DeleteAccessKeyOutput w/o error is a no-op in the mock and 300 | // will return without additional stubbing 301 | }, 302 | mockSTSOpts: []MockSTSOption{WithGetCallerIdentityOutput(&sts.GetCallerIdentityOutput{})}, 303 | require: func(t *testing.T, actual *CredentialsConfig) { 304 | t.Helper() 305 | require := require.New(t) 306 | 307 | require.Equal("foobar", actual.AccessKey) 308 | require.Equal("bazqux", actual.SecretKey) 309 | }, 310 | }, 311 | } 312 | 313 | for _, tc := range cases { 314 | tc := tc 315 | t.Run(tc.name, func(t *testing.T) { 316 | require := require.New(t) 317 | c, err := NewCredentialsConfig( 318 | WithAccessKey("foo"), 319 | WithSecretKey("bar"), 320 | ) 321 | require.NoError(err) 322 | err = c.RotateKeys( 323 | WithIAMAPIFunc(NewMockIAM(tc.mockIAMOpts...)), 324 | WithSTSAPIFunc(NewMockSTS(tc.mockSTSOpts...)), 325 | WithValidityCheckTimeout(time.Nanosecond), 326 | ) 327 | if tc.requireErr != "" { 328 | 
require.EqualError(err, tc.requireErr) 329 | return 330 | } 331 | 332 | require.NoError(err) 333 | tc.require(t, c) 334 | }) 335 | } 336 | } 337 | -------------------------------------------------------------------------------- /base62/base62.go: -------------------------------------------------------------------------------- 1 | // Package base62 provides utilities for working with base62 strings. 2 | // base62 strings will only contain characters: 0-9, a-z, A-Z 3 | package base62 4 | 5 | import ( 6 | "crypto/rand" 7 | "io" 8 | 9 | uuid "github.com/hashicorp/go-uuid" 10 | ) 11 | 12 | const ( 13 | charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" 14 | csLen = byte(len(charset)) 15 | ) 16 | 17 | // MustRandom generates a random string using base-62 characters. Resulting 18 | // entropy is ~5.95 bits/character. If an error is encountered, MustRandom 19 | // panics. 20 | func MustRandom(length int) string { 21 | out, err := RandomWithReader(length, rand.Reader) 22 | if err != nil { 23 | panic(err) 24 | } 25 | return out 26 | } 27 | 28 | // Random generates a random string using base-62 characters. 29 | // Resulting entropy is ~5.95 bits/character. 30 | func Random(length int) (string, error) { 31 | return RandomWithReader(length, rand.Reader) 32 | } 33 | 34 | // RandomWithReader generates a random string using base-62 characters and a given reader. 35 | // Resulting entropy is ~5.95 bits/character. 
36 | func RandomWithReader(length int, reader io.Reader) (string, error) { 37 | if length == 0 { 38 | return "", nil 39 | } 40 | output := make([]byte, 0, length) 41 | 42 | // Request a bit more than length to reduce the chance 43 | // of needing more than one batch of random bytes 44 | batchSize := length + length/4 45 | 46 | for { 47 | buf, err := uuid.GenerateRandomBytesWithReader(batchSize, reader) 48 | if err != nil { 49 | return "", err 50 | } 51 | 52 | for _, b := range buf { 53 | // Avoid bias by using a value range that's a multiple of 62 54 | if b < (csLen * 4) { 55 | output = append(output, charset[b%csLen]) 56 | 57 | if len(output) == length { 58 | return string(output), nil 59 | } 60 | } 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /base62/base62_test.go: -------------------------------------------------------------------------------- 1 | package base62 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestRandom(t *testing.T) { 8 | strings := make(map[string]struct{}) 9 | 10 | for i := 0; i < 100000; i++ { 11 | c, err := Random(16) 12 | if err != nil { 13 | t.Fatal(err) 14 | } 15 | if _, ok := strings[c]; ok { 16 | t.Fatalf("Unexpected duplicate string: %s", c) 17 | } 18 | strings[c] = struct{}{} 19 | 20 | } 21 | 22 | for i := 0; i < 3000; i++ { 23 | c, err := Random(i) 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | if len(c) != i { 28 | t.Fatalf("Expected length %d, got: %d", i, len(c)) 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /base62/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/base62 2 | 3 | go 1.16 4 | 5 | require github.com/hashicorp/go-uuid v1.0.2 6 | -------------------------------------------------------------------------------- /base62/go.sum: 
-------------------------------------------------------------------------------- 1 | github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= 2 | github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= 3 | -------------------------------------------------------------------------------- /configutil/Makefile: -------------------------------------------------------------------------------- 1 | PLUGIN_TMP_DIR := $(shell mktemp -d) 2 | 3 | test-plugin: 4 | go build -o "${PLUGIN_TMP_DIR}/aeadplugin" testplugins/aead/main.go 5 | PLUGIN_PATH="${PLUGIN_TMP_DIR}/aeadplugin" go test -v -run 'TestFilePlugin|TestConfigureWrapperPropagatesOptions' 6 | 7 | .PHONY: test-plugin 8 | -------------------------------------------------------------------------------- /configutil/config.go: -------------------------------------------------------------------------------- 1 | package configutil 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/ioutil" 7 | "time" 8 | 9 | "github.com/hashicorp/go-secure-stdlib/listenerutil" 10 | "github.com/hashicorp/go-secure-stdlib/parseutil" 11 | "github.com/hashicorp/hcl" 12 | "github.com/hashicorp/hcl/hcl/ast" 13 | ) 14 | 15 | // These two functions are overridden if metricsutil is invoked, but keep this 16 | // module from needing to depend on metricsutil and its various deps otherwise. 17 | // Import the metricsutil module, e.g. 18 | // 19 | // _ "github.com/hashicorp/go-secure-stdlib/metricsutil" 20 | // 21 | // in order to have telemetry be parsed. 
22 | var ( 23 | ParseTelemetry = func(*ast.ObjectList) (interface{}, error) { return nil, nil } 24 | SanitizeTelemetry = func(interface{}) map[string]interface{} { return nil } 25 | ) 26 | 27 | // SharedConfig contains some shared values 28 | type SharedConfig struct { 29 | EntSharedConfig 30 | 31 | Listeners []*listenerutil.ListenerConfig `hcl:"-"` 32 | 33 | Seals []*KMS `hcl:"-"` 34 | Entropy *Entropy `hcl:"-"` 35 | 36 | DisableMlock bool `hcl:"-"` 37 | DisableMlockRaw interface{} `hcl:"disable_mlock"` 38 | 39 | Telemetry interface{} `hcl:"telemetry"` 40 | 41 | DefaultMaxRequestDuration time.Duration `hcl:"-"` 42 | DefaultMaxRequestDurationRaw interface{} `hcl:"default_max_request_duration"` 43 | 44 | // LogFormat specifies the log format. Valid values are "standard" and 45 | // "json". The values are case-insenstive. If no log format is specified, 46 | // then standard format will be used. 47 | LogFormat string `hcl:"log_format"` 48 | LogLevel string `hcl:"log_level"` 49 | 50 | PidFile string `hcl:"pid_file"` 51 | 52 | ClusterName string `hcl:"cluster_name"` 53 | } 54 | 55 | // LoadConfigFile loads the configuration from the given file. 56 | func LoadConfigFile(path string, opt ...Option) (*SharedConfig, error) { 57 | // Read the file 58 | d, err := ioutil.ReadFile(path) 59 | if err != nil { 60 | return nil, err 61 | } 62 | return ParseConfig(string(d), opt...) 63 | } 64 | 65 | func LoadConfigKMSes(path string, opt ...Option) ([]*KMS, error) { 66 | // Read the file 67 | d, err := ioutil.ReadFile(path) 68 | if err != nil { 69 | return nil, err 70 | } 71 | return ParseKMSes(string(d), opt...) 72 | } 73 | 74 | func ParseConfig(d string, opt ...Option) (*SharedConfig, error) { 75 | // Parse! 
76 | obj, err := hcl.Parse(d) 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | // Start building the result 82 | var result SharedConfig 83 | if err := hcl.DecodeObject(&result, obj); err != nil { 84 | return nil, err 85 | } 86 | 87 | if result.DefaultMaxRequestDurationRaw != nil { 88 | if result.DefaultMaxRequestDuration, err = parseutil.ParseDurationSecond(result.DefaultMaxRequestDurationRaw); err != nil { 89 | return nil, err 90 | } 91 | result.DefaultMaxRequestDurationRaw = nil 92 | } 93 | 94 | if result.DisableMlockRaw != nil { 95 | if result.DisableMlock, err = parseutil.ParseBool(result.DisableMlockRaw); err != nil { 96 | return nil, err 97 | } 98 | result.DisableMlockRaw = nil 99 | } 100 | 101 | result.ClusterName, err = parseutil.ParsePath(result.ClusterName) 102 | if err != nil && !errors.Is(err, parseutil.ErrNotAUrl) { 103 | return nil, fmt.Errorf("error parsing cluster name: %w", err) 104 | } 105 | 106 | list, ok := obj.Node.(*ast.ObjectList) 107 | if !ok { 108 | return nil, fmt.Errorf("error parsing: file doesn't contain a root object") 109 | } 110 | 111 | if result.Seals, err = filterKMSes(list, opt...); err != nil { 112 | return nil, fmt.Errorf("error parsing kms information: %w", err) 113 | } 114 | 115 | if o := list.Filter("entropy"); len(o.Items) > 0 { 116 | if err := ParseEntropy(&result, o, "entropy"); err != nil { 117 | return nil, fmt.Errorf("error parsing 'entropy': %w", err) 118 | } 119 | } 120 | 121 | if o := list.Filter("listener"); len(o.Items) > 0 { 122 | l, err := listenerutil.ParseListeners(o) 123 | if err != nil { 124 | return nil, fmt.Errorf("error parsing 'listener': %w", err) 125 | } 126 | result.Listeners = l 127 | } 128 | 129 | if o := list.Filter("telemetry"); len(o.Items) > 0 { 130 | t, err := ParseTelemetry(o) 131 | if err != nil { 132 | return nil, fmt.Errorf("error parsing 'telemetry': %w", err) 133 | } 134 | result.Telemetry = t 135 | } 136 | 137 | entConfig := &(result.EntSharedConfig) 138 | if err := 
entConfig.ParseConfig(list); err != nil { 139 | return nil, fmt.Errorf("error parsing enterprise config: %w", err) 140 | } 141 | 142 | return &result, nil 143 | } 144 | 145 | // Sanitized returns a copy of the config with all values that are considered 146 | // sensitive stripped. It also strips all `*Raw` values that are mainly 147 | // used for parsing. 148 | // 149 | // Specifically, the fields that this method strips are: 150 | // - KMS.Config 151 | // - Telemetry.CirconusAPIToken 152 | func (c *SharedConfig) Sanitized() map[string]interface{} { 153 | if c == nil { 154 | return nil 155 | } 156 | 157 | result := map[string]interface{}{ 158 | "disable_mlock": c.DisableMlock, 159 | 160 | "default_max_request_duration": c.DefaultMaxRequestDuration, 161 | 162 | "log_level": c.LogLevel, 163 | "log_format": c.LogFormat, 164 | 165 | "pid_file": c.PidFile, 166 | 167 | "cluster_name": c.ClusterName, 168 | } 169 | 170 | // Sanitize listeners 171 | if len(c.Listeners) != 0 { 172 | var sanitizedListeners []interface{} 173 | for _, ln := range c.Listeners { 174 | cleanLn := map[string]interface{}{ 175 | "type": ln.Type, 176 | "config": ln.RawConfig, 177 | } 178 | sanitizedListeners = append(sanitizedListeners, cleanLn) 179 | } 180 | result["listeners"] = sanitizedListeners 181 | } 182 | 183 | // Sanitize seals stanza 184 | if len(c.Seals) != 0 { 185 | var sanitizedSeals []interface{} 186 | for _, s := range c.Seals { 187 | cleanSeal := map[string]interface{}{ 188 | "type": s.Type, 189 | "disabled": s.Disabled, 190 | } 191 | sanitizedSeals = append(sanitizedSeals, cleanSeal) 192 | } 193 | result["seals"] = sanitizedSeals 194 | } 195 | 196 | // Sanitize telemetry stanza 197 | if c.Telemetry != nil { 198 | result["telemetry"] = SanitizeTelemetry(c.Telemetry) 199 | } 200 | 201 | return result 202 | } 203 | -------------------------------------------------------------------------------- /configutil/config_test.go: 
-------------------------------------------------------------------------------- 1 | package configutil 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestParseConfig(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | in string 14 | stateFn func(t *testing.T) 15 | expSharedConfig *SharedConfig 16 | expErr bool 17 | expErrIs error 18 | expErrStr string 19 | }{ 20 | { 21 | name: "cluster name set directly", 22 | in: `cluster_name = "test-cluster"`, 23 | expSharedConfig: &SharedConfig{ClusterName: "test-cluster"}, 24 | expErr: false, 25 | }, 26 | { 27 | name: "cluster name set to environment variable", 28 | in: `cluster_name = "env://SHARED_CFG_CLUSTER_NAME"`, 29 | stateFn: func(t *testing.T) { 30 | t.Setenv("SHARED_CFG_CLUSTER_NAME", "test-cluster") 31 | }, 32 | expSharedConfig: &SharedConfig{ClusterName: "test-cluster"}, 33 | expErr: false, 34 | }, 35 | { 36 | name: "cluster name set to something that isn't a URL", 37 | in: `cluster_name = "test\x00cluster"`, 38 | expSharedConfig: &SharedConfig{ClusterName: "test\x00cluster"}, 39 | expErr: false, 40 | }, 41 | { 42 | name: "cluster name ParsePath fail (missing file)", 43 | in: `cluster_name = "file://doesnt_exist_ck3iop2w"`, 44 | expSharedConfig: nil, 45 | expErr: true, 46 | expErrIs: os.ErrNotExist, 47 | }, 48 | } 49 | 50 | for _, tt := range tests { 51 | t.Run(tt.name, func(t *testing.T) { 52 | if tt.stateFn != nil { 53 | tt.stateFn(t) 54 | } 55 | 56 | sc, err := ParseConfig(tt.in) 57 | if tt.expErr { 58 | if tt.expErrIs != nil { 59 | require.ErrorIs(t, err, tt.expErrIs) 60 | } else { 61 | require.EqualError(t, err, tt.expErrStr) 62 | } 63 | require.Nil(t, sc) 64 | return 65 | } 66 | 67 | require.NoError(t, err) 68 | require.EqualValues(t, tt.expSharedConfig, sc) 69 | }) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /configutil/config_util.go: 
-------------------------------------------------------------------------------- 1 | // +build !enterprise 2 | 3 | package configutil 4 | 5 | import ( 6 | "github.com/hashicorp/hcl/hcl/ast" 7 | ) 8 | 9 | type EntSharedConfig struct { 10 | } 11 | 12 | func (ec *EntSharedConfig) ParseConfig(list *ast.ObjectList) error { 13 | return nil 14 | } 15 | 16 | func ParseEntropy(result *SharedConfig, list *ast.ObjectList, blockName string) error { 17 | return nil 18 | } 19 | -------------------------------------------------------------------------------- /configutil/encrypt_decrypt.go: -------------------------------------------------------------------------------- 1 | package configutil 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/base64" 7 | "errors" 8 | "fmt" 9 | "regexp" 10 | 11 | wrapping "github.com/hashicorp/go-kms-wrapping/v2" 12 | "google.golang.org/protobuf/proto" 13 | ) 14 | 15 | var ( 16 | encryptRegex = regexp.MustCompile(`{{encrypt\(.*\)}}`) 17 | decryptRegex = regexp.MustCompile(`{{decrypt\(.*\)}}`) 18 | ) 19 | 20 | func EncryptDecrypt(rawStr string, decrypt, strip bool, wrapper wrapping.Wrapper) (string, error) { 21 | var locs [][]int 22 | raw := []byte(rawStr) 23 | searchVal := "{{encrypt(" 24 | replaceVal := "{{decrypt(" 25 | suffixVal := ")}}" 26 | if decrypt { 27 | searchVal = "{{decrypt(" 28 | replaceVal = "{{encrypt(" 29 | locs = decryptRegex.FindAllIndex(raw, -1) 30 | } else { 31 | locs = encryptRegex.FindAllIndex(raw, -1) 32 | } 33 | if strip { 34 | replaceVal = "" 35 | suffixVal = "" 36 | } 37 | 38 | out := make([]byte, 0, len(rawStr)*2) 39 | var prevMaxLoc int 40 | for _, match := range locs { 41 | if len(match) != 2 { 42 | return "", fmt.Errorf("expected two values for match, got %d", len(match)) 43 | } 44 | 45 | // Append everything from the end of the last match to the beginning of this one 46 | out = append(out, raw[prevMaxLoc:match[0]]...) 47 | 48 | // Transform. 
First pull off the suffix/prefix 49 | matchBytes := raw[match[0]:match[1]] 50 | matchBytes = bytes.TrimSuffix(bytes.TrimPrefix(matchBytes, []byte(searchVal)), []byte(")}}")) 51 | var finalVal string 52 | 53 | // Now encrypt or decrypt 54 | switch decrypt { 55 | case false: 56 | outBlob, err := wrapper.Encrypt(context.Background(), matchBytes, nil) 57 | if err != nil { 58 | return "", fmt.Errorf("error encrypting parameter: %w", err) 59 | } 60 | if outBlob == nil { 61 | return "", errors.New("nil value returned from encrypting parameter") 62 | } 63 | outMsg, err := proto.Marshal(outBlob) 64 | if err != nil { 65 | return "", fmt.Errorf("error marshaling encrypted parameter: %w", err) 66 | } 67 | finalVal = base64.RawURLEncoding.EncodeToString(outMsg) 68 | 69 | default: 70 | inMsg, err := base64.RawURLEncoding.DecodeString(string(matchBytes)) 71 | if err != nil { 72 | return "", fmt.Errorf("error decoding encrypted parameter: %w", err) 73 | } 74 | inBlob := new(wrapping.BlobInfo) 75 | if err := proto.Unmarshal(inMsg, inBlob); err != nil { 76 | return "", fmt.Errorf("error unmarshaling encrypted parameter: %w", err) 77 | } 78 | dec, err := wrapper.Decrypt(context.Background(), inBlob, nil) 79 | if err != nil { 80 | return "", fmt.Errorf("error decrypting encrypted parameter: %w", err) 81 | } 82 | finalVal = string(dec) 83 | } 84 | 85 | // Append new value 86 | out = append(out, []byte(fmt.Sprintf("%s%s%s", replaceVal, finalVal, suffixVal))...) 87 | prevMaxLoc = match[1] 88 | } 89 | // At the end, append the rest 90 | out = append(out, raw[prevMaxLoc:]...) 
91 | return string(out), nil 92 | } 93 | -------------------------------------------------------------------------------- /configutil/encrypt_decrypt_test.go: -------------------------------------------------------------------------------- 1 | package configutil 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/base64" 7 | "testing" 8 | 9 | wrapping "github.com/hashicorp/go-kms-wrapping/v2" 10 | "google.golang.org/protobuf/proto" 11 | ) 12 | 13 | func TestEncryptParams(t *testing.T) { 14 | rawStr := ` 15 | storage "consul" { 16 | api_key = "{{encrypt(foobar)}}" 17 | } 18 | 19 | telemetry { 20 | some_param = "something" 21 | circonus_api_key = "{{encrypt(barfoo)}}" 22 | } 23 | ` 24 | 25 | finalStr := ` 26 | storage "consul" { 27 | api_key = "foobar" 28 | } 29 | 30 | telemetry { 31 | some_param = "something" 32 | circonus_api_key = "barfoo" 33 | } 34 | ` 35 | 36 | reverser := new(reversingWrapper) 37 | out, err := EncryptDecrypt(rawStr, false, false, reverser) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | 42 | first := true 43 | locs := decryptRegex.FindAllIndex([]byte(out), -1) 44 | for _, match := range locs { 45 | matchBytes := []byte(out)[match[0]:match[1]] 46 | matchBytes = bytes.TrimSuffix(bytes.TrimPrefix(matchBytes, []byte("{{decrypt(")), []byte(")}}")) 47 | inMsg, err := base64.RawURLEncoding.DecodeString(string(matchBytes)) 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | inBlob := new(wrapping.BlobInfo) 52 | if err := proto.Unmarshal(inMsg, inBlob); err != nil { 53 | t.Fatal(err) 54 | } 55 | ct := string(inBlob.Ciphertext) 56 | if first { 57 | if ct != "raboof" { 58 | t.Fatal(ct) 59 | } 60 | first = false 61 | } else { 62 | if ct != "oofrab" { 63 | t.Fatal(ct) 64 | } 65 | } 66 | } 67 | 68 | decOut, err := EncryptDecrypt(out, true, false, reverser) 69 | if err != nil { 70 | t.Fatal(err) 71 | } 72 | 73 | if decOut != rawStr { 74 | t.Fatal(decOut) 75 | } 76 | 77 | decOut, err = EncryptDecrypt(out, true, true, reverser) 78 | if err != nil { 79 | 
t.Fatal(err) 80 | } 81 | 82 | if decOut != finalStr { 83 | t.Fatal(decOut) 84 | } 85 | } 86 | 87 | type reversingWrapper struct{} 88 | 89 | func (r *reversingWrapper) SetConfig(_ context.Context, _ ...wrapping.Option) (*wrapping.WrapperConfig, error) { 90 | return nil, nil 91 | } 92 | 93 | func (r *reversingWrapper) Type(_ context.Context) (wrapping.WrapperType, error) { 94 | return wrapping.WrapperType("reversing"), nil 95 | } 96 | func (r *reversingWrapper) KeyId(_ context.Context) (string, error) { return "reverser", nil } 97 | func (r *reversingWrapper) HmacKeyId() string { return "" } 98 | func (r *reversingWrapper) Init(_ context.Context) error { return nil } 99 | func (r *reversingWrapper) Finalize(_ context.Context) error { return nil } 100 | func (r *reversingWrapper) Encrypt(_ context.Context, input []byte, _ ...wrapping.Option) (*wrapping.BlobInfo, error) { 101 | return &wrapping.BlobInfo{ 102 | Ciphertext: r.reverse(input), 103 | }, nil 104 | } 105 | 106 | func (r *reversingWrapper) Decrypt(_ context.Context, input *wrapping.BlobInfo, _ ...wrapping.Option) ([]byte, error) { 107 | return r.reverse(input.Ciphertext), nil 108 | } 109 | 110 | func (r *reversingWrapper) reverse(input []byte) []byte { 111 | output := make([]byte, len(input)) 112 | for i, j := 0, len(input)-1; i < j; i, j = i+1, j-1 { 113 | output[i], output[j] = input[j], input[i] 114 | } 115 | return output 116 | } 117 | -------------------------------------------------------------------------------- /configutil/file_plugin_test.go: -------------------------------------------------------------------------------- 1 | package configutil 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "encoding/hex" 7 | "fmt" 8 | "os" 9 | "testing" 10 | 11 | wrapping "github.com/hashicorp/go-kms-wrapping/v2" 12 | "github.com/hashicorp/go-secure-stdlib/pluginutil/v2" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | "golang.org/x/crypto/sha3" 16 | ) 17 | 18 | func 
TestFilePlugin(t *testing.T) { 19 | ctx := context.Background() 20 | 21 | pluginPath := os.Getenv("PLUGIN_PATH") 22 | if pluginPath == "" { 23 | t.Skipf("skipping plugin test as no PLUGIN_PATH specified") 24 | } 25 | 26 | pluginBytes, err := os.ReadFile(pluginPath) 27 | require.NoError(t, err) 28 | 29 | sha2256Bytes := sha256.Sum256(pluginBytes) 30 | modifiedSha2 := sha256.Sum256(pluginBytes) 31 | modifiedSha2[0] = '0' 32 | modifiedSha2[1] = '0' 33 | sha3384Hash := sha3.New384() 34 | _, err = sha3384Hash.Write(pluginBytes) 35 | require.NoError(t, err) 36 | sha3384Bytes := sha3384Hash.Sum(nil) 37 | 38 | testCases := []struct { 39 | name string // name of the test 40 | pluginChecksum []byte // checksum to use 41 | pluginHashMethod pluginutil.HashMethod // hash method to use 42 | wantErrContains string // Error from the plugin process 43 | hacheSeeEll string // If set, will be parsed and used to populate values 44 | wantConfigErrContains string // Error from any set config 45 | }{ 46 | { 47 | name: "valid checksum", 48 | pluginChecksum: sha2256Bytes[:], 49 | pluginHashMethod: pluginutil.HashMethodSha2256, 50 | }, 51 | { 52 | name: "invalid checksum", 53 | pluginChecksum: modifiedSha2[:], 54 | pluginHashMethod: pluginutil.HashMethodSha2256, 55 | wantErrContains: "checksums did not match", 56 | }, 57 | { 58 | name: "valid checksum, other type", 59 | pluginChecksum: sha3384Bytes[:], 60 | pluginHashMethod: pluginutil.HashMethodSha3384, 61 | }, 62 | { 63 | name: "invalid hcl no checksum", 64 | hacheSeeEll: fmt.Sprintf(` 65 | kms "aead" { 66 | purpose = "root" 67 | aead_type = "aes-gcm" 68 | plugin_path = "%s" 69 | } 70 | `, pluginPath), 71 | wantConfigErrContains: "plugin_path specified but plugin_checksum empty", 72 | }, 73 | { 74 | name: "invalid hcl no path", 75 | hacheSeeEll: fmt.Sprintf(` 76 | kms "aead" { 77 | purpose = "root" 78 | aead_type = "aes-gcm" 79 | plugin_checksum = "%s" 80 | } 81 | `, hex.EncodeToString(sha2256Bytes[:])), 82 | wantConfigErrContains: 
"plugin_checksum specified but plugin_path empty", 83 | }, 84 | { 85 | name: "invalid hcl unknown hash method", 86 | hacheSeeEll: fmt.Sprintf(` 87 | kms "aead" { 88 | purpose = "root" 89 | aead_type = "aes-gcm" 90 | plugin_path = "%s" 91 | plugin_checksum = "%s" 92 | plugin_hash_method = "foobar" 93 | } 94 | `, pluginPath, hex.EncodeToString(sha2256Bytes[:])), 95 | wantErrContains: "unsupported hash method", 96 | }, 97 | { 98 | name: "valid hcl", 99 | hacheSeeEll: fmt.Sprintf(` 100 | kms "aead" { 101 | purpose = "root" 102 | aead_type = "aes-gcm" 103 | plugin_path = "%s" 104 | plugin_checksum = "%s" 105 | } 106 | `, pluginPath, hex.EncodeToString(sha2256Bytes[:])), 107 | }, 108 | { 109 | name: "valid hcl alternate checksum", 110 | hacheSeeEll: fmt.Sprintf(` 111 | kms "aead" { 112 | purpose = "root" 113 | aead_type = "aes-gcm" 114 | plugin_path = "%s" 115 | plugin_checksum = "%s" 116 | plugin_hash_method = "%s" 117 | } 118 | `, pluginPath, hex.EncodeToString(sha3384Bytes[:]), pluginutil.HashMethodSha3384), 119 | }, 120 | } 121 | for _, tc := range testCases { 122 | t.Run(tc.name, func(t *testing.T) { 123 | assert, require := assert.New(t), require.New(t) 124 | var kms *KMS 125 | var pluginOpts []pluginutil.Option 126 | switch tc.hacheSeeEll == "" { 127 | case true: 128 | kms = &KMS{ 129 | Type: string(wrapping.WrapperTypeAead), 130 | Purpose: []string{"foobar"}, 131 | } 132 | pluginOpts = append(pluginOpts, pluginutil.WithPluginFile( 133 | pluginutil.PluginFileInfo{ 134 | Name: "aead", 135 | Path: pluginPath, 136 | Checksum: tc.pluginChecksum, 137 | HashMethod: tc.pluginHashMethod, 138 | }), 139 | ) 140 | default: 141 | conf, err := ParseConfig(tc.hacheSeeEll) 142 | if tc.wantConfigErrContains != "" { 143 | require.Error(err) 144 | assert.Contains(err.Error(), tc.wantConfigErrContains) 145 | return 146 | } 147 | require.NoError(err) 148 | require.Len(conf.Seals, 1) 149 | kms = conf.Seals[0] 150 | } 151 | wrapper, cleanup, err := configureWrapper( 152 | ctx, 153 | 
kms, 154 | nil, 155 | nil, 156 | WithPluginOptions(pluginOpts...), 157 | ) 158 | if tc.wantErrContains != "" { 159 | require.Error(err) 160 | assert.Contains(err.Error(), tc.wantErrContains) 161 | return 162 | } 163 | require.NoError(err) 164 | assert.NotNil(wrapper) 165 | assert.NoError(cleanup()) 166 | }) 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /configutil/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/configutil/v2 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/hashicorp/go-hclog v1.1.0 7 | github.com/hashicorp/go-kms-wrapping/plugin/v2 v2.0.2 8 | github.com/hashicorp/go-kms-wrapping/v2 v2.0.4 9 | github.com/hashicorp/go-multierror v1.1.1 10 | github.com/hashicorp/go-plugin v1.4.3 11 | github.com/hashicorp/go-secure-stdlib/listenerutil v0.1.4 12 | github.com/hashicorp/go-secure-stdlib/parseutil v0.1.2 13 | github.com/hashicorp/go-secure-stdlib/pluginutil/v2 v2.0.2 14 | github.com/hashicorp/hcl v1.0.0 15 | github.com/stretchr/testify v1.7.0 16 | golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 17 | google.golang.org/protobuf v1.27.1 18 | ) 19 | 20 | require ( 21 | github.com/Masterminds/goutils v1.1.0 // indirect 22 | github.com/Masterminds/semver v1.5.0 // indirect 23 | github.com/Masterminds/sprig v2.22.0+incompatible // indirect 24 | github.com/armon/go-radix v1.0.0 // indirect 25 | github.com/bgentry/speakeasy v0.1.0 // indirect 26 | github.com/davecgh/go-spew v1.1.1 // indirect 27 | github.com/fatih/color v1.7.0 // indirect 28 | github.com/golang/protobuf v1.5.2 // indirect 29 | github.com/google/uuid v1.1.2 // indirect 30 | github.com/hashicorp/errwrap v1.1.0 // indirect 31 | github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 // indirect 32 | github.com/hashicorp/go-secure-stdlib/reloadutil v0.1.1 // indirect 33 | github.com/hashicorp/go-secure-stdlib/strutil v0.1.1 // indirect 34 | 
github.com/hashicorp/go-secure-stdlib/tlsutil v0.1.1 // indirect 35 | github.com/hashicorp/go-sockaddr v1.0.2 // indirect 36 | github.com/hashicorp/go-uuid v1.0.2 // indirect 37 | github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect 38 | github.com/huandu/xstrings v1.3.2 // indirect 39 | github.com/imdario/mergo v0.3.11 // indirect 40 | github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f // indirect 41 | github.com/mattn/go-colorable v0.1.6 // indirect 42 | github.com/mattn/go-isatty v0.0.12 // indirect 43 | github.com/mitchellh/cli v1.1.2 // indirect 44 | github.com/mitchellh/copystructure v1.0.0 // indirect 45 | github.com/mitchellh/go-testing-interface v1.0.0 // indirect 46 | github.com/mitchellh/mapstructure v1.4.1 // indirect 47 | github.com/mitchellh/reflectwalk v1.0.0 // indirect 48 | github.com/oklog/run v1.0.0 // indirect 49 | github.com/pmezard/go-difflib v1.0.0 // indirect 50 | github.com/posener/complete v1.1.1 // indirect 51 | github.com/rogpeppe/go-internal v1.8.1 // indirect 52 | github.com/ryanuber/go-glob v1.0.0 // indirect 53 | golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect 54 | golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a // indirect 55 | golang.org/x/text v0.3.7 // indirect 56 | google.golang.org/genproto v0.0.0-20220208230804-65c12eb4c068 // indirect 57 | google.golang.org/grpc v1.44.0 // indirect 58 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect 59 | ) 60 | -------------------------------------------------------------------------------- /configutil/kms_test.go: -------------------------------------------------------------------------------- 1 | package configutil 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "os" 7 | "testing" 8 | 9 | wrapping "github.com/hashicorp/go-kms-wrapping/v2" 10 | "github.com/hashicorp/go-secure-stdlib/pluginutil/v2" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func 
TestConfigureWrapperPropagatesOptions(t *testing.T) { 16 | pluginPath := os.Getenv("PLUGIN_PATH") 17 | if pluginPath == "" { 18 | t.Skipf("skipping plugin test as no PLUGIN_PATH specified") 19 | } 20 | assert, require := assert.New(t), require.New(t) 21 | ctx := context.Background() 22 | 23 | pluginBytes, err := os.ReadFile(pluginPath) 24 | require.NoError(err) 25 | sha2256Bytes := sha256.Sum256(pluginBytes) 26 | kms := &KMS{ 27 | Type: string(wrapping.WrapperTypeAead), 28 | Purpose: []string{"foobar"}, 29 | } 30 | tmpDir := t.TempDir() 31 | pluginOptions := []pluginutil.Option{ 32 | pluginutil.WithPluginExecutionDirectory(tmpDir), 33 | pluginutil.WithPluginFile( 34 | pluginutil.PluginFileInfo{ 35 | Name: "aead", 36 | Path: pluginPath, 37 | Checksum: sha2256Bytes[:], 38 | HashMethod: pluginutil.HashMethodSha2256, 39 | }), 40 | } 41 | wrapper, cleanup, err := configureWrapper(ctx, kms, nil, nil, WithPluginOptions(pluginOptions...)) 42 | require.NoError(err) 43 | require.NotNil(wrapper) 44 | require.NotNil(cleanup) 45 | t.Cleanup(func() { 46 | err := cleanup() 47 | require.NoError(err) 48 | }) 49 | files, err := os.ReadDir(tmpDir) 50 | require.NoError(err) 51 | require.Len(files, 1) 52 | assert.Equal("aeadplugin", files[0].Name()) 53 | blob, err := wrapper.Encrypt(ctx, []byte("secret")) 54 | require.NoError(err) 55 | decrypted, err := wrapper.Decrypt(ctx, blob) 56 | require.NoError(err) 57 | assert.EqualValues("secret", decrypted) 58 | } 59 | -------------------------------------------------------------------------------- /configutil/merge.go: -------------------------------------------------------------------------------- 1 | package configutil 2 | 3 | func (c *SharedConfig) Merge(c2 *SharedConfig) *SharedConfig { 4 | if c2 == nil { 5 | return c 6 | } 7 | 8 | result := new(SharedConfig) 9 | 10 | result.Listeners = append(result.Listeners, c.Listeners...) 11 | result.Listeners = append(result.Listeners, c2.Listeners...) 
12 | 13 | result.Entropy = c.Entropy 14 | if c2.Entropy != nil { 15 | result.Entropy = c2.Entropy 16 | } 17 | 18 | result.Seals = append(result.Seals, c.Seals...) 19 | result.Seals = append(result.Seals, c2.Seals...) 20 | 21 | result.Telemetry = c.Telemetry 22 | if c2.Telemetry != nil { 23 | result.Telemetry = c2.Telemetry 24 | } 25 | 26 | result.DisableMlock = c.DisableMlock 27 | if c2.DisableMlock { 28 | result.DisableMlock = c2.DisableMlock 29 | } 30 | 31 | result.DefaultMaxRequestDuration = c.DefaultMaxRequestDuration 32 | if c2.DefaultMaxRequestDuration > result.DefaultMaxRequestDuration { 33 | result.DefaultMaxRequestDuration = c2.DefaultMaxRequestDuration 34 | } 35 | 36 | result.LogLevel = c.LogLevel 37 | if c2.LogLevel != "" { 38 | result.LogLevel = c2.LogLevel 39 | } 40 | 41 | result.LogFormat = c.LogFormat 42 | if c2.LogFormat != "" { 43 | result.LogFormat = c2.LogFormat 44 | } 45 | 46 | result.PidFile = c.PidFile 47 | if c2.PidFile != "" { 48 | result.PidFile = c2.PidFile 49 | } 50 | 51 | result.ClusterName = c.ClusterName 52 | if c2.ClusterName != "" { 53 | result.ClusterName = c2.ClusterName 54 | } 55 | 56 | return result 57 | } 58 | -------------------------------------------------------------------------------- /configutil/options.go: -------------------------------------------------------------------------------- 1 | package configutil 2 | 3 | import ( 4 | "github.com/hashicorp/go-hclog" 5 | "github.com/hashicorp/go-secure-stdlib/pluginutil/v2" 6 | ) 7 | 8 | // getOpts - iterate the inbound Options and return a struct 9 | func getOpts(opt ...Option) (*options, error) { 10 | opts := getDefaultOptions() 11 | for _, o := range opt { 12 | if o != nil { 13 | if err := o(&opts); err != nil { 14 | return nil, err 15 | } 16 | } 17 | } 18 | return &opts, nil 19 | } 20 | 21 | // Option - how Options are passed as arguments 22 | type Option func(*options) error 23 | 24 | // options = how options are represented 25 | type options struct { 26 | withPluginOptions 
[]pluginutil.Option 27 | withMaxKmsBlocks int 28 | withLogger hclog.Logger 29 | } 30 | 31 | func getDefaultOptions() options { 32 | return options{} 33 | } 34 | 35 | // WithMaxKmsBlocks provides a maximum number of allowed kms(/seal/hsm) blocks. 36 | // Set negative for unlimited. 0 uses the lib default, which is currently 37 | // unlimited. 38 | func WithMaxKmsBlocks(blocks int) Option { 39 | return func(o *options) error { 40 | o.withMaxKmsBlocks = blocks 41 | return nil 42 | } 43 | } 44 | 45 | // WithPluginOptions allows providing plugin-related (as opposed to 46 | // configutil-related) options 47 | func WithPluginOptions(opts ...pluginutil.Option) Option { 48 | return func(o *options) error { 49 | o.withPluginOptions = append(o.withPluginOptions, opts...) 50 | return nil 51 | } 52 | } 53 | 54 | // WithLogger provides a way to override default logger for some purposes (e.g. 55 | // kms plugins) 56 | func WithLogger(logger hclog.Logger) Option { 57 | return func(o *options) error { 58 | o.withLogger = logger 59 | return nil 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /configutil/options_test.go: -------------------------------------------------------------------------------- 1 | package configutil 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/hashicorp/go-hclog" 7 | "github.com/hashicorp/go-secure-stdlib/pluginutil/v2" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func Test_GetOpts(t *testing.T) { 13 | t.Parallel() 14 | t.Run("nil", func(t *testing.T) { 15 | assert := assert.New(t) 16 | opts, err := getOpts(nil) 17 | assert.NoError(err) 18 | assert.NotNil(opts) 19 | }) 20 | t.Run("with-plugin-options", func(t *testing.T) { 21 | assert, require := assert.New(t), require.New(t) 22 | opts, err := getOpts() 23 | require.NoError(err) 24 | assert.Nil(opts.withPluginOptions) 25 | opts, err = getOpts( 26 | WithPluginOptions(pluginutil.WithPluginsMap(nil), 
pluginutil.WithSecureConfig(nil)), 27 | ) 28 | require.NoError(err) 29 | require.NotNil(opts) 30 | assert.Len(opts.withPluginOptions, 2) 31 | }) 32 | t.Run("with-max-kms-blocks", func(t *testing.T) { 33 | assert, require := assert.New(t), require.New(t) 34 | opts, err := getOpts() 35 | require.NoError(err) 36 | assert.Zero(opts.withMaxKmsBlocks) 37 | opts, err = getOpts(WithMaxKmsBlocks(2)) 38 | require.NoError(err) 39 | require.NotNil(opts) 40 | assert.Equal(2, opts.withMaxKmsBlocks) 41 | }) 42 | t.Run("with-logger", func(t *testing.T) { 43 | assert, require := assert.New(t), require.New(t) 44 | opts, err := getOpts() 45 | require.NoError(err) 46 | assert.Nil(opts.withLogger) 47 | logger := hclog.Default() 48 | opts, err = getOpts(WithLogger(logger)) 49 | require.NoError(err) 50 | require.NotNil(opts) 51 | assert.Equal(logger, opts.withLogger) 52 | }) 53 | } 54 | -------------------------------------------------------------------------------- /configutil/testplugins/aead/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "fmt" 7 | "os" 8 | 9 | gkwp "github.com/hashicorp/go-kms-wrapping/plugin/v2" 10 | aead "github.com/hashicorp/go-kms-wrapping/v2/aead" 11 | ) 12 | 13 | func main() { 14 | block, err := aes.NewCipher([]byte("1234567890123456")) 15 | if err != nil { 16 | fmt.Println("Error creating AES block", err) 17 | os.Exit(1) 18 | } 19 | aeadCipher, err := cipher.NewGCM(block) 20 | if err != nil { 21 | fmt.Println("Error creating GCM cipher", err) 22 | os.Exit(1) 23 | } 24 | wrapper := aead.NewWrapper() 25 | wrapper.SetAead(aeadCipher) 26 | if err := gkwp.ServePlugin(wrapper); err != nil { 27 | fmt.Println("Error serving plugin", err) 28 | os.Exit(1) 29 | } 30 | os.Exit(0) 31 | } 32 | -------------------------------------------------------------------------------- /fileutil/caching_file_reader.go: 
-------------------------------------------------------------------------------- 1 | package fileutil 2 | 3 | import ( 4 | "os" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // CachingFileReader reads a file and keeps an in-memory copy of it, until the 10 | // copy is considered stale. Next ReadFile() after expiry will re-read the file from disk. 11 | type CachingFileReader struct { 12 | // path is the file path to the cached file. 13 | path string 14 | 15 | // ttl is the time-to-live duration when cached file is considered stale 16 | ttl time.Duration 17 | 18 | // cache is the buffer holding the in-memory copy of the file. 19 | cache cachedFile 20 | // l guards cache and currentTime. 21 | l sync.RWMutex 22 | 23 | // currentTime is a function that returns the current local time. 24 | // Normally set to time.Now but it can be overwritten by test cases to manipulate time. 25 | currentTime func() time.Time 26 | } 27 | 28 | type cachedFile struct { 29 | // buf is the buffer holding the in-memory copy of the file. 30 | buf []byte 31 | 32 | // expiry is the time when the cached copy is considered stale and must be re-read. 33 | expiry time.Time 34 | } 35 | // NewCachingFileReader returns a reader for the file at path whose cached contents stay fresh for ttl after each read from disk. 36 | func NewCachingFileReader(path string, ttl time.Duration) *CachingFileReader { 37 | return &CachingFileReader{ 38 | path: path, 39 | ttl: ttl, 40 | currentTime: time.Now, 41 | } 42 | } 43 | // ReadFile returns a copy of the file's contents, served from the in-memory cache while it is fresh and re-read from disk once it has expired. 44 | func (r *CachingFileReader) ReadFile() ([]byte, error) { 45 | // Fast path requiring read lock only: file is already in memory and not stale. 46 | r.l.RLock() 47 | now := r.currentTime() 48 | cache := r.cache 49 | r.l.RUnlock() 50 | if now.Before(cache.expiry) { 51 | newBuf := make([]byte, len(cache.buf)) 52 | copy(newBuf, cache.buf) 53 | return newBuf, nil 54 | } 55 | 56 | // Slow path: read the file from disk.
57 | r.l.Lock() 58 | defer r.l.Unlock() 59 | 60 | buf, err := os.ReadFile(r.path) 61 | if err != nil { 62 | return nil, err 63 | } 64 | r.cache = cachedFile{ 65 | buf: buf, 66 | expiry: r.currentTime().Add(r.ttl), 67 | } 68 | 69 | newBuf := make([]byte, len(r.cache.buf)) 70 | copy(newBuf, r.cache.buf) 71 | return newBuf, nil 72 | } 73 | // setStaticTime pins currentTime to staticTime so tests can control when the cache expires. 74 | func (r *CachingFileReader) setStaticTime(staticTime time.Time) { 75 | r.l.Lock() 76 | defer r.l.Unlock() 77 | r.currentTime = func() time.Time { 78 | return staticTime 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /fileutil/caching_file_reader_test.go: -------------------------------------------------------------------------------- 1 | package fileutil 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestCachingFileReader(t *testing.T) { 10 | content1 := []byte("before") 11 | content2 := []byte("after") 12 | 13 | // Create temporary file. 14 | f, err := os.CreateTemp("", "testfile") 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | f.Close() 19 | defer os.Remove(f.Name()) 20 | 21 | r := NewCachingFileReader(f.Name(), 1*time.Minute) 22 | currentTime := time.Now() 23 | r.setStaticTime(currentTime) 24 | 25 | // Write initial content to file and check that we can read it. 26 | os.WriteFile(f.Name(), []byte(content1), 0o644) 27 | got, err := r.ReadFile() 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | if string(got) != string(content1) { 32 | t.Errorf("got '%s', expected '%s'", got, content1) 33 | } 34 | 35 | // Write new content to the file. 36 | os.WriteFile(f.Name(), []byte(content2), 0o644) 37 | 38 | // Advance simulated time, but not enough for cache to expire. 39 | currentTime = currentTime.Add(30 * time.Second) 40 | r.setStaticTime(currentTime) 41 | 42 | // Read again and check we still got the old cached content.
43 | got, err = r.ReadFile() 44 | if err != nil { 45 | t.Error(err) 46 | } 47 | if string(got) != string(content1) { 48 | t.Errorf("got '%s', expected '%s'", got, content1) 49 | } 50 | 51 | // Advance simulated time for cache to expire. 52 | currentTime = currentTime.Add(30 * time.Second) 53 | r.setStaticTime(currentTime) 54 | 55 | // Read again and check that we got the new content. 56 | got, err = r.ReadFile() 57 | if err != nil { 58 | t.Error(err) 59 | } 60 | if string(got) != string(content2) { 61 | t.Errorf("got '%s', expected '%s'", got, content2) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /fileutil/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/fileutil 2 | 3 | go 1.16 4 | -------------------------------------------------------------------------------- /fileutil/go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Flyingon/go-secure-stdlib/7849be51188ffe09900bf3232e94695389cfc8aa/fileutil/go.sum -------------------------------------------------------------------------------- /gatedwriter/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/gatedwriter 2 | 3 | go 1.16 4 | -------------------------------------------------------------------------------- /gatedwriter/writer.go: -------------------------------------------------------------------------------- 1 | package gatedwriter 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "sync" 7 | ) 8 | 9 | // Writer is an io.Writer implementation that buffers all of its 10 | // data into an internal buffer until it is told to let data through. 
11 | type Writer struct { 12 | writer io.Writer 13 | 14 | buf bytes.Buffer 15 | flush bool 16 | lock sync.Mutex 17 | } 18 | 19 | func NewWriter(underlying io.Writer) *Writer { 20 | return &Writer{writer: underlying} 21 | } 22 | 23 | // Flush tells the Writer to flush any buffered data and to stop 24 | // buffering. 25 | func (w *Writer) Flush() error { 26 | w.lock.Lock() 27 | defer w.lock.Unlock() 28 | 29 | w.flush = true 30 | _, err := w.buf.WriteTo(w.writer) 31 | return err 32 | } 33 | 34 | func (w *Writer) Write(p []byte) (n int, err error) { 35 | w.lock.Lock() 36 | defer w.lock.Unlock() 37 | 38 | if w.flush { 39 | return w.writer.Write(p) 40 | } 41 | 42 | return w.buf.Write(p) 43 | } 44 | -------------------------------------------------------------------------------- /gatedwriter/writer_test.go: -------------------------------------------------------------------------------- 1 | package gatedwriter 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "testing" 7 | ) 8 | 9 | func TestWriter_impl(t *testing.T) { 10 | var _ io.Writer = new(Writer) 11 | } 12 | 13 | func TestWriter(t *testing.T) { 14 | buf := new(bytes.Buffer) 15 | w := NewWriter(buf) 16 | w.Write([]byte("foo\n")) 17 | w.Write([]byte("bar\n")) 18 | 19 | if buf.String() != "" { 20 | t.Fatalf("bad: %s", buf.String()) 21 | } 22 | 23 | w.Flush() 24 | 25 | if buf.String() != "foo\nbar\n" { 26 | t.Fatalf("bad: %s", buf.String()) 27 | } 28 | 29 | w.Write([]byte("baz\n")) 30 | 31 | if buf.String() != "foo\nbar\nbaz\n" { 32 | t.Fatalf("bad: %s", buf.String()) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /kv-builder/builder.go: -------------------------------------------------------------------------------- 1 | package kvbuilder 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "os" 10 | "strings" 11 | 12 | "github.com/mitchellh/mapstructure" 13 | ) 14 | 15 | // Builder is a struct to build a key/value mapping based on a list 
16 | // of "k=v" pairs, where the value might come from stdin, a file, etc. 17 | type Builder struct { 18 | Stdin io.Reader 19 | 20 | result map[string]interface{} 21 | stdin bool 22 | } 23 | 24 | // Map returns the built map. 25 | func (b *Builder) Map() map[string]interface{} { 26 | return b.result 27 | } 28 | 29 | // Add adds to the mapping with the given args. 30 | func (b *Builder) Add(args ...string) error { 31 | for _, a := range args { 32 | if err := b.add(a); err != nil { 33 | return fmt.Errorf("invalid key/value pair %q: %w", a, err) 34 | } 35 | } 36 | 37 | return nil 38 | } 39 | 40 | func (b *Builder) add(raw string) error { 41 | // Regardless of validity, make sure we make our result 42 | if b.result == nil { 43 | b.result = make(map[string]interface{}) 44 | } 45 | 46 | // Empty strings are fine, just ignored 47 | if raw == "" { 48 | return nil 49 | } 50 | 51 | // Split into key/value 52 | parts := strings.SplitN(raw, "=", 2) 53 | 54 | // If the arg is exactly "-", then we need to read from stdin 55 | // and merge the results into the resulting structure. 
56 | if len(parts) == 1 { 57 | if raw == "-" { 58 | if b.Stdin == nil { 59 | return fmt.Errorf("stdin is not supported") 60 | } 61 | if b.stdin { 62 | return fmt.Errorf("stdin already consumed") 63 | } 64 | 65 | b.stdin = true 66 | return b.addReader(b.Stdin) 67 | } 68 | 69 | // If the arg begins with "@" then we need to read a file directly 70 | if raw[0] == '@' { 71 | f, err := os.Open(raw[1:]) 72 | if err != nil { 73 | return err 74 | } 75 | defer f.Close() 76 | 77 | return b.addReader(f) 78 | } 79 | } 80 | 81 | if len(parts) != 2 { 82 | return fmt.Errorf("format must be key=value") 83 | } 84 | key, value := parts[0], parts[1] 85 | 86 | if len(value) > 0 { 87 | if value[0] == '@' { 88 | contents, err := ioutil.ReadFile(value[1:]) 89 | if err != nil { 90 | return fmt.Errorf("error reading file: %w", err) 91 | } 92 | 93 | value = string(contents) 94 | } else if len(value) >= 2 && value[0] == '\\' && value[1] == '@' { 95 | value = value[1:] 96 | } else if value == "-" { 97 | if b.Stdin == nil { 98 | return fmt.Errorf("stdin is not supported") 99 | } 100 | if b.stdin { 101 | return fmt.Errorf("stdin already consumed") 102 | } 103 | b.stdin = true 104 | 105 | var buf bytes.Buffer 106 | if _, err := io.Copy(&buf, b.Stdin); err != nil { 107 | return err 108 | } 109 | 110 | value = buf.String() 111 | } 112 | } 113 | 114 | // Repeated keys will be converted into a slice 115 | if existingValue, ok := b.result[key]; ok { 116 | var sliceValue []interface{} 117 | if err := mapstructure.WeakDecode(existingValue, &sliceValue); err != nil { 118 | return err 119 | } 120 | sliceValue = append(sliceValue, value) 121 | b.result[key] = sliceValue 122 | return nil 123 | } 124 | 125 | b.result[key] = value 126 | return nil 127 | } 128 | 129 | func (b *Builder) addReader(r io.Reader) error { 130 | if r == nil { 131 | return fmt.Errorf("'io.Reader' being decoded is nil") 132 | } 133 | 134 | dec := json.NewDecoder(r) 135 | // While decoding JSON values, interpret the integer values as 
136 | // `json.Number`s instead of `float64`. 137 | dec.UseNumber() 138 | 139 | return dec.Decode(&b.result) 140 | } 141 | -------------------------------------------------------------------------------- /kv-builder/builder_test.go: -------------------------------------------------------------------------------- 1 | package kvbuilder 2 | 3 | import ( 4 | "bytes" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestBuilder_basic(t *testing.T) { 10 | var b Builder 11 | err := b.Add("foo=bar", "bar=baz", "baz=") 12 | if err != nil { 13 | t.Fatalf("err: %s", err) 14 | } 15 | 16 | expected := map[string]interface{}{ 17 | "foo": "bar", 18 | "bar": "baz", 19 | "baz": "", 20 | } 21 | actual := b.Map() 22 | if !reflect.DeepEqual(actual, expected) { 23 | t.Fatalf("bad: %#v", actual) 24 | } 25 | } 26 | 27 | func TestBuilder_escapedAt(t *testing.T) { 28 | var b Builder 29 | err := b.Add("foo=bar", "bar=\\@baz") 30 | if err != nil { 31 | t.Fatalf("err: %s", err) 32 | } 33 | 34 | expected := map[string]interface{}{ 35 | "foo": "bar", 36 | "bar": "@baz", 37 | } 38 | actual := b.Map() 39 | if !reflect.DeepEqual(actual, expected) { 40 | t.Fatalf("bad: %#v", actual) 41 | } 42 | } 43 | 44 | func TestBuilder_singleBackslash(t *testing.T) { 45 | var b Builder 46 | err := b.Add("foo=bar", "bar=\\") 47 | if err != nil { 48 | t.Fatalf("err: %s", err) 49 | } 50 | 51 | expected := map[string]interface{}{ 52 | "foo": "bar", 53 | "bar": "\\", 54 | } 55 | actual := b.Map() 56 | if !reflect.DeepEqual(actual, expected) { 57 | t.Fatalf("bad: %#v", actual) 58 | } 59 | } 60 | 61 | func TestBuilder_stdin(t *testing.T) { 62 | var b Builder 63 | b.Stdin = bytes.NewBufferString("baz") 64 | err := b.Add("foo=bar", "bar=-") 65 | if err != nil { 66 | t.Fatalf("err: %s", err) 67 | } 68 | 69 | expected := map[string]interface{}{ 70 | "foo": "bar", 71 | "bar": "baz", 72 | } 73 | actual := b.Map() 74 | if !reflect.DeepEqual(actual, expected) { 75 | t.Fatalf("bad: %#v", actual) 76 | } 77 | } 78 | 79 | func 
TestBuilder_stdinMap(t *testing.T) { 80 | var b Builder 81 | b.Stdin = bytes.NewBufferString(`{"foo": "bar"}`) 82 | err := b.Add("-", "bar=baz") 83 | if err != nil { 84 | t.Fatalf("err: %s", err) 85 | } 86 | 87 | expected := map[string]interface{}{ 88 | "foo": "bar", 89 | "bar": "baz", 90 | } 91 | actual := b.Map() 92 | if !reflect.DeepEqual(actual, expected) { 93 | t.Fatalf("bad: %#v", actual) 94 | } 95 | } 96 | 97 | func TestBuilder_stdinTwice(t *testing.T) { 98 | var b Builder 99 | b.Stdin = bytes.NewBufferString(`{"foo": "bar"}`) 100 | err := b.Add("-", "-") 101 | if err == nil { 102 | t.Fatal("should error") 103 | } 104 | } 105 | 106 | func TestBuilder_sameKeyTwice(t *testing.T) { 107 | var b Builder 108 | err := b.Add("foo=bar", "foo=baz") 109 | if err != nil { 110 | t.Fatalf("err: %s", err) 111 | } 112 | 113 | expected := map[string]interface{}{ 114 | "foo": []interface{}{"bar", "baz"}, 115 | } 116 | actual := b.Map() 117 | if !reflect.DeepEqual(actual, expected) { 118 | t.Fatalf("bad: %#v", actual) 119 | } 120 | } 121 | 122 | func TestBuilder_sameKeyMultipleTimes(t *testing.T) { 123 | var b Builder 124 | err := b.Add("foo=bar", "foo=baz", "foo=bay", "foo=bax", "bar=baz") 125 | if err != nil { 126 | t.Fatalf("err: %s", err) 127 | } 128 | 129 | expected := map[string]interface{}{ 130 | "foo": []interface{}{"bar", "baz", "bay", "bax"}, 131 | "bar": "baz", 132 | } 133 | actual := b.Map() 134 | if !reflect.DeepEqual(actual, expected) { 135 | t.Fatalf("bad: %#v", actual) 136 | } 137 | } 138 | 139 | func TestBuilder_specialCharactersInKey(t *testing.T) { 140 | var b Builder 141 | b.Stdin = bytes.NewBufferString("{\"foo\": \"bay\"}") 142 | err := b.Add("@foo=bar", "-foo=baz", "-") 143 | if err != nil { 144 | t.Fatalf("err: %s", err) 145 | } 146 | 147 | expected := map[string]interface{}{ 148 | "@foo": "bar", 149 | "-foo": "baz", 150 | "foo": "bay", 151 | } 152 | actual := b.Map() 153 | if !reflect.DeepEqual(actual, expected) { 154 | t.Fatalf("bad: %#v", actual) 155 
| } 156 | } 157 | -------------------------------------------------------------------------------- /kv-builder/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/kv-builder 2 | 3 | go 1.16 4 | 5 | require github.com/mitchellh/mapstructure v1.4.1 6 | -------------------------------------------------------------------------------- /kv-builder/go.sum: -------------------------------------------------------------------------------- 1 | github.com/mitchellh/mapstructure v1.4.1 h1:CpVNEelQCZBooIPDn+AR3NpivK/TIKU8bDxdASFVQag= 2 | github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= 3 | -------------------------------------------------------------------------------- /listenerutil/error.go: -------------------------------------------------------------------------------- 1 | package listenerutil 2 | 3 | import ( 4 | "errors" 5 | ) 6 | // ErrInvalidParameter is wrapped by errors reporting a missing or invalid argument (e.g. a nil handler passed to WrapForwardedForHandler). 7 | var ( 8 | ErrInvalidParameter = errors.New("invalid parameter") 9 | ) 10 | -------------------------------------------------------------------------------- /listenerutil/forwarded_for.go: -------------------------------------------------------------------------------- 1 | package listenerutil 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | "net/textproto" 9 | "strings" 10 | 11 | "github.com/hashicorp/go-sockaddr" 12 | ) 13 | 14 | type key int 15 | 16 | const ( 17 | remoteAddrKey key = iota 18 | 19 | missingPortErrStr = "missing port in address" 20 | ) 21 | 22 | // ErrResponseFn provides a func to call whenever WrapForwardedForHandler 23 | // encounters an error. 24 | type ErrResponseFn func(w http.ResponseWriter, status int, err error) 25 | 26 | // WrapForwardedForHandler is an http middleware handler which uses the 27 | // XForwardedFor* listener config settings to determine how/if X-Forwarded-For 28 | // are trusted/allowed for an inbound request.
In the end, if a "trusted" 29 | // X-Forwarded-For header is found, then the request RemoteAddr will be 30 | // overwritten with it before the request is served. An error wrapping ErrInvalidParameter is returned if h, l, or respErrFn is nil. 31 | func WrapForwardedForHandler(h http.Handler, l *ListenerConfig, respErrFn ErrResponseFn) (http.Handler, error) { 32 | if h == nil { 33 | return nil, fmt.Errorf("missing http handler: %w", ErrInvalidParameter) 34 | } 35 | if l == nil { 36 | return nil, fmt.Errorf("missing listener config: %w", ErrInvalidParameter) 37 | } 38 | if respErrFn == nil { 39 | return nil, fmt.Errorf("missing response error function: %w", ErrInvalidParameter) 40 | } 41 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 42 | 43 | trusted, remoteAddr, err := TrustedFromXForwardedFor(r, l) 44 | if err != nil { 45 | respErrFn(w, http.StatusBadRequest, err) 46 | return 47 | } 48 | if trusted == nil || remoteAddr == nil { 49 | h.ServeHTTP(w, r) 50 | return 51 | } 52 | newCtx, err := newOrigRemoteAddrCtx(r.Context(), r.RemoteAddr) 53 | if err != nil { 54 | respErrFn(w, http.StatusBadRequest, fmt.Errorf("error setting orig remote header ctx: %w", err)) 55 | return 56 | } 57 | r = r.WithContext(newCtx) 58 | switch { 59 | case trusted.Port != "": 60 | r.RemoteAddr = net.JoinHostPort(trusted.Host, trusted.Port) 61 | default: 62 | // setting remote address to a combination is a bit different, but 63 | // it's needed to satisfy the requirement that remote addr always 64 | // have a port which is likely relied upon by downstream callers in 65 | // the call chain. 66 | // 67 | // this is intentionally the default since it's very likely the 68 | // "trusted" address will not have a port making this the most 69 | // likely execution path 70 | r.RemoteAddr = net.JoinHostPort(trusted.Host, remoteAddr.Port) 71 | } 72 | h.ServeHTTP(w, r) 73 | 74 | }), nil 75 | } 76 | 77 | // Addr represents only the Host and Port of a TCP address.
78 | type Addr struct {
79 | 	Host string
80 | 	Port string
81 | }
82 |
83 | // TrustedFromXForwardedFor uses the XForwardedFor* listener config settings
84 | // to determine how/if X-Forwarded-For headers are trusted/allowed for an
85 | // inbound request. On success it returns the trusted client address and the
86 | // address parsed from the request's RemoteAddr. Important: return values of
87 | // nil, nil, nil are valid and simply mean that no "trusted" header was found
88 | // and no error was raised. Errors can be raised for a number of conditions
89 | // based on the listener config settings, especially when
90 | // XForwardedForRejectNotPresent is true and no "trusted" header can be found.
91 | func TrustedFromXForwardedFor(r *http.Request, l *ListenerConfig) (trustedAddress *Addr, remoteAddress *Addr, e error) {
92 | 	if r == nil {
93 | 		return nil, nil, fmt.Errorf("missing http request: %w", ErrInvalidParameter)
94 | 	}
95 | 	if l == nil {
96 | 		return nil, nil, fmt.Errorf("missing listener config: %w", ErrInvalidParameter)
97 | 	}
98 | 	rejectNotPresent := l.XForwardedForRejectNotPresent
99 | 	hopSkips := l.XForwardedForHopSkips
100 | 	authorizedAddrs := l.XForwardedForAuthorizedAddrs
101 | 	rejectNotAuthz := l.XForwardedForRejectNotAuthorized
102 |
103 | 	headers, headersOK := r.Header[textproto.CanonicalMIMEHeaderKey("X-Forwarded-For")]
104 | 	if !headersOK || len(headers) == 0 {
105 | 		if !rejectNotPresent {
106 | 			return nil, nil, nil
107 | 		}
108 | 		return nil, nil, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present")
109 | 	}
110 |
111 | 	// http request remote address will always have a remoteAddrHost:port
112 | 	// (see:
113 | 	// https://cs.opensource.google/go/go/+/refs/tags/go1.17.3:src/net/http/request.go;l=279-286)
114 | 	var remoteAddr Addr
115 | 	var err error
116 | 	remoteAddr.Host, remoteAddr.Port, err = net.SplitHostPort(r.RemoteAddr)
117 | 	if err != nil {
118 | 		// If not rejecting treat it like we just don't have a valid
119 | 		// header because we can't do a comparison against an address we
120 | 		// can't understand
121 | 		if !rejectNotPresent {
122 | 			return nil, nil, nil
123 | 		}
124 | 		return nil, nil, fmt.Errorf("error parsing client hostport: %w", err)
125 | 	}
126 |
127 | 	addr, err := sockaddr.NewIPAddr(remoteAddr.Host)
128 | 	if err != nil {
129 | 		// We treat this the same as the case above
130 | 		if !rejectNotPresent {
131 | 			return nil, nil, nil
132 | 		}
133 | 		return nil, nil, fmt.Errorf("error parsing client address: %w", err)
134 | 	}
135 |
136 | 	var found bool
137 | 	for _, authz := range authorizedAddrs {
138 | 		if authz.Contains(addr) {
139 | 			found = true
140 | 			break
141 | 		}
142 | 	}
143 | 	if !found {
144 | 		// If we didn't find it and aren't configured to reject, simply
145 | 		// don't trust it
146 | 		if !rejectNotAuthz {
147 | 			return nil, nil, nil
148 | 		}
149 | 		return nil, nil, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection")
150 | 	}
151 |
152 | 	// At this point we have at least one value and it's authorized
153 |
154 | 	// Split comma separated ones, which are common. This brings it in line
155 | 	// to the multiple-header case.
156 | 	var acc []*Addr
157 | 	for _, header := range headers {
158 | 		vals := strings.Split(header, ",")
159 | 		for _, v := range vals {
160 | 			// validate the header contains a valid IP, optionally with a port
161 | 			v = strings.TrimSpace(v)
162 | 			h, p, err := net.SplitHostPort(v)
163 | 			if err != nil {
164 | 				bare := strings.TrimSuffix(strings.TrimPrefix(v, "["), "]")
165 | 				switch {
166 | 				case net.ParseIP(bare) != nil:
167 | 					// A bare IP without a port. This also covers IPv6 values
168 | 					// such as "2001:db8::1", which SplitHostPort rejects with
169 | 					// "too many colons" rather than "missing port".
170 | 					h = bare
171 | 				case strings.Contains(err.Error(), missingPortErrStr):
172 | 					h = v
173 | 				default:
174 | 					if !rejectNotPresent {
175 | 						return nil, nil, nil
176 | 					}
177 | 					return nil, nil, fmt.Errorf("error parsing client address host/port (%s) from header", v)
178 | 				}
179 | 			}
180 | 			ip := net.ParseIP(h)
181 | 			if ip == nil {
182 | 				if !rejectNotPresent {
183 | 					return nil, nil, nil
184 | 				}
185 | 				return nil, nil, fmt.Errorf("error parsing client address (%s) from header", v)
186 | 			}
187 | 			acc = append(acc, &Addr{Host: h, Port: p})
188 | 		}
189 | 	}
190 |
191 | 	// Select the entry hopSkips positions back from the last one in the chain.
192 | 	indexToUse := int64(len(acc)) - 1 - hopSkips
193 | 	if indexToUse < 0 {
194 | 		// This is likely an error in either configuration or other
195 | 		// infrastructure. We could either deny the request, or we
196 | 		// could simply not trust the value. Denying the request is
197 | 		// "safer" since if this logic is configured at all there may
198 | 		// be an assumption it can always be trusted. Given that we can
199 | 		// deny accepting the request at all if it's not from an
200 | 		// authorized address, if we're at this point the address is
201 | 		// authorized (or we've turned off explicit rejection) and we
202 | 		// should assume that what comes in should be properly
203 | 		// formatted.
204 | 		return nil, nil, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(acc))
205 | 	}
206 |
207 | 	return acc[indexToUse], &remoteAddr, nil
208 | }
209 |
210 | // newOrigRemoteAddrCtx returns a child of ctx carrying origRemoteAddr. It
211 | // returns an ErrInvalidParameter error if ctx is nil or origRemoteAddr is empty.
212 | func newOrigRemoteAddrCtx(ctx context.Context, origRemoteAddr string) (context.Context, error) {
213 | 	const op = "listenerutil.newOrigRemoteAddrCtx"
214 | 	if ctx == nil {
215 | 		return nil, fmt.Errorf("%s: missing context: %w", op, ErrInvalidParameter)
216 | 	}
217 | 	if origRemoteAddr == "" {
218 | 		return nil, fmt.Errorf("%s: missing original remote address: %w", op, ErrInvalidParameter)
219 | 	}
220 | 	return context.WithValue(ctx, remoteAddrKey, origRemoteAddr), nil
221 | }
222 |
223 | // OrigRemoteAddrFromCtx returns the original remote address stored in ctx by
224 | // WrapForwardedForHandler, and whether one was found.
225 | func OrigRemoteAddrFromCtx(ctx context.Context) (string, bool) {
226 | 	if ctx == nil {
227 | 		return "", false
228 | 	}
229 | 	orig, ok := ctx.Value(remoteAddrKey).(string)
230 | 	return orig, ok
231 | }
232 |
--------------------------------------------------------------------------------
/listenerutil/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/hashicorp/go-secure-stdlib/listenerutil
2 |
3 | go 1.16
4 |
5 | require (
6 | 	github.com/armon/go-radix v1.0.0 // indirect
7 | 	github.com/hashicorp/errwrap v1.1.0 // indirect
8 | 	github.com/hashicorp/go-multierror v1.1.1
9 | 	github.com/hashicorp/go-secure-stdlib/parseutil v0.1.1
10 | 	github.com/hashicorp/go-secure-stdlib/reloadutil v0.1.1
11 | 	github.com/hashicorp/go-secure-stdlib/strutil v0.1.1
12 | 	github.com/hashicorp/go-secure-stdlib/tlsutil v0.1.1
13 | 	github.com/hashicorp/go-sockaddr v1.0.2
14 | 	github.com/hashicorp/hcl v1.0.0
15 | 	github.com/jefferai/isbadcipher
v0.0.0-20190226160619-51d2077c035f 16 | github.com/mattn/go-colorable v0.1.6 // indirect 17 | github.com/mitchellh/cli v1.1.2 18 | github.com/stretchr/testify v1.7.0 19 | golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /listenerutil/go.sum: -------------------------------------------------------------------------------- 1 | github.com/Masterminds/goutils v1.1.0 h1:zukEsf/1JZwCMgHiK3GZftabmxiCw4apj3a28RPBiVg= 2 | github.com/Masterminds/goutils v1.1.0/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= 3 | github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= 4 | github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= 5 | github.com/Masterminds/sprig v2.22.0+incompatible h1:z4yfnGrZ7netVz+0EDJ0Wi+5VZCSYp4Z0m2dk6cEM60= 6 | github.com/Masterminds/sprig v2.22.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuNhlNS5hqE0NB0E6fgfo2Br3o= 7 | github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= 8 | github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI= 9 | github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= 10 | github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY= 11 | github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= 12 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 13 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 14 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 15 | github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= 16 | github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= 17 | github.com/google/uuid v1.1.2 
h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= 18 | github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 19 | github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 20 | github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= 21 | github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 22 | github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= 23 | github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= 24 | github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= 25 | github.com/hashicorp/go-secure-stdlib/parseutil v0.1.1 h1:78ki3QBevHwYrVxnyVeaEz+7WtifHhauYF23es/0KlI= 26 | github.com/hashicorp/go-secure-stdlib/parseutil v0.1.1/go.mod h1:QmrqtbKuxxSWTN3ETMPuB+VtEiBJ/A9XhoYGv8E1uD8= 27 | github.com/hashicorp/go-secure-stdlib/reloadutil v0.1.1 h1:SMGUnbpAcat8rIKHkBPjfv81yC46a8eCNZ2hsR2l1EI= 28 | github.com/hashicorp/go-secure-stdlib/reloadutil v0.1.1/go.mod h1:Ch/bf00Qnx77MZd49JRgHYqHQjtEmTgGU2faufpVZb0= 29 | github.com/hashicorp/go-secure-stdlib/strutil v0.1.1 h1:nd0HIW15E6FG1MsnArYaHfuw9C2zgzM8LxkG5Ty/788= 30 | github.com/hashicorp/go-secure-stdlib/strutil v0.1.1/go.mod h1:gKOamz3EwoIoJq7mlMIRBpVTAUn8qPCrEclOKKWhD3U= 31 | github.com/hashicorp/go-secure-stdlib/tlsutil v0.1.1 h1:Yc026VyMyIpq1UWRnakHRG01U8fJm+nEfEmjoAb00n8= 32 | github.com/hashicorp/go-secure-stdlib/tlsutil v0.1.1/go.mod h1:l8slYwnJA26yBz+ErHpp2IRCLr0vuOMGBORIz4rRiAs= 33 | github.com/hashicorp/go-sockaddr v1.0.2 h1:ztczhD1jLxIRjVejw8gFomI1BQZOe2WoVOu0SyteCQc= 34 | github.com/hashicorp/go-sockaddr v1.0.2/go.mod h1:rB4wwRAUzs07qva3c5SdrY/NEtAUjGlgmH/UkBUC97A= 35 | github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= 36 | github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= 37 | 
github.com/huandu/xstrings v1.3.2 h1:L18LIDzqlW6xN2rEkpdV8+oL/IXWJ1APd+vsdYy4Wdw= 38 | github.com/huandu/xstrings v1.3.2/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= 39 | github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA= 40 | github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= 41 | github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f h1:E87tDTVS5W65euzixn7clSzK66puSt1H4I5SC0EmHH4= 42 | github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f/go.mod h1:3J2qVK16Lq8V+wfiL2lPeDZ7UWMxk5LemerHa1p6N00= 43 | github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= 44 | github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= 45 | github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= 46 | github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= 47 | github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= 48 | github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= 49 | github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= 50 | github.com/mitchellh/cli v1.1.2 h1:PvH+lL2B7IQ101xQL63Of8yFS2y+aDlsFcsqNc+u/Kw= 51 | github.com/mitchellh/cli v1.1.2/go.mod h1:6iaV0fGdElS6dPBx0EApTxHrcWvmJphyh2n8YBLPPZ4= 52 | github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= 53 | github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= 54 | github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= 55 | github.com/mitchellh/mapstructure v1.4.1 h1:CpVNEelQCZBooIPDn+AR3NpivK/TIKU8bDxdASFVQag= 56 | github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= 57 | github.com/mitchellh/reflectwalk v1.0.0 
h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY= 58 | github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= 59 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 60 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 61 | github.com/posener/complete v1.1.1 h1:ccV59UEOTzVDnDUEFdT95ZzHVZ+5+158q8+SJb2QV5w= 62 | github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= 63 | github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= 64 | github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= 65 | github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= 66 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 67 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 68 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 69 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 70 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 71 | golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= 72 | golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 73 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 74 | golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 75 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 76 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 77 | golang.org/x/sys 
v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
78 | golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
79 | golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 h1:OjiUf46hAmXblsZdnoSXsEUSKU8r1UEzcL5RVZ4gO9Y=
80 | golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
81 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
82 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
83 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
84 | gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
85 | gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
86 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
87 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
88 |
--------------------------------------------------------------------------------
/listenerutil/listener.go:
--------------------------------------------------------------------------------
1 | package listenerutil
2 |
3 | import (
4 | 	"crypto/tls"
5 | 	"crypto/x509"
6 | 	"errors"
7 | 	"fmt"
8 | 	"io/ioutil"
9 | 	"net"
10 | 	"os"
11 | 	osuser "os/user"
12 | 	"strconv"
13 |
14 | 	"github.com/hashicorp/go-secure-stdlib/reloadutil"
15 | 	"github.com/hashicorp/go-secure-stdlib/tlsutil"
16 | 	"github.com/jefferai/isbadcipher"
17 | 	"github.com/mitchellh/cli"
18 | )
19 |
20 | // Listener embeds a net.Listener alongside its ListenerConfig.
21 | type Listener struct {
22 | 	net.Listener
23 | 	Config ListenerConfig
24 | }
25 |
26 | // UnixSocketsConfig holds the optional owner/group/mode settings for a Unix socket file.
27 | type UnixSocketsConfig struct {
28 | 	User  string `hcl:"user"`
29 | 	Mode  string `hcl:"mode"`
30 | 	Group string `hcl:"group"`
31 | }
32 |
33 | // rmListener is an implementation of net.Listener that forwards most
34 | // calls to the listener but also removes a file as part of the close. We
35 | // use this to cleanup the unix domain socket on close.
36 | type rmListener struct {
37 | 	net.Listener
38 | 	Path string
39 | }
40 |
41 | // Close closes the wrapped listener and then removes the socket file.
42 | func (l *rmListener) Close() error {
43 | 	if err := l.Listener.Close(); err != nil {
44 | 		return err
45 | 	}
46 | 	return os.Remove(l.Path)
47 | }
48 |
49 | // UnixSocketListener listens on the Unix domain socket at path, removing any
50 | // existing file there first, and applies the optional ownership/mode settings.
51 | // The returned listener removes the socket file when closed.
52 | func UnixSocketListener(path string, unixSocketsConfig *UnixSocketsConfig) (net.Listener, error) {
53 | 	if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
54 | 		return nil, fmt.Errorf("failed to remove socket file: %v", err)
55 | 	}
56 | 	ln, err := net.Listen("unix", path)
57 | 	if err != nil {
58 | 		return nil, err
59 | 	}
60 | 	// Wrap the listener in rmListener so that the Unix domain socket file is
61 | 	// removed on close.
62 | 	rl := &rmListener{Listener: ln, Path: path}
63 | 	if unixSocketsConfig != nil {
64 | 		err = setFilePermissions(path, unixSocketsConfig.User, unixSocketsConfig.Group, unixSocketsConfig.Mode)
65 | 		if err != nil {
66 | 			// Don't leak the listener or the socket file on failure.
67 | 			_ = rl.Close()
68 | 			return nil, fmt.Errorf("failed to set file system permissions on the socket file: %w", err)
69 | 		}
70 | 	}
71 | 	return rl, nil
72 | }
73 |
74 | // TLSConfig builds a *tls.Config from l's TLS settings, prompting via ui for a
75 | // key passphrase when needed. It returns all nils when l.TLSDisable is set;
76 | // on success it sets props["tls"] to "enabled" and returns the cert reload func.
77 | func TLSConfig(l *ListenerConfig, props map[string]string, ui cli.Ui) (*tls.Config, reloadutil.ReloadFunc, error) {
78 | 	props["tls"] = "disabled"
79 |
80 | 	if l.TLSDisable {
81 | 		return nil, nil, nil
82 | 	}
83 |
84 | 	cg := reloadutil.NewCertificateGetter(l.TLSCertFile, l.TLSKeyFile, "")
85 | 	if err := cg.Reload(); err != nil {
86 | 		// We try the key without a passphrase first and if the error chain
87 | 		// contains x509.IncorrectPasswordError, prompt for a passphrase and retry.
88 | 		if errors.Is(err, x509.IncorrectPasswordError) {
89 | 			var passphrase string
90 | 			passphrase, err = ui.AskSecret(fmt.Sprintf("Enter passphrase for %s:", l.TLSKeyFile))
91 | 			if err == nil {
92 | 				cg = reloadutil.NewCertificateGetter(l.TLSCertFile, l.TLSKeyFile, passphrase)
93 | 				if err = cg.Reload(); err == nil {
94 | 					goto PASSPHRASECORRECT
95 | 				}
96 | 			}
97 | 		}
98 | 		return nil, nil, fmt.Errorf("error loading TLS cert: %w", err)
99 | 	}
100 |
101 | PASSPHRASECORRECT:
102 | 	tlsConf := &tls.Config{
103 | 		GetCertificate:           cg.GetCertificate,
104 | 		NextProtos:               []string{"h2", "http/1.1"},
105 | 		ClientAuth:               tls.RequestClientCert,
106 | 		PreferServerCipherSuites: l.TLSPreferServerCipherSuites,
107 | 	}
108 | 	// Version defaults are written back into the caller's config.
109 | 	if l.TLSMinVersion == "" {
110 | 		l.TLSMinVersion = "tls12"
111 | 	}
112 |
113 | 	if l.TLSMaxVersion == "" {
114 | 		l.TLSMaxVersion = "tls13"
115 | 	}
116 |
117 | 	var ok bool
118 | 	tlsConf.MinVersion, ok = tlsutil.TLSLookup[l.TLSMinVersion]
119 | 	if !ok {
120 | 		return nil, nil, fmt.Errorf("'tls_min_version' value %q not supported, please specify one of [tls10,tls11,tls12,tls13]", l.TLSMinVersion)
121 | 	}
122 |
123 | 	tlsConf.MaxVersion, ok = tlsutil.TLSLookup[l.TLSMaxVersion]
124 | 	if !ok {
125 | 		return nil, nil, fmt.Errorf("'tls_max_version' value %q not supported, please specify one of [tls10,tls11,tls12,tls13]", l.TLSMaxVersion)
126 | 	}
127 |
128 | 	if tlsConf.MaxVersion < tlsConf.MinVersion {
129 | 		return nil, nil, fmt.Errorf("'tls_max_version' must be greater than or equal to 'tls_min_version'")
130 | 	}
131 |
132 | 	if len(l.TLSCipherSuites) > 0 {
133 | 		// HTTP/2 with TLS 1.2 blacklists several cipher suites.
134 | 		// https://tools.ietf.org/html/rfc7540#appendix-A
135 | 		//
136 | 		// Since the CLI (net/http) automatically uses HTTP/2 with TLS 1.2,
137 | 		// we check here if all or some specified cipher suites are blacklisted.
138 | 		badCiphers := []string{}
139 | 		for _, cipher := range l.TLSCipherSuites {
140 | 			if isbadcipher.IsBadCipher(cipher) {
141 | 				// Get the name of the current cipher.
142 | 				cipherStr, err := tlsutil.GetCipherName(cipher)
143 | 				if err != nil {
144 | 					return nil, nil, fmt.Errorf("invalid value for 'tls_cipher_suites': %w", err)
145 | 				}
146 | 				badCiphers = append(badCiphers, cipherStr)
147 | 			}
148 | 		}
149 | 		if len(badCiphers) == len(l.TLSCipherSuites) {
150 | 			ui.Warn(`WARNING! All cipher suites defined by 'tls_cipher_suites' are blacklisted by the
151 | HTTP/2 specification. HTTP/2 communication with TLS 1.2 will not work as intended
152 | and Vault will be unavailable via the CLI.
153 | Please see https://tools.ietf.org/html/rfc7540#appendix-A for further information.`)
154 | 		} else if len(badCiphers) > 0 {
155 | 			ui.Warn(fmt.Sprintf(`WARNING! The following cipher suites defined by 'tls_cipher_suites' are
156 | blacklisted by the HTTP/2 specification:
157 | %v
158 | Please see https://tools.ietf.org/html/rfc7540#appendix-A for further information.`, badCiphers))
159 | 		}
160 | 		tlsConf.CipherSuites = l.TLSCipherSuites
161 | 	}
162 |
163 | 	if l.TLSRequireAndVerifyClientCert {
164 | 		tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
165 | 		if l.TLSClientCAFile != "" {
166 | 			caPool := x509.NewCertPool()
167 | 			data, err := ioutil.ReadFile(l.TLSClientCAFile)
168 | 			if err != nil {
169 | 				return nil, nil, fmt.Errorf("failed to read tls_client_ca_file: %w", err)
170 | 			}
171 |
172 | 			if !caPool.AppendCertsFromPEM(data) {
173 | 				return nil, nil, fmt.Errorf("failed to parse CA certificate in tls_client_ca_file")
174 | 			}
175 | 			tlsConf.ClientCAs = caPool
176 | 		}
177 | 	}
178 |
179 | 	if l.TLSDisableClientCerts {
180 | 		if l.TLSRequireAndVerifyClientCert {
181 | 			return nil, nil, fmt.Errorf("'tls_disable_client_certs' and 'tls_require_and_verify_client_cert' are mutually exclusive")
182 | 		}
183 | 		tlsConf.ClientAuth = tls.NoClientCert
184 | 	}
185 |
186 | 	props["tls"] = "enabled"
187 | 	return tlsConf, cg.Reload, nil
188 | }
189 |
190 | // setFilePermissions handles configuring ownership and permissions
191 | // settings on a given file. All permission/ownership settings are
192 | // optional. If no user or group is specified, the current user/group
193 | // will be used. Mode is optional, and has no default (the operation is
194 | // not performed if absent). User and group may each be given by name or
195 | // numeric ID; mode is parsed as an octal string such as "644".
196 | func setFilePermissions(path string, user, group, mode string) error { 197 | var err error 198 | uid, gid := os.Getuid(), os.Getgid() 199 | 200 | if user != "" { 201 | if uid, err = strconv.Atoi(user); err == nil { 202 | goto GROUP 203 | } 204 | 205 | // Try looking up the user by name 206 | u, err := osuser.Lookup(user) 207 | if err != nil { 208 | return fmt.Errorf("failed to look up user %q: %v", user, err) 209 | } 210 | uid, _ = strconv.Atoi(u.Uid) 211 | } 212 | 213 | GROUP: 214 | if group != "" { 215 | if gid, err = strconv.Atoi(group); err == nil { 216 | goto OWN 217 | } 218 | 219 | // Try looking up the user by name 220 | g, err := osuser.LookupGroup(group) 221 | if err != nil { 222 | return fmt.Errorf("failed to look up group %q: %v", user, err) 223 | } 224 | gid, _ = strconv.Atoi(g.Gid) 225 | } 226 | 227 | OWN: 228 | if err := os.Chown(path, uid, gid); err != nil { 229 | return fmt.Errorf("failed setting ownership to %d:%d on %q: %v", 230 | uid, gid, path, err) 231 | } 232 | 233 | if mode != "" { 234 | mode, err := strconv.ParseUint(mode, 8, 32) 235 | if err != nil { 236 | return fmt.Errorf("invalid mode specified: %v", mode) 237 | } 238 | if err := os.Chmod(path, os.FileMode(mode)); err != nil { 239 | return fmt.Errorf("failed setting permissions to %d on %q: %v", 240 | mode, path, err) 241 | } 242 | } 243 | 244 | return nil 245 | } 246 | -------------------------------------------------------------------------------- /listenerutil/listener_test.go: -------------------------------------------------------------------------------- 1 | package listenerutil 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | osuser "os/user" 7 | "strconv" 8 | "testing" 9 | ) 10 | 11 | func TestUnixSocketListener(t *testing.T) { 12 | t.Run("ids", func(t *testing.T) { 13 | socket, err := ioutil.TempFile("", "socket") 14 | if err != nil { 15 | t.Fatal(err) 16 | } 17 | defer os.Remove(socket.Name()) 18 | 19 | uid, gid := os.Getuid(), os.Getgid() 20 | 21 | u, err := 
osuser.LookupId(strconv.Itoa(uid)) 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | user := u.Username 26 | 27 | g, err := osuser.LookupGroupId(strconv.Itoa(gid)) 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | group := g.Name 32 | 33 | l, err := UnixSocketListener(socket.Name(), &UnixSocketsConfig{ 34 | User: user, 35 | Group: group, 36 | Mode: "644", 37 | }) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | defer l.Close() 42 | 43 | fi, err := os.Stat(socket.Name()) 44 | if err != nil { 45 | t.Fatal(err) 46 | } 47 | 48 | mode, err := strconv.ParseUint("644", 8, 32) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | if fi.Mode().Perm() != os.FileMode(mode) { 53 | t.Fatalf("failed to set permissions on the socket file") 54 | } 55 | }) 56 | t.Run("names", func(t *testing.T) { 57 | socket, err := ioutil.TempFile("", "socket") 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | defer os.Remove(socket.Name()) 62 | 63 | uid, gid := os.Getuid(), os.Getgid() 64 | l, err := UnixSocketListener(socket.Name(), &UnixSocketsConfig{ 65 | User: strconv.Itoa(uid), 66 | Group: strconv.Itoa(gid), 67 | Mode: "644", 68 | }) 69 | if err != nil { 70 | t.Fatal(err) 71 | } 72 | defer l.Close() 73 | 74 | fi, err := os.Stat(socket.Name()) 75 | if err != nil { 76 | t.Fatal(err) 77 | } 78 | 79 | mode, err := strconv.ParseUint("644", 8, 32) 80 | if err != nil { 81 | t.Fatal(err) 82 | } 83 | if fi.Mode().Perm() != os.FileMode(mode) { 84 | t.Fatalf("failed to set permissions on the socket file") 85 | } 86 | }) 87 | 88 | } 89 | -------------------------------------------------------------------------------- /mlock/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/mlock 2 | 3 | go 1.16 4 | 5 | require golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c 6 | -------------------------------------------------------------------------------- /mlock/go.sum: 
-------------------------------------------------------------------------------- 1 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= 2 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 3 | -------------------------------------------------------------------------------- /mlock/mlock.go: -------------------------------------------------------------------------------- 1 | package mlock 2 | 3 | // This should be set by the OS-specific packages to tell whether LockMemory 4 | // is supported or not. 5 | var supported bool 6 | 7 | // Supported returns true if LockMemory is functional on this system. 8 | func Supported() bool { 9 | return supported 10 | } 11 | 12 | // LockMemory prevents any memory from being swapped to disk. 13 | func LockMemory() error { 14 | return lockMemory() 15 | } 16 | -------------------------------------------------------------------------------- /mlock/mlock_unavail.go: -------------------------------------------------------------------------------- 1 | // +build darwin nacl netbsd plan9 windows 2 | 3 | package mlock 4 | 5 | func init() { 6 | supported = false 7 | } 8 | 9 | func lockMemory() error { 10 | // XXX: No good way to do this on Windows. There is the VirtualLock 11 | // method, but it requires a specific address and offset. 12 | return nil 13 | } 14 | -------------------------------------------------------------------------------- /mlock/mlock_unix.go: -------------------------------------------------------------------------------- 1 | // +build dragonfly freebsd linux openbsd solaris 2 | 3 | package mlock 4 | 5 | import ( 6 | "syscall" 7 | 8 | "golang.org/x/sys/unix" 9 | ) 10 | 11 | func init() { 12 | supported = true 13 | } 14 | 15 | func lockMemory() error { 16 | // Mlockall prevents all current and future pages from being swapped out. 
17 | return unix.Mlockall(syscall.MCL_CURRENT | syscall.MCL_FUTURE) 18 | } 19 | -------------------------------------------------------------------------------- /parseutil/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/parseutil 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/hashicorp/go-secure-stdlib/strutil v0.1.1 7 | github.com/hashicorp/go-sockaddr v1.0.2 8 | github.com/mitchellh/mapstructure v1.4.1 9 | github.com/stretchr/testify v1.7.0 10 | ) 11 | -------------------------------------------------------------------------------- /parseutil/go.sum: -------------------------------------------------------------------------------- 1 | github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= 2 | github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= 3 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= 6 | github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 7 | github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= 8 | github.com/hashicorp/go-secure-stdlib/strutil v0.1.1 h1:nd0HIW15E6FG1MsnArYaHfuw9C2zgzM8LxkG5Ty/788= 9 | github.com/hashicorp/go-secure-stdlib/strutil v0.1.1/go.mod h1:gKOamz3EwoIoJq7mlMIRBpVTAUn8qPCrEclOKKWhD3U= 10 | github.com/hashicorp/go-sockaddr v1.0.2 h1:ztczhD1jLxIRjVejw8gFomI1BQZOe2WoVOu0SyteCQc= 11 | github.com/hashicorp/go-sockaddr v1.0.2/go.mod h1:rB4wwRAUzs07qva3c5SdrY/NEtAUjGlgmH/UkBUC97A= 12 | github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= 13 | github.com/mattn/go-isatty v0.0.3/go.mod 
h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= 14 | github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= 15 | github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= 16 | github.com/mitchellh/mapstructure v1.4.1 h1:CpVNEelQCZBooIPDn+AR3NpivK/TIKU8bDxdASFVQag= 17 | github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= 18 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 19 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 20 | github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= 21 | github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= 22 | github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= 23 | github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= 24 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 25 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 26 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 27 | golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 28 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 29 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 30 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 31 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 32 | -------------------------------------------------------------------------------- /parseutil/parsepath.go: 
-------------------------------------------------------------------------------- 1 | package parseutil 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/ioutil" 7 | "net/url" 8 | "os" 9 | "strings" 10 | ) 11 | 12 | var ( 13 | ErrNotAUrl = errors.New("not a url") // the input could not be parsed by url.Parse 14 | ErrNotParsed = errors.New("not a parsed value") // MustParsePath: the scheme is not one this package handles 15 | ) 16 | 17 | // ParsePath parses a URL with schemes file://, env://, or any other. Depending 18 | // on the scheme it will return specific types of data: 19 | // 20 | // * file:// will return a string with the file's contents, whitespace-trimmed 21 | // 22 | // * env:// will return a string with the env var's contents, whitespace-trimmed (empty if unset) 23 | // 24 | // * Anything else will return the (whitespace-trimmed) string as it was. Functionally this means 25 | // anything for which Go's `url.Parse` function does not throw an error. If you 26 | // want to ensure that this function errors if a known scheme is not found, use 27 | // MustParsePath. 28 | // 29 | // On error, we return the (whitespace-trimmed) original string along with the error. The caller can 30 | // switch on errors.Is(err, ErrNotAUrl) to understand whether it was the parsing 31 | // step that errored or something else (such as a file not found). This is 32 | // useful to attempt to read a non-URL string from some resource, but where the 33 | // original input may simply be a valid string of that type. 34 | func ParsePath(path string) (string, error) { 35 | return parsePath(path, false) 36 | } 37 | 38 | // MustParsePath behaves like ParsePath but returns an empty string and ErrNotParsed 39 | // if the value's scheme is not one this function handles (file:// or env://).
40 | func MustParsePath(path string) (string, error) { 41 | return parsePath(path, true) 42 | } 43 | // parsePath implements ParsePath and MustParsePath; mustParse turns an unhandled scheme into ErrNotParsed. 44 | func parsePath(path string, mustParse bool) (string, error) { 45 | path = strings.TrimSpace(path) // the trimmed value is what is returned on error and for unhandled schemes 46 | parsed, err := url.Parse(path) 47 | if err != nil { 48 | return path, fmt.Errorf("error parsing url (%q): %w", err.Error(), ErrNotAUrl) 49 | } 50 | switch parsed.Scheme { 51 | case "file": 52 | contents, err := ioutil.ReadFile(strings.TrimPrefix(path, "file://")) // not parsed.Path: url.Parse puts the first segment of a relative file://dir/x in Host 53 | if err != nil { 54 | return path, fmt.Errorf("error reading file at %s: %w", path, err) 55 | } 56 | return strings.TrimSpace(string(contents)), nil 57 | case "env": // an unset variable yields an empty string, not an error 58 | return strings.TrimSpace(os.Getenv(strings.TrimPrefix(path, "env://"))), nil 59 | default: 60 | if mustParse { 61 | return "", ErrNotParsed 62 | } 63 | return path, nil 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /parseutil/parsepath_test.go: -------------------------------------------------------------------------------- 1 | package parseutil 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestParsePath(t *testing.T) { 14 | t.Parallel() 15 | 16 | file, err := os.CreateTemp("", "") 17 | require.NoError(t, err) 18 | _, err = file.WriteString("foo") 19 | require.NoError(t, err) 20 | require.NoError(t, file.Close()) 21 | defer os.Remove(file.Name()) 22 | 23 | require.NoError(t, os.Setenv("PATHTEST", "bar")) // read by the env:// cases below 24 | 25 | cases := []struct { 26 | name string 27 | inPath string 28 | outStr string 29 | notAUrl bool 30 | must bool 31 | notParsed bool 32 | expErrorContains string 33 | }{ 34 | { 35 | name: "file", 36 | inPath: fmt.Sprintf("file://%s", file.Name()), 37 | outStr: "foo", 38 | }, 39 | { 40 | name: "file-mustparse", 41 | inPath: fmt.Sprintf("file://%s", file.Name()), 42 | outStr: "foo", 43 | must: true, 44 | }, 45 | { 46 | name: "env", 47 | inPath:
"env://PATHTEST", 48 | outStr: "bar", 49 | }, 50 | { 51 | name: "env-mustparse", 52 | inPath: "env://PATHTEST", 53 | outStr: "bar", 54 | must: true, 55 | }, 56 | { 57 | name: "plain", 58 | inPath: "zipzap", 59 | outStr: "zipzap", 60 | }, 61 | { 62 | name: "plain-mustparse", 63 | inPath: "zipzap", 64 | outStr: "zipzap", // not compared: notParsed cases assert an empty result instead 65 | must: true, 66 | notParsed: true, 67 | }, 68 | { 69 | name: "no file", 70 | inPath: "file:///dev/nullface", 71 | outStr: "file:///dev/nullface", 72 | expErrorContains: "no such file or directory", 73 | }, 74 | { 75 | name: "not a url", 76 | inPath: "http://" + string([]byte{0x00}), // the NUL control byte makes url.Parse fail 77 | outStr: "http://" + string([]byte{0x00}), 78 | notAUrl: true, 79 | }, 80 | } 81 | for _, tt := range cases { 82 | t.Run(tt.name, func(t *testing.T) { 83 | assert, require := assert.New(t), require.New(t) 84 | var out string 85 | var err error 86 | switch tt.must { 87 | case false: 88 | out, err = ParsePath(tt.inPath) 89 | default: 90 | out, err = MustParsePath(tt.inPath) 91 | } 92 | if tt.expErrorContains != "" { 93 | require.Error(err) 94 | assert.Contains(err.Error(), tt.expErrorContains) 95 | return 96 | } 97 | if tt.notAUrl { 98 | require.Error(err) 99 | assert.True(errors.Is(err, ErrNotAUrl)) 100 | assert.Equal(tt.inPath, out) 101 | return 102 | } 103 | if tt.notParsed { 104 | require.Error(err) 105 | assert.True(errors.Is(err, ErrNotParsed)) 106 | assert.Empty(out) 107 | return 108 | } 109 | require.NoError(err) 110 | assert.Equal(tt.outStr, out) 111 | }) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /password/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/password 2 | 3 | go 1.16 4 | 5 | require ( 6 | golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 7 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c 8 | ) 9 | --------------------------------------------------------------------------------
/password/go.sum: -------------------------------------------------------------------------------- 1 | golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= 2 | golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 3 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 4 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 5 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 6 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= 7 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 8 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= 9 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 10 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 11 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 12 | -------------------------------------------------------------------------------- /password/password.go: -------------------------------------------------------------------------------- 1 | // Package password reads a password securely from a terminal. 2 | // The code in this package disables echo in the terminal so that the 3 | // password is not echoed back in plaintext to the user. 4 | package password 5 | 6 | import ( 7 | "errors" 8 | "io" 9 | "os" 10 | "os/signal" 11 | "strings" 12 | ) 13 | // ErrInterrupted is returned by Read when reading is aborted by os.Interrupt or by a raw Ctrl-C byte. 14 | var ErrInterrupted = errors.New("interrupted") 15 | 16 | // Read reads the password from the given os.File. The password 17 | // will not be echoed back to the user.
Ctrl-C will automatically return 18 | // from this function with a blank string and an ErrInterrupted. 19 | func Read(f *os.File) (string, error) { 20 | ch := make(chan os.Signal, 1) 21 | signal.Notify(ch, os.Interrupt) 22 | defer signal.Stop(ch) 23 | 24 | // Run the actual read in a go-routine so that we can still detect signals 25 | var result string 26 | var resultErr error 27 | doneCh := make(chan struct{}) 28 | go func() { 29 | defer close(doneCh) 30 | result, resultErr = read(f) 31 | }() 32 | 33 | // Wait on either the read to finish or the signal to come through. If the signal wins, the reader goroutine is not waited for and keeps running until read returns. 34 | select { 35 | case <-ch: 36 | return "", ErrInterrupted 37 | case <-doneCh: 38 | return removeiTermDelete(result), resultErr 39 | } 40 | } 41 | // readline reads f one byte at a time until a newline or carriage return, end of input, a read error, or a Ctrl-C (0x03) byte. 42 | func readline(f *os.File) (string, error) { 43 | var buf [1]byte 44 | resultBuf := make([]byte, 0, 64) 45 | for { 46 | n, err := f.Read(buf[:]) 47 | if err != nil && err != io.EOF { 48 | return "", err 49 | } 50 | if n == 0 || buf[0] == '\n' || buf[0] == '\r' { 51 | break 52 | } 53 | 54 | // ASCII code 3 is what is sent for a Ctrl-C while reading raw. 55 | // If we see that, then get the interrupt. We have to do this here 56 | // because terminals in raw mode won't catch it at the shell level.
57 | if buf[0] == 3 { 58 | return "", ErrInterrupted 59 | } 60 | 61 | resultBuf = append(resultBuf, buf[0]) 62 | } 63 | 64 | return string(resultBuf), nil 65 | } 66 | // removeiTermDelete strips a single leading "\x20\x7f" (space, DEL) pair; per its name, presumably a sequence sent by iTerm. 67 | func removeiTermDelete(input string) string { 68 | return strings.TrimPrefix(input, "\x20\x7f") 69 | } 70 | -------------------------------------------------------------------------------- /password/password_solaris.go: -------------------------------------------------------------------------------- 1 | // +build solaris 2 | 3 | package password 4 | 5 | import ( 6 | "fmt" 7 | "os" 8 | "syscall" 9 | 10 | "golang.org/x/sys/unix" 11 | ) 12 | 13 | func read(f *os.File) (string, error) { 14 | fd := int(f.Fd()) 15 | if !isTerminal(fd) { 16 | return "", fmt.Errorf("file descriptor %d is not a terminal", fd) 17 | } 18 | 19 | oldState, err := makeRaw(fd) 20 | if err != nil { 21 | return "", err 22 | } 23 | defer unix.IoctlSetTermios(fd, unix.TCSETS, oldState) // best-effort restore; the error is ignored 24 | 25 | return readline(f) 26 | } 27 | 28 | // isTerminal returns true if there is a terminal attached to the given 29 | // file descriptor. 30 | // Source: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c 31 | func isTerminal(fd int) bool { 32 | // TCGETA only reads the terminal settings, and fails if fd is not a terminal. 33 | _, err := unix.IoctlGetTermio(fd, unix.TCGETA) 34 | return err == nil 35 | } 36 | 37 | // makeRaw turns off echo (ECHO, ECHOE, ECHOK, ECHONL) on the terminal connected 38 | // to the given file descriptor and returns the previous state of the terminal so 39 | // that it can be restored. Despite the name, canonical mode and signals stay enabled.
40 | // Source: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libast/common/uwin/getpass.c 41 | func makeRaw(fd int) (*unix.Termios, error) { 42 | oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS) 43 | if err != nil { 44 | return nil, err 45 | } 46 | oldTermios := *oldTermiosPtr 47 | 48 | newTermios := oldTermios 49 | newTermios.Lflag &^= syscall.ECHO | syscall.ECHOE | syscall.ECHOK | syscall.ECHONL 50 | if err := unix.IoctlSetTermios(fd, unix.TCSETS, &newTermios); err != nil { 51 | return nil, err 52 | } 53 | 54 | return oldTermiosPtr, nil 55 | } 56 | -------------------------------------------------------------------------------- /password/password_test.go: -------------------------------------------------------------------------------- 1 | package password 2 | 3 | import "testing" 4 | 5 | type testCase struct { 6 | name string 7 | input string 8 | expected string 9 | } 10 | 11 | func TestRemoveiTermDelete(t *testing.T) { 12 | tests := []testCase{ 13 | {"NoDelete", "TestingStuff", "TestingStuff"}, 14 | {"SingleDelete", "Testing\x7fStuff", "Testing\x7fStuff"}, 15 | {"DeleteFirst", "\x7fTestingStuff", "\x7fTestingStuff"}, 16 | {"DoubleDelete", "\x7f\x7fTestingStuff", "\x7f\x7fTestingStuff"}, 17 | {"SpaceFirst", "\x20TestingStuff", "\x20TestingStuff"}, 18 | {"iTermDelete", "\x20\x7fTestingStuff", "TestingStuff"}, 19 | } 20 | 21 | for _, test := range tests { 22 | result := removeiTermDelete(test.input) 23 | if result != test.expected { 24 | t.Errorf("Test %s failed, input: '%s', expected: '%s', output: '%s'", test.name, test.input, test.expected, result) 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /password/password_unix.go: -------------------------------------------------------------------------------- 1 | // +build linux darwin freebsd netbsd openbsd dragonfly 2 | 3 | package password 4 | 5 | import ( 6 | "fmt" 7 | "os" 8 | 9 | "golang.org/x/crypto/ssh/terminal" 10 | ) 11
| 12 | func read(f *os.File) (string, error) { 13 | fd := int(f.Fd()) 14 | if !terminal.IsTerminal(fd) { 15 | return "", fmt.Errorf("file descriptor %d is not a terminal", fd) 16 | } 17 | 18 | oldState, err := terminal.MakeRaw(fd) 19 | if err != nil { 20 | return "", err 21 | } 22 | defer terminal.Restore(fd, oldState) 23 | 24 | return readline(f) 25 | } 26 | -------------------------------------------------------------------------------- /password/password_windows.go: -------------------------------------------------------------------------------- 1 | // +build windows 2 | 3 | package password 4 | 5 | import ( 6 | "os" 7 | "syscall" 8 | ) 9 | 10 | var ( 11 | kernel32 = syscall.MustLoadDLL("kernel32.dll") 12 | setConsoleModeProc = kernel32.MustFindProc("SetConsoleMode") 13 | ) 14 | 15 | // Magic constant from MSDN to control whether characters read are 16 | // repeated back on the console. 17 | // 18 | // http://msdn.microsoft.com/en-us/library/windows/desktop/ms686033(v=vs.85).aspx 19 | const ENABLE_ECHO_INPUT = 0x0004 20 | 21 | func read(f *os.File) (string, error) { 22 | handle := syscall.Handle(f.Fd()) 23 | 24 | // Grab the old console mode so we can reset it. We defer the reset 25 | // right away because it doesn't matter (it is idempotent). 26 | var oldMode uint32 27 | if err := syscall.GetConsoleMode(handle, &oldMode); err != nil { 28 | return "", err 29 | } 30 | defer setConsoleMode(handle, oldMode) 31 | 32 | // The new mode is the old mode WITHOUT the echo input flag set. 
33 | var newMode uint32 = uint32(int(oldMode) & ^ENABLE_ECHO_INPUT) 34 | if err := setConsoleMode(handle, newMode); err != nil { 35 | return "", err 36 | } 37 | 38 | return readline(f) 39 | } 40 | 41 | func setConsoleMode(console syscall.Handle, mode uint32) error { 42 | r, _, err := setConsoleModeProc.Call(uintptr(console), uintptr(mode)) 43 | if r == 0 { 44 | return err 45 | } 46 | 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /pluginutil/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package pluginutil provides common functions to make it easier to load plugins, 3 | especially if they can be either instantiated in memory or implemented as 4 | go-plugin plugins. 5 | 6 | The package takes care of the actual building of the plugin map and execution of 7 | the plugins. 8 | 9 | The general flow is that BuildPluginMap is called with the various plugin 10 | sources, which gives back a map of plugin information. Program-side validation 11 | logic can then be used to decide whether or not to proceed, e.g. "if a certain 12 | plugin is not available after parsing sources, quit". 13 | 14 | The desired plugin information can then be sent to the CreatePlugin function, 15 | along with potentially additional options, such as a SecureConfig section. This 16 | function returns an interface that either represents a go-plugin client or a 17 | direct Go interface. The calling code can do a type switch to figure out which 18 | it is, dispense the plugin if needed, and return the interface back to the 19 | caller. 20 | 21 | For an example of usage, see the kms.go file in the configutil 22 | package in this repository. 
23 | */ 24 | 25 | package pluginutil 26 | -------------------------------------------------------------------------------- /pluginutil/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/pluginutil/v2 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/hashicorp/go-plugin v1.4.3 7 | github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 8 | github.com/stretchr/testify v1.7.0 9 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 10 | ) 11 | 12 | require ( 13 | github.com/davecgh/go-spew v1.1.1 // indirect 14 | github.com/fatih/color v1.7.0 // indirect 15 | github.com/golang/protobuf v1.5.2 // indirect 16 | github.com/google/go-cmp v0.5.7 // indirect 17 | github.com/hashicorp/go-hclog v1.1.0 // indirect 18 | github.com/hashicorp/go-uuid v1.0.2 // indirect 19 | github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb // indirect 20 | github.com/mattn/go-colorable v0.1.6 // indirect 21 | github.com/mattn/go-isatty v0.0.12 // indirect 22 | github.com/mitchellh/go-testing-interface v1.0.0 // indirect 23 | github.com/oklog/run v1.0.0 // indirect 24 | github.com/pmezard/go-difflib v1.0.0 // indirect 25 | golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect 26 | golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a // indirect 27 | golang.org/x/text v0.3.7 // indirect 28 | google.golang.org/genproto v0.0.0-20220208230804-65c12eb4c068 // indirect 29 | google.golang.org/grpc v1.44.0 // indirect 30 | google.golang.org/protobuf v1.27.1 // indirect 31 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect 32 | ) 33 | -------------------------------------------------------------------------------- /pluginutil/options.go: -------------------------------------------------------------------------------- 1 | package pluginutil 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/fs" 7 | "os" 8 | 9 | gp "github.com/hashicorp/go-plugin" 10 | ) 11 | 12 | // GetOpts - iterate the 
inbound Options and return a struct 13 | func GetOpts(opt ...Option) (*options, error) { 14 | opts := getDefaultOptions() 15 | for _, o := range opt { 16 | if o != nil { 17 | if err := o(&opts); err != nil { 18 | return nil, err 19 | } 20 | } 21 | } 22 | return &opts, nil 23 | } 24 | 25 | // Option - how Options are passed as arguments 26 | type Option func(*options) error 27 | 28 | // pluginSourceInfo contains possibilities for plugin creation -- a map that can 29 | // be used to directly create instances, or an FS that can be used to source 30 | // plugin instances. 31 | type pluginSourceInfo struct { 32 | pluginMap map[string]InmemCreationFunc 33 | 34 | pluginFs fs.FS 35 | pluginFsPrefix string 36 | 37 | pluginFileInfo *PluginFileInfo 38 | } 39 | 40 | // options = how options are represented 41 | type options struct { 42 | withPluginSources []pluginSourceInfo 43 | withPluginExecutionDirectory string 44 | withPluginClientCreationFunc PluginClientCreationFunc 45 | WithSecureConfig *gp.SecureConfig 46 | } 47 | 48 | func getDefaultOptions() options { 49 | return options{} 50 | } 51 | 52 | // WithPluginsFilesystem provides an fs.FS containing plugins that can be 53 | // executed to provide functionality. This can be specified multiple times; all 54 | // FSes will be scanned. Any conflicts will be resolved later (e.g. in 55 | // BuildPluginsMap, the behavior will be last scanned plugin with the same name 56 | // wins).If there are conflicts, the last one wins, a property shared with 57 | // WithPluginsMap and WithPluginFile). The prefix will be stripped from each 58 | // entry when determining the plugin type. 59 | // 60 | // This doesn't currently support any kind of secure config and is meant for 61 | // cases where you can build up this FS securely. See WithPluginFile for adding 62 | // individual files with checksumming. 
63 | func WithPluginsFilesystem(withPrefix string, withPlugins fs.FS) Option { 64 | return func(o *options) error { 65 | if withPlugins == nil { 66 | return errors.New("nil plugin filesystem passed into option") 67 | } 68 | o.withPluginSources = append(o.withPluginSources, 69 | pluginSourceInfo{ 70 | pluginFs: withPlugins, 71 | pluginFsPrefix: withPrefix, 72 | }, 73 | ) 74 | return nil 75 | } 76 | } 77 | 78 | // WithPluginsMap provides a map containing functions that can be called to 79 | // instantiate plugins directly. This can be specified multiple times; all maps 80 | // will be scanned. Conflicts are resolved later, in BuildPluginMap: if 81 | // several sources provide a plugin with the same name, the last one scanned 82 | // wins (a property shared with WithPluginsFilesystem and WithPluginFile). The 83 | // map is stored without being copied; an empty map is rejected with an error. 84 | func WithPluginsMap(with map[string]InmemCreationFunc) Option { 85 | return func(o *options) error { 86 | if len(with) == 0 { 87 | return errors.New("no entries in plugins map passed into option") 88 | } 89 | o.withPluginSources = append(o.withPluginSources, 90 | pluginSourceInfo{ 91 | pluginMap: with, 92 | }, 93 | ) 94 | return nil 95 | } 96 | } 97 | 98 | // WithPluginFile provides source information for a file on disk (rather than an 99 | // fs.FS abstraction or an in-memory function). Secure hash info _must_ be 100 | // provided in this case; an unspecified HashMethod defaults to sha2-256. If there 101 | // are conflicts with the name, the last one wins (a property shared with WithPluginsFilesystem and WithPluginsMap).
102 | func WithPluginFile(with PluginFileInfo) Option { 103 | return func(o *options) error { 104 | // Validate the required fields first; the file itself is checked with os.Stat below. 105 | switch { 106 | case with.Name == "": 107 | return errors.New("plugin file name is empty") 108 | case with.Path == "": 109 | return errors.New("plugin file path is empty") 110 | case len(with.Checksum) == 0: 111 | return errors.New("plugin file checksum is empty") 112 | } 113 | 114 | switch with.HashMethod { 115 | case HashMethodUnspecified: 116 | with.HashMethod = HashMethodSha2256 117 | case HashMethodSha2256, 118 | HashMethodSha2384, 119 | HashMethodSha2512, 120 | HashMethodSha3256, 121 | HashMethodSha3384, 122 | HashMethodSha3512: 123 | default: 124 | return fmt.Errorf("unsupported hash method %q", string(with.HashMethod)) 125 | } 126 | info, err := os.Stat(with.Path) // any Stat error, not only a missing file, is reported as "not found" 127 | if err != nil { 128 | return fmt.Errorf("plugin at %q not found on filesystem: %w", with.Path, err) 129 | } 130 | if info.IsDir() { 131 | return fmt.Errorf("plugin at path %q is a directory", with.Path) 132 | } 133 | 134 | o.withPluginSources = append(o.withPluginSources, 135 | pluginSourceInfo{ 136 | pluginFileInfo: &with, 137 | }, 138 | ) 139 | return nil 140 | } 141 | } 142 | 143 | // WithPluginExecutionDirectory allows setting a specific directory for writing 144 | // out and executing plugins; if not set, os.TempDir will be used to create a 145 | // suitable directory. 146 | func WithPluginExecutionDirectory(with string) Option { 147 | return func(o *options) error { 148 | o.withPluginExecutionDirectory = with 149 | return nil 150 | } 151 | } 152 | 153 | // WithPluginClientCreationFunc allows passing in the func to use to create a plugin 154 | // client on the host side. Not necessary if only inmem functions are used, but 155 | // required otherwise.
156 | func WithPluginClientCreationFunc(with PluginClientCreationFunc) Option { 157 | return func(o *options) error { 158 | o.withPluginClientCreationFunc = with 159 | return nil 160 | } 161 | } 162 | 163 | // WithSecureConfig allows passing in the go-plugin secure config struct for 164 | // validating a plugin prior to execution. Generally not needed if the plugin is 165 | // being spun out of the binary at runtime. 166 | func WithSecureConfig(with *gp.SecureConfig) Option { 167 | return func(o *options) error { 168 | o.WithSecureConfig = with 169 | return nil 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /pluginutil/options_test.go: -------------------------------------------------------------------------------- 1 | package pluginutil 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "testing/fstest" 7 | 8 | gp "github.com/hashicorp/go-plugin" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func Test_GetOpts(t *testing.T) { 14 | t.Parallel() 15 | t.Run("nil", func(t *testing.T) { 16 | assert := assert.New(t) 17 | opts, err := GetOpts(nil) 18 | assert.NoError(err) 19 | assert.NotNil(opts) 20 | }) 21 | t.Run("with-plugins-filesystem", func(t *testing.T) { 22 | assert, require := assert.New(t), require.New(t) 23 | opts, err := GetOpts() 24 | require.NoError(err) 25 | assert.Nil(opts.withPluginSources) 26 | opts, err = GetOpts(WithPluginsFilesystem("foo", nil)) 27 | require.Error(err) 28 | assert.Nil(opts) 29 | opts, err = GetOpts(WithPluginsFilesystem("foo", make(fstest.MapFS))) 30 | require.NoError(err) 31 | require.NotNil(opts) 32 | assert.NotNil(opts.withPluginSources) 33 | }) 34 | t.Run("with-plugins-map", func(t *testing.T) { 35 | assert, require := assert.New(t), require.New(t) 36 | opts, err := GetOpts() 37 | require.NoError(err) 38 | assert.Nil(opts.withPluginSources) 39 | opts, err = GetOpts(WithPluginsMap( 40 | map[string]InmemCreationFunc{ 41 | "foo":
nil, 42 | }, 43 | )) 44 | require.NoError(err) 45 | require.NotNil(opts) 46 | assert.NotNil(opts.withPluginSources) 47 | }) 48 | t.Run("with-multiple-calls", func(t *testing.T) { 49 | assert, require := assert.New(t), require.New(t) 50 | opts, err := GetOpts() 51 | require.NoError(err) 52 | assert.Nil(opts.withPluginSources) 53 | opts, err = GetOpts( 54 | WithPluginsMap( 55 | map[string]InmemCreationFunc{ 56 | "foo": nil, 57 | }, 58 | ), 59 | WithPluginsMap( 60 | map[string]InmemCreationFunc{ 61 | "bar": nil, 62 | }, 63 | ), 64 | ) 65 | require.NoError(err) 66 | require.NotNil(opts) 67 | assert.NotNil(opts.withPluginSources) 68 | assert.Len(opts.withPluginSources, 2) 69 | }) 70 | t.Run("with-plugins-execution-directory", func(t *testing.T) { 71 | assert, require := assert.New(t), require.New(t) 72 | opts, err := GetOpts(WithPluginExecutionDirectory("foo")) 73 | require.NoError(err) 74 | require.NotNil(opts) 75 | assert.Equal("foo", opts.withPluginExecutionDirectory) 76 | }) 77 | t.Run("with-plugin-client-creation-func", func(t *testing.T) { 78 | assert, require := assert.New(t), require.New(t) 79 | opts, err := GetOpts() 80 | require.NoError(err) 81 | assert.Nil(opts.withPluginClientCreationFunc) 82 | opts, err = GetOpts(WithPluginClientCreationFunc( 83 | func(string, ...Option) (*gp.Client, error) { 84 | return new(gp.Client), nil 85 | }, 86 | )) 87 | require.NoError(err) 88 | require.NotNil(opts) 89 | client, err := opts.withPluginClientCreationFunc("") 90 | assert.NoError(err) 91 | assert.NotNil(client) 92 | }) 93 | t.Run("with-secure-config", func(t *testing.T) { 94 | assert, require := assert.New(t), require.New(t) 95 | opts, err := GetOpts() 96 | require.NoError(err) 97 | assert.Nil(opts.WithSecureConfig) 98 | opts, err = GetOpts(WithSecureConfig(new(gp.SecureConfig))) 99 | require.NoError(err) 100 | require.NotNil(opts.WithSecureConfig) 101 | }) 102 | t.Run("with-plugin-file", func(t *testing.T) { 103 | file, err := os.CreateTemp("", "") 104
| require.NoError(t, err) 105 | require.NoError(t, file.Close()) 106 | t.Cleanup(func() { os.Remove(file.Name()) }) 107 | 108 | currDir, err := os.Getwd() 109 | require.NoError(t, err) 110 | testCases := []struct { 111 | name string 112 | plugin PluginFileInfo 113 | wantErrContains string 114 | wantHashMethod HashMethod 115 | }{ 116 | { 117 | name: "no name", 118 | plugin: PluginFileInfo{}, 119 | wantErrContains: "name is empty", 120 | }, 121 | { 122 | name: "no path", 123 | plugin: PluginFileInfo{ 124 | Name: "testing", 125 | }, 126 | wantErrContains: "path is empty", 127 | }, 128 | { 129 | name: "no checksum", 130 | plugin: PluginFileInfo{ 131 | Name: "testing", 132 | Path: file.Name(), 133 | }, 134 | wantErrContains: "checksum is empty", 135 | }, 136 | { 137 | name: "bad hash type", 138 | plugin: PluginFileInfo{ 139 | Name: "testing", 140 | Path: file.Name(), 141 | Checksum: []byte("foobar"), 142 | HashMethod: "foobar", 143 | }, 144 | wantErrContains: "unsupported hash method", 145 | }, 146 | { 147 | name: "invalid path - missing", 148 | plugin: PluginFileInfo{ 149 | Name: "testing", 150 | Path: file.Name() + ".foobar", 151 | Checksum: []byte("foobar"), 152 | HashMethod: HashMethodSha2384, 153 | }, 154 | wantErrContains: "not found on filesystem", 155 | }, 156 | { 157 | name: "invalid path - dir", 158 | plugin: PluginFileInfo{ 159 | Name: "testing", 160 | Path: currDir, 161 | Checksum: []byte("foobar"), 162 | HashMethod: HashMethodSha2384, 163 | }, 164 | wantErrContains: "is a directory", 165 | }, 166 | { 167 | name: "unspecified hash type", 168 | plugin: PluginFileInfo{ 169 | Name: "testing", 170 | Path: file.Name(), 171 | Checksum: []byte("foobar"), 172 | HashMethod: HashMethodUnspecified, // WithPluginFile should default this to sha2-256 173 | }, 174 | wantHashMethod: HashMethodSha2256, 175 | }, 176 | { 177 | name: "specified hash type", 178 | plugin: PluginFileInfo{ 179 | Name: "testing", 180 | Path: file.Name(), 181 | Checksum: []byte("foobar"), 182 | HashMethod: HashMethodSha3384, 183 | }, 184 | wantHashMethod: HashMethodSha3384, 185 | }, 186
| } 187 | for _, tc := range testCases { 188 | t.Run(tc.name, func(t *testing.T) { 189 | assert, require := assert.New(t), require.New(t) 190 | opts, err := GetOpts(WithPluginFile(tc.plugin)) 191 | if tc.wantErrContains != "" { 192 | require.Error(err) 193 | assert.Contains(err.Error(), tc.wantErrContains) 194 | return 195 | } 196 | require.NoError(err) 197 | require.NotNil(opts) 198 | require.Len(opts.withPluginSources, 1) 199 | assert.Equal(tc.wantHashMethod, opts.withPluginSources[0].pluginFileInfo.HashMethod) 200 | }) 201 | } 202 | }) 203 | } 204 | -------------------------------------------------------------------------------- /pluginutil/pluginutil.go: -------------------------------------------------------------------------------- 1 | package pluginutil 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "crypto/sha256" 7 | "crypto/sha512" 8 | "fmt" 9 | "hash" 10 | "io/fs" 11 | "io/ioutil" 12 | "os" 13 | "path/filepath" 14 | "runtime" 15 | "strings" 16 | 17 | gp "github.com/hashicorp/go-plugin" 18 | "github.com/hashicorp/go-secure-stdlib/base62" 19 | "golang.org/x/crypto/sha3" 20 | ) 21 | 22 | // HashMethod is a string representation of a hash method 23 | type HashMethod string 24 | 25 | const ( 26 | HashMethodUnspecified HashMethod = "" 27 | HashMethodSha2256 HashMethod = "sha2-256" 28 | HashMethodSha2384 HashMethod = "sha2-384" 29 | HashMethodSha2512 HashMethod = "sha2-512" 30 | HashMethodSha3256 HashMethod = "sha3-256" 31 | HashMethodSha3384 HashMethod = "sha3-384" 32 | HashMethodSha3512 HashMethod = "sha3-512" 33 | ) 34 | 35 | // PluginFileInfo represents user-specified on-disk file information. Note that 36 | // testing for how this works in go-plugin, e.g. passing it into SecureConfig, 37 | // is in configutil to avoid pulling in go-kms-wrapping as a dep of this 38 | // package.
39 | type PluginFileInfo struct { 40 | Name string 41 | Path string 42 | Checksum []byte 43 | HashMethod HashMethod 44 | } 45 | 46 | type ( 47 | // InmemCreationFunc is a function that, when run, returns the thing you 48 | // want created (almost certainly an interface that is also supported by a 49 | // go-plugin plugin implementation) 50 | InmemCreationFunc func() (interface{}, error) 51 | 52 | // PluginClientCreationFunc is a function that, when run, returns a client 53 | // corresponding to a spun out go-plugin plugin. The string argument is the 54 | // filename. WithSecureConfig is supported as an option that will be round 55 | // tripped to the given function if provided to this package so that it can 56 | // be given to go-plugin. 57 | PluginClientCreationFunc func(string, ...Option) (*gp.Client, error) 58 | ) 59 | 60 | // PluginInfo contains plugin instantiation information for a single plugin, 61 | // parsed from the various maps and FSes that can be input to the BuildPluginMap 62 | // function. 63 | type PluginInfo struct { 64 | ContainerFs fs.FS 65 | Path string 66 | SecureConfig *gp.SecureConfig 67 | InmemCreationFunc InmemCreationFunc 68 | PluginClientCreationFunc PluginClientCreationFunc 69 | } 70 | 71 | // BuildPluginMap takes in options that contain one or more sets of plugin maps 72 | // or filesystems and builds an overall mapping of a plugin name to its 73 | // information. The desired plugin can then be sent to CreatePlugin to actually 74 | // instantiate it. If a plugin is specified by name multiple times in option, 75 | // the last one wins. 76 | func BuildPluginMap(opt ...Option) (map[string]*PluginInfo, error) { 77 | opts, err := GetOpts(opt...)
78 | if err != nil { 79 | return nil, fmt.Errorf("error parsing plugin options: %w", err) 80 | } 81 | 82 | if len(opts.withPluginSources) == 0 { 83 | return nil, fmt.Errorf("no plugins available") 84 | } 85 | 86 | pluginMap := map[string]*PluginInfo{} 87 | for _, sourceInfo := range opts.withPluginSources { 88 | switch { 89 | case sourceInfo.pluginFs != nil: 90 | if opts.withPluginClientCreationFunc == nil { 91 | return nil, fmt.Errorf("non-in-memory plugin found but no creation func provided") 92 | } 93 | dirs, err := fs.ReadDir(sourceInfo.pluginFs, ".") 94 | if err != nil { 95 | return nil, fmt.Errorf("error scanning plugins: %w", err) 96 | } 97 | // Store a match between the config type string and the expected plugin name 98 | for _, entry := range dirs { 99 | pluginType := strings.TrimSuffix(strings.TrimPrefix(entry.Name(), sourceInfo.pluginFsPrefix), ".gz") 100 | if runtime.GOOS == "windows" { 101 | pluginType = strings.TrimSuffix(pluginType, ".exe") 102 | } 103 | pluginMap[pluginType] = &PluginInfo{ 104 | ContainerFs: sourceInfo.pluginFs, 105 | Path: entry.Name(), 106 | PluginClientCreationFunc: opts.withPluginClientCreationFunc, 107 | } 108 | } 109 | case sourceInfo.pluginMap != nil: 110 | for k, creationFunc := range sourceInfo.pluginMap { 111 | pluginMap[k] = &PluginInfo{InmemCreationFunc: creationFunc} 112 | } 113 | // On-disk plugin: the checksum and the hash used to check it travel in a SecureConfig, which CreatePlugin forwards to the client creation func for go-plugin. 114 | case sourceInfo.pluginFileInfo != nil: 115 | fileInfo := sourceInfo.pluginFileInfo 116 | var h hash.Hash 117 | switch fileInfo.HashMethod { 118 | case HashMethodUnspecified, HashMethodSha2256: // an unspecified method defaults to SHA2-256 instead of leaving h nil 119 | h = sha256.New() 120 | case HashMethodSha2384: 121 | h = sha512.New384() 122 | case HashMethodSha2512: 123 | h = sha512.New() 124 | case HashMethodSha3256: 125 | h = sha3.New256() 126 | case HashMethodSha3384: 127 | h = sha3.New384() 128 | case HashMethodSha3512: 129 | h = sha3.New512() 130 | } 131 | pluginMap[fileInfo.Name] = &PluginInfo{ 132 | Path: fileInfo.Path, 133 | PluginClientCreationFunc: opts.withPluginClientCreationFunc, 134 | SecureConfig:
&gp.SecureConfig{ 135 | Checksum: fileInfo.Checksum, 136 | Hash: h, 137 | }, 138 | } 139 | } 140 | } 141 | 142 | return pluginMap, nil 143 | } 144 | 145 | // CreatePlugin instantiates a given plugin either via an in-memory function or 146 | // by executing a go-plugin plugin. The interface returned will either be a 147 | // gp.ClientProtocol (as returned by (*gp.Client).Client) or the value returned from an in-memory function. A type 148 | // switch should be used by the calling code to determine this, and the 149 | // appropriate service should be Dispensed if what is returned is a go-plugin 150 | // plugin. 151 | // 152 | // If the WithSecureConfig option is passed, this will be round-tripped into the 153 | // PluginClientCreationFunction from the given *PluginInfo, where it can be sent 154 | // into the go-plugin client configuration. 155 | // 156 | // The caller should ensure that cleanup() is executed when they are done using 157 | // the plugin. In the case of an in-memory plugin it will be nil, however, if 158 | // the plugin is via RPC it will ensure that it is torn down properly. 159 | func CreatePlugin(plugin *PluginInfo, opt ...Option) (interface{}, func() error, error) { 160 | opts, err := GetOpts(opt...)
161 | if err != nil { 162 | return nil, nil, fmt.Errorf("error parsing plugin options: %w", err) 163 | } 164 | 165 | var file fs.File 166 | var name string 167 | 168 | switch { 169 | case plugin == nil: 170 | return nil, nil, fmt.Errorf("plugin is nil") 171 | 172 | // Prioritize in-memory functions 173 | case plugin.InmemCreationFunc != nil: 174 | raw, err := plugin.InmemCreationFunc() 175 | return raw, nil, err 176 | 177 | // If not in-memory we need a filename, whether direct on disk or from a container FS 178 | case plugin.Path == "": 179 | return nil, nil, fmt.Errorf("no inmem creation func and file path not provided") 180 | 181 | // We need the client creation func to use once we've spun out the plugin 182 | case plugin.PluginClientCreationFunc == nil: 183 | return nil, nil, fmt.Errorf("plugin creation func not provided") 184 | 185 | // Either we need to have a validated FS to read from or a secure config 186 | case plugin.ContainerFs == nil && plugin.SecureConfig == nil: 187 | return nil, nil, fmt.Errorf("plugin container filesystem and secure config are both nil") 188 | 189 | // If we have a constructed filesystem, read from there 190 | case plugin.ContainerFs != nil: 191 | file, err = plugin.ContainerFs.Open(plugin.Path) 192 | name = plugin.Path 193 | 194 | // If we have secure config, read from disk 195 | case plugin.SecureConfig != nil: 196 | file, err = os.Open(plugin.Path) 197 | name = filepath.Base(plugin.Path) 198 | 199 | default: 200 | return nil, nil, fmt.Errorf("unhandled path in create plugin switch") 201 | } 202 | 203 | // This is the error from opening the file 204 | if err != nil { 205 | return nil, nil, err 206 | } 207 | defer file.Close() 208 | stat, err := file.Stat() 209 | if err != nil { 210 | return nil, nil, fmt.Errorf("error discovering plugin information: %w", err) 211 | } 212 | if stat.IsDir() { 213 | return nil, nil, fmt.Errorf("plugin is a directory, not a file") 214 | } 215 | 216 | // Read in plugin bytes. A single Read may legally return fewer bytes than requested (io.Reader contract), so read until EOF and then check the length. 217 | expLen :=
stat.Size() 218 | buf, err := ioutil.ReadAll(file) 219 | if err != nil { 220 | return nil, nil, fmt.Errorf("error reading plugin bytes: %w", err) 221 | } 222 | readLen := len(buf) 223 | if int64(readLen) != expLen { 224 | return nil, nil, fmt.Errorf("reading plugin, expected %d bytes, read %d", expLen, readLen) 225 | } 226 | 227 | // If it's compressed, uncompress it 228 | if strings.HasSuffix(name, ".gz") { 229 | gzipReader, err := gzip.NewReader(bytes.NewReader(buf)) 230 | if err != nil { 231 | return nil, nil, fmt.Errorf("error creating gzip decompression reader: %w", err) 232 | } 233 | uncompBuf := new(bytes.Buffer) 234 | _, err = uncompBuf.ReadFrom(gzipReader) 235 | gzipReader.Close() 236 | if err != nil { 237 | return nil, nil, fmt.Errorf("error reading gzip compressed data from reader: %w", err) 238 | } 239 | buf = uncompBuf.Bytes() 240 | name = strings.TrimSuffix(name, ".gz") 241 | } 242 | 243 | cleanup := func() error { 244 | return nil 245 | } 246 | 247 | // Now, create a temp dir and write out the plugin bytes 248 | dir := opts.withPluginExecutionDirectory 249 | if dir == "" { 250 | tmpDir, err := ioutil.TempDir("", "*") 251 | if err != nil { 252 | return nil, nil, fmt.Errorf("error creating tmp dir for plugin execution: %w", err) 253 | } 254 | cleanup = func() error { 255 | return os.RemoveAll(tmpDir) 256 | } 257 | dir = tmpDir 258 | } 259 | pluginPath := filepath.Join(dir, name) 260 | randSuffix, err := base62.Random(5) 261 | if err != nil { 262 | return nil, cleanup, fmt.Errorf("error generating random suffix for plugin execution: %w", err) // hand back cleanup so the temp dir created above is not leaked 263 | } 264 | pluginPath = fmt.Sprintf("%s-%s", pluginPath, randSuffix) 265 | if runtime.GOOS == "windows" { 266 | pluginPath = fmt.Sprintf("%s.exe", pluginPath) 267 | } 268 | if err := ioutil.WriteFile(pluginPath, buf, fs.FileMode(0o700)); err != nil { 269 | return nil, cleanup, fmt.Errorf("error writing out plugin for execution: %w", err) 270 | } 271 | 272 | // Execute the plugin, passing in secure config if
available 273 | creationFuncOpts := opt[:len(opt):len(opt)] // cap == len, so the append below copies instead of writing into the caller's backing array 274 | if plugin.SecureConfig != nil { 275 | creationFuncOpts = append(creationFuncOpts, WithSecureConfig(plugin.SecureConfig)) 276 | } 277 | client, err := plugin.PluginClientCreationFunc(pluginPath, creationFuncOpts...) 278 | if err != nil { 279 | return nil, cleanup, fmt.Errorf("error fetching kms plugin client: %w", err) 280 | } 281 | origCleanup := cleanup 282 | cleanup = func() error { 283 | client.Kill() 284 | return origCleanup() 285 | } 286 | rpcClient, err := client.Client() 287 | if err != nil { 288 | return nil, cleanup, fmt.Errorf("error fetching kms plugin rpc client: %w", err) 289 | } 290 | 291 | return rpcClient, cleanup, nil 292 | } 293 | -------------------------------------------------------------------------------- /reloadutil/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/reloadutil 2 | 3 | go 1.16 4 | -------------------------------------------------------------------------------- /reloadutil/go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Flyingon/go-secure-stdlib/7849be51188ffe09900bf3232e94695389cfc8aa/reloadutil/go.sum -------------------------------------------------------------------------------- /reloadutil/reload.go: -------------------------------------------------------------------------------- 1 | package reloadutil 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "encoding/pem" 7 | "errors" 8 | "fmt" 9 | "io/ioutil" 10 | "sync" 11 | ) 12 | 13 | // ReloadFunc are functions that are called when a reload is requested 14 | type ReloadFunc func() error 15 | 16 | // CertificateGetter satisfies ReloadFunc and its GetCertificate method 17 | // satisfies the tls.GetCertificate function signature. Currently it does not 18 | // allow changing paths after the fact.
19 | type CertificateGetter struct { 20 | sync.RWMutex 21 | 22 | cert *tls.Certificate 23 | 24 | certFile string 25 | keyFile string 26 | passphrase string 27 | } 28 | // NewCertificateGetter returns a CertificateGetter for the given PEM files. It does not read them; call Reload before the first GetCertificate. 29 | func NewCertificateGetter(certFile, keyFile, passphrase string) *CertificateGetter { 30 | return &CertificateGetter{ 31 | certFile: certFile, 32 | keyFile: keyFile, 33 | passphrase: passphrase, 34 | } 35 | } 36 | // Reload re-reads certFile and keyFile, decrypts the key with the configured passphrase if its PEM block is encrypted, and swaps in the new certificate under the write lock. On any error the previously loaded certificate is kept. 37 | func (cg *CertificateGetter) Reload() error { 38 | certPEMBlock, err := ioutil.ReadFile(cg.certFile) 39 | if err != nil { 40 | return err 41 | } 42 | keyPEMBlock, err := ioutil.ReadFile(cg.keyFile) 43 | if err != nil { 44 | return err 45 | } 46 | 47 | // Check for encrypted pem block 48 | keyBlock, _ := pem.Decode(keyPEMBlock) 49 | if keyBlock == nil { 50 | return errors.New("decoded PEM is blank") 51 | } 52 | // NOTE(review): x509.IsEncryptedPEMBlock/DecryptPEMBlock are deprecated since Go 1.16 (legacy RFC 1423 encryption is insecure by design); kept for compatibility with existing key files. 53 | if x509.IsEncryptedPEMBlock(keyBlock) { 54 | keyBlock.Bytes, err = x509.DecryptPEMBlock(keyBlock, []byte(cg.passphrase)) 55 | if err != nil { 56 | return fmt.Errorf("Decrypting PEM block failed: %w", err) 57 | } 58 | keyPEMBlock = pem.EncodeToMemory(keyBlock) 59 | } 60 | 61 | cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) 62 | if err != nil { 63 | return err 64 | } 65 | 66 | cg.Lock() 67 | defer cg.Unlock() 68 | 69 | cg.cert = &cert 70 | 71 | return nil 72 | } 73 | // GetCertificate returns the most recently loaded certificate, or an error if Reload has not yet succeeded. Its signature matches tls.Config.GetCertificate. 74 | func (cg *CertificateGetter) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { 75 | cg.RLock() 76 | defer cg.RUnlock() 77 | 78 | if cg.cert == nil { 79 | return nil, fmt.Errorf("nil certificate") 80 | } 81 | 82 | return cg.cert, nil 83 | } 84 | -------------------------------------------------------------------------------- /reloadutil/reload_test.go: -------------------------------------------------------------------------------- 1 | package reloadutil 2 | 3 | import ( 4 | "crypto/x509" 5 | "errors" 6 | "io/ioutil" 7 | "testing" 8 | ) 9 | 10 | func TestReload_KeyWithPassphrase(t *testing.T) { 11 | password := "password" 12 | cert := []byte(`-----BEGIN CERTIFICATE-----
13 | MIICLzCCAZgCCQCq27CeP4WhlDANBgkqhkiG9w0BAQUFADBcMQswCQYDVQQGEwJV 14 | UzELMAkGA1UECAwCQ0ExFjAUBgNVBAcMDVNhbiBGcmFuY2lzY28xEjAQBgNVBAoM 15 | CUhhc2hpQ29ycDEUMBIGA1UEAwwLbXl2YXVsdC5jb20wHhcNMTcxMjEzMjEzNTM3 16 | WhcNMTgxMjEzMjEzNTM3WjBcMQswCQYDVQQGEwJVUzELMAkGA1UECAwCQ0ExFjAU 17 | BgNVBAcMDVNhbiBGcmFuY2lzY28xEjAQBgNVBAoMCUhhc2hpQ29ycDEUMBIGA1UE 18 | AwwLbXl2YXVsdC5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMvsz/9l 19 | EJIlRG6DOw4fXdB/aJgJk2rR8cU0D8+vECIzb+MdDK0cBHtLiVpZC/RnZMdMzjGn 20 | Z++Fp3dEnT6CD0IjKdJcD+qSyZSjHIuYpHjnjrVlM/Le0xST7egoG+fXkSt4myzG 21 | ec2WK1jcZefRRGPycvMqx1yUWU76jDdFZSL5AgMBAAEwDQYJKoZIhvcNAQEFBQAD 22 | gYEAQfYE26FLZ9SPPU8bHNDxoxDmGrn8yJ78C490Qpix/w6gdLaBtILenrZbhpnB 23 | 3L3okraM8mplaN2KdAcpnsr4wPv9hbYkam0coxCQEKs8ltHSBaXT6uKRWb00nkGu 24 | yAXDRpuPdFRqbXW3ZFC5broUrz4ujxTDKfVeIn0zpPZkv24= 25 | -----END CERTIFICATE-----`) 26 | key := []byte(`-----BEGIN RSA PRIVATE KEY----- 27 | Proc-Type: 4,ENCRYPTED 28 | DEK-Info: DES-EDE3-CBC,64B032D83BD6A6DC 29 | 30 | qVJ+mXEBKMkUPrQ8odHunMpPgChQUny4CX73/dAcm7O9iXIv9eXQSxj2qfgCOloj 31 | vthg7jYNwtRb0ydzCEnEud35zWw38K/l19/pe4ULfNXlOddlsk4XIHarBiz+KUaX 32 | WTbNk0H+DwdcEwhprPgpTk8gp88lZBiHCnTG/s8v/JNt+wkdqjfAp0Xbm9m+OZ7s 33 | hlNxZin1OuBdprBqfKWBltUALZYiIBhspMTmh+jGQSyEKNTAIBejIiRH5+xYWuOy 34 | xKencq8UpQMOMPR2ZiSw42dU9j8HHMgldI7KszU2FDIEFXG7aSjcxNyyybeBT+Uz 35 | YPoxGxSdUYWqaz50UszvHg/QWR8NlPlQc3nFAUVpGKUF9MEQCIAK8HjcpMP+IAVO 36 | ertp4cTa2Rpm9YeoFrY6tabvmXApXlQPw6rBn6o5KpceWG3ceOsDOsT+e3edHu9g 37 | SGO4hjggbRpO+dBOuwfw4rMn9X1BbqXKJcREAmrgVVSf9/s942E4YOQ+IGJPdtmY 38 | WHAFk8hiJepsVCA2NpwVlAD+QbPPaR2RtvYOtq3IKlWRuVQ+6dpxDsz5FlJhs2L+ 39 | HsX6XqtwuQM8kk1hO8Gm3VeV7+b64r9kfbO8jCM18GexCYiCtig51mJW6IO42d1K 40 | bS1axMx/KeDc/sy7LKEbHnjnYanpGz2Wa2EWhnWAeNXD1nUfUNFPp2SsIGbCMnat 41 | mC4O4cO7YRl3+iJg3kHtTPGtgtCjrZcjlyBtxT2VC7SsTcTXZBWovczMIstyr4Ka 42 | opM24uvQT3Bc0UM0WNh3tdRFuboxDeBDh7PX/2RIoiaMuCCiRZ3O0A== 43 | -----END RSA PRIVATE KEY-----`) 44 | // t.TempDir is removed automatically when the test completes, so the key material does not outlive the test. 45 | tempDir := t.TempDir()
46 | 47 | keyFile := tempDir + "/server.key" 48 | certFile := tempDir + "/server.crt" 49 | 50 | // Neither file needs to be executable; keep the private key owner-only. 51 | err := ioutil.WriteFile(certFile, cert, 0o600) 52 | if err != nil { 53 | t.Fatalf("Error writing to temp file: %s", err) 54 | } 55 | err = ioutil.WriteFile(keyFile, key, 0o600) 56 | if err != nil { 57 | t.Fatalf("Error writing to temp file: %s", err) 58 | } 59 | 60 | cg := NewCertificateGetter(certFile, keyFile, "") 61 | err = cg.Reload() 62 | if err == nil { 63 | t.Fatal("error expected") 64 | } 65 | if !errors.Is(err, x509.IncorrectPasswordError) { // Reload wraps with %w; errors.As(err, &sentinel) would match any error and overwrite the package variable 66 | t.Fatalf("expected incorrect password error, got %v", err) 67 | } 68 | 69 | cg = NewCertificateGetter(certFile, keyFile, password) 70 | if err := cg.Reload(); err != nil { 71 | t.Fatalf("err: %v", err) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /strutil/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/strutil 2 | 3 | go 1.16 4 | 5 | require github.com/ryanuber/go-glob v1.0.0 6 | -------------------------------------------------------------------------------- /strutil/go.sum: -------------------------------------------------------------------------------- 1 | github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= 2 | github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= 3 | -------------------------------------------------------------------------------- /strutil/strutil_benchmark_test.go: -------------------------------------------------------------------------------- 1 | package strutil 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func BenchmarkRemoveDuplicates(b *testing.B) { 9 | a := make([]string, 1_000_000) 10 | for i := 0; i < len(a); i++ { 11 | a[i] = fmt.Sprintf("test.%d", i) 12 | } 13 | b.ResetTimer() 14 | 15 | for i := 0; i < b.N;
i++ { 16 | RemoveDuplicates(a, true) 17 | } 18 | } 19 | 20 | func BenchmarkRemoveDuplicatesStable(b *testing.B) { 21 | a := make([]string, 1_000_000) 22 | for i := 0; i < len(a); i++ { 23 | a[i] = fmt.Sprintf("test.%d", i) 24 | } 25 | b.ResetTimer() 26 | 27 | for i := 0; i < b.N; i++ { 28 | RemoveDuplicatesStable(a, true) 29 | } 30 | } 31 | 32 | func BenchmarkEquivalentSlices(b *testing.B) { 33 | x := make([]string, 1_000_000) 34 | y := make([]string, len(x)) 35 | for i := 0; i < len(x); i++ { 36 | x[i] = fmt.Sprintf("test.%d", i) 37 | y[i] = fmt.Sprintf("test.%d", i) 38 | } 39 | b.ResetTimer() 40 | 41 | for i := 0; i < b.N; i++ { 42 | EquivalentSlices(x, y) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /tlsutil/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-secure-stdlib/tlsutil 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/hashicorp/go-secure-stdlib/parseutil v0.1.1 7 | github.com/hashicorp/go-secure-stdlib/strutil v0.1.1 8 | ) 9 | -------------------------------------------------------------------------------- /tlsutil/go.sum: -------------------------------------------------------------------------------- 1 | github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= 2 | github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= 3 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= 6 | github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 7 | github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= 8 | github.com/hashicorp/go-secure-stdlib/parseutil v0.1.1
h1:78ki3QBevHwYrVxnyVeaEz+7WtifHhauYF23es/0KlI= 9 | github.com/hashicorp/go-secure-stdlib/parseutil v0.1.1/go.mod h1:QmrqtbKuxxSWTN3ETMPuB+VtEiBJ/A9XhoYGv8E1uD8= 10 | github.com/hashicorp/go-secure-stdlib/strutil v0.1.1 h1:nd0HIW15E6FG1MsnArYaHfuw9C2zgzM8LxkG5Ty/788= 11 | github.com/hashicorp/go-secure-stdlib/strutil v0.1.1/go.mod h1:gKOamz3EwoIoJq7mlMIRBpVTAUn8qPCrEclOKKWhD3U= 12 | github.com/hashicorp/go-sockaddr v1.0.2 h1:ztczhD1jLxIRjVejw8gFomI1BQZOe2WoVOu0SyteCQc= 13 | github.com/hashicorp/go-sockaddr v1.0.2/go.mod h1:rB4wwRAUzs07qva3c5SdrY/NEtAUjGlgmH/UkBUC97A= 14 | github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= 15 | github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= 16 | github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= 17 | github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= 18 | github.com/mitchellh/mapstructure v1.4.1 h1:CpVNEelQCZBooIPDn+AR3NpivK/TIKU8bDxdASFVQag= 19 | github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= 20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 22 | github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= 23 | github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= 24 | github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= 25 | github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= 26 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 27 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 28 | github.com/stretchr/testify v1.7.0/go.mod 
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 29 | golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 30 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 31 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 32 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 33 | -------------------------------------------------------------------------------- /tlsutil/tlsutil.go: -------------------------------------------------------------------------------- 1 | package tlsutil 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "errors" 7 | "fmt" 8 | "io/ioutil" 9 | "net" 10 | "strings" 11 | 12 | "github.com/hashicorp/go-secure-stdlib/parseutil" 13 | "github.com/hashicorp/go-secure-stdlib/strutil" 14 | ) 15 | // ErrInvalidCertParams is returned by ClientTLSConfig and LoadClientTLSConfig when neither a CA certificate nor a complete client certificate/key pair is supplied. 16 | var ErrInvalidCertParams = errors.New("invalid certificate parameters") 17 | 18 | // TLSLookup maps tls_min_version configuration strings ("tls10" through "tls13") to the corresponding crypto/tls protocol version constants. 19 | var TLSLookup = map[string]uint16{ 20 | "tls10": tls.VersionTLS10, 21 | "tls11": tls.VersionTLS11, 22 | "tls12": tls.VersionTLS12, 23 | "tls13": tls.VersionTLS13, 24 | } 25 | 26 | // cipherMap maps cipher suite names to their crypto/tls codes; ParseCiphers uses it forward and GetCipherName in reverse. It includes RC4 and 3DES suites, which are considered insecure — NOTE(review): confirm they should remain accepted.
27 | var cipherMap = map[string]uint16{ 28 | "TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA, 29 | "TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, 30 | "TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, 31 | "TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, 32 | "TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256, 33 | "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, 34 | "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, 35 | "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 36 | "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 37 | "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 38 | "TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, 39 | "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 40 | "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 41 | "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 42 | "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 43 | "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 44 | "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 45 | "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 46 | "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 47 | "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 48 | "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 49 | "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 50 | "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, 51 | "TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384, 52 | 
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256, 53 | } 54 | 55 | // ParseCiphers converts a comma-separated list of cipher suite names into their crypto/tls codes, preserving order. On the first unrecognized name it returns the codes parsed so far together with an error. 56 | func ParseCiphers(cipherStr string) ([]uint16, error) { 57 | suites := []uint16{} 58 | names := strutil.ParseStringSlice(cipherStr, ",") 59 | for _, name := range names { 60 | code, known := cipherMap[name] 61 | if !known { 62 | return suites, fmt.Errorf("unsupported cipher %q", name) 63 | } 64 | suites = append(suites, code) 65 | } 66 | 67 | return suites, nil 68 | } 69 | 70 | // GetCipherName performs a reverse lookup in cipherMap, returning the name registered for the given 71 | // cipher suite code, or an error when the code is not one of the supported suites. 72 | func GetCipherName(cipher uint16) (string, error) { 73 | for name, code := range cipherMap { 74 | if code == cipher { 75 | return name, nil 76 | } 77 | } 78 | return "", fmt.Errorf("unsupported cipher %d", cipher) 79 | } 80 | 81 | // ClientTLSConfig builds a client *tls.Config from PEM-encoded inputs: an optional CA bundle (used as RootCAs) and an optional 82 | // client certificate/key pair. At least one of the two must be supplied, otherwise ErrInvalidCertParams is returned.
83 | func ClientTLSConfig(caCert []byte, clientCert []byte, clientKey []byte) (*tls.Config, error) { 84 | var pool *x509.CertPool 85 | 86 | // At least one of a CA bundle or a complete client certificate/key pair must be supplied. 87 | switch { 88 | case len(caCert) != 0: 89 | // Valid 90 | case len(clientCert) != 0 && len(clientKey) != 0: 91 | // Valid 92 | default: 93 | return nil, ErrInvalidCertParams 94 | } 95 | 96 | if len(caCert) != 0 { 97 | pool = x509.NewCertPool() 98 | // As in LoadClientTLSConfig, reject a bundle with no usable certificates instead of returning an empty pool that would fail every server verification. 99 | if !pool.AppendCertsFromPEM(caCert) { 100 | return nil, fmt.Errorf("failed to parse CA certificate") 101 | } 102 | } 103 | 104 | tlsConfig := &tls.Config{ 105 | RootCAs: pool, 106 | ClientAuth: tls.RequireAndVerifyClientCert, 107 | MinVersion: tls.VersionTLS12, 108 | } 109 | 110 | if len(clientCert) != 0 && len(clientKey) != 0 { 111 | cert, err := tls.X509KeyPair(clientCert, clientKey) 112 | if err != nil { 113 | return nil, fmt.Errorf("failed to parse client certificate key pair: %w", err) 114 | } 115 | tlsConfig.Certificates = []tls.Certificate{cert} 116 | } 117 | tlsConfig.BuildNameToCertificate() 118 | return tlsConfig, nil 119 | } 120 | 121 | // LoadClientTLSConfig loads and parses the CA certificate, and optionally a 122 | // public/private client certificate key pair. The certificates must be in PEM 123 | // encoded format.
124 | func LoadClientTLSConfig(caCert, clientCert, clientKey string) (*tls.Config, error) { 125 | var pool *x509.CertPool 126 | 127 | // At least one of a CA file or a complete client certificate/key file pair must be supplied. 128 | switch { 129 | case len(caCert) != 0: 130 | // Valid 131 | case len(clientCert) != 0 && len(clientKey) != 0: 132 | // Valid 133 | default: 134 | return nil, ErrInvalidCertParams 135 | } 136 | 137 | if len(caCert) != 0 { 138 | pool = x509.NewCertPool() 139 | 140 | data, err := ioutil.ReadFile(caCert) 141 | if err != nil { 142 | return nil, fmt.Errorf("failed to read CA file: %w", err) 143 | } 144 | 145 | if !pool.AppendCertsFromPEM(data) { 146 | return nil, fmt.Errorf("failed to parse CA certificate") 147 | } 148 | } 149 | 150 | tlsConfig := &tls.Config{ 151 | RootCAs: pool, 152 | ClientAuth: tls.RequireAndVerifyClientCert, 153 | MinVersion: tls.VersionTLS12, 154 | } 155 | 156 | // Only a complete pair is loaded; if just one path is set (and a CA file was given) it is ignored rather than rejected. 157 | if len(clientCert) != 0 && len(clientKey) != 0 { 158 | cert, err := tls.LoadX509KeyPair(clientCert, clientKey) 159 | if err != nil { 160 | return nil, fmt.Errorf("failed to load client certificate key pair: %w", err) 161 | } 162 | tlsConfig.Certificates = []tls.Certificate{cert} 163 | } 164 | // NOTE(review): BuildNameToCertificate is deprecated since Go 1.14 and NameToCertificate is only consulted server-side; kept for compatibility. 165 | tlsConfig.BuildNameToCertificate() 166 | 167 | return tlsConfig, nil 168 | } 169 | // SetupTLSConfig builds a client *tls.Config from conf: tls_skip_verify, tls_min_version (default "tls12"), tls_cert_file/tls_key_file (both or neither) and tls_ca_file. ServerName is the host part of address, or conf["address"] when address has no port. 170 | func SetupTLSConfig(conf map[string]string, address string) (*tls.Config, error) { 171 | serverName, _, err := net.SplitHostPort(address) 172 | switch { 173 | case err == nil: 174 | case strings.Contains(err.Error(), "missing port"): 175 | serverName = conf["address"] // NOTE(review): reads conf["address"] rather than the address argument; confirm this is intended 176 | default: 177 | return nil, err 178 | } 179 | 180 | insecureSkipVerify := false 181 | tlsSkipVerify := conf["tls_skip_verify"] 182 | 183 | if tlsSkipVerify != "" { 184 | b, err := parseutil.ParseBool(tlsSkipVerify) 185 | if err != nil { 186 | return nil, fmt.Errorf("failed parsing tls_skip_verify parameter: %w", err) 187 | } 188 | insecureSkipVerify = b 189 | } 190 | 191 | tlsMinVersionStr, ok := conf["tls_min_version"] 192 | if !ok { 193 | // Set the default value 194 | tlsMinVersionStr = "tls12"
195 | } 196 | // Unknown names are rejected rather than silently falling back to a default minimum version. 197 | tlsMinVersion, ok := TLSLookup[tlsMinVersionStr] 198 | if !ok { 199 | return nil, fmt.Errorf("invalid 'tls_min_version' %q", tlsMinVersionStr) 200 | } 201 | 202 | tlsClientConfig := &tls.Config{ 203 | MinVersion: tlsMinVersion, 204 | InsecureSkipVerify: insecureSkipVerify, 205 | ServerName: serverName, 206 | } 207 | // The client key pair is all-or-nothing: setting only one of the two files is an error. 208 | _, okCert := conf["tls_cert_file"] 209 | _, okKey := conf["tls_key_file"] 210 | 211 | if okCert && okKey { 212 | tlsCert, err := tls.LoadX509KeyPair(conf["tls_cert_file"], conf["tls_key_file"]) 213 | if err != nil { 214 | return nil, fmt.Errorf("client tls setup failed: %w", err) 215 | } 216 | 217 | tlsClientConfig.Certificates = []tls.Certificate{tlsCert} 218 | } else if okCert || okKey { 219 | return nil, fmt.Errorf("both tls_cert_file and tls_key_file must be provided") 220 | } 221 | 222 | if tlsCaFile, ok := conf["tls_ca_file"]; ok { 223 | caPool := x509.NewCertPool() 224 | 225 | data, err := ioutil.ReadFile(tlsCaFile) 226 | if err != nil { 227 | return nil, fmt.Errorf("failed to read CA file: %w", err) 228 | } 229 | 230 | if !caPool.AppendCertsFromPEM(data) { 231 | return nil, fmt.Errorf("failed to parse CA certificate") 232 | } 233 | 234 | tlsClientConfig.RootCAs = caPool 235 | } 236 | return tlsClientConfig, nil 237 | } 238 | -------------------------------------------------------------------------------- /tlsutil/tlsutil_test.go: -------------------------------------------------------------------------------- 1 | package tlsutil 2 | 3 | import ( 4 | "crypto/tls" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestParseCiphers(t *testing.T) { 10 | testOk :=
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_128_GCM_SHA256,TLS_RSA_WITH_AES_256_CBC_SHA,TLS_RSA_WITH_AES_256_GCM_SHA384,TLS_RSA_WITH_AES_128_CBC_SHA256,TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305" 11 | v, err := ParseCiphers(testOk) 12 | if err != nil { 13 | t.Fatal(err) 14 | } 15 | if len(v) != 17 { 16 | t.Fatal("missed ciphers after parse") 17 | } 18 | 19 | testBad := "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,cipherX" 20 | if _, err := ParseCiphers(testBad); err == nil { 21 | t.Fatal("should fail on unsupported cipherX") 22 | } 23 | 24 | testOrder := "TLS_RSA_WITH_AES_256_GCM_SHA384,TLS_RSA_WITH_AES_128_GCM_SHA256" 25 | v, _ = ParseCiphers(testOrder) 26 | expected := []uint16{tls.TLS_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_RSA_WITH_AES_128_GCM_SHA256} 27 | if !reflect.DeepEqual(expected, v) { 28 | t.Fatal("cipher order is not preserved") 29 | } 30 | } 31 | 32 | func TestGetCipherName(t *testing.T) { 33 | testOkCipherStr := "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA" 34 | testOkCipher := tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA 35 | cipherStr, err := GetCipherName(testOkCipher) 36 | if err != nil { 37 | t.Fatal(err) 38 | } 39 | if cipherStr != testOkCipherStr { 40 | t.Fatalf("cipher string should be %s but is %s", testOkCipherStr, cipherStr) 41 | } 42 | 43 | var testBadCipher uint16 = 0xC022 44 | cipherStr, err = GetCipherName(testBadCipher) 45 | if err == nil { 46 | t.Fatal("should fail on unsupported cipher 0xC022") 47 | } 48 | } 49 | --------------------------------------------------------------------------------