├── .github ├── CODEOWNERS ├── dependabot.yml └── workflows │ ├── ci.yml │ └── codeql-analysis.yml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── buffer_period.go ├── buffer_period_test.go ├── dep ├── dependency.go ├── doc.go ├── errors.go └── template_function_types.go ├── doc.go ├── doc_test.go ├── events ├── events.go └── events_test.go ├── file_perms.go ├── file_perms_windows.go ├── go.mod ├── go.sum ├── internal ├── dependency │ ├── catalog_datacenters.go │ ├── catalog_datacenters_test.go │ ├── catalog_node.go │ ├── catalog_node_test.go │ ├── catalog_nodes.go │ ├── catalog_nodes_test.go │ ├── catalog_service.go │ ├── catalog_service_test.go │ ├── catalog_services.go │ ├── catalog_services_test.go │ ├── client_set.go │ ├── client_set_test.go │ ├── connect_ca.go │ ├── connect_ca_test.go │ ├── connect_leaf.go │ ├── connect_leaf_test.go │ ├── consul_common_test.go │ ├── dependency.go │ ├── dependency_test.go │ ├── errors.go │ ├── fakedep.go │ ├── file.go │ ├── file_test.go │ ├── health_service.go │ ├── health_service_test.go │ ├── kv_exists.go │ ├── kv_exists_get.go │ ├── kv_exists_get_test.go │ ├── kv_exists_test.go │ ├── kv_get.go │ ├── kv_get_test.go │ ├── kv_keys.go │ ├── kv_keys_test.go │ ├── kv_list.go │ ├── kv_list_test.go │ ├── testdata │ │ ├── cert.pem │ │ └── key.pem │ ├── vault_agent_token.go │ ├── vault_agent_token_test.go │ ├── vault_common.go │ ├── vault_common_test.go │ ├── vault_list.go │ ├── vault_list_test.go │ ├── vault_read.go │ ├── vault_read_test.go │ ├── vault_token.go │ ├── vault_token_test.go │ ├── vault_write.go │ └── vault_write_test.go └── test │ └── helpers.go ├── looker.go ├── looker_test.go ├── main_test.go ├── renderer.go ├── renderer_test.go ├── resolver.go ├── resolver_test.go ├── sets.go ├── store.go ├── store_test.go ├── template.go ├── template_test.go ├── testdata └── sandbox │ └── path │ └── to │ ├── bad-symlink │ ├── file │ └── ok-symlink ├── tfunc ├── consul_filter.go ├── consul_filter_test.go ├── consul_v0.go ├── 
consul_v0_test.go ├── consul_v1.go ├── consul_v1_test.go ├── contains.go ├── contains_test.go ├── deny.go ├── deny_test.go ├── env.go ├── env_test.go ├── file.go ├── file_test.go ├── loop.go ├── loop_test.go ├── maps.go ├── maps_test.go ├── math.go ├── math_test.go ├── parse.go ├── parse_test.go ├── sockaddr.go ├── sockaddr_test.go ├── string.go ├── string_test.go ├── tfunc.go ├── tfunc_test.go ├── time.go ├── time_test.go ├── transform.go ├── transform_test.go ├── vault.go └── vault_test.go ├── vaulttoken ├── main_test.go ├── notifier.go ├── vault_agent_token.go ├── vault_agent_token_test.go ├── vault_token.go ├── vault_token_test.go ├── watcher.go └── watcher_test.go ├── view.go ├── view_test.go ├── watcher.go └── watcher_test.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # default PR reviews to the team 2 | * @hashicorp/consul-selfmanage-maintainers 3 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "gomod" 4 | directory: "/" 5 | # 0 here disables PRs, 5 is default 6 | open-pull-requests-limit: 5 7 | schedule: 8 | interval: "daily" 9 | time: "06:05" 10 | - package-ecosystem: "github-actions" 11 | directory: "/" 12 | schedule: 13 | interval: "daily" 14 | time: "06:05" 15 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 7 | jobs: 8 | run-tests: 9 | name: Run test cases 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | os: [ubuntu-latest] 14 | go: [^1] 15 | 16 | steps: 17 | - uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 18 | 19 | - name: Set up Go 20 | uses: 
actions/setup-go@93397bea11091df50f3d7e59dc26a7711a8bcfbe # v4.1.0 21 | with: 22 | go-version: ${{ matrix.go }} 23 | 24 | - name: Install Consul and Vault for integration testing 25 | run: | 26 | curl -fsSL https://apt.releases.hashicorp.com/gpg | sudo apt-key add - 27 | sudo apt-add-repository "deb [arch=amd64] https://apt.releases.hashicorp.com $(lsb_release -cs) main" 28 | sudo apt-get update && sudo apt-get install consul vault nomad 29 | 30 | - name: Run tests 31 | run: | 32 | go test -race ./... 33 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL" 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | # The branches below must be a subset of the branches above 8 | branches: [ main ] 9 | schedule: 10 | - cron: '42 01 * * 6' 11 | 12 | jobs: 13 | analyze: 14 | name: Analyze 15 | runs-on: ubuntu-latest 16 | permissions: 17 | actions: read 18 | contents: read 19 | security-events: write 20 | 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | language: [ 'go' ] 25 | # More: https://aka.ms/codeql-docs/language-support 26 | 27 | steps: 28 | - name: Checkout repository 29 | uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 30 | 31 | - name: Initialize CodeQL 32 | uses: github/codeql-action/init@e4262713b504983e61c7728f5452be240d9385a7 # codeql-bundle-v2.14.3 33 | with: 34 | languages: ${{ matrix.language }} 35 | # If you wish to specify custom queries, you can do so here or in 36 | # a config file. By default, queries listed here will override any 37 | # specified in a config file. Prefix the list here with "+" to use 38 | # these queries and those in the config file. 
39 | 40 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 41 | # queries: security-extended,security-and-quality 42 | 43 | 44 | # compile? 45 | - name: Autobuild 46 | uses: github/codeql-action/autobuild@e4262713b504983e61c7728f5452be240d9385a7 # codeql-bundle-v2.14.3 47 | 48 | - name: Perform CodeQL Analysis 49 | uses: github/codeql-action/analyze@e4262713b504983e61c7728f5452be240d9385a7 # codeql-bundle-v2.14.3 50 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Hashicat CHANGELOG 2 | 3 | ## v0.2.0 (Dec 07, 2021) 4 | 5 | IMPROVEMENTS: 6 | 7 | * Move all template functions into template library [[GH-82](https://github.com/hashicorp/hcat/pull/82), [GH-36](https://github.com/hashicorp/hcat/issues/36)] 8 | * Port all relevant consul-template updates [[GH-81](https://github.com/hashicorp/hcat/pull/81), [GH-69](https://github.com/hashicorp/hcat/issues/69)] 9 | 10 | ## v0.1.0 (Nov 11, 2021) 11 | 12 | IMPROVEMENTS: 13 | 14 | * Enable application side logging of runtime [[GH-77](https://github.com/hashicorp/hcat/pull/77), [GH-68](https://github.com/hashicorp/hcat/issues/68)] 15 | 16 | -------------------------------------------------------------------------------- /dep/dependency.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dep 5 | 6 | import ( 7 | "fmt" 8 | "time" 9 | 10 | consulapi "github.com/hashicorp/consul/api" 11 | vaultapi "github.com/hashicorp/vault/api" 12 | ) 13 | 14 | // Dependency is an interface for an external dependency to be monitored. 
15 | type Dependency interface { 16 | Fetch(Clients) (interface{}, *ResponseMetadata, error) 17 | ID() string 18 | Stop() 19 | fmt.Stringer 20 | } 21 | 22 | // Clients interface for the API clients used for external dependency calls. 23 | type Clients interface { 24 | Consul() *consulapi.Client 25 | Vault() *vaultapi.Client 26 | } 27 | 28 | // Metadata returned by external dependency Fetch-ing. 29 | // LastIndex is used with the Consul backend. Needed to track changes. 30 | // LastContact is used to help calculate staleness of records. 31 | type ResponseMetadata struct { 32 | LastIndex uint64 33 | LastContact time.Duration 34 | } 35 | -------------------------------------------------------------------------------- /dep/doc.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | /* 5 | 6 | Public Dependency type information. 7 | 8 | This sub-package contains all the required types needed to implement an 9 | external dependency as used by this library. 10 | 11 | It is a sub-package as all the current dependency implentations are contained 12 | in an internal/ package as it needs some significant refactoring that shouldn't 13 | interfere with the initial release. 14 | 15 | */ 16 | package dep 17 | -------------------------------------------------------------------------------- /dep/errors.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dep 5 | 6 | import "errors" 7 | 8 | // ErrStopped is a special error that is returned when a dependency is 9 | // prematurely stopped, usually due to a configuration reload or a process 10 | // interrupt. 11 | var ErrStopped = errors.New("dependency stopped") 12 | 13 | // ErrContinue is a special error which says to continue (retry) on error. 
14 | var ErrContinue = errors.New("dependency continue") 15 | 16 | var ErrLeaseExpired = errors.New("lease expired or is not renewable") 17 | -------------------------------------------------------------------------------- /dep/template_function_types.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dep 5 | 6 | import ( 7 | "time" 8 | 9 | "github.com/hashicorp/consul/api" 10 | ) 11 | 12 | // Node is a node entry in Consul 13 | type Node struct { 14 | ID string 15 | Node string 16 | Address string 17 | Datacenter string 18 | TaggedAddresses map[string]string 19 | Meta map[string]string 20 | } 21 | 22 | // CatalogNode is a wrapper around the node and its services. 23 | type CatalogNode struct { 24 | Node *Node 25 | Services []*CatalogNodeService 26 | } 27 | 28 | // ServiceTags is a slice of tags assigned to a Service 29 | type ServiceTags []string 30 | 31 | // CatalogNodeService is a service on a single node. 32 | type CatalogNodeService struct { 33 | ID string 34 | Service string 35 | Tags ServiceTags 36 | Meta map[string]string 37 | Port int 38 | Address string 39 | EnableTagOverride bool 40 | } 41 | 42 | // CatalogSnippet is a catalog entry in Consul. 43 | type CatalogSnippet struct { 44 | Name string 45 | Tags ServiceTags 46 | } 47 | 48 | // HealthService is a service entry in Consul. 
49 | type HealthService struct { 50 | Node string 51 | NodeID string 52 | NodeAddress string 53 | NodeDatacenter string 54 | NodeTaggedAddresses map[string]string 55 | NodeMeta map[string]string 56 | ServiceMeta map[string]string 57 | Address string 58 | ServiceTaggedAddresses map[string]api.ServiceAddress 59 | ID string 60 | Name string 61 | Kind string 62 | Tags ServiceTags 63 | Checks api.HealthChecks 64 | Status string 65 | Port int 66 | Weights api.AgentWeights 67 | Namespace string 68 | Proxy *api.AgentServiceConnectProxyConfig 69 | } 70 | 71 | // KvValue is here to type the KV return string 72 | type KvValue string 73 | 74 | type KVExists bool 75 | 76 | // KeyPair is a simple Key-Value pair 77 | type KeyPair struct { 78 | Path string 79 | Key string 80 | Value string 81 | Exists bool 82 | 83 | // Lesser-used, but still valuable keys from api.KV 84 | CreateIndex uint64 85 | ModifyIndex uint64 86 | LockIndex uint64 87 | Flags uint64 88 | Session string 89 | } 90 | 91 | // Secret is the structure returned for every secret within Vault. 92 | type Secret struct { 93 | // The request ID that generated this response 94 | RequestID string 95 | 96 | LeaseID string 97 | LeaseDuration int 98 | Renewable bool 99 | 100 | // Data is the actual contents of the secret. The format of the data 101 | // is arbitrary and up to the secret backend. 102 | Data map[string]interface{} 103 | 104 | // Warnings contains any warnings related to the operation. These 105 | // are not issues that caused the command to fail, but that the 106 | // client should be aware of. 107 | Warnings []string 108 | 109 | // Auth, if non-nil, means that there was authentication information 110 | // attached to this response. 
111 | Auth *SecretAuth 112 | 113 | // WrapInfo, if non-nil, means that the initial response was wrapped in the 114 | // cubbyhole of the given token (which has a TTL of the given number of 115 | // seconds) 116 | WrapInfo *SecretWrapInfo 117 | } 118 | 119 | // SecretAuth is the structure containing auth information if we have it. 120 | type SecretAuth struct { 121 | ClientToken string 122 | Accessor string 123 | Policies []string 124 | Metadata map[string]string 125 | 126 | LeaseDuration int 127 | Renewable bool 128 | } 129 | 130 | // SecretWrapInfo contains wrapping information if we have it. If what is 131 | // contained is an authentication token, the accessor for the token will be 132 | // available in WrappedAccessor. 133 | type SecretWrapInfo struct { 134 | Token string 135 | TTL int 136 | CreationTime time.Time 137 | WrappedAccessor string 138 | } 139 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | /* 5 | The Hashicat library. 6 | 7 | This library provides a means to fetch data managed by external services and 8 | render templates using that data. It also enables monitoring those services for 9 | data changes to trigger updates to the templates. 10 | 11 | A simple example of how you might use this library to generate the contents of 12 | a single template, waiting for all its dependencies (external data) to be 13 | fetched and filled in, then have that content returned. 14 | */ 15 | package hcat 16 | -------------------------------------------------------------------------------- /events/events.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package events 5 | 6 | import "time" 7 | 8 | // EventHandler is the interface of the call back function for receiveing events. 9 | type EventHandler func(Event) 10 | 11 | // Event is used to type restrict the Events 12 | type Event interface { 13 | isEvent() 14 | } 15 | 16 | // Trace is useful to see some details of what's going on 17 | type Trace struct { 18 | event 19 | ID string 20 | Message string 21 | } 22 | 23 | // BlockingWait means a blocking query was made 24 | type BlockingWait struct { 25 | event 26 | ID string 27 | } 28 | 29 | // ServerContacted indicates that the tracked service has been successfully 30 | // contacted (received a non-error response). 31 | type ServerContacted struct { 32 | event 33 | ID string 34 | } 35 | 36 | // ServerError indicates that an tracked service has been contacted but with 37 | // an error returned. 38 | type ServerError struct { 39 | event 40 | Error error 41 | ID string 42 | } 43 | 44 | // ServerTimeout indicates that a call to the server timed out. 45 | type ServerTimeout struct { 46 | event 47 | ID string 48 | } 49 | 50 | // RetryAttempt indicates that a tracked call is being retried. 51 | type RetryAttempt struct { 52 | event 53 | Error error 54 | ID string 55 | Attempt int 56 | Sleep time.Duration 57 | } 58 | 59 | // MaxRetries indicates that the maximum number of retries has been reached 60 | // (and failed). 61 | type MaxRetries struct { 62 | event 63 | ID string 64 | Count int 65 | } 66 | 67 | // NewData indicates that fresh/new data has been retrieved from the service. 68 | type NewData struct { 69 | event 70 | Data interface{} 71 | ID string 72 | } 73 | 74 | // StaleData indicates that the service returned stale (possibly old) data. 
75 | type StaleData struct { 76 | event 77 | ID string 78 | LastContant time.Duration 79 | } 80 | 81 | // NoNewData indicates that data was retrieved from the service, but that it 82 | // matches the current data so no change would be triggered. 83 | type NoNewData struct { 84 | event 85 | ID string 86 | } 87 | 88 | // TrackStart indicates that a new data point is being tracked. 89 | type TrackStart struct { 90 | event 91 | ID string 92 | } 93 | 94 | // TrackStop indicates that a data point is no longer being tracked. 95 | type TrackStop struct { 96 | event 97 | ID string 98 | } 99 | 100 | // Not used yet, need an PolllingQuery interface to match on 101 | // see BlockingQuery for how it should work 102 | type PollingWait struct { 103 | event 104 | ID string 105 | Duration time.Duration 106 | } 107 | 108 | // Event interface type fulfillment 109 | type event struct{} 110 | 111 | func (event) isEvent() {} 112 | -------------------------------------------------------------------------------- /events/events_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package events 5 | 6 | import ( 7 | "testing" 8 | ) 9 | 10 | var ( 11 | _ Event = (*Trace)(nil) 12 | _ Event = (*BlockingWait)(nil) 13 | _ Event = (*ServerContacted)(nil) 14 | _ Event = (*ServerError)(nil) 15 | _ Event = (*ServerTimeout)(nil) 16 | _ Event = (*RetryAttempt)(nil) 17 | _ Event = (*MaxRetries)(nil) 18 | _ Event = (*NewData)(nil) 19 | _ Event = (*StaleData)(nil) 20 | _ Event = (*NoNewData)(nil) 21 | _ Event = (*TrackStart)(nil) 22 | _ Event = (*TrackStop)(nil) 23 | _ Event = (*PollingWait)(nil) 24 | ) 25 | 26 | func TestEvents(t *testing.T) { 27 | var event EventHandler 28 | event = func(e Event) { 29 | switch e.(type) { 30 | case Trace, BlockingWait, ServerContacted, ServerError, 31 | ServerTimeout, RetryAttempt, MaxRetries, NewData, StaleData, 32 | NoNewData, TrackStart, TrackStop, PollingWait: 33 | default: 34 | t.Errorf("Bad event type: %T", e) 35 | } 36 | } 37 | event(Trace{}) 38 | event(MaxRetries{}) 39 | event(TrackStop{}) 40 | } 41 | -------------------------------------------------------------------------------- /file_perms.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | //+build !windows 5 | 6 | package hcat 7 | 8 | import ( 9 | "os" 10 | "syscall" 11 | ) 12 | 13 | func preserveFilePermissions(path string, fileInfo os.FileInfo) error { 14 | sysInfo := fileInfo.Sys() 15 | if sysInfo != nil { 16 | stat, ok := sysInfo.(*syscall.Stat_t) 17 | if ok { 18 | if err := os.Chown(path, int(stat.Uid), int(stat.Gid)); err != nil { 19 | return err 20 | } 21 | } 22 | } 23 | 24 | return nil 25 | } 26 | -------------------------------------------------------------------------------- /file_perms_windows.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | //+build windows 5 | 6 | package hcat 7 | 8 | import "os" 9 | 10 | func preserveFilePermissions(path string, fileInfo os.FileInfo) error { 11 | return nil 12 | } 13 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/hcat 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/BurntSushi/toml v1.2.1 7 | github.com/hashicorp/consul/api v1.17.0 8 | github.com/hashicorp/consul/sdk v0.13.0 9 | github.com/hashicorp/go-bexpr v0.1.11 10 | github.com/hashicorp/go-rootcerts v1.0.2 11 | github.com/hashicorp/go-sockaddr v1.0.2 12 | github.com/hashicorp/vault/api v1.0.5-0.20190730042357-746c0b111519 13 | github.com/imdario/mergo v0.3.13 14 | github.com/pkg/errors v0.9.1 15 | github.com/stretchr/testify v1.8.1 16 | gopkg.in/yaml.v2 v2.4.0 17 | ) 18 | 19 | require ( 20 | github.com/armon/go-metrics v0.3.10 // indirect 21 | github.com/davecgh/go-spew v1.1.1 // indirect 22 | github.com/fatih/color v1.9.0 // indirect 23 | github.com/frankban/quicktest v1.4.0 // indirect 24 | github.com/golang/snappy v0.0.1 // indirect 25 | github.com/hashicorp/errwrap v1.0.0 // indirect 26 | github.com/hashicorp/go-cleanhttp v0.5.1 // indirect 27 | github.com/hashicorp/go-hclog v0.14.1 // indirect 28 | github.com/hashicorp/go-immutable-radix v1.3.0 // indirect 29 | github.com/hashicorp/go-multierror v1.1.1 // indirect 30 | github.com/hashicorp/go-retryablehttp v0.6.6 // indirect 31 | github.com/hashicorp/go-uuid v1.0.2 // indirect 32 | github.com/hashicorp/go-version v1.2.1 // indirect 33 | github.com/hashicorp/golang-lru v0.5.4 // indirect 34 | github.com/hashicorp/hcl v1.0.0 // indirect 35 | github.com/hashicorp/serf v0.10.1 // indirect 36 | github.com/hashicorp/vault/sdk v0.1.14-0.20190730042320-0dc007d98cc8 // indirect 37 | github.com/mattn/go-colorable v0.1.6 // indirect 38 | github.com/mattn/go-isatty 
v0.0.12 // indirect 39 | github.com/mitchellh/go-homedir v1.1.0 // indirect 40 | github.com/mitchellh/mapstructure v1.4.1 // indirect 41 | github.com/mitchellh/pointerstructure v1.2.1 // indirect 42 | github.com/pierrec/lz4 v2.5.2+incompatible // indirect 43 | github.com/pmezard/go-difflib v1.0.0 // indirect 44 | github.com/ryanuber/go-glob v1.0.0 // indirect 45 | golang.org/x/crypto v0.0.0-20200429183012-4b2356b1ed79 // indirect 46 | golang.org/x/net v0.0.0-20211216030914-fe4d6282115f // indirect 47 | golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect 48 | golang.org/x/text v0.3.8 // indirect 49 | golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 // indirect 50 | gopkg.in/square/go-jose.v2 v2.5.1 // indirect 51 | gopkg.in/yaml.v3 v3.0.1 // indirect 52 | ) 53 | -------------------------------------------------------------------------------- /internal/dependency/catalog_datacenters.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "sort" 8 | "time" 9 | 10 | "github.com/hashicorp/consul/api" 11 | "github.com/hashicorp/hcat/dep" 12 | "github.com/pkg/errors" 13 | ) 14 | 15 | var ( 16 | // Ensure implements 17 | _ isDependency = (*CatalogDatacentersQuery)(nil) 18 | 19 | // CatalogDatacentersQuerySleepTime is the amount of time to sleep between 20 | // queries, since the endpoint does not support blocking queries. 21 | CatalogDatacentersQuerySleepTime = 15 * time.Second 22 | ) 23 | 24 | // CatalogDatacentersQuery is the dependency to query all datacenters 25 | type CatalogDatacentersQuery struct { 26 | isConsul 27 | ignoreFailing bool 28 | stopCh chan struct{} 29 | opts QueryOptions 30 | } 31 | 32 | // NewCatalogDatacentersQuery creates a new datacenter dependency. 
33 | func NewCatalogDatacentersQuery(ignoreFailing bool) (*CatalogDatacentersQuery, error) { 34 | return &CatalogDatacentersQuery{ 35 | ignoreFailing: ignoreFailing, 36 | stopCh: make(chan struct{}, 1), 37 | }, nil 38 | } 39 | 40 | // Fetch queries the Consul API defined by the given client and returns a slice 41 | // of strings representing the datacenters 42 | func (d *CatalogDatacentersQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 43 | opts := d.opts.Merge(&QueryOptions{}) 44 | 45 | // This is pretty ghetto, but the datacenters endpoint does not support 46 | // blocking queries, so we are going to "fake it until we make it". When we 47 | // first query, the LastIndex will be "0", meaning we should immediately 48 | // return data, but future calls will include a LastIndex. If we have a 49 | // LastIndex in the query metadata, sleep for 15 seconds before asking Consul 50 | // again. 51 | // 52 | // This is probably okay given the frequency in which datacenters actually 53 | // change, but is technically not edge-triggering. 54 | if opts.WaitIndex != 0 { 55 | select { 56 | case <-d.stopCh: 57 | return nil, nil, ErrStopped 58 | case <-time.After(CatalogDatacentersQuerySleepTime): 59 | } 60 | } 61 | 62 | result, err := clients.Consul().Catalog().Datacenters() 63 | if err != nil { 64 | return nil, nil, errors.Wrapf(err, d.ID()) 65 | } 66 | 67 | // If the user opted in for skipping "down" datacenters, figure out which 68 | // datacenters are down. 69 | if d.ignoreFailing { 70 | dcs := make([]string, 0, len(result)) 71 | for _, dc := range result { 72 | if _, _, err := clients.Consul().Catalog().Services(&api.QueryOptions{ 73 | Datacenter: dc, 74 | AllowStale: false, 75 | RequireConsistent: true, 76 | }); err == nil { 77 | dcs = append(dcs, dc) 78 | } 79 | } 80 | result = dcs 81 | } 82 | 83 | sort.Strings(result) 84 | 85 | return respWithMetadata(result) 86 | } 87 | 88 | // CanShare returns if this dependency is shareable. 
89 | func (d *CatalogDatacentersQuery) CanShare() bool { 90 | return true 91 | } 92 | 93 | // ID returns the human-friendly version of this dependency. 94 | func (d *CatalogDatacentersQuery) ID() string { 95 | return "catalog.datacenters" 96 | } 97 | 98 | // Stringer interface reuses ID 99 | func (d *CatalogDatacentersQuery) String() string { 100 | return d.ID() 101 | } 102 | 103 | // Stop terminates this dependency's fetch. 104 | func (d *CatalogDatacentersQuery) Stop() { 105 | close(d.stopCh) 106 | } 107 | 108 | func (d *CatalogDatacentersQuery) SetOptions(opts QueryOptions) { 109 | d.opts = opts 110 | } 111 | -------------------------------------------------------------------------------- /internal/dependency/catalog_datacenters_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func init() { 15 | CatalogDatacentersQuerySleepTime = 50 * time.Millisecond 16 | } 17 | 18 | func TestNewCatalogDatacentersQuery(t *testing.T) { 19 | t.Parallel() 20 | 21 | cases := []struct { 22 | name string 23 | exp *CatalogDatacentersQuery 24 | err bool 25 | }{ 26 | { 27 | "empty", 28 | &CatalogDatacentersQuery{}, 29 | false, 30 | }, 31 | } 32 | 33 | for i, tc := range cases { 34 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 35 | act, err := NewCatalogDatacentersQuery(false) 36 | if (err != nil) != tc.err { 37 | t.Fatal(err) 38 | } 39 | 40 | if act != nil { 41 | act.stopCh = nil 42 | } 43 | 44 | assert.Equal(t, tc.exp, act) 45 | }) 46 | } 47 | } 48 | 49 | func TestCatalogDatacentersQuery_Fetch(t *testing.T) { 50 | t.Parallel() 51 | 52 | cases := []struct { 53 | name string 54 | exp []string 55 | }{ 56 | { 57 | "default", 58 | []string{"dc1"}, 59 | }, 60 | } 61 | 62 | for i, tc := range cases { 63 | 
t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 64 | d, err := NewCatalogDatacentersQuery(false) 65 | if err != nil { 66 | t.Fatal(err) 67 | } 68 | 69 | act, _, err := d.Fetch(testClients) 70 | if err != nil { 71 | t.Fatal(err) 72 | } 73 | 74 | assert.Equal(t, tc.exp, act) 75 | }) 76 | } 77 | 78 | t.Run("stops", func(t *testing.T) { 79 | d, err := NewCatalogDatacentersQuery(false) 80 | if err != nil { 81 | t.Fatal(err) 82 | } 83 | 84 | dataCh := make(chan interface{}, 1) 85 | errCh := make(chan error, 1) 86 | go func() { 87 | for { 88 | d.SetOptions(QueryOptions{WaitIndex: 10}) 89 | data, _, err := d.Fetch(testClients) 90 | if err != nil { 91 | errCh <- err 92 | return 93 | } 94 | dataCh <- data 95 | } 96 | }() 97 | 98 | select { 99 | case err := <-errCh: 100 | t.Fatal(err) 101 | case <-dataCh: 102 | } 103 | 104 | d.Stop() 105 | 106 | select { 107 | case err := <-errCh: 108 | if err != ErrStopped { 109 | t.Fatal(err) 110 | } 111 | case <-time.After(100 * time.Millisecond): 112 | t.Errorf("did not stop") 113 | } 114 | }) 115 | 116 | t.Run("fires_changes", func(t *testing.T) { 117 | d, err := NewCatalogDatacentersQuery(false) 118 | if err != nil { 119 | t.Fatal(err) 120 | } 121 | 122 | //_, qm, err := d.Fetch(testClients) 123 | _, _, err = d.Fetch(testClients) 124 | if err != nil { 125 | t.Fatal(err) 126 | } 127 | 128 | dataCh := make(chan interface{}, 1) 129 | errCh := make(chan error, 1) 130 | go func() { 131 | for { 132 | //data, _, err := d.Fetch(testClients, &QueryOptions{WaitIndex: qm.LastIndex}) 133 | data, _, err := d.Fetch(testClients) 134 | if err != nil { 135 | errCh <- err 136 | return 137 | } 138 | dataCh <- data 139 | return 140 | } 141 | }() 142 | 143 | select { 144 | case err := <-errCh: 145 | t.Fatal(err) 146 | case <-dataCh: 147 | } 148 | }) 149 | } 150 | 151 | func TestCatalogDatacentersQuery_String(t *testing.T) { 152 | t.Parallel() 153 | 154 | cases := []struct { 155 | name string 156 | exp string 157 | }{ 158 | { 159 | "empty", 160 
| "catalog.datacenters", 161 | }, 162 | } 163 | 164 | for i, tc := range cases { 165 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 166 | d, err := NewCatalogDatacentersQuery(false) 167 | if err != nil { 168 | t.Fatal(err) 169 | } 170 | assert.Equal(t, tc.exp, d.ID()) 171 | }) 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /internal/dependency/catalog_node.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "encoding/gob" 8 | "fmt" 9 | "regexp" 10 | "sort" 11 | 12 | "github.com/hashicorp/hcat/dep" 13 | "github.com/pkg/errors" 14 | ) 15 | 16 | var ( 17 | // Ensure implements 18 | _ isDependency = (*CatalogNodeQuery)(nil) 19 | 20 | // CatalogNodeQueryRe is the regular expression to use. 21 | CatalogNodeQueryRe = regexp.MustCompile(`\A` + nodeNameRe + dcRe + `\z`) 22 | ) 23 | 24 | func init() { 25 | gob.Register([]*dep.CatalogNode{}) 26 | gob.Register([]*dep.CatalogNodeService{}) 27 | } 28 | 29 | // CatalogNodeQuery represents a single node from the Consul catalog. 30 | type CatalogNodeQuery struct { 31 | isConsul 32 | stopCh chan struct{} 33 | 34 | dc string 35 | name string 36 | opts QueryOptions 37 | } 38 | 39 | // NewCatalogNodeQuery parses the given string into a dependency. If the name is 40 | // empty then the name of the local agent is used. 41 | func NewCatalogNodeQuery(s string) (*CatalogNodeQuery, error) { 42 | if s != "" && !CatalogNodeQueryRe.MatchString(s) { 43 | return nil, fmt.Errorf("catalog.node: invalid format: %q", s) 44 | } 45 | 46 | m := regexpMatch(CatalogNodeQueryRe, s) 47 | return &CatalogNodeQuery{ 48 | dc: m["dc"], 49 | name: m["name"], 50 | stopCh: make(chan struct{}, 1), 51 | }, nil 52 | } 53 | 54 | // Fetch queries the Consul API defined by the given client and returns a 55 | // of CatalogNode object. 
56 | func (d *CatalogNodeQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 57 | select { 58 | case <-d.stopCh: 59 | return nil, nil, ErrStopped 60 | default: 61 | } 62 | 63 | opts := d.opts.Merge(&QueryOptions{ 64 | Datacenter: d.dc, 65 | }) 66 | 67 | // Grab the name 68 | name := d.name 69 | 70 | if name == "" { 71 | var err error 72 | name, err = clients.Consul().Agent().NodeName() 73 | if err != nil { 74 | return nil, nil, errors.Wrapf(err, d.ID()) 75 | } 76 | } 77 | 78 | node, qm, err := clients.Consul().Catalog().Node(name, opts.ToConsulOpts()) 79 | if err != nil { 80 | return nil, nil, errors.Wrap(err, d.ID()) 81 | } 82 | 83 | rm := &dep.ResponseMetadata{ 84 | LastIndex: qm.LastIndex, 85 | LastContact: qm.LastContact, 86 | } 87 | 88 | if node == nil { 89 | var node dep.CatalogNode 90 | return &node, rm, nil 91 | } 92 | 93 | services := make([]*dep.CatalogNodeService, 0, len(node.Services)) 94 | for _, v := range node.Services { 95 | services = append(services, &dep.CatalogNodeService{ 96 | ID: v.ID, 97 | Service: v.Service, 98 | Tags: dep.ServiceTags(deepCopyAndSortTags(v.Tags)), 99 | Meta: v.Meta, 100 | Port: v.Port, 101 | Address: v.Address, 102 | EnableTagOverride: v.EnableTagOverride, 103 | }) 104 | } 105 | sort.SliceStable(services, 106 | func(i, j int) bool { 107 | if services[i].Service == services[j].Service { 108 | return services[i].ID < services[j].ID 109 | } 110 | return services[i].Service < services[j].Service 111 | }) 112 | 113 | detail := &dep.CatalogNode{ 114 | Node: &dep.Node{ 115 | ID: node.Node.ID, 116 | Node: node.Node.Node, 117 | Address: node.Node.Address, 118 | Datacenter: node.Node.Datacenter, 119 | TaggedAddresses: node.Node.TaggedAddresses, 120 | Meta: node.Node.Meta, 121 | }, 122 | Services: services, 123 | } 124 | 125 | return detail, rm, nil 126 | } 127 | 128 | // CanShare returns a boolean if this dependency is shareable. 
129 | func (d *CatalogNodeQuery) CanShare() bool { 130 | return false 131 | } 132 | 133 | // ID returns the human-friendly version of this dependency. 134 | func (d *CatalogNodeQuery) ID() string { 135 | name := d.name 136 | if d.dc != "" { 137 | name = name + "@" + d.dc 138 | } 139 | 140 | if name == "" { 141 | return "catalog.node" 142 | } 143 | return fmt.Sprintf("catalog.node(%s)", name) 144 | } 145 | 146 | // Stringer interface reuses ID 147 | func (d *CatalogNodeQuery) String() string { 148 | return d.ID() 149 | } 150 | 151 | // Stop halts the dependency's fetch function. 152 | func (d *CatalogNodeQuery) Stop() { 153 | close(d.stopCh) 154 | } 155 | 156 | func (d *CatalogNodeQuery) SetOptions(opts QueryOptions) { 157 | d.opts = opts 158 | } 159 | -------------------------------------------------------------------------------- /internal/dependency/catalog_node_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/hashicorp/hcat/dep" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestNewCatalogNodeQuery(t *testing.T) { 15 | t.Parallel() 16 | 17 | cases := []struct { 18 | name string 19 | i string 20 | exp *CatalogNodeQuery 21 | err bool 22 | }{ 23 | { 24 | "empty", 25 | "", 26 | &CatalogNodeQuery{}, 27 | false, 28 | }, 29 | { 30 | "bad", 31 | "!4d", 32 | nil, 33 | true, 34 | }, 35 | { 36 | "dc_only", 37 | "@dc1", 38 | nil, 39 | true, 40 | }, 41 | { 42 | "node", 43 | "node", 44 | &CatalogNodeQuery{ 45 | name: "node", 46 | }, 47 | false, 48 | }, 49 | { 50 | "dc", 51 | "node@dc1", 52 | &CatalogNodeQuery{ 53 | name: "node", 54 | dc: "dc1", 55 | }, 56 | false, 57 | }, 58 | { 59 | "periods", 60 | "node.bar.com@dc1", 61 | &CatalogNodeQuery{ 62 | name: "node.bar.com", 63 | dc: "dc1", 64 | }, 65 | false, 66 | }, 67 | } 68 | 69 | for i, tc := range cases { 70 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 71 | act, err := NewCatalogNodeQuery(tc.i) 72 | if (err != nil) != tc.err { 73 | t.Fatal(err) 74 | } 75 | 76 | if act != nil { 77 | act.stopCh = nil 78 | } 79 | 80 | assert.Equal(t, tc.exp, act) 81 | }) 82 | } 83 | } 84 | 85 | func TestCatalogNodeQuery_Fetch(t *testing.T) { 86 | t.Parallel() 87 | 88 | cases := []struct { 89 | name string 90 | i string 91 | exp *dep.CatalogNode 92 | }{ 93 | { 94 | "local", 95 | "", 96 | &dep.CatalogNode{ 97 | Node: &dep.Node{ 98 | Node: testConsul.Config.NodeName, 99 | Address: testConsul.Config.Bind, 100 | Datacenter: "dc1", 101 | TaggedAddresses: map[string]string{ 102 | "lan": "127.0.0.1", 103 | "wan": "127.0.0.1", 104 | }, 105 | Meta: map[string]string{ 106 | "consul-network-segment": "", 107 | }, 108 | }, 109 | Services: []*dep.CatalogNodeService{ 110 | { 111 | ID: "consul", 112 | Service: "consul", 113 | Port: testConsul.Config.Ports.Server, 114 | Tags: dep.ServiceTags([]string{}), 115 
| Meta: map[string]string{}, 116 | }, 117 | { 118 | ID: "critical-service", 119 | Service: "critical-service", 120 | Tags: dep.ServiceTags([]string{}), 121 | Meta: map[string]string{}, 122 | }, 123 | { 124 | ID: "foo", 125 | Service: "foo-sidecar-proxy", 126 | Tags: dep.ServiceTags([]string{}), 127 | Meta: map[string]string{}, 128 | Port: 21999, 129 | }, 130 | { 131 | ID: "service-meta", 132 | Service: "service-meta", 133 | Tags: dep.ServiceTags([]string{"tag1"}), 134 | Meta: map[string]string{ 135 | "meta1": "value1", 136 | }, 137 | }, 138 | { 139 | ID: "service-taggedAddresses", 140 | Service: "service-taggedAddresses", 141 | Tags: dep.ServiceTags([]string{}), 142 | Meta: map[string]string{}, 143 | }, 144 | }, 145 | }, 146 | }, 147 | { 148 | "unknown", 149 | "not_a_real_node", 150 | &dep.CatalogNode{}, 151 | }, 152 | } 153 | 154 | for i, tc := range cases { 155 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 156 | d, err := NewCatalogNodeQuery(tc.i) 157 | if err != nil { 158 | t.Fatal(err) 159 | } 160 | 161 | act, _, err := d.Fetch(testClients) 162 | if err != nil { 163 | t.Fatal(err) 164 | } 165 | 166 | if act != nil { 167 | if n := act.(*dep.CatalogNode).Node; n != nil { 168 | n.ID = "" 169 | n.TaggedAddresses = filterAddresses(n.TaggedAddresses) 170 | } 171 | // delete any version data from ServiceMeta 172 | services := act.(*dep.CatalogNode).Services 173 | for i := range services { 174 | services[i].Meta = filterMeta(services[i].Meta) 175 | } 176 | } 177 | 178 | assert.Equal(t, tc.exp, act) 179 | }) 180 | } 181 | } 182 | 183 | func TestCatalogNodeQuery_String(t *testing.T) { 184 | t.Parallel() 185 | 186 | cases := []struct { 187 | name string 188 | i string 189 | exp string 190 | }{ 191 | { 192 | "empty", 193 | "", 194 | "catalog.node", 195 | }, 196 | { 197 | "node", 198 | "node1", 199 | "catalog.node(node1)", 200 | }, 201 | { 202 | "datacenter", 203 | "node1@dc1", 204 | "catalog.node(node1@dc1)", 205 | }, 206 | } 207 | 208 | for i, tc := range 
cases { 209 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 210 | d, err := NewCatalogNodeQuery(tc.i) 211 | if err != nil { 212 | t.Fatal(err) 213 | } 214 | assert.Equal(t, tc.exp, d.ID()) 215 | }) 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /internal/dependency/catalog_nodes.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "encoding/gob" 8 | "fmt" 9 | "regexp" 10 | "sort" 11 | 12 | "github.com/hashicorp/hcat/dep" 13 | "github.com/pkg/errors" 14 | ) 15 | 16 | var ( 17 | // Ensure implements 18 | _ isDependency = (*CatalogNodesQuery)(nil) 19 | 20 | // CatalogNodesQueryRe is the regular expression to use. 21 | CatalogNodesQueryRe = regexp.MustCompile(`\A` + dcRe + nearRe + `\z`) 22 | ) 23 | 24 | func init() { 25 | gob.Register([]*dep.Node{}) 26 | } 27 | 28 | // CatalogNodesQuery is the representation of all registered nodes in Consul. 29 | type CatalogNodesQuery struct { 30 | isConsul 31 | stopCh chan struct{} 32 | 33 | dc string 34 | near string 35 | opts QueryOptions 36 | } 37 | 38 | // NewCatalogNodesQuery parses the given string into a dependency. If the name is 39 | // empty then the name of the local agent is used. 
40 | func NewCatalogNodesQuery(s string) (*CatalogNodesQuery, error) { 41 | if !CatalogNodesQueryRe.MatchString(s) { 42 | return nil, fmt.Errorf("catalog.nodes: invalid format: %q", s) 43 | } 44 | 45 | m := regexpMatch(CatalogNodesQueryRe, s) 46 | return &CatalogNodesQuery{ 47 | dc: m["dc"], 48 | near: m["near"], 49 | stopCh: make(chan struct{}, 1), 50 | }, nil 51 | } 52 | 53 | // Fetch queries the Consul API defined by the given client and returns a slice 54 | // of Node objects 55 | func (d *CatalogNodesQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 56 | select { 57 | case <-d.stopCh: 58 | return nil, nil, ErrStopped 59 | default: 60 | } 61 | 62 | opts := d.opts.Merge(&QueryOptions{ 63 | Datacenter: d.dc, 64 | Near: d.near, 65 | }) 66 | 67 | n, qm, err := clients.Consul().Catalog().Nodes(opts.ToConsulOpts()) 68 | if err != nil { 69 | return nil, nil, errors.Wrap(err, d.ID()) 70 | } 71 | 72 | nodes := make([]*dep.Node, 0, len(n)) 73 | for _, node := range n { 74 | nodes = append(nodes, &dep.Node{ 75 | ID: node.ID, 76 | Node: node.Node, 77 | Address: node.Address, 78 | Datacenter: node.Datacenter, 79 | TaggedAddresses: node.TaggedAddresses, 80 | Meta: node.Meta, 81 | }) 82 | } 83 | 84 | // Sort unless the user explicitly asked for nearness 85 | if d.near == "" { 86 | sort.SliceStable(nodes, 87 | func(i, j int) bool { 88 | if nodes[i].Node == nodes[j].Node { 89 | return nodes[i].Address < nodes[j].Address 90 | } 91 | return nodes[i].Node < nodes[j].Node 92 | }) 93 | } 94 | 95 | rm := &dep.ResponseMetadata{ 96 | LastIndex: qm.LastIndex, 97 | LastContact: qm.LastContact, 98 | } 99 | 100 | return nodes, rm, nil 101 | } 102 | 103 | // CanShare returns a boolean if this dependency is shareable. 104 | func (d *CatalogNodesQuery) CanShare() bool { 105 | return true 106 | } 107 | 108 | // ID returns the human-friendly version of this dependency. 
109 | func (d *CatalogNodesQuery) ID() string { 110 | name := "" 111 | if d.dc != "" { 112 | name = name + "@" + d.dc 113 | } 114 | if d.near != "" { 115 | name = name + "~" + d.near 116 | } 117 | 118 | if name == "" { 119 | return "catalog.nodes" 120 | } 121 | return fmt.Sprintf("catalog.nodes(%s)", name) 122 | } 123 | 124 | // Stringer interface reuses ID 125 | func (d *CatalogNodesQuery) String() string { 126 | return d.ID() 127 | } 128 | 129 | // Stop halts the dependency's fetch function. 130 | func (d *CatalogNodesQuery) Stop() { 131 | close(d.stopCh) 132 | } 133 | 134 | func (d *CatalogNodesQuery) SetOptions(opts QueryOptions) { 135 | d.opts = opts 136 | } 137 | -------------------------------------------------------------------------------- /internal/dependency/catalog_nodes_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/hashicorp/hcat/dep" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestNewCatalogNodesQuery(t *testing.T) { 15 | t.Parallel() 16 | 17 | cases := []struct { 18 | name string 19 | i string 20 | exp *CatalogNodesQuery 21 | err bool 22 | }{ 23 | { 24 | "empty", 25 | "", 26 | &CatalogNodesQuery{}, 27 | false, 28 | }, 29 | { 30 | "node", 31 | "node", 32 | nil, 33 | true, 34 | }, 35 | { 36 | "dc", 37 | "@dc1", 38 | &CatalogNodesQuery{ 39 | dc: "dc1", 40 | }, 41 | false, 42 | }, 43 | { 44 | "near", 45 | "~node1", 46 | &CatalogNodesQuery{ 47 | near: "node1", 48 | }, 49 | false, 50 | }, 51 | { 52 | "dc_near", 53 | "@dc1~node1", 54 | &CatalogNodesQuery{ 55 | dc: "dc1", 56 | near: "node1", 57 | }, 58 | false, 59 | }, 60 | } 61 | 62 | for i, tc := range cases { 63 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 64 | act, err := NewCatalogNodesQuery(tc.i) 65 | if (err != nil) != tc.err { 66 | t.Fatal(err) 67 | } 68 | 
69 | if act != nil { 70 | act.stopCh = nil 71 | } 72 | 73 | assert.Equal(t, tc.exp, act) 74 | }) 75 | } 76 | } 77 | 78 | func TestCatalogNodesQuery_Fetch(t *testing.T) { 79 | t.Parallel() 80 | 81 | cases := []struct { 82 | name string 83 | i string 84 | exp []*dep.Node 85 | }{ 86 | { 87 | "all", 88 | "", 89 | []*dep.Node{ 90 | { 91 | Node: testConsul.Config.NodeName, 92 | Address: testConsul.Config.Bind, 93 | Datacenter: "dc1", 94 | TaggedAddresses: map[string]string{ 95 | "lan": "127.0.0.1", 96 | "wan": "127.0.0.1", 97 | }, 98 | Meta: map[string]string{ 99 | "consul-network-segment": "", 100 | }, 101 | }, 102 | }, 103 | }, 104 | } 105 | 106 | for i, tc := range cases { 107 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 108 | d, err := NewCatalogNodesQuery(tc.i) 109 | if err != nil { 110 | t.Fatal(err) 111 | } 112 | 113 | act, _, err := d.Fetch(testClients) 114 | if err != nil { 115 | t.Fatal(err) 116 | } 117 | 118 | if act != nil { 119 | for _, n := range act.([]*dep.Node) { 120 | n.ID = "" 121 | n.TaggedAddresses = filterAddresses(n.TaggedAddresses) 122 | } 123 | } 124 | 125 | assert.Equal(t, tc.exp, act) 126 | }) 127 | } 128 | } 129 | 130 | func TestCatalogNodesQuery_String(t *testing.T) { 131 | t.Parallel() 132 | 133 | cases := []struct { 134 | name string 135 | i string 136 | exp string 137 | }{ 138 | { 139 | "empty", 140 | "", 141 | "catalog.nodes", 142 | }, 143 | { 144 | "datacenter", 145 | "@dc1", 146 | "catalog.nodes(@dc1)", 147 | }, 148 | { 149 | "near", 150 | "~node1", 151 | "catalog.nodes(~node1)", 152 | }, 153 | { 154 | "datacenter_near", 155 | "@dc1~node1", 156 | "catalog.nodes(@dc1~node1)", 157 | }, 158 | } 159 | 160 | for i, tc := range cases { 161 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 162 | d, err := NewCatalogNodesQuery(tc.i) 163 | if err != nil { 164 | t.Fatal(err) 165 | } 166 | assert.Equal(t, tc.exp, d.ID()) 167 | }) 168 | } 169 | } 170 | 
-------------------------------------------------------------------------------- /internal/dependency/catalog_service.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "encoding/gob" 8 | "fmt" 9 | "net/url" 10 | "regexp" 11 | 12 | "github.com/hashicorp/hcat/dep" 13 | "github.com/pkg/errors" 14 | ) 15 | 16 | var ( 17 | // Ensure implements 18 | _ isDependency = (*CatalogServiceQuery)(nil) 19 | 20 | // CatalogServiceQueryRe is the regular expression to use. 21 | CatalogServiceQueryRe = regexp.MustCompile(`\A` + tagRe + serviceNameRe + dcRe + nearRe + `\z`) 22 | ) 23 | 24 | func init() { 25 | gob.Register([]*dep.CatalogSnippet{}) 26 | } 27 | 28 | // CatalogService is a catalog entry in Consul. 29 | type CatalogService struct { 30 | ID string 31 | Node string 32 | Address string 33 | Datacenter string 34 | TaggedAddresses map[string]string 35 | NodeMeta map[string]string 36 | ServiceID string 37 | ServiceName string 38 | ServiceAddress string 39 | ServiceTags dep.ServiceTags 40 | ServiceMeta map[string]string 41 | ServicePort int 42 | Namespace string 43 | } 44 | 45 | // CatalogServiceQuery is the representation of a requested catalog services 46 | // dependency from inside a template. 47 | type CatalogServiceQuery struct { 48 | isConsul 49 | stopCh chan struct{} 50 | 51 | dc string 52 | name string 53 | near string 54 | tag string 55 | opts QueryOptions 56 | } 57 | 58 | // NewCatalogServiceQuery parses a string into a CatalogServiceQuery. 
59 | func NewCatalogServiceQuery(s string) (*CatalogServiceQuery, error) { 60 | if !CatalogServiceQueryRe.MatchString(s) { 61 | return nil, fmt.Errorf("catalog.service: invalid format: %q", s) 62 | } 63 | 64 | m := regexpMatch(CatalogServiceQueryRe, s) 65 | return &CatalogServiceQuery{ 66 | stopCh: make(chan struct{}, 1), 67 | dc: m["dc"], 68 | name: m["name"], 69 | near: m["near"], 70 | tag: m["tag"], 71 | }, nil 72 | } 73 | 74 | // Fetch queries the Consul API defined by the given client and returns a slice 75 | // of CatalogService objects. 76 | func (d *CatalogServiceQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 77 | select { 78 | case <-d.stopCh: 79 | return nil, nil, ErrStopped 80 | default: 81 | } 82 | 83 | opts := d.opts.Merge(&QueryOptions{ 84 | Datacenter: d.dc, 85 | Near: d.near, 86 | }) 87 | 88 | u := &url.URL{ 89 | Path: "/v1/catalog/service/" + d.name, 90 | RawQuery: opts.String(), 91 | } 92 | if d.tag != "" { 93 | q := u.Query() 94 | q.Set("tag", d.tag) 95 | u.RawQuery = q.Encode() 96 | } 97 | 98 | entries, qm, err := clients.Consul().Catalog().Service(d.name, d.tag, opts.ToConsulOpts()) 99 | if err != nil { 100 | return nil, nil, errors.Wrap(err, d.ID()) 101 | } 102 | 103 | var list []*CatalogService 104 | for _, s := range entries { 105 | list = append(list, &CatalogService{ 106 | ID: s.ID, 107 | Node: s.Node, 108 | Address: s.Address, 109 | Datacenter: s.Datacenter, 110 | TaggedAddresses: s.TaggedAddresses, 111 | NodeMeta: s.NodeMeta, 112 | ServiceID: s.ServiceID, 113 | ServiceName: s.ServiceName, 114 | ServiceAddress: s.ServiceAddress, 115 | ServiceTags: dep.ServiceTags(deepCopyAndSortTags(s.ServiceTags)), 116 | ServiceMeta: s.ServiceMeta, 117 | ServicePort: s.ServicePort, 118 | Namespace: s.Namespace, 119 | }) 120 | } 121 | 122 | rm := &dep.ResponseMetadata{ 123 | LastIndex: qm.LastIndex, 124 | LastContact: qm.LastContact, 125 | } 126 | 127 | return list, rm, nil 128 | } 129 | 130 | // CanShare returns a boolean 
if this dependency is shareable. 131 | func (d *CatalogServiceQuery) CanShare() bool { 132 | return true 133 | } 134 | 135 | // ID returns the human-friendly version of this dependency. 136 | func (d *CatalogServiceQuery) ID() string { 137 | name := d.name 138 | if d.tag != "" { 139 | name = d.tag + "." + name 140 | } 141 | if d.dc != "" { 142 | name = name + "@" + d.dc 143 | } 144 | if d.near != "" { 145 | name = name + "~" + d.near 146 | } 147 | return fmt.Sprintf("catalog.service(%s)", name) 148 | } 149 | 150 | // Stringer interface reuses ID 151 | func (d *CatalogServiceQuery) String() string { 152 | return d.ID() 153 | } 154 | 155 | // Stop halts the dependency's fetch function. 156 | func (d *CatalogServiceQuery) Stop() { 157 | close(d.stopCh) 158 | } 159 | func (d *CatalogServiceQuery) SetOptions(opts QueryOptions) { 160 | d.opts = opts 161 | } 162 | -------------------------------------------------------------------------------- /internal/dependency/catalog_services.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "encoding/gob" 8 | "fmt" 9 | "regexp" 10 | "sort" 11 | "strings" 12 | 13 | "github.com/hashicorp/hcat/dep" 14 | "github.com/pkg/errors" 15 | ) 16 | 17 | var ( 18 | // Ensure implements 19 | _ isDependency = (*CatalogServicesQuery)(nil) 20 | 21 | // CatalogServicesQueryRe is the regular expression to use for CatalogNodesQuery. 22 | CatalogServicesQueryRe = regexp.MustCompile(`\A` + dcRe + `\z`) 23 | ) 24 | 25 | func init() { 26 | gob.Register([]*dep.CatalogSnippet{}) 27 | } 28 | 29 | // CatalogServicesQuery is the representation of a requested catalog service 30 | // dependency from inside a template. 
31 | type CatalogServicesQuery struct { 32 | isConsul 33 | stopCh chan struct{} 34 | 35 | dc string 36 | ns string 37 | nodeMeta map[string]string 38 | opts QueryOptions 39 | } 40 | 41 | // NewCatalogServicesQueryV1 processes options in the format of "key=value" 42 | // e.g. "dc=dc1" 43 | func NewCatalogServicesQueryV1(opts []string) (*CatalogServicesQuery, error) { 44 | catalogServicesQuery := CatalogServicesQuery{ 45 | stopCh: make(chan struct{}, 1), 46 | } 47 | 48 | for _, opt := range opts { 49 | if strings.TrimSpace(opt) == "" { 50 | continue 51 | } 52 | 53 | query, value, err := stringsSplit2(opt, "=") 54 | if err != nil { 55 | return nil, fmt.Errorf( 56 | "catalog.services: invalid query parameter format: %q", opt) 57 | } 58 | switch query { 59 | case "dc", "datacenter": 60 | catalogServicesQuery.dc = value 61 | case "ns", "namespace": 62 | catalogServicesQuery.ns = value 63 | case "node-meta": 64 | if catalogServicesQuery.nodeMeta == nil { 65 | catalogServicesQuery.nodeMeta = make(map[string]string) 66 | } 67 | k, v, err := stringsSplit2(value, ":") 68 | if err != nil { 69 | return nil, fmt.Errorf( 70 | "catalog.services: invalid format for query parameter %q: %s", 71 | query, value, 72 | ) 73 | } 74 | catalogServicesQuery.nodeMeta[k] = v 75 | default: 76 | return nil, fmt.Errorf( 77 | "catalog.services: invalid query parameter: %q", opt) 78 | } 79 | } 80 | 81 | return &catalogServicesQuery, nil 82 | } 83 | 84 | // NewCatalogServicesQuery parses a string of the format @dc. 85 | func NewCatalogServicesQuery(s string) (*CatalogServicesQuery, error) { 86 | if !CatalogServicesQueryRe.MatchString(s) { 87 | return nil, fmt.Errorf("catalog.services: invalid format: %q", s) 88 | } 89 | 90 | m := regexpMatch(CatalogServicesQueryRe, s) 91 | return &CatalogServicesQuery{ 92 | stopCh: make(chan struct{}, 1), 93 | dc: m["dc"], 94 | }, nil 95 | } 96 | 97 | // Fetch queries the Consul API defined by the given client and returns a slice 98 | // of CatalogService objects. 
99 | func (d *CatalogServicesQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 100 | select { 101 | case <-d.stopCh: 102 | return nil, nil, ErrStopped 103 | default: 104 | } 105 | 106 | opts := d.opts.Merge(&QueryOptions{ 107 | Datacenter: d.dc, 108 | Namespace: d.ns, 109 | }).ToConsulOpts() 110 | // node-meta is handled specifically for /v1/catalog/services endpoint since 111 | // it does not support the preferred filter option. 112 | opts.NodeMeta = d.nodeMeta 113 | 114 | entries, qm, err := clients.Consul().Catalog().Services(opts) 115 | if err != nil { 116 | return nil, nil, errors.Wrap(err, d.ID()) 117 | } 118 | 119 | var catalogServices []*dep.CatalogSnippet 120 | for name, tags := range entries { 121 | catalogServices = append(catalogServices, &dep.CatalogSnippet{ 122 | Name: name, 123 | Tags: dep.ServiceTags(deepCopyAndSortTags(tags)), 124 | }) 125 | } 126 | 127 | sort.SliceStable(catalogServices, 128 | func(i, j int) bool { 129 | if catalogServices[i].Name < catalogServices[j].Name { 130 | return true 131 | } 132 | return false 133 | }) 134 | 135 | rm := &dep.ResponseMetadata{ 136 | LastIndex: qm.LastIndex, 137 | LastContact: qm.LastContact, 138 | } 139 | 140 | return catalogServices, rm, nil 141 | } 142 | 143 | // CanShare returns a boolean if this dependency is shareable. 144 | func (d *CatalogServicesQuery) CanShare() bool { 145 | return true 146 | } 147 | 148 | // ID returns the human-friendly version of this dependency. 
149 | func (d *CatalogServicesQuery) ID() string { 150 | var opts []string 151 | if d.dc != "" { 152 | opts = append(opts, fmt.Sprintf("@%s", d.dc)) 153 | } 154 | if d.ns != "" { 155 | opts = append(opts, fmt.Sprintf("ns=%s", d.ns)) 156 | } 157 | for k, v := range d.nodeMeta { 158 | opts = append(opts, fmt.Sprintf("node-meta=%s:%s", k, v)) 159 | } 160 | if len(opts) > 0 { 161 | sort.Strings(opts) 162 | return fmt.Sprintf("catalog.services(%s)", strings.Join(opts, "&")) 163 | } 164 | return "catalog.services" 165 | } 166 | 167 | // Stringer interface reuses ID 168 | func (d *CatalogServicesQuery) String() string { 169 | return d.ID() 170 | } 171 | 172 | // Stop halts the dependency's fetch function. 173 | func (d *CatalogServicesQuery) Stop() { 174 | close(d.stopCh) 175 | } 176 | 177 | func (d *CatalogServicesQuery) SetOptions(opts QueryOptions) { 178 | d.opts = opts 179 | } 180 | 181 | // stringsSplit2 splits a string 182 | func stringsSplit2(s string, sep string) (string, string, error) { 183 | split := strings.Split(s, sep) 184 | if len(split) != 2 { 185 | return "", "", fmt.Errorf("unexpected split on separator %q: %s", sep, s) 186 | } 187 | return strings.TrimSpace(split[0]), strings.TrimSpace(split[1]), nil 188 | } 189 | -------------------------------------------------------------------------------- /internal/dependency/client_set_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | 12 | capi "github.com/hashicorp/consul/api" 13 | vapi "github.com/hashicorp/vault/api" 14 | ) 15 | 16 | func TestClientSet_unwrapVaultToken(t *testing.T) { 17 | // Don't use t.Parallel() here as the SetWrappingLookupFunc is a global 18 | // setting and breaks other tests if run in parallel 19 | 20 | vault := testClients.Vault() 21 | 22 | // Create a wrapped token 23 | vault.SetWrappingLookupFunc(func(operation, path string) string { 24 | return "30s" 25 | }) 26 | defer vault.SetWrappingLookupFunc(nil) 27 | 28 | wrappedToken, err := vault.Auth().Token().Create(&vapi.TokenCreateRequest{ 29 | Lease: "1h", 30 | }) 31 | if err != nil { 32 | t.Fatal(err) 33 | } 34 | 35 | token := vault.Token() 36 | 37 | if token == wrappedToken.WrapInfo.Token { 38 | t.Errorf("expected %q to not be %q", token, 39 | wrappedToken.WrapInfo.Token) 40 | } 41 | 42 | if _, err := vault.Auth().Token().LookupSelf(); err != nil { 43 | t.Fatal(err) 44 | } 45 | } 46 | 47 | func TestClientSet_hasLeader(t *testing.T) { 48 | t.Parallel() 49 | 50 | t.Run("success", func(t *testing.T) { 51 | t.Parallel() 52 | 53 | var err error 54 | client := testClients.Consul() 55 | if err = hasLeader(client, time.Minute); err != nil { 56 | t.Fatal("unexpected hasLeader error:", err) 57 | } 58 | }) 59 | 60 | t.Run("non temp error", func(t *testing.T) { 61 | t.Parallel() 62 | 63 | cconf := capi.DefaultConfig() 64 | cconf.Address = "bad.host:8500" 65 | client, err := capi.NewClient(cconf) 66 | if err != nil { 67 | t.Fatal("client create error:", err) 68 | } 69 | if err = hasLeader(client, time.Minute); err == nil { 70 | t.Fatal("hasLeader should have returned an error") 71 | } 72 | }) 73 | 74 | t.Run("retry exceeds", func(t *testing.T) { 75 | t.Parallel() 76 | 77 | testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 | // 
Force timeout by setting client transport timeout to a shorter duration 79 | // than the delayed response 80 | time.Sleep(20 * time.Millisecond) 81 | w.WriteHeader(http.StatusOK) 82 | w.Write([]byte("leader.address:8500")) 83 | })) 84 | defer testServer.Close() 85 | 86 | transport := http.DefaultTransport.(*http.Transport).Clone() 87 | transport.ResponseHeaderTimeout = 10 * time.Millisecond 88 | cconf := capi.Config{ 89 | Address: testServer.URL, 90 | HttpClient: testServer.Client(), 91 | Transport: transport, 92 | } 93 | client, err := capi.NewClient(&cconf) 94 | if err != nil { 95 | t.Fatal("client create error:", err) 96 | } 97 | 98 | startTime := time.Now() 99 | if err = hasLeader(client, 3*time.Second); err == nil { 100 | t.Fatal("hasLeader should have returned an error") 101 | } 102 | 103 | // Test retry logic reaches the maxRetryWait 104 | // retries once and exists before the next retry with delay 4s 105 | elapsed := time.Now().Sub(startTime) 106 | expected := 2 * time.Second 107 | if elapsed < expected { 108 | t.Fatal("hasLeader should have exceeded retry duration but returned early", elapsed, expected) 109 | } 110 | }) 111 | } 112 | -------------------------------------------------------------------------------- /internal/dependency/connect_ca.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "github.com/hashicorp/hcat/dep" 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | var ( 12 | // Ensure implements 13 | _ isDependency = (*ConnectCAQuery)(nil) 14 | _ BlockingQuery = (*ConnectCAQuery)(nil) 15 | ) 16 | 17 | type ConnectCAQuery struct { 18 | isConsul 19 | isBlocking 20 | stopCh chan struct{} 21 | opts QueryOptions 22 | } 23 | 24 | func NewConnectCAQuery() *ConnectCAQuery { 25 | return &ConnectCAQuery{ 26 | stopCh: make(chan struct{}, 1), 27 | } 28 | } 29 | 30 | func (d *ConnectCAQuery) Fetch(clients dep.Clients) ( 31 | interface{}, *dep.ResponseMetadata, error, 32 | ) { 33 | select { 34 | case <-d.stopCh: 35 | return nil, nil, ErrStopped 36 | default: 37 | } 38 | 39 | opts := d.opts.Merge(nil) 40 | certs, md, err := clients.Consul().Agent().ConnectCARoots( 41 | opts.ToConsulOpts()) 42 | if err != nil { 43 | return nil, nil, errors.Wrap(err, d.ID()) 44 | } 45 | 46 | rm := &dep.ResponseMetadata{ 47 | LastIndex: md.LastIndex, 48 | LastContact: md.LastContact, 49 | } 50 | 51 | return certs.Roots, rm, nil 52 | } 53 | 54 | func (d *ConnectCAQuery) Stop() { 55 | close(d.stopCh) 56 | } 57 | 58 | func (d *ConnectCAQuery) CanShare() bool { 59 | return false 60 | } 61 | 62 | // ID returns the human-friendly version of this dependency. 63 | func (d *ConnectCAQuery) ID() string { 64 | return "connect.caroots" 65 | } 66 | 67 | // Stringer interface reuses ID 68 | func (d *ConnectCAQuery) String() string { 69 | return d.ID() 70 | } 71 | 72 | func (d *ConnectCAQuery) SetOptions(opts QueryOptions) { 73 | d.opts = opts 74 | } 75 | -------------------------------------------------------------------------------- /internal/dependency/connect_ca_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "testing" 8 | 9 | "github.com/hashicorp/consul/api" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestConnectCAQuery_Fetch(t *testing.T) { 14 | t.Parallel() 15 | 16 | d := NewConnectCAQuery() 17 | raw, _, err := d.Fetch(testClients) 18 | assert.NoError(t, err) 19 | act := raw.([]*api.CARoot) 20 | if assert.Len(t, act, 1) { 21 | ca := act[0] 22 | // Root CA name can vary 23 | valid := []string{"Consul CA Root Cert", "Consul CA Primary Cert"} 24 | assert.Contains(t, valid, ca.Name) 25 | assert.True(t, ca.Active) 26 | assert.NotEmpty(t, ca.RootCertPEM) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /internal/dependency/connect_leaf.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | 9 | "github.com/hashicorp/hcat/dep" 10 | "github.com/pkg/errors" 11 | ) 12 | 13 | var ( 14 | // Ensure implements 15 | _ isDependency = (*ConnectLeafQuery)(nil) 16 | _ BlockingQuery = (*ConnectLeafQuery)(nil) 17 | ) 18 | 19 | type ConnectLeafQuery struct { 20 | isConsul 21 | isBlocking 22 | stopCh chan struct{} 23 | 24 | service string 25 | opts QueryOptions 26 | } 27 | 28 | func NewConnectLeafQuery(service string) *ConnectLeafQuery { 29 | return &ConnectLeafQuery{ 30 | stopCh: make(chan struct{}, 1), 31 | service: service, 32 | } 33 | } 34 | 35 | func (d *ConnectLeafQuery) Fetch(clients dep.Clients) ( 36 | interface{}, *dep.ResponseMetadata, error, 37 | ) { 38 | select { 39 | case <-d.stopCh: 40 | return nil, nil, ErrStopped 41 | default: 42 | } 43 | opts := d.opts.Merge(nil) 44 | 45 | cert, md, err := clients.Consul().Agent().ConnectCALeaf(d.service, 46 | opts.ToConsulOpts()) 47 | if err != nil { 48 | return nil, nil, errors.Wrap(err, d.ID()) 49 | } 50 | 51 | rm := 
&dep.ResponseMetadata{ 52 | LastIndex: md.LastIndex, 53 | LastContact: md.LastContact, 54 | } 55 | 56 | return cert, rm, nil 57 | } 58 | 59 | func (d *ConnectLeafQuery) Stop() { 60 | close(d.stopCh) 61 | } 62 | 63 | func (d *ConnectLeafQuery) CanShare() bool { 64 | return false 65 | } 66 | 67 | // ID returns the human-friendly version of this dependency. 68 | func (d *ConnectLeafQuery) ID() string { 69 | if d.service != "" { 70 | return fmt.Sprintf("connect.caleaf(%s)", d.service) 71 | } 72 | return "connect.caleaf" 73 | } 74 | 75 | // Stringer interface reuses ID 76 | func (d *ConnectLeafQuery) String() string { 77 | return d.ID() 78 | } 79 | 80 | func (d *ConnectLeafQuery) SetOptions(opts QueryOptions) { 81 | d.opts = opts 82 | } 83 | -------------------------------------------------------------------------------- /internal/dependency/connect_leaf_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "math/rand" 9 | "strings" 10 | "testing" 11 | "time" 12 | 13 | "github.com/hashicorp/consul/api" 14 | 15 | "github.com/pkg/errors" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | func TestNewConnectLeafQuery(t *testing.T) { 20 | t.Parallel() 21 | 22 | act := NewConnectLeafQuery("foo") 23 | act.stopCh = nil 24 | exp := &ConnectLeafQuery{service: "foo"} 25 | assert.Equal(t, exp, act) 26 | } 27 | 28 | func TestConnectLeafQuery_Fetch(t *testing.T) { 29 | t.Parallel() 30 | 31 | // leaf tests require new/unique names to generate the certs correctly 32 | uniqueName := func(name string) string { 33 | return fmt.Sprintf("%s_%d", name, rand.Int31()) 34 | } 35 | 36 | t.Run("empty-service", func(t *testing.T) { 37 | d := NewConnectLeafQuery("") 38 | 39 | _, _, err := d.Fetch(testClients) 40 | expPrefix := "Unexpected response code: 500 (URI must be either" 41 | if 
!strings.HasPrefix(errors.Cause(err).Error(), expPrefix) { 42 | t.Fatalf("Unexpected error: %v", err) 43 | } 44 | }) 45 | t.Run("with-service", func(t *testing.T) { 46 | name := uniqueName("foo") 47 | d := NewConnectLeafQuery(name) 48 | raw, _, err := d.Fetch(testClients) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | cert := raw.(*api.LeafCert) 53 | if cert.Service != name { 54 | t.Fatalf("Unexpected service: %v", cert.Service) 55 | } 56 | if cert.CertPEM == "" { 57 | t.Fatal("Empty cert PEM") 58 | } 59 | if cert.ValidAfter.After(time.Now()) { 60 | t.Fatalf("Bad cert: (bad ValidAfter: %v)", cert.ValidAfter) 61 | } 62 | if cert.ValidBefore.Before(time.Now()) { 63 | t.Fatalf("Bad cert: (bad ValidBefore: %v)", cert.ValidBefore) 64 | } 65 | }) 66 | t.Run("double-check", func(t *testing.T) { 67 | name := uniqueName("foo") 68 | d1 := NewConnectLeafQuery(name) 69 | raw1, _, err := d1.Fetch(testClients) 70 | if err != nil { 71 | t.Fatal(err) 72 | } 73 | cert1 := raw1.(*api.LeafCert) 74 | d2 := NewConnectLeafQuery(name) 75 | raw2, _, err := d2.Fetch(testClients) 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | cert2 := raw2.(*api.LeafCert) 80 | if cert1.CertPEM != cert2.CertPEM { 81 | t.Fatalf("Certs should match:\n%v\n%v", 82 | cert1.CertPEM, cert2.CertPEM) 83 | } 84 | }) 85 | } 86 | 87 | func TestConnectLeafQuery_String(t *testing.T) { 88 | t.Parallel() 89 | 90 | cases := []struct { 91 | name string 92 | service string 93 | exp string 94 | }{ 95 | { 96 | "empty", 97 | "", 98 | "connect.caleaf", 99 | }, 100 | { 101 | "service", 102 | "foo", 103 | "connect.caleaf(foo)", 104 | }, 105 | } 106 | 107 | for i, tc := range cases { 108 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 109 | d := NewConnectLeafQuery(tc.service) 110 | assert.Equal(t, tc.exp, d.ID()) 111 | }) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /internal/dependency/consul_common_test.go: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | // filter deletes each key in remove from data in place and returns data; a nil data yields a new, empty map. 7 | func filter(data map[string]string, remove []string) map[string]string { 8 | if data == nil { 9 | return make(map[string]string) 10 | } 11 | for _, k := range remove { 12 | delete(data, k) 13 | } 14 | return data 15 | } 16 | // filterMeta strips both enterprise and version metadata keys from meta (in place). 17 | func filterMeta(meta map[string]string) map[string]string { 18 | return filterVersionMeta(filterEnterprise(meta)) 19 | } 20 | 21 | // filterVersionMeta filters out all version information from the returned 22 | // metadata. It allocates the meta map if it is nil to make the tests backward 23 | // compatible with older versions. 24 | func filterVersionMeta(meta map[string]string) map[string]string { 25 | filteredMeta := []string{ 26 | "raft_version", "version", 27 | "serf_protocol_current", "serf_protocol_min", "serf_protocol_max", 28 | "grpc_port", "grpc_tls_port", 29 | } 30 | return filter(meta, filteredMeta) 31 | } 32 | 33 | // filterEnterprise filters out enterprise service metadata default values. 34 | func filterEnterprise(meta map[string]string) map[string]string { 35 | filtered := []string{"non_voter", "read_replica"} 36 | return filter(meta, filtered) 37 | } 38 | 39 | // filterAddresses filters out consul >1.7 ipv4/ipv6 specific entries 40 | // from TaggedAddresses entries on nodes, catalog and health services. 41 | func filterAddresses(addrs map[string]string) map[string]string { 42 | ipvKeys := []string{"lan_ipv4", "wan_ipv4", "lan_ipv6", "wan_ipv6"} 43 | return filter(addrs, ipvKeys) 44 | } 45 | -------------------------------------------------------------------------------- /internal/dependency/errors.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import "errors" 7 | 8 | // ErrStopped is a special error that is returned when a dependency is 9 | // prematurely stopped, usually due to a configuration reload or a process 10 | // interrupt. 11 | var ErrStopped = errors.New("dependency stopped") 12 | 13 | // ErrContinue is a special error which says to continue (retry) on error. 14 | var ErrContinue = errors.New("dependency continue") 15 | // ErrLeaseExpired is returned when a lease has expired or is not renewable. 16 | var ErrLeaseExpired = errors.New("lease expired or is not renewable") 17 | -------------------------------------------------------------------------------- /internal/dependency/file.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "io/ioutil" 9 | "os" 10 | "strings" 11 | "time" 12 | 13 | "github.com/hashicorp/hcat/dep" 14 | "github.com/pkg/errors" 15 | ) 16 | 17 | var ( 18 | // Ensure FileQuery implements isDependency. 19 | _ isDependency = (*FileQuery)(nil) 20 | 21 | // FileQuerySleepTime is the amount of time to sleep between queries, since 22 | // the fsnotify library is not compatible with solaris and other OSes yet. 23 | FileQuerySleepTime = 2 * time.Second 24 | ) 25 | 26 | // FileQuery represents a local file dependency. 27 | type FileQuery struct { 28 | stopCh chan struct{} // closed by Stop to abort Fetch and the watch goroutine 29 | 30 | path string 31 | stat os.FileInfo // last observed file info; nil until the first successful Fetch 32 | } 33 | 34 | // NewFileQuery creates a file dependency from the given path. Surrounding whitespace is trimmed and an empty path is rejected. 35 | func NewFileQuery(s string) (*FileQuery, error) { 36 | s = strings.TrimSpace(s) 37 | if s == "" { 38 | return nil, fmt.Errorf("file: invalid format: %q", s) 39 | } 40 | 41 | return &FileQuery{ 42 | stopCh: make(chan struct{}, 1), 43 | path: s, 44 | }, nil 45 | } 46 | 47 | // Fetch blocks until the file is first seen or has changed (or Stop is called) 48 | // and returns its contents as a string; stat and read errors are wrapped with ID().
49 | func (d *FileQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 50 | select { 51 | case <-d.stopCh: 52 | return "", nil, ErrStopped 53 | case r := <-d.watch(d.stat): 54 | if r.err != nil { 55 | return "", nil, errors.Wrap(r.err, d.ID()) 56 | } 57 | 58 | data, err := ioutil.ReadFile(d.path) 59 | if err != nil { 60 | return "", nil, errors.Wrap(err, d.ID()) 61 | } 62 | 63 | d.stat = r.stat // record the new stat only after a successful read, so a failed read is retried 64 | return respWithMetadata(string(data)) 65 | } 66 | } 67 | 68 | // CanShare reports false: file queries are never shared. 69 | func (d *FileQuery) CanShare() bool { 70 | return false 71 | } 72 | 73 | // Stop halts the dependency's fetch function by closing stopCh; calling it twice panics. 74 | func (d *FileQuery) Stop() { 75 | close(d.stopCh) 76 | } 77 | 78 | // ID returns the human-friendly version of this dependency. 79 | func (d *FileQuery) ID() string { 80 | return fmt.Sprintf("file(%s)", d.path) 81 | } 82 | 83 | // String implements fmt.Stringer by returning ID(). 84 | func (d *FileQuery) String() string { 85 | return d.ID() 86 | } 87 | // SetOptions is a no-op; file queries take no query options. 88 | func (d *FileQuery) SetOptions(opts QueryOptions) {} 89 | // watchResult carries either the new file info or the os.Stat error from watch. 90 | type watchResult struct { 91 | stat os.FileInfo 92 | err error 93 | } 94 | 95 | // watch polls d.path every FileQuerySleepTime and sends a single result: the new stat once size or mtime changes (immediately when lastStat is nil), or the os.Stat error. 96 | func (d *FileQuery) watch(lastStat os.FileInfo) <-chan *watchResult { 97 | ch := make(chan *watchResult, 1) 98 | 99 | go func(lastStat os.FileInfo) { 100 | for { 101 | stat, err := os.Stat(d.path) 102 | if err != nil { 103 | select { 104 | case <-d.stopCh: 105 | return 106 | case ch <- &watchResult{err: err}: 107 | return 108 | } 109 | } 110 | 111 | changed := lastStat == nil || 112 | lastStat.Size() != stat.Size() || 113 | !lastStat.ModTime().Equal(stat.ModTime()) // compare instants with Equal; == also compares Location and monotonic data 114 | 115 | if changed { 116 | select { 117 | case <-d.stopCh: 118 | return 119 | case ch <- &watchResult{stat: stat}: 120 | return 121 | } 122 | } 123 | select { 124 | case <-d.stopCh: 125 | return 126 | case <-time.After(FileQuerySleepTime): 127 | } 128 | } 129 | }(lastStat) 130 | 131 | return ch 132 | } 133 |
-------------------------------------------------------------------------------- /internal/dependency/file_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "io/ioutil" 9 | "os" 10 | "testing" 11 | "time" 12 | 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | // Shorten the polling interval so the watch-based tests below finish quickly. 16 | func init() { 17 | FileQuerySleepTime = 50 * time.Millisecond 18 | } 19 | 20 | func TestNewFileQuery(t *testing.T) { 21 | t.Parallel() 22 | 23 | cases := []struct { 24 | name string 25 | i string 26 | exp *FileQuery 27 | err bool 28 | }{ 29 | { 30 | "empty", 31 | "", 32 | nil, 33 | true, 34 | }, 35 | { 36 | "path", 37 | "path", 38 | &FileQuery{ 39 | path: "path", 40 | }, 41 | false, 42 | }, 43 | } 44 | 45 | for i, tc := range cases { 46 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 47 | act, err := NewFileQuery(tc.i) 48 | if (err != nil) != tc.err { 49 | t.Fatal(err) 50 | } 51 | 52 | if act != nil { 53 | act.stopCh = nil // the freshly made channel would not match exp's nil stopCh 54 | } 55 | 56 | assert.Equal(t, tc.exp, act) 57 | }) 58 | } 59 | } 60 | 61 | func TestFileQuery_Fetch(t *testing.T) { 62 | t.Parallel() 63 | 64 | f, err := ioutil.TempFile("", "") 65 | if err != nil { 66 | t.Fatal(err) 67 | } 68 | defer os.Remove(f.Name()) 69 | if _, err := f.WriteString("hello world"); err != nil { 70 | t.Fatal(err) 71 | } 72 | 73 | cases := []struct { 74 | name string 75 | i string 76 | exp string 77 | err bool 78 | }{ 79 | { 80 | "non_existent", 81 | "/not/a/real/path/ever", 82 | "", 83 | true, 84 | }, 85 | { 86 | "contents", 87 | f.Name(), 88 | "hello world", 89 | false, 90 | }, 91 | } 92 | 93 | for i, tc := range cases { 94 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 95 | d, err := NewFileQuery(tc.i) 96 | if err != nil { 97 | t.Fatal(err) 98 | } 99 | 100 | act, _, err := d.Fetch(nil) 101 | if (err != nil) != tc.err { 102 | t.Fatal(err) 103 | } 104 | 105 |
assert.Equal(t, tc.exp, act) 106 | }) 107 | } 108 | 109 | t.Run("stops", func(t *testing.T) { 110 | f, err := ioutil.TempFile("", "") 111 | if err != nil { 112 | t.Fatal(err) 113 | } 114 | defer os.Remove(f.Name()) 115 | 116 | d, err := NewFileQuery(f.Name()) 117 | if err != nil { 118 | t.Fatal(err) 119 | } 120 | 121 | errCh := make(chan error, 1) 122 | go func() { 123 | for { 124 | _, _, err := d.Fetch(nil) 125 | if err != nil { 126 | errCh <- err 127 | return 128 | } 129 | } 130 | }() 131 | 132 | d.Stop() // the fetch loop must then end with ErrStopped within the timeout below 133 | 134 | select { 135 | case err := <-errCh: 136 | if err != ErrStopped { 137 | t.Fatal(err) 138 | } 139 | case <-time.After(100 * time.Millisecond): 140 | t.Errorf("did not stop") 141 | } 142 | }) 143 | 144 | t.Run("fires_changes", func(t *testing.T) { 145 | f, err := os.CreateTemp("", "") 146 | if err != nil { 147 | t.Fatal(err) 148 | } 149 | if err := os.WriteFile(f.Name(), []byte("hello"), 0644); err != nil { 150 | t.Fatal(err) 151 | } 152 | defer os.Remove(f.Name()) 153 | 154 | d, err := NewFileQuery(f.Name()) 155 | if err != nil { 156 | t.Fatal(err) 157 | } 158 | 159 | dataCh := make(chan interface{}, 1) 160 | errCh := make(chan error, 1) 161 | go func() { 162 | for { 163 | data, _, err := d.Fetch(nil) 164 | if err != nil { 165 | errCh <- err 166 | return 167 | } 168 | dataCh <- data 169 | } 170 | }() 171 | defer d.Stop() 172 | 173 | select { 174 | case err := <-errCh: 175 | t.Fatal(err) 176 | case <-dataCh: 177 | } 178 | 179 | tmp, err := os.CreateTemp("", "") 180 | if err != nil { 181 | t.Fatal(err) 182 | } 183 | defer os.Remove(tmp.Name()) 184 | 185 | if err := os.WriteFile(tmp.Name(), []byte("goodbye"), 0644); err != nil { 186 | t.Fatal(err) 187 | } 188 | if err := f.Sync(); err != nil { 189 | t.Fatal(err) 190 | } 191 | if err := os.Rename(tmp.Name(), f.Name()); err != nil { // replace the watched file by rename; the watcher must report the new contents 192 | t.Fatal(err) 193 | } 194 | 195 | select { 196 | case err := <-errCh: 197 | t.Fatal(err) 198 | case data := <-dataCh: 199 | assert.Equal(t, "goodbye", data) 200 | } 201 |
}) 202 | } 203 | 204 | func TestFileQuery_String(t *testing.T) { 205 | t.Parallel() 206 | 207 | cases := []struct { 208 | name string 209 | i string 210 | exp string 211 | }{ 212 | { 213 | "path", 214 | "path", 215 | "file(path)", 216 | }, 217 | } 218 | 219 | for i, tc := range cases { 220 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 221 | d, err := NewFileQuery(tc.i) 222 | if err != nil { 223 | t.Fatal(err) 224 | } 225 | assert.Equal(t, tc.exp, d.ID()) 226 | }) 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /internal/dependency/kv_exists.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "regexp" 9 | "strings" 10 | 11 | "github.com/hashicorp/hcat/dep" 12 | "github.com/pkg/errors" 13 | ) 14 | 15 | var ( 16 | // Ensure KVExistsQuery implements isDependency. 17 | _ isDependency = (*KVExistsQuery)(nil) 18 | 19 | // KVExistsQueryRe is the regular expression to use. 20 | KVExistsQueryRe = regexp.MustCompile(`\A` + keyRe + dcRe + `\z`) 21 | ) 22 | 23 | // KVExistsQuery uses a non-blocking query with the KV store for key lookup. 24 | type KVExistsQuery struct { 25 | isConsul 26 | stopCh chan struct{} 27 | 28 | dc string 29 | key string 30 | ns string 31 | opts QueryOptions 32 | } 33 | // SetOptions stores opts with WaitIndex and WaitTime cleared, keeping the lookup non-blocking. 34 | func (d *KVExistsQuery) SetOptions(opts QueryOptions) { 35 | opts.WaitIndex = 0 36 | opts.WaitTime = 0 37 | d.opts = opts 38 | } 39 | 40 | // ID returns the human-friendly version of this dependency. NOTE(review): ns is not part of the ID although CanShare is true; confirm queries differing only by namespace are not deduplicated by ID. 41 | func (d *KVExistsQuery) ID() string { 42 | key := d.key 43 | if d.dc != "" { 44 | key = key + "@" + d.dc 45 | } 46 | return fmt.Sprintf("kv.exists(%s)", key) 47 | } 48 | 49 | // String implements fmt.Stringer by returning ID(). 50 | func (d *KVExistsQuery) String() string { 51 | return d.ID() 52 | } 53 | 54 | // NewKVExistsQueryV1 processes options in the format of "key key=value" 55 | // e.g.
"my/key dc=dc1" 56 | func NewKVExistsQueryV1(key string, opts []string) (*KVExistsQuery, error) { 57 | if key == "" || key == "/" { 58 | return nil, fmt.Errorf("kv.exists: key required") 59 | } 60 | // A leading "/" is stripped from the key; blank options are skipped. 61 | q := KVExistsQuery{ 62 | stopCh: make(chan struct{}, 1), 63 | key: strings.TrimPrefix(key, "/"), 64 | } 65 | for _, opt := range opts { 66 | if strings.TrimSpace(opt) == "" { 67 | continue 68 | } 69 | query, value, err := stringsSplit2(opt, "=") // presumably splits at the first "="; see stringsSplit2 70 | if err != nil { 71 | return nil, fmt.Errorf( 72 | "kv.exists: invalid query parameter format: %q", opt) 73 | } 74 | switch query { 75 | case "dc", "datacenter": 76 | q.dc = value 77 | case "ns", "namespace": 78 | q.ns = value 79 | default: 80 | return nil, fmt.Errorf( 81 | "kv.exists: invalid query parameter: %q", opt) 82 | } 83 | } 84 | 85 | return &q, nil 86 | } 87 | 88 | // NewKVExistsQuery parses s, which must match KVExistsQueryRe, into a KV lookup; the namespace is always left empty. 89 | func NewKVExistsQuery(s string) (*KVExistsQuery, error) { 90 | if !KVExistsQueryRe.MatchString(s) { 91 | return nil, fmt.Errorf("kv.exists: invalid format: %q", s) 92 | } 93 | 94 | m := regexpMatch(KVExistsQueryRe, s) 95 | return &KVExistsQuery{ 96 | stopCh: make(chan struct{}, 1), 97 | dc: m["dc"], 98 | key: m["key"], 99 | ns: "", 100 | }, nil 101 | } 102 | 103 | // Fetch queries the Consul API defined by the given client.
104 | func (d *KVExistsQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 105 | select { 106 | case <-d.stopCh: 107 | return nil, nil, ErrStopped 108 | default: 109 | } 110 | 111 | opts := d.opts.Merge(&QueryOptions{ 112 | Datacenter: d.dc, 113 | Namespace: d.ns, 114 | }) 115 | 116 | pair, qm, err := clients.Consul().KV().Get(d.key, opts.ToConsulOpts()) 117 | if err != nil { 118 | return nil, nil, errors.Wrap(err, d.ID()) 119 | } 120 | 121 | rm := &dep.ResponseMetadata{ 122 | LastIndex: qm.LastIndex, 123 | LastContact: qm.LastContact, 124 | } 125 | 126 | if pair == nil { // a missing key is reported as false, not as an error 127 | return dep.KVExists(false), rm, nil 128 | } 129 | 130 | return dep.KVExists(true), rm, nil 131 | } 132 | 133 | // Stop halts the dependency's fetch function by closing stopCh; calling it twice panics. 134 | func (d *KVExistsQuery) Stop() { 135 | close(d.stopCh) 136 | } 137 | 138 | // CanShare reports true: identical KV existence queries may be shared. 139 | func (d *KVExistsQuery) CanShare() bool { 140 | return true 141 | } 142 | -------------------------------------------------------------------------------- /internal/dependency/kv_exists_get.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "strings" 9 | 10 | "github.com/hashicorp/hcat/dep" 11 | "github.com/pkg/errors" 12 | ) 13 | 14 | var ( 15 | // Ensure KVExistsGetQuery implements isDependency. 16 | _ isDependency = (*KVExistsGetQuery)(nil) 17 | ) 18 | 19 | // KVExistsGetQuery looks up a single key in the KV store and returns whether it exists plus its value if it does. 20 | // Unlike KVExistsQuery, its SetOptions keeps WaitIndex/WaitTime, so the lookup may block. 21 | type KVExistsGetQuery struct { 22 | isBlocking // blocking marker, as in KVGetQuery; an embedded BlockingQuery interface is always nil and would panic if a promoted method were called 23 | KVExistsQuery 24 | } 25 | 26 | // NewKVExistsGetQueryV1 processes options in the format of "key key=value" 27 | // e.g.
"my/key dc=dc1" 28 | func NewKVExistsGetQueryV1(key string, opts []string) (*KVExistsGetQuery, error) { 29 | if key == "" || key == "/" { 30 | return nil, fmt.Errorf("kv.exists.get: key required") 31 | } 32 | 33 | q, err := NewKVExistsQueryV1(key, opts) // option errors keep NewKVExistsQueryV1's "kv.exists:" prefix 34 | if err != nil { 35 | return nil, err 36 | } 37 | return &KVExistsGetQuery{KVExistsQuery: *q}, nil 38 | } 39 | 40 | // CanShare reports true: identical queries may be shared. 41 | func (d *KVExistsGetQuery) CanShare() bool { 42 | return true 43 | } 44 | 45 | // ID returns the human-friendly version of this dependency, e.g. kv.exists.get(key?dc=dc1&ns=ns1). 46 | func (d *KVExistsGetQuery) ID() string { 47 | key := d.key 48 | var opts []string 49 | if d.dc != "" { 50 | opts = append(opts, "dc="+d.dc) 51 | } 52 | if d.ns != "" { 53 | opts = append(opts, "ns="+d.ns) 54 | } 55 | if len(opts) > 0 { 56 | key = fmt.Sprintf("%s?%s", key, strings.Join(opts, "&")) 57 | } 58 | return fmt.Sprintf("kv.exists.get(%s)", key) 59 | } 60 | 61 | // String implements fmt.Stringer by returning ID(). 62 | func (d *KVExistsGetQuery) String() string { 63 | return d.ID() 64 | } 65 | 66 | // Stop halts the dependency's fetch function by closing stopCh; calling it twice panics. 67 | func (d *KVExistsGetQuery) Stop() { 68 | close(d.stopCh) 69 | } 70 | // SetOptions stores opts unchanged; unlike KVExistsQuery.SetOptions it keeps WaitIndex and WaitTime. 71 | func (d *KVExistsGetQuery) SetOptions(opts QueryOptions) { 72 | d.opts = opts 73 | } 74 | 75 | // Fetch queries the Consul API defined by the given client.
76 | func (d *KVExistsGetQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 77 | select { 78 | case <-d.stopCh: 79 | return nil, nil, ErrStopped 80 | default: 81 | } 82 | 83 | opts := d.opts.Merge(&QueryOptions{ 84 | Datacenter: d.dc, 85 | Namespace: d.ns, 86 | }) 87 | 88 | pair, qm, err := clients.Consul().KV().Get(d.key, opts.ToConsulOpts()) 89 | if err != nil { 90 | return nil, nil, errors.Wrap(err, d.ID()) 91 | } 92 | 93 | rm := &dep.ResponseMetadata{ 94 | LastIndex: qm.LastIndex, 95 | LastContact: qm.LastContact, 96 | } 97 | 98 | if pair == nil { // missing key: report Exists=false, keyed by the requested key 99 | return &dep.KeyPair{ 100 | Path: d.key, 101 | Key: d.key, 102 | Exists: false, 103 | }, rm, nil 104 | } 105 | 106 | return &dep.KeyPair{ 107 | Path: pair.Key, 108 | Key: pair.Key, 109 | Value: string(pair.Value), 110 | Exists: true, 111 | CreateIndex: pair.CreateIndex, 112 | ModifyIndex: pair.ModifyIndex, 113 | LockIndex: pair.LockIndex, 114 | Flags: pair.Flags, 115 | Session: pair.Session, 116 | }, rm, nil 117 | } 118 | -------------------------------------------------------------------------------- /internal/dependency/kv_get.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "regexp" 9 | 10 | "github.com/hashicorp/hcat/dep" 11 | "github.com/pkg/errors" 12 | ) 13 | 14 | var ( 15 | // Ensure KVGetQuery implements isDependency and BlockingQuery. 16 | _ isDependency = (*KVGetQuery)(nil) 17 | _ BlockingQuery = (*KVGetQuery)(nil) 18 | 19 | // KVGetQueryRe is the regular expression to use. 20 | KVGetQueryRe = regexp.MustCompile(`\A` + keyRe + dcRe + `\z`) 21 | ) 22 | 23 | // KVGetQuery queries the KV store for a single key and returns its raw value. 24 | type KVGetQuery struct { 25 | KVExistsQuery 26 | isBlocking 27 | } 28 | 29 | // NewKVGetQueryV1 processes options in the format of "key key=value" 30 | // e.g.
"my/key dc=dc1" 31 | func NewKVGetQueryV1(key string, opts []string) (*KVGetQuery, error) { 32 | if key == "" || key == "/" { 33 | return nil, fmt.Errorf("kv.get: key required") 34 | } 35 | 36 | q, err := NewKVExistsQueryV1(key, opts) // option errors keep NewKVExistsQueryV1's "kv.exists:" prefix 37 | if err != nil { 38 | return nil, err 39 | } 40 | return &KVGetQuery{KVExistsQuery: *q}, nil 41 | } 42 | 43 | // NewKVGetQuery parses a string matching KVGetQueryRe into a KV lookup. NOTE(review): KVGetQuery embeds isBlocking and its SetOptions keeps the wait index, so presumably whether it blocks depends on those options — confirm. 44 | func NewKVGetQuery(s string) (*KVGetQuery, error) { 45 | if !KVGetQueryRe.MatchString(s) { 46 | return nil, fmt.Errorf("kv.get: invalid format: %q", s) 47 | } 48 | 49 | q, err := NewKVExistsQuery(s) 50 | if err != nil { 51 | return nil, err 52 | } 53 | return &KVGetQuery{KVExistsQuery: *q}, nil 54 | } 55 | 56 | // Fetch queries the Consul API defined by the given client. 57 | func (d *KVGetQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 58 | select { 59 | case <-d.stopCh: 60 | return nil, nil, ErrStopped 61 | default: 62 | } 63 | 64 | opts := d.opts.Merge(&QueryOptions{ 65 | Datacenter: d.dc, 66 | Namespace: d.ns, 67 | }) 68 | 69 | pair, qm, err := clients.Consul().KV().Get(d.key, opts.ToConsulOpts()) 70 | if err != nil { 71 | return nil, nil, errors.Wrap(err, d.ID()) 72 | } 73 | 74 | rm := &dep.ResponseMetadata{ 75 | LastIndex: qm.LastIndex, 76 | LastContact: qm.LastContact, 77 | } 78 | 79 | if pair == nil { // a missing key yields a nil value, not an error 80 | return nil, rm, nil 81 | } 82 | 83 | value := dep.KvValue(pair.Value) 84 | return value, rm, nil 85 | } 86 | 87 | // CanShare reports true: identical KV get queries may be shared. 88 | func (d *KVGetQuery) CanShare() bool { 89 | return true 90 | } 91 | 92 | // ID returns the human-friendly version of this dependency. NOTE(review): ns is not part of the ID although CanShare is true; confirm queries differing only by namespace are not deduplicated by ID.
93 | func (d *KVGetQuery) ID() string { 94 | key := d.key 95 | if d.dc != "" { 96 | key = key + "@" + d.dc 97 | } 98 | 99 | return fmt.Sprintf("kv.get(%s)", key) 100 | } 101 | 102 | // String implements fmt.Stringer by returning ID(). 103 | func (d *KVGetQuery) String() string { 104 | return d.ID() 105 | } 106 | 107 | // Stop halts the dependency's fetch function by closing stopCh; calling it twice panics. 108 | func (d *KVGetQuery) Stop() { 109 | close(d.stopCh) 110 | } 111 | // SetOptions stores opts unchanged, including WaitIndex and WaitTime. 112 | func (d *KVGetQuery) SetOptions(opts QueryOptions) { 113 | d.opts = opts 114 | } 115 | -------------------------------------------------------------------------------- /internal/dependency/kv_keys.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "regexp" 9 | "strings" 10 | 11 | "github.com/hashicorp/hcat/dep" 12 | "github.com/pkg/errors" 13 | ) 14 | 15 | var ( 16 | // Ensure KVKeysQuery implements isDependency. 17 | _ isDependency = (*KVKeysQuery)(nil) 18 | 19 | // KVKeysQueryRe is the regular expression to use. 20 | KVKeysQueryRe = regexp.MustCompile(`\A` + prefixRe + dcRe + `\z`) 21 | ) 22 | 23 | // KVKeysQuery lists the names of all keys stored under a prefix in the KV store. 24 | type KVKeysQuery struct { 25 | isConsul 26 | stopCh chan struct{} 27 | 28 | dc string 29 | prefix string 30 | opts QueryOptions 31 | } 32 | 33 | // NewKVKeysQuery parses a string into a dependency. 34 | func NewKVKeysQuery(s string) (*KVKeysQuery, error) { 35 | if s != "" && !KVKeysQueryRe.MatchString(s) { // an empty s skips the format check 36 | return nil, fmt.Errorf("kv.keys: invalid format: %q", s) 37 | } 38 | 39 | m := regexpMatch(KVKeysQueryRe, s) 40 | return &KVKeysQuery{ 41 | stopCh: make(chan struct{}, 1), 42 | dc: m["dc"], 43 | prefix: m["prefix"], 44 | }, nil 45 | } 46 | 47 | // Fetch queries the Consul API defined by the given client.
48 | func (d *KVKeysQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 49 | select { 50 | case <-d.stopCh: 51 | return nil, nil, ErrStopped 52 | default: 53 | } 54 | 55 | opts := d.opts.Merge(&QueryOptions{ 56 | Datacenter: d.dc, 57 | }) 58 | 59 | list, qm, err := clients.Consul().KV().Keys(d.prefix, "", opts.ToConsulOpts()) // empty separator: every key under the prefix, not just one level 60 | if err != nil { 61 | return nil, nil, errors.Wrap(err, d.ID()) 62 | } 63 | 64 | keys := make([]string, len(list)) 65 | for i, v := range list { // report keys relative to the prefix, without a leading slash 66 | v = strings.TrimPrefix(v, d.prefix) 67 | v = strings.TrimLeft(v, "/") 68 | keys[i] = v 69 | } 70 | 71 | rm := &dep.ResponseMetadata{ 72 | LastIndex: qm.LastIndex, 73 | LastContact: qm.LastContact, 74 | } 75 | 76 | return keys, rm, nil 77 | } 78 | 79 | // CanShare reports true: identical key-listing queries may be shared. 80 | func (d *KVKeysQuery) CanShare() bool { 81 | return true 82 | } 83 | 84 | // ID returns the human-friendly version of this dependency. 85 | func (d *KVKeysQuery) ID() string { 86 | prefix := d.prefix 87 | if d.dc != "" { 88 | prefix = prefix + "@" + d.dc 89 | } 90 | return fmt.Sprintf("kv.keys(%s)", prefix) 91 | } 92 | 93 | // String implements fmt.Stringer by returning ID(). 94 | func (d *KVKeysQuery) String() string { 95 | return d.ID() 96 | } 97 | 98 | // Stop halts the dependency's fetch function by closing stopCh; calling it twice panics. 99 | func (d *KVKeysQuery) Stop() { 100 | close(d.stopCh) 101 | } 102 | // SetOptions stores opts unchanged. 103 | func (d *KVKeysQuery) SetOptions(opts QueryOptions) { 104 | d.opts = opts 105 | } 106 | -------------------------------------------------------------------------------- /internal/dependency/kv_list.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "encoding/gob" 8 | "fmt" 9 | "regexp" 10 | "strings" 11 | 12 | "github.com/hashicorp/hcat/dep" 13 | "github.com/pkg/errors" 14 | ) 15 | 16 | var ( 17 | // Ensure KVListQuery implements isDependency. 18 | _ isDependency = (*KVListQuery)(nil) 19 | 20 | // KVListQueryRe is the regular expression to use. 21 | KVListQueryRe = regexp.MustCompile(`\A` + prefixRe + dcRe + `\z`) 22 | ) 23 | // init registers []*dep.KeyPair with encoding/gob, presumably so Fetch results can be gob-encoded — confirm with the consumer. 24 | func init() { 25 | gob.Register([]*dep.KeyPair{}) 26 | } 27 | 28 | // KVListQuery lists all key/value pairs stored under a prefix in the KV store. 29 | type KVListQuery struct { 30 | isConsul 31 | stopCh chan struct{} 32 | 33 | dc string 34 | prefix string 35 | ns string 36 | opts QueryOptions 37 | } 38 | 39 | // NewKVListQueryV1 processes options in the format of "prefix key=value" 40 | // e.g. "key_prefix dc=dc1" 41 | func NewKVListQueryV1(prefix string, opts []string) (*KVListQuery, error) { 42 | if prefix == "" || prefix == "/" { 43 | return nil, fmt.Errorf("kv.list: prefix required") 44 | } 45 | 46 | q := KVListQuery{ 47 | stopCh: make(chan struct{}, 1), 48 | prefix: strings.TrimPrefix(prefix, "/"), 49 | } 50 | 51 | for _, opt := range opts { 52 | if strings.TrimSpace(opt) == "" { 53 | continue 54 | } 55 | 56 | queryParam := strings.Split(opt, "=") // unlike NewKVExistsQueryV1, a value containing "=" is rejected here 57 | if len(queryParam) != 2 { 58 | return nil, fmt.Errorf( 59 | "kv.list: invalid query parameter format: %q", opt) 60 | } 61 | query := strings.TrimSpace(queryParam[0]) 62 | value := strings.TrimSpace(queryParam[1]) 63 | switch query { 64 | case "dc", "datacenter": 65 | q.dc = value 66 | case "ns", "namespace": 67 | q.ns = value 68 | default: 69 | return nil, fmt.Errorf( 70 | "kv.list: invalid query parameter: %q", opt) 71 | } 72 | } 73 | 74 | return &q, nil 75 | } 76 | 77 | // NewKVListQuery parses a string into a dependency.
78 | func NewKVListQuery(s string) (*KVListQuery, error) { 79 | if s != "" && !KVListQueryRe.MatchString(s) { // an empty s skips the format check 80 | return nil, fmt.Errorf("kv.list: invalid format: %q", s) 81 | } 82 | 83 | m := regexpMatch(KVListQueryRe, s) 84 | return &KVListQuery{ 85 | stopCh: make(chan struct{}, 1), 86 | dc: m["dc"], 87 | prefix: m["prefix"], 88 | ns: "", 89 | }, nil 90 | } 91 | 92 | // Fetch queries the Consul API defined by the given client. 93 | func (d *KVListQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 94 | select { 95 | case <-d.stopCh: 96 | return nil, nil, ErrStopped 97 | default: 98 | } 99 | 100 | opts := d.opts.Merge(&QueryOptions{ 101 | Datacenter: d.dc, 102 | Namespace: d.ns, 103 | }) 104 | 105 | list, qm, err := clients.Consul().KV().List(d.prefix, opts.ToConsulOpts()) 106 | if err != nil { 107 | return nil, nil, errors.Wrap(err, d.ID()) 108 | } 109 | 110 | pairs := make([]*dep.KeyPair, 0, len(list)) 111 | for _, pair := range list { 112 | key := strings.TrimPrefix(pair.Key, d.prefix) // Key is relative to the prefix; Path keeps the full key 113 | key = strings.TrimLeft(key, "/") 114 | 115 | pairs = append(pairs, &dep.KeyPair{ 116 | Path: pair.Key, 117 | Key: key, 118 | Value: string(pair.Value), 119 | Exists: true, 120 | CreateIndex: pair.CreateIndex, 121 | ModifyIndex: pair.ModifyIndex, 122 | LockIndex: pair.LockIndex, 123 | Flags: pair.Flags, 124 | Session: pair.Session, 125 | }) 126 | } 127 | 128 | rm := &dep.ResponseMetadata{ 129 | LastIndex: qm.LastIndex, 130 | LastContact: qm.LastContact, 131 | } 132 | 133 | return pairs, rm, nil 134 | } 135 | 136 | // CanShare reports true: identical list queries may be shared. 137 | func (d *KVListQuery) CanShare() bool { 138 | return true 139 | } 140 | 141 | // ID returns the human-friendly version of this dependency. NOTE(review): ns is not part of the ID although CanShare is true; confirm queries differing only by namespace are not deduplicated by ID.
142 | func (d *KVListQuery) ID() string { 143 | prefix := d.prefix 144 | if d.dc != "" { 145 | prefix = prefix + "@" + d.dc 146 | } 147 | return fmt.Sprintf("kv.list(%s)", prefix) 148 | } 149 | 150 | // String implements fmt.Stringer by returning ID(). 151 | func (d *KVListQuery) String() string { 152 | return d.ID() 153 | } 154 | 155 | // Stop halts the dependency's fetch function by closing stopCh; calling it twice panics. 156 | func (d *KVListQuery) Stop() { 157 | close(d.stopCh) 158 | } 159 | // SetOptions stores opts unchanged. 160 | func (d *KVListQuery) SetOptions(opts QueryOptions) { 161 | d.opts = opts 162 | } 163 | -------------------------------------------------------------------------------- /internal/dependency/testdata/cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIICNDCCAZ2gAwIBAgIQFwpF8sMHYeFwQKzjObM95jANBgkqhkiG9w0BAQsFADAS 3 | MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw 4 | MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB 5 | iQKBgQDc/BnhasDc0dI0lRyWpD39IyJ4bZfal2DWDoP5FoSfdr92OZzlDgvYBRPU 6 | nVURyV3eaJMIVtT+q1TICfOFGa5r7EEwWOF9R5rqoEJQjkfWYJ2lXmn2OZOuSLix 7 | Gon1KPe6yZqJGpqvVUiMQGDSHGpM0EKnJeE0Q3tNJ8w6sZj2owIDAQABo4GIMIGF 8 | MA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8E 9 | BTADAQH/MB0GA1UdDgQWBBTFCdtMBKtGiG5YxOLOe6fBOEUSqzAuBgNVHREEJzAl 10 | ggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0B 11 | AQsFAAOBgQCGWklb+40ucgXdosA+KxPBYnUiWkNlrr8nPshz9yCAlPMVW9EWQWI7 12 | npdaLV6sVGRx6bYhF/6D0r0l8TE6qN5I9zfsEytdSs8pxgiyOpNwu5fifPj04nN8 13 | YF4vm3hCitS6nyHxWWjgZYJaH+ZGFOaNLwYOeL7u3PkWpVfNCOumoA== 14 | -----END CERTIFICATE----- 15 | -------------------------------------------------------------------------------- /internal/dependency/testdata/key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBANz8GeFqwNzR0jSV 3 | HJakPf0jInhtl9qXYNYOg/kWhJ92v3Y5nOUOC9gFE9SdVRHJXd5okwhW1P6rVMgJ 4
| 84UZrmvsQTBY4X1HmuqgQlCOR9ZgnaVeafY5k65IuLEaifUo97rJmokamq9VSIxA 5 | YNIcakzQQqcl4TRDe00nzDqxmPajAgMBAAECgYAzpcQSuBWNRoi/e14sIwTN5elH 6 | hi2ojBq4zLmxfL7QWjuTURHHQwonmcAxv/fC6XJD6eL7Xvf28WomOpUstXzbExJD 7 | bouut4ApKc4C9HT3FGMBKfZVQRBwBhe/15eS7ZCzp445WriulWVbV8AbgPP6U6+x 8 | qZA9BekhYR85cMQBeQJBAPn99AimQV2nbbADOUwDAykonNte7Pdo3f5Y1NB4Y6T2 9 | 2KM2gArZtXo2iD6F1bqyqF5TUNOsqbXXd2jzf02nRUcCQQDiS68nnaNzXwctEFSj 10 | ufGppF+OXdS891WwPyuJ8MSW8iwAntFwqGbxc2A69TA1iR/HfaE1teB5g7Cru6hr 11 | uaHFAkEAiqpSsnmFyG0WaotfPMpu9mWQnB4LUzDX8j1Tzk749of1opKYc2xPPXsC 12 | F6wk4Wo3+ho8uy0K9dKOaaim9GvUAQJBAN7W2KySNxqtQUvHARIZUThUfSSckZlj 13 | liXwjtdPGMfrwhj6TBQ8QOMTUne8arTNS1YPCGjzqRD/9UGnkbpDGmECQCCOPqlL 14 | 9MvGWS+bJdM4M8zC2G6H/voHqRCAbRlQ11hELr74nwfKDgVS4Uq4ULDeLV/Sqlcb 15 | SeXVNtxw/l+Ru2E= 16 | -----END PRIVATE KEY----- 17 | -------------------------------------------------------------------------------- /internal/dependency/vault_agent_token.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "io/ioutil" 8 | "os" 9 | "strings" 10 | "time" 11 | 12 | "github.com/hashicorp/hcat/dep" 13 | "github.com/pkg/errors" 14 | ) 15 | 16 | // Ensure VaultAgentTokenQuery implements isDependency. 17 | var _ isDependency = (*VaultAgentTokenQuery)(nil) 18 | 19 | const ( 20 | // VaultAgentTokenSleepTime is the amount of time to sleep between queries, since 21 | // the fsnotify library is not compatible with solaris and other OSes yet. 22 | VaultAgentTokenSleepTime = 15 * time.Second 23 | ) 24 | 25 | // VaultAgentTokenQuery watches a Vault Agent token file and installs its contents as the Vault client's token. 26 | type VaultAgentTokenQuery struct { 27 | isVault 28 | stopCh chan struct{} // closed by Stop to abort Fetch and the watch goroutine 29 | 30 | path string 31 | stat os.FileInfo // last observed file info; nil until the first successful Fetch 32 | } 33 | 34 | // NewVaultAgentTokenQuery creates a new dependency.
35 | func NewVaultAgentTokenQuery(path string) (*VaultAgentTokenQuery, error) { // never returns a non-nil error; path is not validated here 36 | return &VaultAgentTokenQuery{ 37 | stopCh: make(chan struct{}, 1), 38 | path: path, 39 | }, nil 40 | } 41 | 42 | // Fetch waits until the token file is first seen or has changed (or Stop is called), then sets the 43 | // Vault client's token to the file's whitespace-trimmed contents; the returned data is built from "". 44 | func (d *VaultAgentTokenQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 45 | select { 46 | case <-d.stopCh: 47 | return "", nil, ErrStopped 48 | case r := <-d.watch(d.stat): 49 | if r.err != nil { 50 | return "", nil, errors.Wrap(r.err, d.ID()) 51 | } 52 | 53 | token, err := ioutil.ReadFile(d.path) 54 | if err != nil { 55 | return "", nil, errors.Wrap(err, d.ID()) 56 | } 57 | 58 | d.stat = r.stat // record the new stat only after a successful read 59 | clients.Vault().SetToken(strings.TrimSpace(string(token))) 60 | } 61 | 62 | return respWithMetadata("") 63 | } 64 | 65 | // CanShare reports false: this dependency is never shared. 66 | func (d *VaultAgentTokenQuery) CanShare() bool { 67 | return false 68 | } 69 | 70 | // Stop halts the dependency's fetch function by closing stopCh; calling it twice panics. 71 | func (d *VaultAgentTokenQuery) Stop() { 72 | close(d.stopCh) 73 | } 74 | 75 | // ID returns the human-friendly version of this dependency.
76 | func (d *VaultAgentTokenQuery) ID() string { 77 | return "vault-agent.token" 78 | } 79 | 80 | // Stringer interface reuses ID 81 | func (d *VaultAgentTokenQuery) String() string { 82 | return d.ID() 83 | } 84 | 85 | func (d *VaultAgentTokenQuery) SetOptions(opts QueryOptions) {} 86 | 87 | // watch watches the file for changes 88 | func (d *VaultAgentTokenQuery) watch(lastStat os.FileInfo) <-chan *watchResult { 89 | ch := make(chan *watchResult, 1) 90 | 91 | go func(lastStat os.FileInfo) { 92 | for { 93 | stat, err := os.Stat(d.path) 94 | if err != nil { 95 | select { 96 | case <-d.stopCh: 97 | return 98 | case ch <- &watchResult{err: err}: 99 | return 100 | } 101 | } 102 | 103 | changed := lastStat == nil || 104 | lastStat.Size() != stat.Size() || 105 | lastStat.ModTime() != stat.ModTime() 106 | 107 | if changed { 108 | select { 109 | case <-d.stopCh: 110 | return 111 | case ch <- &watchResult{stat: stat}: 112 | return 113 | } 114 | } 115 | 116 | select { 117 | case <-d.stopCh: 118 | return 119 | case <-time.After(VaultAgentTokenSleepTime): 120 | } 121 | } 122 | }(lastStat) 123 | 124 | return ch 125 | } 126 | -------------------------------------------------------------------------------- /internal/dependency/vault_agent_token_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "io/ioutil" 8 | "os" 9 | "path/filepath" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestVaultAgentTokenQuery_Fetch(t *testing.T) { 17 | // Don't use t.Parallel() here as the SetToken() calls are global and break 18 | // other tests if run in parallel 19 | 20 | // reset token back to original 21 | vc := testClients.Vault() 22 | token := vc.Token() 23 | defer vc.SetToken(token) 24 | 25 | // Set up the Vault token file. 
26 | tokenFile, err := ioutil.TempFile("", "token1") 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | defer os.Remove(tokenFile.Name()) 31 | testWrite(tokenFile.Name(), []byte("token")) 32 | 33 | d, err := NewVaultAgentTokenQuery(tokenFile.Name()) 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | 38 | clientSet := testClients 39 | _, _, err = d.Fetch(clientSet) 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | 44 | assert.Equal(t, "token", clientSet.Vault().Token()) 45 | 46 | // Update the contents. 47 | testWrite(tokenFile.Name(), []byte("another_token")) 48 | _, _, err = d.Fetch(clientSet) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | assert.Equal(t, "another_token", clientSet.Vault().Token()) 54 | } 55 | 56 | func TestVaultAgentTokenQuery_Fetch_missingFile(t *testing.T) { 57 | t.Parallel() 58 | 59 | // Use a non-existant token file path. 60 | d, err := NewVaultAgentTokenQuery("/tmp/invalid-file") 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | clientSet := NewClientSet() 66 | clientSet.CreateVaultClient(&CreateClientInput{ 67 | Token: "foo", 68 | }) 69 | _, _, err = d.Fetch(clientSet) 70 | if err == nil || !strings.Contains(err.Error(), "no such file") { 71 | t.Fatal(err) 72 | } 73 | 74 | // Token should be unaffected. 
75 | assert.Equal(t, "foo", clientSet.Vault().Token()) 76 | } 77 | 78 | // 79 | func testWrite(path string, contents []byte) error { 80 | if path == "" { 81 | panic("missing path") 82 | } 83 | 84 | parent := filepath.Dir(path) 85 | if _, err := os.Stat(parent); os.IsNotExist(err) { 86 | if err := os.MkdirAll(parent, 0755); err != nil { 87 | return err 88 | } 89 | } 90 | 91 | f, err := ioutil.TempFile(parent, "") 92 | if err != nil { 93 | return err 94 | } 95 | defer os.Remove(f.Name()) 96 | 97 | if _, err := f.Write(contents); err != nil { 98 | return err 99 | } 100 | 101 | for _, err := range []error{ 102 | f.Sync(), 103 | f.Close(), 104 | os.Chmod(f.Name(), 0644), 105 | os.Rename(f.Name(), path), 106 | } { 107 | if err != nil { 108 | return err 109 | } 110 | } 111 | 112 | return nil 113 | } 114 | -------------------------------------------------------------------------------- /internal/dependency/vault_list.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "path" 9 | "sort" 10 | "strings" 11 | "time" 12 | 13 | "github.com/hashicorp/hcat/dep" 14 | "github.com/pkg/errors" 15 | ) 16 | 17 | var ( 18 | // Ensure implements 19 | _ isDependency = (*VaultListQuery)(nil) 20 | ) 21 | 22 | // VaultListQuery is the dependency to Vault for a secret 23 | type VaultListQuery struct { 24 | isVault 25 | stopCh chan struct{} 26 | 27 | path string 28 | opts QueryOptions 29 | } 30 | 31 | // NewVaultListQuery creates a new datacenter dependency. 
32 | func NewVaultListQuery(s string) (*VaultListQuery, error) { 33 | s = strings.TrimSpace(s) 34 | s = strings.Trim(s, "/") 35 | if s == "" { 36 | return nil, fmt.Errorf("vault.list: invalid format: %q", s) 37 | } 38 | 39 | return &VaultListQuery{ 40 | stopCh: make(chan struct{}, 1), 41 | path: s, 42 | }, nil 43 | } 44 | 45 | // Fetch queries the Vault API 46 | func (d *VaultListQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 47 | select { 48 | case <-d.stopCh: 49 | return nil, nil, ErrStopped 50 | default: 51 | } 52 | 53 | opts := d.opts.Merge(&QueryOptions{}) 54 | 55 | // If this is not the first query, poll to simulate blocking-queries. 56 | if opts.WaitIndex != 0 { 57 | dur := opts.DefaultLease 58 | select { 59 | case <-d.stopCh: 60 | return nil, nil, ErrStopped 61 | case <-time.After(dur): 62 | } 63 | } 64 | 65 | path := d.path 66 | // Checking secret engine version. If it's v2, we should shim /metadata/ 67 | // to secret path if necessary. 68 | mountPath, isV2, _ := isKVv2(clients.Vault(), path) 69 | if isV2 { 70 | path = shimKv2ListPath(path, mountPath) 71 | } 72 | // If we got this far, we either didn't have a secret to renew, the secret was 73 | // not renewable, or the renewal failed, so attempt a fresh list. 74 | secret, err := clients.Vault().Logical().List(path) 75 | if err != nil { 76 | return nil, nil, errors.Wrap(err, d.ID()) 77 | } 78 | 79 | var result []string 80 | 81 | // The secret could be nil if it does not exist. 82 | if secret == nil || secret.Data == nil { 83 | return respWithMetadata(result) 84 | } 85 | 86 | // This is a weird thing that happened once... 
87 | keys, ok := secret.Data["keys"] 88 | if !ok { 89 | return respWithMetadata(result) 90 | } 91 | 92 | list, ok := keys.([]interface{}) 93 | if !ok { 94 | return nil, nil, fmt.Errorf("%s: unexpected response", d) 95 | } 96 | 97 | for _, v := range list { 98 | typed, ok := v.(string) 99 | if !ok { 100 | return nil, nil, fmt.Errorf("%s: non-string in list", d) 101 | } 102 | result = append(result, typed) 103 | } 104 | sort.Strings(result) 105 | 106 | return respWithMetadata(result) 107 | } 108 | 109 | // CanShare returns if this dependency is shareable. 110 | func (d *VaultListQuery) CanShare() bool { 111 | return false 112 | } 113 | 114 | // Stop halts the given dependency's fetch. 115 | func (d *VaultListQuery) Stop() { 116 | close(d.stopCh) 117 | } 118 | 119 | // ID returns the human-friendly version of this dependency. 120 | func (d *VaultListQuery) ID() string { 121 | return fmt.Sprintf("vault.list(%s)", d.path) 122 | } 123 | 124 | // Stringer interface reuses ID 125 | func (d *VaultListQuery) String() string { 126 | return d.ID() 127 | } 128 | 129 | func (d *VaultListQuery) SetOptions(opts QueryOptions) { 130 | d.opts = opts 131 | } 132 | 133 | // shimKvV2ListPath aligns the supported legacy path to KV v2 specs by inserting 134 | // /metadata/ into the path for listing secrets. Paths with /metadata/ are not modified. 135 | func shimKv2ListPath(rawPath, mountPath string) string { 136 | mountPath = strings.TrimSuffix(mountPath, "/") 137 | 138 | if strings.HasPrefix(rawPath, path.Join(mountPath, "metadata")) { 139 | // It doesn't need modifying. 
140 | return rawPath 141 | } 142 | 143 | switch { 144 | case rawPath == mountPath: 145 | return path.Join(mountPath, "metadata") 146 | default: 147 | rawPath = strings.TrimPrefix(rawPath, mountPath) 148 | return path.Join(mountPath, "metadata", rawPath) 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /internal/dependency/vault_list_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestNewVaultListQuery(t *testing.T) { 15 | t.Parallel() 16 | 17 | cases := []struct { 18 | name string 19 | i string 20 | exp *VaultListQuery 21 | err bool 22 | }{ 23 | { 24 | "empty", 25 | "", 26 | nil, 27 | true, 28 | }, 29 | { 30 | "path", 31 | "path", 32 | &VaultListQuery{ 33 | path: "path", 34 | }, 35 | false, 36 | }, 37 | { 38 | "leading_slash", 39 | "/leading/slash", 40 | &VaultListQuery{ 41 | path: "leading/slash", 42 | }, 43 | false, 44 | }, 45 | { 46 | "trailing_slash", 47 | "trailing/slash/", 48 | &VaultListQuery{ 49 | path: "trailing/slash", 50 | }, 51 | false, 52 | }, 53 | } 54 | 55 | for i, tc := range cases { 56 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 57 | act, err := NewVaultListQuery(tc.i) 58 | if (err != nil) != tc.err { 59 | t.Fatal(err) 60 | } 61 | 62 | if act != nil { 63 | act.stopCh = nil 64 | } 65 | 66 | assert.Equal(t, tc.exp, act) 67 | }) 68 | } 69 | } 70 | 71 | func TestVaultListQuery_Fetch(t *testing.T) { 72 | t.Parallel() 73 | 74 | clients, vault := testVaultServer(t, "listfetch", "1") 75 | secretsPath := vault.secretsPath 76 | 77 | clientsKv2, vaultKv2 := testVaultServer(t, "listfetchV2", "2") 78 | secretsPathKv2 := vaultKv2.secretsPath 79 | 80 | for _, v := range []*vaultServer{vault, vaultKv2} { 81 | err := 
v.CreateSecret("foo/bar", map[string]interface{}{ 82 | "ttl": "100ms", // explicitly make this a short duration for testing 83 | "zip": "zap", 84 | }) 85 | if err != nil { 86 | t.Fatal(err) 87 | } 88 | } 89 | cases := []struct { 90 | name string 91 | i string 92 | exp []string 93 | clients *ClientSet 94 | }{ 95 | { 96 | "exists", 97 | secretsPath, 98 | []string{"foo/"}, 99 | clients, 100 | }, 101 | { 102 | "no_exist", 103 | "not/a/real/path/like/ever", 104 | nil, 105 | clients, 106 | }, 107 | { 108 | "exists_v2", 109 | secretsPathKv2, 110 | []string{"foo/"}, 111 | clientsKv2, 112 | }, 113 | { 114 | "no_exist_kvv2", 115 | "not/a/real/path/like/ever", 116 | nil, 117 | clientsKv2, 118 | }, 119 | } 120 | 121 | for i, tc := range cases { 122 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 123 | d, err := NewVaultListQuery(tc.i) 124 | if err != nil { 125 | t.Fatal(err) 126 | } 127 | 128 | act, _, err := d.Fetch(tc.clients) 129 | if err != nil { 130 | t.Fatal(err) 131 | } 132 | 133 | assert.Equal(t, tc.exp, act) 134 | }) 135 | } 136 | 137 | t.Run("stops", func(t *testing.T) { 138 | d, err := NewVaultListQuery(secretsPath + "/foo/bar") 139 | if err != nil { 140 | t.Fatal(err) 141 | } 142 | 143 | dataCh := make(chan interface{}, 1) 144 | errCh := make(chan error, 1) 145 | go func() { 146 | for { 147 | data, _, err := d.Fetch(clients) 148 | if err != nil { 149 | errCh <- err 150 | return 151 | } 152 | dataCh <- data 153 | } 154 | }() 155 | 156 | select { 157 | case err := <-errCh: 158 | t.Fatal(err) 159 | case <-dataCh: 160 | } 161 | 162 | d.Stop() 163 | 164 | select { 165 | case err := <-errCh: 166 | if err != ErrStopped { 167 | t.Fatal(err) 168 | } 169 | case <-time.After(100 * time.Millisecond): 170 | t.Errorf("did not stop") 171 | } 172 | }) 173 | 174 | t.Run("fires_changes", func(t *testing.T) { 175 | d, err := NewVaultListQuery(secretsPath) 176 | if err != nil { 177 | t.Fatal(err) 178 | } 179 | 180 | //_, qm, err := d.Fetch(clients, nil) 181 | _, _, err 
= d.Fetch(clients) 182 | if err != nil { 183 | t.Fatal(err) 184 | } 185 | 186 | dataCh := make(chan interface{}, 1) 187 | errCh := make(chan error, 1) 188 | go func() { 189 | for { 190 | //data, _, err := d.Fetch(clients, &QueryOptions{WaitIndex: qm.LastIndex}) 191 | data, _, err := d.Fetch(clients) 192 | if err != nil { 193 | errCh <- err 194 | return 195 | } 196 | dataCh <- data 197 | return 198 | } 199 | }() 200 | 201 | select { 202 | case err := <-errCh: 203 | t.Fatal(err) 204 | case <-dataCh: 205 | } 206 | }) 207 | } 208 | 209 | func TestVaultListQuery_String(t *testing.T) { 210 | t.Parallel() 211 | 212 | cases := []struct { 213 | name string 214 | i string 215 | exp string 216 | }{ 217 | { 218 | "path", 219 | "path", 220 | "vault.list(path)", 221 | }, 222 | } 223 | 224 | for i, tc := range cases { 225 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 226 | d, err := NewVaultListQuery(tc.i) 227 | if err != nil { 228 | t.Fatal(err) 229 | } 230 | assert.Equal(t, tc.exp, d.ID()) 231 | }) 232 | } 233 | } 234 | -------------------------------------------------------------------------------- /internal/dependency/vault_read.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "net/url" 9 | "strings" 10 | "time" 11 | 12 | "github.com/hashicorp/hcat/dep" 13 | "github.com/hashicorp/vault/api" 14 | "github.com/pkg/errors" 15 | ) 16 | 17 | // Ensure implements 18 | var _ isDependency = (*VaultReadQuery)(nil) 19 | 20 | // VaultReadQuery is the dependency to Vault for a secret 21 | type VaultReadQuery struct { 22 | isVault 23 | stopCh chan struct{} 24 | sleepCh chan time.Duration 25 | 26 | rawPath string 27 | queryValues url.Values 28 | secret *dep.Secret 29 | isKVv2 *bool 30 | secretPath string 31 | opts QueryOptions 32 | 33 | // vaultSecret is the actual Vault secret which we are renewing 34 | vaultSecret *api.Secret 35 | } 36 | 37 | // NewVaultReadQuery creates a new datacenter dependency. 38 | func NewVaultReadQuery(s string) (*VaultReadQuery, error) { 39 | s = strings.TrimSpace(s) 40 | s = strings.Trim(s, "/") 41 | if s == "" { 42 | return nil, fmt.Errorf("vault.read: invalid format: %q", s) 43 | } 44 | 45 | secretURL, err := url.Parse(s) 46 | if err != nil { 47 | return nil, err 48 | } 49 | 50 | return &VaultReadQuery{ 51 | stopCh: make(chan struct{}, 1), 52 | sleepCh: make(chan time.Duration, 1), 53 | rawPath: secretURL.Path, 54 | queryValues: secretURL.Query(), 55 | }, nil 56 | } 57 | 58 | // Fetch queries the Vault API 59 | func (d *VaultReadQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 60 | select { 61 | case <-d.stopCh: 62 | return nil, nil, ErrStopped 63 | case dur := <-d.sleepCh: 64 | select { 65 | case <-time.After(dur): 66 | break 67 | case <-d.stopCh: 68 | return nil, nil, ErrStopped 69 | } 70 | default: 71 | } 72 | 73 | firstRun := d.secret == nil 74 | 75 | if !firstRun && vaultSecretRenewable(d.secret) { 76 | err := renewSecret(clients, d) 77 | if err != nil { 78 | return nil, nil, errors.Wrap(err, d.ID()) 79 | } 80 | } 81 | 82 | err := d.fetchSecret(clients) 83 | if err != nil { 84 | return nil, 
nil, errors.Wrap(err, d.ID()) 85 | } 86 | 87 | if !vaultSecretRenewable(d.secret) { 88 | dur := leaseCheckWait(d.secret, nil) 89 | d.sleepCh <- dur 90 | } 91 | 92 | return respWithMetadata(d.secret) 93 | } 94 | 95 | func (d *VaultReadQuery) fetchSecret(clients dep.Clients) error { 96 | opts := d.opts.Merge(&QueryOptions{}) 97 | vaultSecret, err := d.readSecret(clients, opts) 98 | if err == nil { 99 | d.vaultSecret = vaultSecret 100 | // the cloned secret which will be exposed to the template 101 | d.secret = transformSecret(vaultSecret, opts.DefaultLease) 102 | } 103 | return err 104 | } 105 | 106 | func (d *VaultReadQuery) stopChan() chan struct{} { 107 | return d.stopCh 108 | } 109 | 110 | func (d *VaultReadQuery) secrets() (*dep.Secret, *api.Secret) { 111 | return d.secret, d.vaultSecret 112 | } 113 | 114 | // CanShare returns if this dependency is shareable. 115 | func (d *VaultReadQuery) CanShare() bool { 116 | return false 117 | } 118 | 119 | // Stop halts the given dependency's fetch. 120 | func (d *VaultReadQuery) Stop() { 121 | close(d.stopCh) 122 | } 123 | 124 | // ID returns the human-friendly version of this dependency. 125 | func (d *VaultReadQuery) ID() string { 126 | if v := d.queryValues["version"]; len(v) > 0 { 127 | return fmt.Sprintf("vault.read(%s.v%s)", d.rawPath, v[0]) 128 | } 129 | return fmt.Sprintf("vault.read(%s)", d.rawPath) 130 | } 131 | 132 | // Stringer interface reuses ID 133 | func (d *VaultReadQuery) String() string { 134 | return d.ID() 135 | } 136 | 137 | func (d *VaultReadQuery) readSecret(clients dep.Clients, opts *QueryOptions) (*api.Secret, error) { 138 | vaultClient := clients.Vault() 139 | 140 | // Check whether this secret refers to a KV v2 entry if we haven't yet. 
141 | if d.isKVv2 == nil { 142 | mountPath, isKVv2, err := isKVv2(vaultClient, d.rawPath) 143 | if err != nil { 144 | isKVv2 = false 145 | d.secretPath = d.rawPath 146 | } else if isKVv2 { 147 | d.secretPath = shimKVv2Path(d.rawPath, mountPath) 148 | } else { 149 | d.secretPath = d.rawPath 150 | } 151 | d.isKVv2 = &isKVv2 152 | } 153 | 154 | vaultSecret, err := vaultClient.Logical().ReadWithData(d.secretPath, 155 | d.queryValues) 156 | if err != nil { 157 | return nil, errors.Wrap(err, d.ID()) 158 | } 159 | if vaultSecret == nil || deletedKVv2(vaultSecret) { 160 | return nil, fmt.Errorf("no secret exists at %s", d.secretPath) 161 | } 162 | return vaultSecret, nil 163 | } 164 | 165 | func (d *VaultReadQuery) SetOptions(opts QueryOptions) { 166 | d.opts = opts 167 | } 168 | 169 | func deletedKVv2(s *api.Secret) bool { 170 | switch md := s.Data["metadata"].(type) { 171 | case map[string]interface{}: 172 | return md["deletion_time"] != "" 173 | } 174 | return false 175 | } 176 | -------------------------------------------------------------------------------- /internal/dependency/vault_token.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "time" 8 | 9 | "github.com/hashicorp/hcat/dep" 10 | "github.com/hashicorp/vault/api" 11 | "github.com/pkg/errors" 12 | ) 13 | 14 | var ( 15 | // Ensure implements 16 | _ isDependency = (*VaultTokenQuery)(nil) 17 | ) 18 | 19 | // VaultTokenQuery is the dependency to Vault for a secret 20 | type VaultTokenQuery struct { 21 | isVault 22 | stopCh chan struct{} 23 | secret *dep.Secret 24 | vaultSecret *api.Secret 25 | } 26 | 27 | // NewVaultTokenQuery creates a new dependency. 
28 | func NewVaultTokenQuery(token string) (*VaultTokenQuery, error) { 29 | vaultSecret := &api.Secret{ 30 | Auth: &api.SecretAuth{ 31 | ClientToken: token, 32 | Renewable: true, 33 | LeaseDuration: 1, 34 | }, 35 | } 36 | const tokenLeaseDuration = 5 * time.Minute 37 | return &VaultTokenQuery{ 38 | stopCh: make(chan struct{}, 1), 39 | vaultSecret: vaultSecret, 40 | secret: transformSecret(vaultSecret, tokenLeaseDuration), 41 | }, nil 42 | } 43 | 44 | // Fetch queries the Vault API 45 | func (d *VaultTokenQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 46 | select { 47 | case <-d.stopCh: 48 | return nil, nil, ErrStopped 49 | default: 50 | } 51 | 52 | if vaultSecretRenewable(d.secret) { 53 | err := renewSecret(clients, d) 54 | if err != nil { 55 | return nil, nil, errors.Wrap(err, d.ID()) 56 | } 57 | } 58 | 59 | return nil, nil, ErrLeaseExpired 60 | } 61 | 62 | func (d *VaultTokenQuery) stopChan() chan struct{} { 63 | return d.stopCh 64 | } 65 | 66 | func (d *VaultTokenQuery) secrets() (*dep.Secret, *api.Secret) { 67 | return d.secret, d.vaultSecret 68 | } 69 | 70 | // CanShare returns if this dependency is shareable. 71 | func (d *VaultTokenQuery) CanShare() bool { 72 | return false 73 | } 74 | 75 | // Stop halts the dependency's fetch function. 76 | func (d *VaultTokenQuery) Stop() { 77 | close(d.stopCh) 78 | } 79 | 80 | // ID returns the human-friendly version of this dependency. 81 | func (d *VaultTokenQuery) ID() string { 82 | return "vault.token" 83 | } 84 | 85 | // Stringer interface reuses ID 86 | func (d *VaultTokenQuery) String() string { 87 | return d.ID() 88 | } 89 | 90 | func (d *VaultTokenQuery) SetOptions(opts QueryOptions) {} 91 | -------------------------------------------------------------------------------- /internal/dependency/vault_token_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/hashicorp/hcat/dep" 11 | "github.com/hashicorp/vault/api" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestNewVaultTokenQuery(t *testing.T) { 16 | t.Parallel() 17 | 18 | cases := []struct { 19 | name string 20 | exp *VaultTokenQuery 21 | err bool 22 | }{ 23 | { 24 | "default", 25 | &VaultTokenQuery{ 26 | secret: &dep.Secret{ 27 | Auth: &dep.SecretAuth{ 28 | ClientToken: "my-token", 29 | Renewable: true, 30 | LeaseDuration: 1, 31 | }, 32 | LeaseDuration: 300, 33 | }, 34 | vaultSecret: &api.Secret{ 35 | Auth: &api.SecretAuth{ 36 | ClientToken: "my-token", 37 | Renewable: true, 38 | LeaseDuration: 1, 39 | }, 40 | }, 41 | }, 42 | false, 43 | }, 44 | } 45 | 46 | for i, tc := range cases { 47 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 48 | act, err := NewVaultTokenQuery("my-token") 49 | if (err != nil) != tc.err { 50 | t.Fatal(err) 51 | } 52 | 53 | if act != nil { 54 | act.stopCh = nil 55 | } 56 | 57 | assert.Equal(t, tc.exp, act) 58 | }) 59 | } 60 | } 61 | 62 | func TestVaultTokenQuery_String(t *testing.T) { 63 | t.Parallel() 64 | 65 | cases := []struct { 66 | name string 67 | exp string 68 | }{ 69 | { 70 | "default", 71 | "vault.token", 72 | }, 73 | } 74 | 75 | for i, tc := range cases { 76 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 77 | d, err := NewVaultTokenQuery("my-token") 78 | if err != nil { 79 | t.Fatal(err) 80 | } 81 | assert.Equal(t, tc.exp, d.ID()) 82 | }) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /internal/dependency/vault_write.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dependency 5 | 6 | import ( 7 | "crypto/sha1" 8 | "fmt" 9 | "io" 10 | "sort" 11 | "strings" 12 | "time" 13 | 14 | "github.com/hashicorp/hcat/dep" 15 | "github.com/hashicorp/vault/api" 16 | "github.com/pkg/errors" 17 | ) 18 | 19 | // Ensure implements 20 | var _ isDependency = (*VaultWriteQuery)(nil) 21 | 22 | // VaultWriteQuery is the dependency to Vault for a secret 23 | type VaultWriteQuery struct { 24 | isVault 25 | stopCh chan struct{} 26 | sleepCh chan time.Duration 27 | 28 | path string 29 | data map[string]interface{} 30 | dataHash string 31 | secret *dep.Secret 32 | opts QueryOptions 33 | 34 | // vaultSecret is the actual Vault secret which we are renewing 35 | vaultSecret *api.Secret 36 | } 37 | 38 | // NewVaultWriteQuery creates a new datacenter dependency. 39 | func NewVaultWriteQuery(s string, d map[string]interface{}) (*VaultWriteQuery, error) { 40 | s = strings.TrimSpace(s) 41 | s = strings.Trim(s, "/") 42 | if s == "" { 43 | return nil, fmt.Errorf("vault.write: invalid format: %q", s) 44 | } 45 | 46 | return &VaultWriteQuery{ 47 | stopCh: make(chan struct{}, 1), 48 | sleepCh: make(chan time.Duration, 1), 49 | path: s, 50 | data: d, 51 | dataHash: sha1Map(d), 52 | }, nil 53 | } 54 | 55 | // Fetch queries the Vault API 56 | func (d *VaultWriteQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 57 | select { 58 | case <-d.stopCh: 59 | return nil, nil, ErrStopped 60 | case dur := <-d.sleepCh: 61 | select { 62 | case <-time.After(dur): 63 | break 64 | case <-d.stopCh: 65 | return nil, nil, ErrStopped 66 | } 67 | default: 68 | } 69 | 70 | firstRun := d.secret == nil 71 | 72 | if !firstRun && vaultSecretRenewable(d.secret) { 73 | err := renewSecret(clients, d) 74 | if err != nil { 75 | return nil, nil, errors.Wrap(err, d.ID()) 76 | } 77 | } 78 | 79 | opts := d.opts.Merge(&QueryOptions{}) 80 | vaultSecret, err := d.writeSecret(clients, opts) 81 | if err != nil { 82 | return 
nil, nil, errors.Wrap(err, d.ID()) 83 | } 84 | 85 | // vaultSecret == nil when writing to KVv1 engines 86 | if vaultSecret == nil { 87 | return respWithMetadata(d.secret) 88 | } 89 | 90 | d.vaultSecret = vaultSecret 91 | // cloned secret which will be exposed to the template 92 | d.secret = transformSecret(vaultSecret, opts.DefaultLease) 93 | 94 | if !vaultSecretRenewable(d.secret) { 95 | dur := leaseCheckWait(d.secret, nil) 96 | d.sleepCh <- dur 97 | } 98 | 99 | return respWithMetadata(d.secret) 100 | } 101 | 102 | // meet renewer interface 103 | func (d *VaultWriteQuery) stopChan() chan struct{} { 104 | return d.stopCh 105 | } 106 | 107 | func (d *VaultWriteQuery) secrets() (*dep.Secret, *api.Secret) { 108 | return d.secret, d.vaultSecret 109 | } 110 | 111 | // CanShare returns if this dependency is shareable. 112 | func (d *VaultWriteQuery) CanShare() bool { 113 | return false 114 | } 115 | 116 | // Stop halts the given dependency's fetch. 117 | func (d *VaultWriteQuery) Stop() { 118 | close(d.stopCh) 119 | } 120 | 121 | // ID returns the human-friendly version of this dependency. 122 | func (d *VaultWriteQuery) ID() string { 123 | return fmt.Sprintf("vault.write(%s -> %s)", d.path, d.dataHash) 124 | } 125 | 126 | // Stringer interface reuses ID 127 | func (d *VaultWriteQuery) String() string { 128 | return d.ID() 129 | } 130 | 131 | // sha1Map returns the sha1 hash of the data in the map. The reason this data is 132 | // hashed is because it appears in the output and could contain sensitive 133 | // information. 
134 | func sha1Map(m map[string]interface{}) string { 135 | keys := make([]string, 0, len(m)) 136 | for k := range m { 137 | keys = append(keys, k) 138 | } 139 | sort.Strings(keys) 140 | 141 | h := sha1.New() 142 | for _, k := range keys { 143 | io.WriteString(h, fmt.Sprintf("%s=%q", k, m[k])) 144 | } 145 | 146 | return fmt.Sprintf("%.4x", h.Sum(nil)) 147 | } 148 | 149 | func (d *VaultWriteQuery) writeSecret(clients dep.Clients, opts *QueryOptions) (*api.Secret, error) { 150 | path := d.path 151 | data := d.data 152 | 153 | mountPath, isv2, _ := isKVv2(clients.Vault(), path) 154 | if isv2 { 155 | path = shimKVv2Path(path, mountPath) 156 | data = map[string]interface{}{"data": d.data} 157 | } 158 | 159 | vaultSecret, err := clients.Vault().Logical().Write(path, data) 160 | if err != nil { 161 | return nil, errors.Wrap(err, d.ID()) 162 | } 163 | // vaultSecret is always nil when KVv1 engine (isv2==false) 164 | if isv2 && vaultSecret == nil { 165 | return nil, fmt.Errorf("no secret exists at %s", path) 166 | } 167 | 168 | return vaultSecret, nil 169 | } 170 | 171 | func (d *VaultWriteQuery) SetOptions(opts QueryOptions) { 172 | d.opts = opts 173 | } 174 | -------------------------------------------------------------------------------- /internal/test/helpers.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package test 5 | 6 | import ( 7 | "sync" 8 | 9 | "github.com/hashicorp/consul/sdk/testutil" 10 | ) 11 | 12 | // Meets consul/sdk/testutil/TestingTB interface 13 | var _ testutil.TestingTB = (*TestingTB)(nil) 14 | 15 | type TestingTB struct { 16 | sync.Mutex 17 | cleanup func() 18 | } 19 | 20 | func (t *TestingTB) DoCleanup() { 21 | t.Lock() 22 | defer t.Unlock() 23 | t.cleanup() 24 | } 25 | 26 | func (*TestingTB) Failed() bool { return false } 27 | func (*TestingTB) Logf(string, ...interface{}) {} 28 | func (*TestingTB) Fatalf(string, ...interface{}) {} 29 | func (*TestingTB) Name() string { return "TestingTB" } 30 | func (*TestingTB) Helper() {} 31 | func (t *TestingTB) Cleanup(f func()) { 32 | t.Lock() 33 | defer t.Unlock() 34 | prev := t.cleanup 35 | t.cleanup = func() { 36 | f() 37 | if prev != nil { 38 | prev() 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /looker.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package hcat 5 | 6 | import ( 7 | "net/http" 8 | "os" 9 | "sync" 10 | "time" 11 | 12 | "github.com/hashicorp/hcat/dep" 13 | idep "github.com/hashicorp/hcat/internal/dependency" 14 | ) 15 | 16 | // Looker is an interface for looking up data from Consul, Vault and the 17 | // Environment. 18 | type Looker interface { 19 | dep.Clients 20 | Env() []string 21 | Stop() 22 | } 23 | 24 | // ClientSet focuses only on external (consul/vault) dependencies 25 | // at this point so we extend it here to include environment variables to meet 26 | // the looker interface. 27 | type ClientSet struct { 28 | *idep.ClientSet 29 | // map of client-structs to retry functions 30 | *sync.RWMutex // locking for env and retry 31 | injectedEnv []string 32 | } 33 | 34 | // NewClientSet is used to create the clients used. 
35 | // Fulfills the Looker interface. 36 | func NewClientSet() *ClientSet { 37 | return &ClientSet{ 38 | ClientSet: idep.NewClientSet(), 39 | 40 | RWMutex: &sync.RWMutex{}, 41 | injectedEnv: []string{}, 42 | } 43 | } 44 | 45 | // "Default" singleton ClientSet simplifies sharing of ClientSet's single-use 46 | // lookup fields, E.g. unwrapped vault tokens 47 | var oneClientSet sync.Once 48 | var defaultClientSet *ClientSet 49 | 50 | func DefaultClientSet() *ClientSet { 51 | oneClientSet.Do(func() { defaultClientSet = NewClientSet() }) 52 | return defaultClientSet 53 | } 54 | 55 | // AddConsul creates a Consul client and adds to the client set 56 | // HTTP/2 requires HTTPS, so if you need HTTP/2 be sure the local agent has 57 | // TLS setup and it's HTTPS port condigured and use with the Address here. 58 | func (cs *ClientSet) AddConsul(i ConsulInput) error { 59 | return cs.CreateConsulClient(i.toInternal()) 60 | } 61 | 62 | // AddVault creates a Vault client and adds to the client set 63 | func (cs *ClientSet) AddVault(i VaultInput) error { 64 | return cs.CreateVaultClient(i.toInternal()) 65 | } 66 | 67 | // Stop closes all idle connections for any attached clients and clears 68 | // the list of injected environment variables. 69 | func (cs *ClientSet) Stop() { 70 | if cs.ClientSet != nil { 71 | cs.ClientSet.Stop() 72 | } 73 | cs.injectedEnv = []string{} 74 | } 75 | 76 | // InjectEnv adds "key=value" pairs to the environment used for template 77 | // evaluations and child process runs. Note that this is in addition to the 78 | // environment running consul template and in the case of duplicates, the 79 | // last entry wins. 80 | func (cs *ClientSet) InjectEnv(env ...string) { 81 | cs.Lock() 82 | defer cs.Unlock() 83 | cs.injectedEnv = append(cs.injectedEnv, env...) 84 | } 85 | 86 | // You should do any messaging of the Environment variables during startup 87 | // As this will just use the raw Environment. 
88 | func (cs *ClientSet) Env() []string { 89 | cs.RLock() 90 | defer cs.RUnlock() 91 | return append(os.Environ(), cs.injectedEnv...) 92 | } 93 | 94 | // Input wrappers around internal structure. Going to rework the internal 95 | // structure, so this abstracts that away to make that workable. 96 | 97 | // VaultInput defines the inputs needed to configure the Vault client. 98 | type VaultInput struct { 99 | HttpClient *http.Client // optional, principally for testing 100 | Address string 101 | Namespace string 102 | Token string 103 | UnwrapToken bool 104 | Transport TransportInput 105 | } 106 | 107 | func (i VaultInput) toInternal() *idep.CreateClientInput { 108 | cci := &idep.CreateClientInput{ 109 | Address: i.Address, 110 | Namespace: i.Namespace, 111 | Token: i.Token, 112 | UnwrapToken: i.UnwrapToken, 113 | } 114 | return i.Transport.toInternal(cci) 115 | } 116 | 117 | // ConsulInput defines the inputs needed to configure the Consul client. 118 | type ConsulInput struct { 119 | Address string 120 | Namespace string 121 | Token string 122 | AuthEnabled bool 123 | AuthUsername string 124 | AuthPassword string 125 | Transport TransportInput 126 | // optional, principally for testing 127 | HttpClient *http.Client 128 | } 129 | 130 | func (i ConsulInput) toInternal() *idep.CreateClientInput { 131 | cci := &idep.CreateClientInput{ 132 | Address: i.Address, 133 | Namespace: i.Namespace, 134 | Token: i.Token, 135 | AuthEnabled: i.AuthEnabled, 136 | AuthUsername: i.AuthUsername, 137 | AuthPassword: i.AuthPassword, 138 | } 139 | return i.Transport.toInternal(cci) 140 | } 141 | 142 | type TransportInput struct { 143 | // Transport/TLS 144 | SSLEnabled bool 145 | SSLVerify bool 146 | SSLCert string 147 | SSLKey string 148 | SSLCACert string 149 | SSLCAPath string 150 | ServerName string 151 | 152 | DialKeepAlive time.Duration 153 | DialTimeout time.Duration 154 | DisableKeepAlives bool 155 | IdleConnTimeout time.Duration 156 | MaxIdleConns int 157 | MaxIdleConnsPerHost 
int 158 | TLSHandshakeTimeout time.Duration 159 | } 160 | 161 | func (i TransportInput) toInternal(cci *idep.CreateClientInput) *idep.CreateClientInput { 162 | cci.SSLEnabled = i.SSLEnabled 163 | cci.SSLVerify = i.SSLVerify 164 | cci.SSLCert = i.SSLCert 165 | cci.SSLKey = i.SSLKey 166 | cci.SSLCACert = i.SSLCACert 167 | cci.SSLCAPath = i.SSLCAPath 168 | cci.ServerName = i.ServerName 169 | cci.TransportDialKeepAlive = i.DialKeepAlive 170 | cci.TransportDialTimeout = i.DialTimeout 171 | cci.TransportDisableKeepAlives = i.DisableKeepAlives 172 | cci.TransportIdleConnTimeout = i.IdleConnTimeout 173 | cci.TransportMaxIdleConns = i.MaxIdleConns 174 | cci.TransportMaxIdleConnsPerHost = i.MaxIdleConnsPerHost 175 | cci.TransportTLSHandshakeTimeout = i.TLSHandshakeTimeout 176 | return cci 177 | } 178 | -------------------------------------------------------------------------------- /looker_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package hcat 5 | 6 | import ( 7 | "fmt" 8 | "net" 9 | "net/http" 10 | "net/http/httptest" 11 | "os" 12 | "testing" 13 | ) 14 | 15 | func TestClientSet(t *testing.T) { 16 | t.Run("client-api-init", func(t *testing.T) { 17 | ts := httptest.NewUnstartedServer(http.HandlerFunc( 18 | func(w http.ResponseWriter, r *http.Request) { 19 | fmt.Fprint(w, `"test"`) 20 | })) 21 | ts.Listener, _ = net.Listen("tcp", "127.0.0.1:8500") 22 | ts.Start() 23 | defer ts.Close() 24 | // ^ fake consul 25 | cs := NewClientSet() 26 | err := cs.AddConsul(ConsulInput{}) 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | err = cs.AddVault(VaultInput{}) 31 | if err != nil { 32 | t.Fatal(err) 33 | } 34 | defer cs.Stop() 35 | if c := cs.Consul(); c == nil { 36 | t.Fatal("Consul Client failed to load.") 37 | } 38 | if v := cs.Vault(); v == nil { 39 | t.Fatal("Consul Client failed to load.") 40 | } 41 | }) 42 | 43 | t.Run("env", func(t *testing.T) { 44 | cs := NewClientSet() 45 | defer cs.Stop() 46 | // All os environment variables should be present 47 | parentEnv := make(map[string]bool) 48 | for _, e := range os.Environ() { 49 | parentEnv[e] = true 50 | } 51 | for _, e := range cs.Env() { 52 | if !parentEnv[e] { 53 | t.Fatal("Missing parent environment variable") 54 | } 55 | } 56 | // Check inject 57 | cs.InjectEnv("foo=bar") 58 | foundit := false 59 | for _, e := range cs.Env() { 60 | if e == "foo=bar" { 61 | foundit = true 62 | break 63 | } 64 | } 65 | if !foundit { 66 | t.Fatal("Injecting environment variable failed") 67 | } 68 | // check that it still pulls in os environ 69 | os.Setenv("key", "value") 70 | for _, e := range cs.Env() { 71 | if e == "key=value" { 72 | foundit = true 73 | break 74 | } 75 | } 76 | if !foundit { 77 | t.Fatal("System environment variable failed") 78 | } 79 | }) 80 | } 81 | -------------------------------------------------------------------------------- /main_test.go: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package hcat 5 | 6 | import ( 7 | "flag" 8 | "io/ioutil" 9 | "log" 10 | "os" 11 | "testing" 12 | 13 | "github.com/hashicorp/consul/sdk/testutil" 14 | "github.com/hashicorp/hcat/internal/test" 15 | ) 16 | 17 | var ( 18 | RunExamples = flag.Bool("egs", false, "Run example tests") 19 | Consuladdr string 20 | ) 21 | 22 | func TestMain(m *testing.M) { 23 | flag.Parse() 24 | cleanup := func() {} 25 | if *RunExamples { 26 | Consuladdr, cleanup = testConsulSetup() 27 | } 28 | retCode := m.Run() 29 | cleanup() // can't defer w/ os.Exit 30 | os.Exit(retCode) 31 | } 32 | 33 | // support for running consul as part of integration testing 34 | func testConsulSetup() (string, func()) { 35 | var err error 36 | origStderr := os.Stderr 37 | os.Stderr, err = os.OpenFile(os.DevNull, os.O_WRONLY, 0o666) 38 | if err != nil { 39 | os.Stderr = origStderr 40 | } 41 | tb := &test.TestingTB{} 42 | consul, err := testutil.NewTestServerConfigT(tb, 43 | func(c *testutil.TestServerConfig) { 44 | c.LogLevel = "error" 45 | c.Stdout = ioutil.Discard 46 | c.Stderr = ioutil.Discard 47 | }) 48 | if err != nil { 49 | log.Fatalf("failed to start consul server: %v", err) 50 | } 51 | os.Stderr = origStderr 52 | return consul.HTTPAddr, func() { consul.Stop() } 53 | } 54 | -------------------------------------------------------------------------------- /resolver.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package hcat 5 | 6 | // Resolver is responsible rendering Templates and invoking Commands. 7 | // Empty but reserving the space for future use. 8 | type Resolver struct{} 9 | 10 | // ResolveEvent captures the whether the template dependencies have all been 11 | // resolved and rendered in memory. 
12 | type ResolveEvent struct { 13 | // Complete is true if all dependencies have values and the template 14 | // is fully rendered (in memory). 15 | Complete bool 16 | 17 | // Contents is the rendered contents from the template. 18 | // Only returned when Complete is true. 19 | Contents []byte 20 | 21 | // NoChange is true if no dependencies have changes in values and therefore 22 | // templates were not re-rendered. 23 | NoChange bool 24 | } 25 | 26 | // Basic constructor, here for consistency and future flexibility. 27 | func NewResolver() *Resolver { 28 | return &Resolver{} 29 | } 30 | 31 | // Watcherer is the subset of the Watcher's API that the resolver needs. 32 | // The interface is used to make the used/required API explicit. 33 | type Watcherer interface { 34 | Buffering(Notifier) bool 35 | Recaller(Notifier) Recaller 36 | Complete(Notifier) bool 37 | Clients() Looker 38 | } 39 | 40 | // Templater the interface the Template provides. 41 | // The interface is used to make the used/required API explicit. 42 | type Templater interface { 43 | Notifier 44 | Execute(Recaller) ([]byte, error) 45 | } 46 | 47 | // Interface that indicates it implements Mark and Sweep "garbage" collection 48 | // to track and collect (stop/dereference) dependencies and views that are no 49 | // longer in use. This happens over longer runs with nested dependencies 50 | // (EG. loop over all services and lookup each service instance, instance 51 | // goes away) and results in goroutine leaks if not managed. 52 | type Collector interface { 53 | MarkForSweep(IDer) 54 | Sweep(IDer) 55 | } 56 | 57 | // Run the template Execute once. You should repeat calling this until 58 | // output returns Complete as true. It uses the watcher for dependency 59 | // lookup state. The content will be updated each pass until complete. 
60 | func (r *Resolver) Run(tmpl Templater, w Watcherer) (ResolveEvent, error) { 61 | 62 | // If Watcherer supports it, wrap the template call with the Mark-n-Sweep 63 | // garbage collector to stop and dereference the old/unused views. 64 | gcViews := func(f func() ([]byte, error)) ([]byte, error) { return f() } 65 | if c, ok := w.(Collector); ok { 66 | gcViews = func(f func() ([]byte, error)) (data []byte, err error) { 67 | c.MarkForSweep(tmpl) 68 | if data, err = f(); err == nil { 69 | c.Sweep(tmpl) 70 | } 71 | return data, err 72 | } 73 | } 74 | 75 | // Attempt to render the template, returning any missing dependencies and 76 | // the rendered contents. If there are any missing dependencies, the 77 | // contents cannot be rendered or trusted! 78 | output, err := gcViews(func() ([]byte, error) { 79 | return tmpl.Execute(w.Recaller(tmpl)) 80 | }) 81 | switch { 82 | case err == ErrNoNewValues || err == nil: 83 | default: 84 | return ResolveEvent{}, err 85 | } 86 | 87 | return ResolveEvent{ 88 | Complete: w.Complete(tmpl), 89 | Contents: output, 90 | NoChange: err == ErrNoNewValues, 91 | }, nil 92 | } 93 | -------------------------------------------------------------------------------- /sets.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package hcat 5 | 6 | import ( 7 | "fmt" 8 | "sync" 9 | 10 | "github.com/hashicorp/hcat/dep" 11 | ) 12 | 13 | // stringSet is a simple string set implementation used 14 | type stringSet struct { 15 | *sync.RWMutex 16 | set map[string]struct{} 17 | } 18 | 19 | func newStringSet() stringSet { 20 | return stringSet{ 21 | RWMutex: &sync.RWMutex{}, 22 | set: make(map[string]struct{}), 23 | } 24 | } 25 | 26 | // Len(gth) or size of set 27 | func (s stringSet) Len() int { 28 | return len(s.set) 29 | } 30 | 31 | // Add and entry to the set 32 | func (s stringSet) add(k string) { 33 | s.set[k] = struct{}{} 34 | } 35 | func (s stringSet) Add(k string) { 36 | s.Lock() 37 | defer s.Unlock() 38 | s.add(k) 39 | } 40 | 41 | // Map returns a copy of the underlying map used by the set for membership. 42 | func (s stringSet) Map() map[string]struct{} { 43 | s.RLock() 44 | defer s.RUnlock() 45 | newmap := make(map[string]struct{}, len(s.set)) 46 | for k, v := range s.set { 47 | newmap[k] = v 48 | } 49 | return newmap 50 | } 51 | 52 | // Clear deletes all entries from set 53 | func (s stringSet) clear() { 54 | for k := range s.set { 55 | delete(s.set, k) 56 | } 57 | } 58 | func (s stringSet) Clear() { 59 | s.Lock() 60 | defer s.Unlock() 61 | s.clear() 62 | } 63 | 64 | // DepSet is a set (type) of Dependencies and is used with public template 65 | // rendering interface. Relative ordering is preserved. 66 | type DepSet struct { 67 | stringSet 68 | list []dep.Dependency 69 | } 70 | 71 | // NewDepSet returns an initialized DepSet (set of dependencies). 72 | func NewDepSet() *DepSet { 73 | return &DepSet{ 74 | list: make([]dep.Dependency, 0, 8), 75 | stringSet: newStringSet(), 76 | } 77 | } 78 | 79 | // Add adds a new element to the set if it does not already exist. 
80 | func (s *DepSet) Add(d dep.Dependency) bool { 81 | s.Lock() 82 | defer s.Unlock() 83 | if _, ok := s.stringSet.set[d.ID()]; !ok { 84 | s.list = append(s.list, d) 85 | s.stringSet.add(d.ID()) 86 | return true 87 | } 88 | return false 89 | } 90 | 91 | // List returns the insertion-ordered list of dependencies. 92 | func (s *DepSet) List() []dep.Dependency { 93 | s.RLock() 94 | defer s.RUnlock() 95 | return s.list 96 | } 97 | 98 | // String is a string representation of the set. 99 | func (s *DepSet) String() string { 100 | s.RLock() 101 | defer s.RUnlock() 102 | return fmt.Sprint(s.list) 103 | } 104 | 105 | // Clear deletes all entries from set. 106 | func (s *DepSet) Clear() { 107 | s.Lock() 108 | defer s.Unlock() 109 | s.stringSet.clear() 110 | for i := range s.list { 111 | s.list[i] = nil 112 | } 113 | s.list = s.list[:0] 114 | } 115 | -------------------------------------------------------------------------------- /store.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package hcat 5 | 6 | import ( 7 | "sync" 8 | ) 9 | 10 | // Store is what Template uses to determine the values that are 11 | // available for template parsing. 12 | type Store struct { 13 | sync.RWMutex 14 | 15 | // data is the map of individual dependencies and the most recent data for 16 | // that dependency. 17 | data map[string]interface{} 18 | } 19 | 20 | // NewStore creates a new Store with empty values for each 21 | // of the key structs. 22 | func NewStore() *Store { 23 | return &Store{ 24 | data: make(map[string]interface{}), 25 | } 26 | } 27 | 28 | // Save accepts a dependency and the data to store associated with that 29 | // dep. This function converts the given data to a proper type and stores 30 | // it interally. 
31 | func (s *Store) Save(id string, data interface{}) { 32 | s.Lock() 33 | defer s.Unlock() 34 | 35 | if _, ok := s.data[id]; ok { 36 | s.data[id] = data 37 | return 38 | } 39 | s.data[id] = data 40 | } 41 | 42 | // Recall gets the current value for the given dependency in the Store. 43 | func (s *Store) Recall(id string) (interface{}, bool) { 44 | s.RLock() 45 | defer s.RUnlock() 46 | 47 | data, ok := s.data[id] 48 | return data, ok 49 | } 50 | 51 | // Forget accepts a dependency and removes all associated data with this 52 | // dependency. 53 | func (s *Store) Delete(id string) { 54 | s.Lock() 55 | defer s.Unlock() 56 | 57 | delete(s.data, id) 58 | } 59 | 60 | // Reset clears all stored data. 61 | func (s *Store) Reset() { 62 | s.Lock() 63 | defer s.Unlock() 64 | 65 | for k := range s.data { 66 | delete(s.data, k) 67 | } 68 | } 69 | 70 | // forceSet is used to force set the value of a dependency for a given hash 71 | // code. Used in testing. 72 | func (s *Store) forceSet(hashCode string, data interface{}) { 73 | s.Lock() 74 | defer s.Unlock() 75 | 76 | s.data[hashCode] = data 77 | } 78 | -------------------------------------------------------------------------------- /store_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package hcat 5 | 6 | import ( 7 | "reflect" 8 | "testing" 9 | 10 | "github.com/hashicorp/hcat/dep" 11 | idep "github.com/hashicorp/hcat/internal/dependency" 12 | ) 13 | 14 | func TestNewStore(t *testing.T) { 15 | t.Parallel() 16 | st := NewStore() 17 | 18 | if st.data == nil { 19 | t.Errorf("expected data to not be nil") 20 | } 21 | } 22 | 23 | func TestRecall(t *testing.T) { 24 | t.Parallel() 25 | st := NewStore() 26 | 27 | d, err := idep.NewCatalogNodesQuery("") 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | 32 | nodes := []*dep.Node{ 33 | { 34 | Node: "node", 35 | Address: "address", 36 | }, 37 | } 38 | 39 | id := d.ID() 40 | st.Save(id, nodes) 41 | 42 | data, ok := st.Recall(id) 43 | if !ok { 44 | t.Fatal("expected data from Store") 45 | } 46 | 47 | result := data.([]*dep.Node) 48 | if !reflect.DeepEqual(result, nodes) { 49 | t.Errorf("expected %#v to be %#v", result, nodes) 50 | } 51 | } 52 | 53 | func TestForceSet(t *testing.T) { 54 | t.Parallel() 55 | st := NewStore() 56 | 57 | d, err := idep.NewCatalogNodesQuery("") 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | 62 | nodes := []*dep.Node{ 63 | { 64 | Node: "node", 65 | Address: "address", 66 | }, 67 | } 68 | 69 | st.forceSet(d.ID(), nodes) 70 | 71 | data, ok := st.Recall(d.ID()) 72 | if !ok { 73 | t.Fatal("expected data from Store") 74 | } 75 | 76 | result := data.([]*dep.Node) 77 | if !reflect.DeepEqual(result, nodes) { 78 | t.Errorf("expected %#v to be %#v", result, nodes) 79 | } 80 | } 81 | 82 | func TestForget(t *testing.T) { 83 | t.Parallel() 84 | st := NewStore() 85 | 86 | d, err := idep.NewCatalogNodesQuery("") 87 | if err != nil { 88 | t.Fatal(err) 89 | } 90 | 91 | nodes := []*dep.Node{ 92 | { 93 | Node: "node", 94 | Address: "address", 95 | }, 96 | } 97 | 98 | id := d.ID() 99 | st.Save(id, nodes) 100 | st.Delete(id) 101 | 102 | if _, ok := st.Recall(id); ok { 103 | t.Errorf("expected %#v to not be forgotten", d) 104 | } 105 | } 106 | 107 | 
func TestReset(t *testing.T) { 108 | t.Parallel() 109 | st := NewStore() 110 | 111 | d, err := idep.NewCatalogNodesQuery("") 112 | if err != nil { 113 | t.Fatal(err) 114 | } 115 | 116 | nodes := []*dep.Node{ 117 | { 118 | Node: "node", 119 | Address: "address", 120 | }, 121 | } 122 | 123 | id := d.ID() 124 | st.Save(id, nodes) 125 | st.Reset() 126 | 127 | if _, ok := st.Recall(id); ok { 128 | t.Errorf("expected %#v to not be forgotten", d) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /testdata/sandbox/path/to/bad-symlink: -------------------------------------------------------------------------------- 1 | ../../../../template_funcs_test.go -------------------------------------------------------------------------------- /testdata/sandbox/path/to/file: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hashicorp/hcat/9b3254bdeff2a4b8265f71092c91db61df3b71f9/testdata/sandbox/path/to/file -------------------------------------------------------------------------------- /testdata/sandbox/path/to/ok-symlink: -------------------------------------------------------------------------------- 1 | file -------------------------------------------------------------------------------- /tfunc/consul_filter.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "fmt" 8 | "regexp" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/hashicorp/hcat/dep" 13 | "github.com/pkg/errors" 14 | ) 15 | 16 | // byMeta returns Services grouped by one or many ServiceMeta fields. 
17 | func byMeta(meta string, services []*dep.HealthService) (groups map[string][]*dep.HealthService, err error) { 18 | re := regexp.MustCompile("[^a-zA-Z0-9_-]") 19 | normalize := func(x string) string { 20 | return re.ReplaceAllString(x, "_") 21 | } 22 | getOrDefault := func(m map[string]string, key string) string { 23 | realKey := strings.TrimSuffix(key, "|int") 24 | if val := m[realKey]; val != "" { 25 | return val 26 | } 27 | if strings.HasSuffix(key, "|int") { 28 | return "0" 29 | } 30 | return fmt.Sprintf("_no_%s_", realKey) 31 | } 32 | 33 | metas := strings.Split(meta, ",") 34 | 35 | groups = make(map[string][]*dep.HealthService) 36 | 37 | for _, s := range services { 38 | sm := s.ServiceMeta 39 | keyParts := []string{} 40 | for _, meta := range metas { 41 | value := getOrDefault(sm, meta) 42 | if strings.HasSuffix(meta, "|int") { 43 | value = getOrDefault(sm, meta) 44 | i, err := strconv.Atoi(value) 45 | if err != nil { 46 | return nil, errors.Wrap(err, fmt.Sprintf("cannot parse %v as number ", value)) 47 | } 48 | value = fmt.Sprintf("%05d", i) 49 | } 50 | keyParts = append(keyParts, normalize(value)) 51 | } 52 | key := strings.Join(keyParts, "_") 53 | groups[key] = append(groups[key], s) 54 | } 55 | 56 | return groups, nil 57 | } 58 | 59 | // byKey accepts a slice of KV pairs and returns a map of the top-level 60 | // key to all its subkeys. For example: 61 | // 62 | // elasticsearch/a //=> "1" 63 | // elasticsearch/b //=> "2" 64 | // redis/a/b //=> "3" 65 | // 66 | // Passing the result from Consul through byTag would yield: 67 | // 68 | // map[string]map[string]string{ 69 | // "elasticsearch": &dep.KeyPair{"a": "1"}, &dep.KeyPair{"b": "2"}, 70 | // "redis": &dep.KeyPair{"a/b": "3"} 71 | // } 72 | // 73 | // Note that the top-most key is stripped from the Key value. Keys that have no 74 | // prefix after stripping are removed from the list. 
75 | func byKey(pairs []*dep.KeyPair) (map[string]map[string]*dep.KeyPair, error) { 76 | m := make(map[string]map[string]*dep.KeyPair) 77 | for _, pair := range pairs { 78 | parts := strings.Split(pair.Key, "/") 79 | top := parts[0] 80 | key := strings.Join(parts[1:], "/") 81 | 82 | if key == "" { 83 | // Do not add a key if it has no prefix after stripping. 84 | continue 85 | } 86 | 87 | if _, ok := m[top]; !ok { 88 | m[top] = make(map[string]*dep.KeyPair) 89 | } 90 | 91 | newPair := *pair 92 | newPair.Key = key 93 | m[top][key] = &newPair 94 | } 95 | 96 | return m, nil 97 | } 98 | 99 | // byTag is a template func that takes the provided services and 100 | // produces a map based on Service tags. 101 | // 102 | // The map key is a string representing the service tag. The map value is a 103 | // slice of Services which have the tag assigned. 104 | func byTag(in interface{}) (map[string][]interface{}, error) { 105 | m := make(map[string][]interface{}) 106 | 107 | switch typed := in.(type) { 108 | case nil: 109 | case []*dep.CatalogSnippet: 110 | for _, s := range typed { 111 | for _, t := range s.Tags { 112 | m[t] = append(m[t], s) 113 | } 114 | } 115 | case []*dep.HealthService: 116 | for _, s := range typed { 117 | for _, t := range s.Tags { 118 | m[t] = append(m[t], s) 119 | } 120 | } 121 | default: 122 | return nil, fmt.Errorf("byTag: wrong argument type %T", in) 123 | } 124 | 125 | return m, nil 126 | } 127 | -------------------------------------------------------------------------------- /tfunc/consul_filter_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "reflect" 10 | "testing" 11 | 12 | "github.com/hashicorp/hcat" 13 | "github.com/hashicorp/hcat/dep" 14 | ) 15 | 16 | func Test_byMeta(t *testing.T) { 17 | t.Parallel() 18 | svcA := &dep.HealthService{ 19 | ServiceMeta: map[string]string{ 20 | "version": "v2", 21 | "version_num": "2", 22 | "bad_version_num": "1zz", 23 | "env": "dev", 24 | }, 25 | ID: "svcA", 26 | } 27 | 28 | svcB := &dep.HealthService{ 29 | ServiceMeta: map[string]string{ 30 | "version": "v11", 31 | "version_num": "11", 32 | "bad_version_num": "1zz", 33 | "env": "prod", 34 | }, 35 | ID: "svcB", 36 | } 37 | 38 | svcC := &dep.HealthService{ 39 | ServiceMeta: map[string]string{ 40 | "version": "v11", 41 | "version_num": "11", 42 | "bad_version_num": "1zz", 43 | "env": "prod", 44 | }, 45 | ID: "svcC", 46 | } 47 | 48 | type args struct { 49 | meta string 50 | services []*dep.HealthService 51 | } 52 | 53 | tests := []struct { 54 | name string 55 | args args 56 | wantGroups map[string][]*dep.HealthService 57 | wantErr bool 58 | }{ 59 | { 60 | name: "version string", 61 | args: args{ 62 | meta: "version", 63 | services: []*dep.HealthService{svcA, svcB, svcC}, 64 | }, 65 | wantGroups: map[string][]*dep.HealthService{ 66 | "v11": {svcB, svcC}, 67 | "v2": {svcA}, 68 | }, 69 | wantErr: false, 70 | }, 71 | { 72 | name: "version number", 73 | args: args{ 74 | meta: "version_num|int", 75 | services: []*dep.HealthService{svcA, svcB, svcC}, 76 | }, 77 | wantGroups: map[string][]*dep.HealthService{ 78 | "00011": {svcB, svcC}, 79 | "00002": {svcA}, 80 | }, 81 | wantErr: false, 82 | }, 83 | { 84 | name: "bad version number", 85 | args: args{ 86 | meta: "bad_version_num|int", 87 | services: []*dep.HealthService{svcA, svcB, svcC}, 88 | }, 89 | wantGroups: nil, 90 | wantErr: true, 91 | }, 92 | { 93 | name: "multiple meta", 94 | args: args{ 95 | meta: "env,version_num|int,version", 96 | services: 
[]*dep.HealthService{svcA, svcB, svcC}, 97 | }, 98 | wantGroups: map[string][]*dep.HealthService{ 99 | "dev_00002_v2": {svcA}, 100 | "prod_00011_v11": {svcB, svcC}, 101 | }, 102 | wantErr: false, 103 | }, 104 | } 105 | for _, tt := range tests { 106 | t.Run(tt.name, func(t *testing.T) { 107 | gotGroups, err := byMeta(tt.args.meta, tt.args.services) 108 | if (err != nil) != tt.wantErr { 109 | t.Errorf("byMeta() error = %v, wantErr %v", err, tt.wantErr) 110 | return 111 | } 112 | 113 | onlyIDs := func(groups map[string][]*dep.HealthService) (ids map[string]map[string]int) { 114 | ids = make(map[string]map[string]int) 115 | for group, svcs := range groups { 116 | ids[group] = make(map[string]int) 117 | for _, svc := range svcs { 118 | ids[group][svc.ID] = 1 119 | } 120 | } 121 | return 122 | } 123 | 124 | gotIDs := onlyIDs(gotGroups) 125 | wantIDs := onlyIDs(tt.wantGroups) 126 | if !reflect.DeepEqual(gotGroups, tt.wantGroups) { 127 | t.Errorf("byMeta() = %v, want %v", gotIDs, wantIDs) 128 | } 129 | }) 130 | } 131 | } 132 | 133 | func TestConsulFilterExecute(t *testing.T) { 134 | t.Parallel() 135 | 136 | type testCase struct { 137 | name string 138 | ti hcat.TemplateInput 139 | i hcat.Watcherer 140 | e string 141 | err bool 142 | } 143 | 144 | testFunc := func(tc testCase) func(*testing.T) { 145 | return func(t *testing.T) { 146 | tpl := newTemplate(tc.ti) 147 | 148 | a, err := tpl.Execute(tc.i.Recaller(tpl)) 149 | if (err != nil) != tc.err { 150 | t.Fatal(err) 151 | } 152 | if !bytes.Equal([]byte(tc.e), a) { 153 | t.Errorf("\nexp: %#v\nact: %#v", tc.e, string(a)) 154 | } 155 | } 156 | } 157 | 158 | cases := []testCase{ 159 | { 160 | "helper_by_key", 161 | hcat.TemplateInput{ 162 | Contents: `{{ range $key, $pairs := tree "list" | byKey }}{{ $key }}:{{ range $pairs }}{{ .Key }}={{ .Value }}{{ end }}{{ end }}`, 163 | }, 164 | func() hcat.Watcherer { 165 | st := hcat.NewStore() 166 | id := testKVListQueryID("list") 167 | st.Save(id, []*dep.KeyPair{ 168 | {Key: "", Value: 
""}, 169 | {Key: "foo/bar", Value: "a"}, 170 | {Key: "zip/zap", Value: "b"}, 171 | }) 172 | return fakeWatcher{st} 173 | }(), 174 | "foo:bar=azip:zap=b", 175 | false, 176 | }, 177 | { 178 | "helper_by_tag", 179 | hcat.TemplateInput{ 180 | Contents: `{{ range $tag, $services := service "webapp" | byTag }}{{ $tag }}:{{ range $services }}{{ .Address }}{{ end }}{{ end }}`, 181 | }, 182 | func() hcat.Watcherer { 183 | st := hcat.NewStore() 184 | id := testHealthServiceQueryID("webapp") 185 | st.Save(id, []*dep.HealthService{ 186 | { 187 | Address: "1.2.3.4", 188 | Tags: []string{"prod", "staging"}, 189 | }, 190 | { 191 | Address: "5.6.7.8", 192 | Tags: []string{"staging"}, 193 | }, 194 | }) 195 | return fakeWatcher{st} 196 | }(), 197 | "prod:1.2.3.4staging:1.2.3.45.6.7.8", 198 | false, 199 | }, 200 | } 201 | 202 | for i, tc := range cases { 203 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), testFunc(tc)) 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /tfunc/contains.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "reflect" 8 | "strings" 9 | ) 10 | 11 | // contains is a function that have reverse arguments of "in" and is designed to 12 | // be used as a pipe instead of a function: 13 | // 14 | // {{ l | contains "thing" }} 15 | // 16 | func contains(v, l interface{}) (bool, error) { 17 | return in(l, v) 18 | } 19 | 20 | // containsSomeFunc returns functions to implement each of the following: 21 | // 22 | // 1. containsAll - true if (∀x ∈ v then x ∈ l); false otherwise 23 | // 2. containsAny - true if (∃x ∈ v such that x ∈ l); false otherwise 24 | // 3. containsNone - true if (∀x ∈ v then x ∉ l); false otherwise 25 | // 2. 
containsNotAll - true if (∃x ∈ v such that x ∉ l); false otherwise 26 | // 27 | // ret_true - return true at end of loop for none/all; false for any/notall 28 | // invert - invert block test for all/notall 29 | func containsSomeFunc(retTrue, invert bool) func([]interface{}, interface{}) (bool, error) { 30 | return func(v []interface{}, l interface{}) (bool, error) { 31 | for i := 0; i < len(v); i++ { 32 | if ok, _ := in(l, v[i]); ok != invert { 33 | return !retTrue, nil 34 | } 35 | } 36 | return retTrue, nil 37 | } 38 | } 39 | 40 | // in searches for a given value in a given interface. 41 | func in(l, v interface{}) (bool, error) { 42 | lv := reflect.ValueOf(l) 43 | vv := reflect.ValueOf(v) 44 | 45 | switch lv.Kind() { 46 | case reflect.Array, reflect.Slice: 47 | // if the slice contains 'interface' elements, then the element needs to be extracted directly to examine its type, 48 | // otherwise it will just resolve to 'interface'. 49 | var interfaceSlice []interface{} 50 | if reflect.TypeOf(l).Elem().Kind() == reflect.Interface { 51 | interfaceSlice = l.([]interface{}) 52 | } 53 | 54 | for i := 0; i < lv.Len(); i++ { 55 | var lvv reflect.Value 56 | if interfaceSlice != nil { 57 | lvv = reflect.ValueOf(interfaceSlice[i]) 58 | } else { 59 | lvv = lv.Index(i) 60 | } 61 | 62 | switch lvv.Kind() { 63 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 64 | switch vv.Kind() { 65 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 66 | if vv.Int() == lvv.Int() { 67 | return true, nil 68 | } 69 | } 70 | case reflect.Float32, reflect.Float64: 71 | switch vv.Kind() { 72 | case reflect.Float32, reflect.Float64: 73 | if vv.Float() == lvv.Float() { 74 | return true, nil 75 | } 76 | } 77 | case reflect.String: 78 | if vv.Type() == lvv.Type() && vv.String() == lvv.String() { 79 | return true, nil 80 | } 81 | } 82 | } 83 | case reflect.String: 84 | if vv.Type() == lv.Type() && strings.Contains(lv.String(), vv.String()) { 85 | 
return true, nil 86 | } 87 | } 88 | 89 | return false, nil 90 | } 91 | -------------------------------------------------------------------------------- /tfunc/deny.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import "errors" 7 | 8 | var disabledErr = errors.New("function disabled") 9 | 10 | // DenyFunc always returns an error, to be used in place of template functions 11 | // that you want denied. For use with the FuncMapMerge. 12 | func DenyFunc(...interface{}) (string, error) { 13 | return "", disabledErr 14 | } 15 | -------------------------------------------------------------------------------- /tfunc/deny_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import "testing" 7 | 8 | func TestDeny(t *testing.T) { 9 | v, err := DenyFunc() 10 | if v != "" { 11 | t.Errorf("bad return string: '%v'", v) 12 | } 13 | if err != disabledErr { 14 | t.Errorf("bad error: %v", err) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /tfunc/env.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "os" 8 | "strings" 9 | ) 10 | 11 | // envFunc returns a function which checks the value of an environment variable. 
// Invokers can specify their own environment, which takes precedence over any
// real environment variables. Entries are "key=value" strings; when a key
// appears more than once the last entry wins, and entries without an "="
// are ignored.
func envFunc(env []string) func(string) (string, error) {
	return func(s string) (string, error) {
		if v, ok := lookupInjectedEnv(env, s); ok {
			return v, nil
		}
		return os.Getenv(s), nil
	}
}

// envOrDefaultFunc returns a function which checks the value of an
// environment variable. Invokers can specify their own environment, which
// takes precedence over any real environment variables (the last entry wins
// for duplicate keys). If an environment variable is found, the value of that
// variable will be used. This includes empty values. Otherwise, the default
// will be used instead.
func envOrDefaultFunc(env []string) func(string, string) (string, error) {
	return func(s string, def string) (string, error) {
		if v, ok := lookupInjectedEnv(env, s); ok {
			return v, nil
		}
		if val, isPresent := os.LookupEnv(s); isPresent {
			return val, nil
		}
		return def, nil
	}
}

// lookupInjectedEnv returns the value of the last "key=value" entry in env
// whose key equals key, and whether one was found. Scanning from the end lets
// later entries override earlier ones, matching the "last entry wins" rule
// documented on ClientSet.InjectEnv. Entries with no "=" are skipped; indexing
// their missing value would panic.
func lookupInjectedEnv(env []string, key string) (string, bool) {
	for i := len(env) - 1; i >= 0; i-- {
		kv := strings.SplitN(env[i], "=", 2)
		if len(kv) == 2 && kv[0] == key {
			return kv[1], true
		}
	}
	return "", false
}

// --------------------------------------------------------------------------------
// /tfunc/env_test.go:
// --------------------------------------------------------------------------------
// Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "os" 10 | "testing" 11 | 12 | "github.com/hashicorp/hcat" 13 | ) 14 | 15 | func TestEnvExecute(t *testing.T) { 16 | t.Parallel() 17 | 18 | // set an environment variable for the tests 19 | envVars := map[string]string{"HCAT_TEST": "foo", "EMPTY_VAR": ""} 20 | for k, v := range envVars { 21 | if err := os.Setenv(k, v); err != nil { 22 | t.Fatal(err) 23 | } 24 | defer func(e string) { os.Unsetenv(e) }(k) 25 | } 26 | 27 | cases := []struct { 28 | name string 29 | ti hcat.TemplateInput 30 | i hcat.Watcherer 31 | e string 32 | err bool 33 | }{ 34 | { 35 | "helper_env", 36 | hcat.TemplateInput{ 37 | // HCAT_TEST set above 38 | Contents: `{{ env "HCAT_TEST" }}`, 39 | }, 40 | fakeWatcher{hcat.NewStore()}, 41 | "foo", 42 | false, 43 | }, 44 | { 45 | "func_envOrDefault", 46 | hcat.TemplateInput{ 47 | Contents: `{{ envOrDefault "HCAT_TEST" "100" }} {{ envOrDefault "EMPTY_VAR" "200" }} {{ envOrDefault "UNSET_VAR" "300" }}`, 48 | }, 49 | fakeWatcher{hcat.NewStore()}, 50 | "foo 300", 51 | false, 52 | }, 53 | } 54 | 55 | for i, tc := range cases { 56 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 57 | tpl := newTemplate(tc.ti) 58 | 59 | a, err := tpl.Execute(tc.i.Recaller(tpl)) 60 | if (err != nil) != tc.err { 61 | t.Fatal(err) 62 | } 63 | if !bytes.Equal([]byte(tc.e), a) { 64 | t.Errorf("\nexp: %#v\nact: %#v", tc.e, string(a)) 65 | } 66 | }) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /tfunc/file.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "os" 8 | "os/user" 9 | "path/filepath" 10 | "strconv" 11 | "strings" 12 | 13 | "github.com/hashicorp/hcat" 14 | idep "github.com/hashicorp/hcat/internal/dependency" 15 | ) 16 | 17 | // fileFunc returns the contents of the file and monitors a file for changes 18 | func fileFunc(recall hcat.Recaller) interface{} { 19 | return func(s string) (string, error) { 20 | if len(s) == 0 { 21 | return "", nil 22 | } 23 | d, err := idep.NewFileQuery(s) 24 | if err != nil { 25 | return "", err 26 | } 27 | 28 | if value, ok := recall(d); ok { 29 | if value == nil { 30 | return "", nil 31 | } 32 | return value.(string), nil 33 | } 34 | 35 | return "", nil 36 | } 37 | } 38 | 39 | // writeToFile writes the content to a file with optional flags for 40 | // permissions, username (or UID), group name (or GID), and to select appending 41 | // mode or add a newline. 42 | // 43 | // The username and group name fields can be left blank to default to the 44 | // current user and group. 
45 | // 46 | // For example: 47 | // key "my/key/path" | writeToFile "/my/file/path.txt" "" "" "0644" 48 | // key "my/key/path" | writeToFile "/my/file/path.txt" "100" "1000" "0644" 49 | // key "my/key/path" | writeToFile "/my/file/path.txt" "my-user" "my-group" "0644" 50 | // key "my/key/path" | writeToFile "/my/file/path.txt" "my-user" "my-group" "0644" "append" 51 | // key "my/key/path" | writeToFile "/my/file/path.txt" "my-user" "my-group" "0644" "append,newline" 52 | // 53 | func writeToFile(path, username, groupName, permissions string, args ...string) (string, error) { 54 | // Parse arguments 55 | flags := "" 56 | if len(args) == 2 { 57 | flags = args[0] 58 | } 59 | content := args[len(args)-1] 60 | 61 | p_u, err := strconv.ParseUint(permissions, 8, 32) 62 | if err != nil { 63 | return "", err 64 | } 65 | perm := os.FileMode(p_u) 66 | 67 | // Write to file 68 | var f *os.File 69 | shouldAppend := strings.Contains(flags, "append") 70 | if shouldAppend { 71 | f, err = os.OpenFile(path, os.O_APPEND|os.O_WRONLY|os.O_CREATE, perm) 72 | if err != nil { 73 | return "", err 74 | } 75 | } else { 76 | dirPath := filepath.Dir(path) 77 | 78 | if _, err := os.Stat(dirPath); err != nil { 79 | err := os.MkdirAll(dirPath, os.ModePerm) 80 | if err != nil { 81 | return "", err 82 | } 83 | } 84 | 85 | f, err = os.Create(path) 86 | if err != nil { 87 | return "", err 88 | } 89 | } 90 | defer f.Close() 91 | 92 | writingContent := []byte(content) 93 | shouldAddNewLine := strings.Contains(flags, "newline") 94 | if shouldAddNewLine { 95 | writingContent = append(writingContent, []byte("\n")...) 
96 | } 97 | if _, err = f.Write(writingContent); err != nil { 98 | return "", err 99 | } 100 | 101 | // Change ownership and permissions 102 | var uid int 103 | var gid int 104 | if err != nil { 105 | return "", err 106 | } 107 | 108 | if username == "" { 109 | uid = os.Getuid() 110 | } else { 111 | var convErr error 112 | u, err := user.Lookup(username) 113 | if err != nil { 114 | // Check if username string is already a UID 115 | uid, convErr = strconv.Atoi(username) 116 | if convErr != nil { 117 | return "", err 118 | } 119 | } else { 120 | uid, _ = strconv.Atoi(u.Uid) 121 | } 122 | } 123 | 124 | if groupName == "" { 125 | gid = os.Getgid() 126 | } else { 127 | var convErr error 128 | g, err := user.LookupGroup(groupName) 129 | if err != nil { 130 | gid, convErr = strconv.Atoi(groupName) 131 | if convErr != nil { 132 | return "", err 133 | } 134 | } else { 135 | gid, _ = strconv.Atoi(g.Gid) 136 | } 137 | } 138 | 139 | // Avoid the chown call altogether if using current user and group. 140 | if username != "" || groupName != "" { 141 | err = os.Chown(path, uid, gid) 142 | if err != nil { 143 | return "", err 144 | } 145 | } 146 | 147 | err = os.Chmod(path, perm) 148 | if err != nil { 149 | return "", err 150 | } 151 | 152 | return "", nil 153 | } 154 | -------------------------------------------------------------------------------- /tfunc/loop.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "fmt" 8 | "reflect" 9 | ) 10 | 11 | // loop accepts varying parameters and differs its behavior. If given one 12 | // parameter, loop will return a goroutine that begins at 0 and loops until the 13 | // given int, increasing the index by 1 each iteration. If given two parameters, 14 | // loop will return a goroutine that begins at the first parameter and loops 15 | // up to but not including the second parameter. 
16 | // 17 | // // Prints 0 1 2 3 4 18 | // for _, i := range loop(5) { 19 | // print(i) 20 | // } 21 | // 22 | // // Prints 5 6 7 23 | // for _, i := range loop(5, 8) { 24 | // print(i) 25 | // } 26 | // 27 | func loop(ifaces ...interface{}) (<-chan int64, error) { 28 | 29 | to64 := func(i interface{}) (int64, error) { 30 | v := reflect.ValueOf(i) 31 | switch v.Kind() { 32 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, 33 | reflect.Int64: 34 | return int64(v.Int()), nil 35 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, 36 | reflect.Uint64: 37 | return int64(v.Uint()), nil 38 | case reflect.String: 39 | return parseInt(v.String()) 40 | } 41 | return 0, fmt.Errorf("loop: bad argument type: %T", i) 42 | } 43 | 44 | var i1, i2 interface{} 45 | switch len(ifaces) { 46 | case 1: 47 | i1, i2 = 0, ifaces[0] 48 | case 2: 49 | i1, i2 = ifaces[0], ifaces[1] 50 | default: 51 | return nil, fmt.Errorf("loop: wrong number of arguments, expected "+ 52 | "1 or 2, but got %d", len(ifaces)) 53 | } 54 | 55 | start, err := to64(i1) 56 | if err != nil { 57 | return nil, err 58 | } 59 | stop, err := to64(i2) 60 | if err != nil { 61 | return nil, err 62 | } 63 | 64 | ch := make(chan int64) 65 | 66 | go func() { 67 | for i := start; i < stop; i++ { 68 | ch <- i 69 | } 70 | close(ch) 71 | }() 72 | 73 | return ch, nil 74 | } 75 | -------------------------------------------------------------------------------- /tfunc/loop_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/hashicorp/hcat" 12 | ) 13 | 14 | func TestLoopExecute(t *testing.T) { 15 | t.Parallel() 16 | 17 | cases := []struct { 18 | name string 19 | ti hcat.TemplateInput 20 | i hcat.Watcherer 21 | e string 22 | err bool 23 | }{ 24 | { 25 | "helper_loop", 26 | hcat.TemplateInput{ 27 | Contents: `{{ range loop 3 }}1{{ end }}`, 28 | }, 29 | fakeWatcher{hcat.NewStore()}, 30 | "111", 31 | false, 32 | }, 33 | { 34 | "helper_loop__i", 35 | hcat.TemplateInput{ 36 | Contents: `{{ range $i := loop 3 }}{{ $i }}{{ end }}`, 37 | }, 38 | fakeWatcher{hcat.NewStore()}, 39 | "012", 40 | false, 41 | }, 42 | { 43 | "helper_loop_start", 44 | hcat.TemplateInput{ 45 | Contents: `{{ range loop 1 3 }}1{{ end }}`, 46 | }, 47 | fakeWatcher{hcat.NewStore()}, 48 | "11", 49 | false, 50 | }, 51 | { 52 | "helper_loop_text", 53 | hcat.TemplateInput{ 54 | Contents: `{{ range loop 1 "3" }}1{{ end }}`, 55 | }, 56 | fakeWatcher{hcat.NewStore()}, 57 | "11", 58 | false, 59 | }, 60 | { 61 | "helper_loop_parseInt", 62 | hcat.TemplateInput{ 63 | Contents: `{{ $i := print "3" | parseInt }}{{ range loop 1 $i }}1{{ end }}`, 64 | }, 65 | fakeWatcher{hcat.NewStore()}, 66 | "11", 67 | false, 68 | }, 69 | { 70 | // GH-1143 71 | "helper_loop_var", 72 | hcat.TemplateInput{ 73 | Contents: `{{$n := 3 }}` + 74 | `{{ range $i := loop $n }}{{ $i }}{{ end }}`, 75 | }, 76 | fakeWatcher{hcat.NewStore()}, 77 | "012", 78 | false, 79 | }, 80 | } 81 | 82 | for i, tc := range cases { 83 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 84 | tpl := newTemplate(tc.ti) 85 | 86 | a, err := tpl.Execute(tc.i.Recaller(tpl)) 87 | if (err != nil) != tc.err { 88 | t.Fatal(err) 89 | } 90 | if !bytes.Equal([]byte(tc.e), a) { 91 | t.Errorf("\nexp: %#v\nact: %#v", tc.e, string(a)) 92 | } 93 | }) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- 
/tfunc/maps.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "fmt" 8 | "sort" 9 | "strings" 10 | 11 | "github.com/hashicorp/hcat/dep" 12 | "github.com/imdario/mergo" 13 | "github.com/pkg/errors" 14 | ) 15 | 16 | // explode is used to expand a list of keypairs into a deeply-nested hash. 17 | func explode(pairs []*dep.KeyPair) (map[string]interface{}, error) { 18 | m := make(map[string]interface{}) 19 | for _, pair := range pairs { 20 | if err := explodeHelper(m, pair.Key, pair.Value, pair.Key); err != nil { 21 | return nil, errors.Wrap(err, "explode") 22 | } 23 | } 24 | return m, nil 25 | } 26 | 27 | // explodeHelper is a recursive helper for explode and explodeMap 28 | func explodeHelper(m map[string]interface{}, k string, v interface{}, p string) error { 29 | if strings.Contains(k, "/") { 30 | parts := strings.Split(k, "/") 31 | top := parts[0] 32 | key := strings.Join(parts[1:], "/") 33 | 34 | if _, ok := m[top]; !ok { 35 | m[top] = make(map[string]interface{}) 36 | } 37 | nest, ok := m[top].(map[string]interface{}) 38 | if !ok { 39 | return fmt.Errorf("not a map: %q: %q already has value %q", p, top, m[top]) 40 | } 41 | return explodeHelper(nest, key, v, k) 42 | } 43 | 44 | if k != "" { 45 | m[k] = v 46 | } 47 | 48 | return nil 49 | } 50 | 51 | // explodeMap turns a single-level map into a deeply-nested hash. 
52 | func explodeMap(mapIn map[string]interface{}) (map[string]interface{}, error) { 53 | mapOut := make(map[string]interface{}) 54 | 55 | var keys []string 56 | for k := range mapIn { 57 | keys = append(keys, k) 58 | } 59 | sort.Strings(keys) 60 | 61 | for i := range keys { 62 | if err := explodeHelper(mapOut, keys[i], mapIn[keys[i]], keys[i]); err != nil { 63 | return nil, errors.Wrap(err, "explodeMap") 64 | } 65 | } 66 | return mapOut, nil 67 | } 68 | 69 | type _map = map[string]interface{} 70 | 71 | // mergeMap is used to merge two maps 72 | func mergeMap(dstMap _map, srcMap _map, args ...func(*mergo.Config)) (_map, error) { 73 | if err := mergo.Map(&dstMap, srcMap, args...); err != nil { 74 | return nil, err 75 | } 76 | return dstMap, nil 77 | } 78 | 79 | // mergeMapWithOverride is used to merge two maps with dstMap overriding vaules in srcMap 80 | func mergeMapWithOverride(dstMap _map, srcMap _map) (_map, error) { 81 | return mergeMap(dstMap, srcMap, mergo.WithOverride) 82 | } 83 | -------------------------------------------------------------------------------- /tfunc/maps_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/hashicorp/hcat" 12 | "github.com/hashicorp/hcat/dep" 13 | ) 14 | 15 | func TestMapExecute(t *testing.T) { 16 | t.Parallel() 17 | 18 | cases := []struct { 19 | name string 20 | ti hcat.TemplateInput 21 | i hcat.Watcherer 22 | e string 23 | err bool 24 | }{ 25 | { 26 | "helper_explode", 27 | hcat.TemplateInput{ 28 | Contents: `{{ range $k, $v := tree "list" | explode }}{{ $k }}{{ $v }}{{ end }}`, 29 | }, 30 | func() hcat.Watcherer { 31 | st := hcat.NewStore() 32 | id := testKVListQueryID("list") 33 | st.Save(id, []*dep.KeyPair{ 34 | {Key: "", Value: ""}, 35 | {Key: "foo/bar", Value: "a"}, 36 | {Key: "zip/zap", Value: "b"}, 37 | }) 38 | return fakeWatcher{st} 39 | }(), 40 | "foomap[bar:a]zipmap[zap:b]", 41 | false, 42 | }, 43 | { 44 | "helper_explodemap", 45 | hcat.TemplateInput{ 46 | Contents: `{{ testMap | explodeMap }}`, 47 | FuncMapMerge: map[string]interface{}{ 48 | "testMap": func() map[string]interface{} { 49 | m := make(map[string]interface{}) 50 | m["foo"] = map[string]string{"bar": "a"} 51 | m["qux"] = "c" 52 | m["zip"] = map[string]string{"zap": "d"} 53 | return m 54 | }, 55 | }, 56 | }, 57 | fakeWatcher{hcat.NewStore()}, 58 | "map[foo:map[bar:a] qux:c zip:map[zap:d]]", 59 | false, 60 | }, 61 | { 62 | "helper_mergeMap", 63 | hcat.TemplateInput{ 64 | Contents: `{{ $base := "{\"voo\":{\"bar\":\"v\"}}" | parseJSON}}{{ $role := tree "list" | explode | mergeMap $base}}{{ range $k, $v := $role }}{{ $k }}{{ $v }}{{ end }}`, 65 | }, 66 | func() hcat.Watcherer { 67 | st := hcat.NewStore() 68 | id := testKVListQueryID("list") 69 | st.Save(id, []*dep.KeyPair{ 70 | {Key: "", Value: ""}, 71 | {Key: "foo/bar", Value: "a"}, 72 | {Key: "zip/zap", Value: "b"}, 73 | }) 74 | return fakeWatcher{st} 75 | }(), 76 | "foomap[bar:a]voomap[bar:v]zipmap[zap:b]", 77 | false, 78 | }, 79 | { 80 | "helper_mergeMapWithOverride", 81 | hcat.TemplateInput{ 
82 | Contents: `{{ $base := "{\"zip\":{\"zap\":\"t\"},\"voo\":{\"bar\":\"v\"}}" | parseJSON}}{{ $role := tree "list" | explode | mergeMapWithOverride $base}}{{ range $k, $v := $role }}{{ $k }}{{ $v }}{{ end }}`, 83 | }, 84 | func() hcat.Watcherer { 85 | st := hcat.NewStore() 86 | id := testKVListQueryID("list") 87 | st.Save(id, []*dep.KeyPair{ 88 | {Key: "", Value: ""}, 89 | {Key: "foo/bar", Value: "a"}, 90 | {Key: "zip/zap", Value: "b"}, 91 | }) 92 | return fakeWatcher{st} 93 | }(), 94 | "foomap[bar:a]voomap[bar:v]zipmap[zap:b]", 95 | false, 96 | }, 97 | } 98 | 99 | for i, tc := range cases { 100 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 101 | tpl := newTemplate(tc.ti) 102 | 103 | a, err := tpl.Execute(tc.i.Recaller(tpl)) 104 | if (err != nil) != tc.err { 105 | t.Fatal(err) 106 | } 107 | if !bytes.Equal([]byte(tc.e), a) { 108 | t.Errorf("\nexp: %#v\nact: %#v", tc.e, string(a)) 109 | } 110 | }) 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /tfunc/math_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/hashicorp/hcat" 12 | ) 13 | 14 | func TestMathExecute(t *testing.T) { 15 | t.Parallel() 16 | 17 | cases := []struct { 18 | name string 19 | ti hcat.TemplateInput 20 | i hcat.Watcherer 21 | e string 22 | err bool 23 | }{ 24 | { 25 | "math_add", 26 | hcat.TemplateInput{ 27 | Contents: `{{ 2 | add 2 }}`, 28 | }, 29 | fakeWatcher{hcat.NewStore()}, 30 | "4", 31 | false, 32 | }, 33 | { 34 | "math_subtract", 35 | hcat.TemplateInput{ 36 | Contents: `{{ 2 | subtract 2 }}`, 37 | }, 38 | fakeWatcher{hcat.NewStore()}, 39 | "0", 40 | false, 41 | }, 42 | { 43 | "math_multiply", 44 | hcat.TemplateInput{ 45 | Contents: `{{ 2 | multiply 2 }}`, 46 | }, 47 | fakeWatcher{hcat.NewStore()}, 48 | "4", 49 | false, 50 | }, 51 | { 52 | "math_divide", 53 | hcat.TemplateInput{ 54 | Contents: `{{ 2 | divide 2 }}`, 55 | }, 56 | fakeWatcher{hcat.NewStore()}, 57 | "1", 58 | false, 59 | }, 60 | { 61 | "math_modulo", 62 | hcat.TemplateInput{ 63 | Contents: `{{ 3 | modulo 2 }}`, 64 | }, 65 | fakeWatcher{hcat.NewStore()}, 66 | "1", 67 | false, 68 | }, 69 | { 70 | "math_minimum", 71 | hcat.TemplateInput{ 72 | Contents: `{{ 3 | minimum 2 }}`, 73 | }, 74 | fakeWatcher{hcat.NewStore()}, 75 | "2", 76 | false, 77 | }, 78 | { 79 | "math_maximum", 80 | hcat.TemplateInput{ 81 | Contents: `{{ 3 | maximum 2 }}`, 82 | }, 83 | fakeWatcher{hcat.NewStore()}, 84 | "3", 85 | false, 86 | }, 87 | } 88 | 89 | for i, tc := range cases { 90 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 91 | tpl := newTemplate(tc.ti) 92 | 93 | a, err := tpl.Execute(tc.i.Recaller(tpl)) 94 | if (err != nil) != tc.err { 95 | t.Fatal(err) 96 | } 97 | if !bytes.Equal([]byte(tc.e), a) { 98 | t.Errorf("\nexp: %#v\nact: %#v", tc.e, string(a)) 99 | } 100 | }) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /tfunc/parse.go: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "encoding/json" 8 | "strconv" 9 | 10 | "github.com/pkg/errors" 11 | yaml "gopkg.in/yaml.v2" 12 | ) 13 | 14 | // parseBool parses a string into a boolean 15 | func parseBool(s string) (bool, error) { 16 | if s == "" { 17 | return false, nil 18 | } 19 | 20 | result, err := strconv.ParseBool(s) 21 | if err != nil { 22 | return false, errors.Wrap(err, "parseBool") 23 | } 24 | return result, nil 25 | } 26 | 27 | // parseFloat parses a string into a base 10 float 28 | func parseFloat(s string) (float64, error) { 29 | if s == "" { 30 | return 0.0, nil 31 | } 32 | 33 | result, err := strconv.ParseFloat(s, 10) 34 | if err != nil { 35 | return 0, errors.Wrap(err, "parseFloat") 36 | } 37 | return result, nil 38 | } 39 | 40 | // parseInt parses a string into a base 10 int 41 | func parseInt(s string) (int64, error) { 42 | if s == "" { 43 | return 0, nil 44 | } 45 | 46 | result, err := strconv.ParseInt(s, 10, 64) 47 | if err != nil { 48 | return 0, errors.Wrap(err, "parseInt") 49 | } 50 | return result, nil 51 | } 52 | 53 | // parseJSON returns a structure for valid JSON 54 | func parseJSON(s string) (interface{}, error) { 55 | if s == "" { 56 | return map[string]interface{}{}, nil 57 | } 58 | 59 | var data interface{} 60 | if err := json.Unmarshal([]byte(s), &data); err != nil { 61 | return nil, err 62 | } 63 | return data, nil 64 | } 65 | 66 | // parseUint parses a string into a base 10 int 67 | func parseUint(s string) (uint64, error) { 68 | if s == "" { 69 | return 0, nil 70 | } 71 | 72 | result, err := strconv.ParseUint(s, 10, 64) 73 | if err != nil { 74 | return 0, errors.Wrap(err, "parseUint") 75 | } 76 | return result, nil 77 | } 78 | 79 | // parseYAML returns a structure for valid YAML 80 | func parseYAML(s string) (interface{}, error) { 81 | if s == "" { 82 | return 
map[string]interface{}{}, nil 83 | } 84 | 85 | var data interface{} 86 | if err := yaml.Unmarshal([]byte(s), &data); err != nil { 87 | return nil, err 88 | } 89 | return data, nil 90 | } 91 | -------------------------------------------------------------------------------- /tfunc/parse_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/hashicorp/hcat" 12 | ) 13 | 14 | func TestParseExecute(t *testing.T) { 15 | t.Parallel() 16 | 17 | cases := []struct { 18 | name string 19 | ti hcat.TemplateInput 20 | i hcat.Watcherer 21 | e string 22 | err bool 23 | }{ 24 | { 25 | "parseBool", 26 | hcat.TemplateInput{ 27 | Contents: `{{ "true" | parseBool }}`, 28 | }, 29 | fakeWatcher{hcat.NewStore()}, 30 | "true", 31 | false, 32 | }, 33 | { 34 | "parseFloat", 35 | hcat.TemplateInput{ 36 | Contents: `{{ "1.2" | parseFloat }}`, 37 | }, 38 | fakeWatcher{hcat.NewStore()}, 39 | "1.2", 40 | false, 41 | }, 42 | { 43 | "parseInt", 44 | hcat.TemplateInput{ 45 | Contents: `{{ "-1" | parseInt }}`, 46 | }, 47 | fakeWatcher{hcat.NewStore()}, 48 | "-1", 49 | false, 50 | }, 51 | { 52 | "parseJSON", 53 | hcat.TemplateInput{ 54 | Contents: `{{ "{\"foo\": \"bar\"}" | parseJSON }}`, 55 | }, 56 | fakeWatcher{hcat.NewStore()}, 57 | "map[foo:bar]", 58 | false, 59 | }, 60 | { 61 | "parseUint", 62 | hcat.TemplateInput{ 63 | Contents: `{{ "1" | parseUint }}`, 64 | }, 65 | fakeWatcher{hcat.NewStore()}, 66 | "1", 67 | false, 68 | }, 69 | { 70 | "parseYAML", 71 | hcat.TemplateInput{ 72 | Contents: `{{ "foo: bar" | parseYAML }}`, 73 | }, 74 | fakeWatcher{hcat.NewStore()}, 75 | "map[foo:bar]", 76 | false, 77 | }, 78 | { 79 | "parseYAMLv2", 80 | hcat.TemplateInput{ 81 | Contents: `{{ "foo: bar\nbaz: \"foo\"" | parseYAML }}`, 82 | }, 83 | fakeWatcher{hcat.NewStore()}, 84 | "map[baz:foo foo:bar]", 85 | 
false, 86 | }, 87 | { 88 | "parseYAMLnested", 89 | hcat.TemplateInput{ 90 | Contents: `{{ "foo:\n bar: \"baz\"\n baz: 7" | parseYAML }}`, 91 | }, 92 | fakeWatcher{hcat.NewStore()}, 93 | "map[foo:map[bar:baz baz:7]]", 94 | false, 95 | }, 96 | } 97 | 98 | for i, tc := range cases { 99 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 100 | tpl := newTemplate(tc.ti) 101 | 102 | a, err := tpl.Execute(tc.i.Recaller(tpl)) 103 | if (err != nil) != tc.err { 104 | t.Fatal(err) 105 | } 106 | if !bytes.Equal([]byte(tc.e), a) { 107 | t.Errorf("\nexp: %#v\nact: %#v", tc.e, string(a)) 108 | } 109 | }) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /tfunc/sockaddr.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "fmt" 8 | "strings" 9 | 10 | socktmpl "github.com/hashicorp/go-sockaddr/template" 11 | ) 12 | 13 | // sockaddr wraps go-sockaddr templating 14 | func sockaddr(args ...string) (string, error) { 15 | t := fmt.Sprintf("{{ %s }}", strings.Join(args, " ")) 16 | k, err := socktmpl.Parse(t) 17 | if err != nil { 18 | return "", err 19 | } 20 | return k, nil 21 | } 22 | -------------------------------------------------------------------------------- /tfunc/sockaddr_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/hashicorp/hcat" 12 | ) 13 | 14 | func TestSockAddrExecute(t *testing.T) { 15 | t.Parallel() 16 | 17 | cases := []struct { 18 | name string 19 | ti hcat.TemplateInput 20 | i hcat.Watcherer 21 | e string 22 | err bool 23 | }{} 24 | 25 | for i, tc := range cases { 26 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 27 | tpl := newTemplate(tc.ti) 28 | 29 | a, err := tpl.Execute(tc.i.Recaller(tpl)) 30 | if (err != nil) != tc.err { 31 | t.Fatal(err) 32 | } 33 | if !bytes.Equal([]byte(tc.e), a) { 34 | t.Errorf("\nexp: %#v\nact: %#v", tc.e, string(a)) 35 | } 36 | }) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /tfunc/string.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "fmt" 8 | "regexp" 9 | "strings" 10 | ) 11 | 12 | // Indent prefixes each line of a string with the specified number of spaces 13 | func indent(spaces int, s string) (string, error) { 14 | if spaces < 0 { 15 | return "", fmt.Errorf("indent value must be a positive integer") 16 | } 17 | var output, prefix []byte 18 | var sp bool 19 | var size int 20 | prefix = []byte(strings.Repeat(" ", spaces)) 21 | sp = true 22 | for _, c := range []byte(s) { 23 | if sp && c != '\n' { 24 | output = append(output, prefix...) 
25 | size += spaces 26 | } 27 | output = append(output, c) 28 | sp = c == '\n' 29 | size++ 30 | } 31 | return string(output[:size]), nil 32 | } 33 | 34 | // join is a version of strings.Join that can be piped 35 | func join(sep string, a []string) (string, error) { 36 | return strings.Join(a, sep), nil 37 | } 38 | 39 | // split is a version of strings.Split that can be piped 40 | func split(sep, s string) ([]string, error) { 41 | s = strings.TrimSpace(s) 42 | if s == "" { 43 | return []string{}, nil 44 | } 45 | return strings.Split(s, sep), nil 46 | } 47 | 48 | // TrimSpace is a version of strings.TrimSpace that can be piped 49 | func trimSpace(s string) (string, error) { 50 | return strings.TrimSpace(s), nil 51 | } 52 | 53 | // replaceAll replaces all occurrences of a value in a string with the given 54 | // replacement value. 55 | func replaceAll(f, t, s string) (string, error) { 56 | return strings.Replace(s, f, t, -1), nil 57 | } 58 | 59 | // regexReplaceAll replaces all occurrences of a regular expression with 60 | // the given replacement value. 61 | func regexReplaceAll(re, pl, s string) (string, error) { 62 | compiled, err := regexp.Compile(re) 63 | if err != nil { 64 | return "", err 65 | } 66 | return compiled.ReplaceAllString(s, pl), nil 67 | } 68 | 69 | // regexMatch returns true or false if the string matches 70 | // the given regular expression 71 | func regexMatch(re, s string) (bool, error) { 72 | compiled, err := regexp.Compile(re) 73 | if err != nil { 74 | return false, err 75 | } 76 | return compiled.MatchString(s), nil 77 | } 78 | -------------------------------------------------------------------------------- /tfunc/string_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/hashicorp/hcat" 12 | ) 13 | 14 | func TestStringExecute(t *testing.T) { 15 | t.Parallel() 16 | 17 | cases := []struct { 18 | name string 19 | ti hcat.TemplateInput 20 | i hcat.Watcherer 21 | e string 22 | err bool 23 | }{ 24 | { 25 | "indent", 26 | hcat.TemplateInput{ 27 | Contents: `{{ "hello\nhello\r\nHELLO\r\nhello\nHELLO" | indent 4 }}`, 28 | }, 29 | fakeWatcher{hcat.NewStore()}, 30 | " hello\n hello\r\n HELLO\r\n hello\n HELLO", 31 | false, 32 | }, 33 | { 34 | "indent_negative", 35 | hcat.TemplateInput{ 36 | Contents: `{{ "hello\nhello\r\nHELLO\r\nhello\nHELLO" | indent -4 }}`, 37 | }, 38 | fakeWatcher{hcat.NewStore()}, 39 | "", 40 | true, 41 | }, 42 | { 43 | "indent_zero", 44 | hcat.TemplateInput{ 45 | Contents: `{{ "hello\nhello\r\nHELLO\r\nhello\nHELLO" | indent 0 }}`, 46 | }, 47 | fakeWatcher{hcat.NewStore()}, 48 | "hello\nhello\r\nHELLO\r\nhello\nHELLO", 49 | false, 50 | }, 51 | { 52 | "join", 53 | hcat.TemplateInput{ 54 | Contents: `{{ "a,b,c" | split "," | join ";" }}`, 55 | }, 56 | fakeWatcher{hcat.NewStore()}, 57 | "a;b;c", 58 | false, 59 | }, 60 | { 61 | "trimSpace", 62 | hcat.TemplateInput{ 63 | Contents: `{{ "\t hi\n " | trimSpace }}`, 64 | }, 65 | fakeWatcher{hcat.NewStore()}, 66 | "hi", 67 | false, 68 | }, 69 | { 70 | "split", 71 | hcat.TemplateInput{ 72 | Contents: `{{ "a,b,c" | split "," }}`, 73 | }, 74 | fakeWatcher{hcat.NewStore()}, 75 | "[a b c]", 76 | false, 77 | }, 78 | { 79 | "replaceAll", 80 | hcat.TemplateInput{ 81 | Contents: `{{ "hello my hello" | regexReplaceAll "hello" "bye" }}`, 82 | }, 83 | fakeWatcher{hcat.NewStore()}, 84 | "bye my bye", 85 | false, 86 | }, 87 | { 88 | "regexReplaceAll", 89 | hcat.TemplateInput{ 90 | Contents: `{{ "foo" | regexReplaceAll "\\w" "x" }}`, 91 | }, 92 | fakeWatcher{hcat.NewStore()}, 93 | "xxx", 94 | false, 95 | }, 96 | { 97 | "regexMatch", 98 | hcat.TemplateInput{ 99 
| Contents: `{{ "foo" | regexMatch "[a-z]+" }}`, 100 | }, 101 | fakeWatcher{hcat.NewStore()}, 102 | "true", 103 | false, 104 | }, 105 | } 106 | 107 | for i, tc := range cases { 108 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 109 | tpl := newTemplate(tc.ti) 110 | 111 | a, err := tpl.Execute(tc.i.Recaller(tpl)) 112 | if (err != nil) != tc.err { 113 | t.Fatal(err) 114 | } 115 | if !bytes.Equal([]byte(tc.e), a) { 116 | t.Errorf("\nexp: %#v\nact: %#v", tc.e, string(a)) 117 | } 118 | }) 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /tfunc/tfunc.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "os" 8 | "text/template" 9 | ) 10 | 11 | // AllUnversioned available template functions 12 | func AllUnversioned() template.FuncMap { 13 | all := make(template.FuncMap) 14 | allfuncs := []func() template.FuncMap{ 15 | ConsulFilters, Env, Control, Helpers, Math, Files} 16 | for _, f := range allfuncs { 17 | for k, v := range f() { 18 | all[k] = v 19 | } 20 | } 21 | return all 22 | } 23 | 24 | // ConsulV0 is a set of template functions for querying Consul endpoints. 25 | func ConsulV0() template.FuncMap { 26 | return template.FuncMap{ 27 | "datacenters": datacentersFunc, 28 | "key": keyFunc, 29 | "keyExists": keyExistsFunc, 30 | "keyOrDefault": keyWithDefaultFunc, 31 | "ls": lsFunc(true), 32 | "safeLs": safeLsFunc, 33 | "node": nodeFunc, 34 | "nodes": nodesFunc, 35 | "service": serviceFunc, 36 | "connect": connectFunc, 37 | "services": servicesFunc, 38 | "tree": treeFunc(true), 39 | "safeTree": safeTreeFunc, 40 | "caRoots": connectCARootsFunc, 41 | "caLeaf": connectLeafFunc, 42 | } 43 | } 44 | 45 | // ConsulV1 is a set of template functions for querying Consul endpoints. 
46 | // The functions support Consul v1 API filter expressions and Consul enterprise 47 | // namespaces. 48 | func ConsulV1() template.FuncMap { 49 | return template.FuncMap{ 50 | "service": v1ServiceFunc, 51 | "connect": v1ConnectFunc, 52 | "services": v1ServicesFunc, 53 | "keys": v1KVListFunc, 54 | "key": v1KVGetFunc, 55 | "keyExists": v1KVExistsFunc, 56 | "keyExistsGet": v1KVExistsGetFunc, 57 | 58 | // Set of Consul functions that are not yet implemented for v1. These 59 | // intentionally error instead of defaulting to the v0 implementations 60 | // to avoid introducing breaking changes when they are supported. 61 | "node": v1TODOFunc, 62 | "nodes": v1TODOFunc, 63 | } 64 | } 65 | 66 | // ConsulFilters provides functions to filter consul results 67 | func ConsulFilters() template.FuncMap { 68 | return template.FuncMap{ 69 | "byKey": byKey, 70 | "byTag": byTag, 71 | "byMeta": byMeta, 72 | } 73 | } 74 | 75 | // VaultV0 querying functions 76 | func VaultV0() template.FuncMap { 77 | return template.FuncMap{ 78 | "secret": secretFunc, 79 | "secrets": secretsFunc, 80 | } 81 | } 82 | 83 | // Environment variable querying functions 84 | func Env() template.FuncMap { 85 | return template.FuncMap{ 86 | "env": envFunc(os.Environ()), 87 | "envOrDefault": envOrDefaultFunc(os.Environ()), 88 | } 89 | } 90 | 91 | // Files provides functions for working with files 92 | func Files() template.FuncMap { 93 | return template.FuncMap{ 94 | "file": fileFunc, 95 | "writeToFile": writeToFile, 96 | } 97 | } 98 | 99 | // Control flow functions 100 | func Control() template.FuncMap { 101 | return template.FuncMap{ 102 | "contains": contains, 103 | "containsAll": containsSomeFunc(true, true), 104 | "containsAny": containsSomeFunc(false, false), 105 | "containsNone": containsSomeFunc(true, false), 106 | "containsNotAll": containsSomeFunc(false, true), 107 | "in": in, 108 | "loop": loop, 109 | } 110 | } 111 | 112 | // Mathimatical functions 113 | func Math() template.FuncMap { 114 | return 
template.FuncMap{ 115 | "add": add, 116 | "subtract": subtract, 117 | "multiply": multiply, 118 | "divide": divide, 119 | "modulo": modulo, 120 | "minimum": minimum, 121 | "maximum": maximum, 122 | } 123 | } 124 | 125 | // Helpers are all the rest... (maybe organize these more?) 126 | func Helpers() template.FuncMap { 127 | return template.FuncMap{ 128 | // Parsing 129 | "parseBool": parseBool, 130 | "parseFloat": parseFloat, 131 | "parseInt": parseInt, 132 | "parseJSON": parseJSON, 133 | "parseUint": parseUint, 134 | "parseYAML": parseYAML, 135 | // ToSomething 136 | "toLower": toLower, 137 | "toUpper": toUpper, 138 | "toTitle": toTitle, 139 | "toJSON": toJSON, 140 | "toJSONPretty": toJSONPretty, 141 | "toUnescapedJSON": toUnescapedJSON, 142 | "toUnescapedJSONPretty": toUnescapedJSONPretty, 143 | "toTOML": toTOML, 144 | "toYAML": toYAML, 145 | // (D)Encoding 146 | "base64Decode": base64Decode, 147 | "base64Encode": base64Encode, 148 | "base64URLDecode": base64URLDecode, 149 | "base64URLEncode": base64URLEncode, 150 | "sha256Hex": sha256Hex, 151 | "md5sum": md5sum, 152 | // String 153 | "join": join, 154 | "split": split, 155 | "trimSpace": trimSpace, 156 | "indent": indent, 157 | "replaceAll": replaceAll, 158 | "regexReplaceAll": regexReplaceAll, 159 | "regexMatch": regexMatch, 160 | // Data type (map, slice, etc) oriented 161 | "explode": explode, 162 | "explodeMap": explodeMap, 163 | "mergeMap": mergeMap, 164 | "mergeMapWithOverride": mergeMapWithOverride, 165 | // Misc/Other 166 | "timestamp": timestamp, 167 | "sockaddr": sockaddr, 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /tfunc/tfunc_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | "text/template" 10 | 11 | "github.com/hashicorp/hcat" 12 | "github.com/hashicorp/hcat/dep" 13 | ) 14 | 15 | func testHealthServiceQueryID(service string) string { 16 | return fmt.Sprintf("health.service(%s|passing)", service) 17 | } 18 | 19 | func testKVListQueryID(prefix string) string { 20 | return fmt.Sprintf("kv.list(%s)", prefix) 21 | } 22 | 23 | // simple check for duplicate names for template functions 24 | func TestAllForDups(t *testing.T) { 25 | all := make(template.FuncMap) 26 | allfuncs := []func() template.FuncMap{ 27 | ConsulFilters, Env, Control, Helpers, Math} 28 | for _, f := range allfuncs { 29 | for k, v := range f() { 30 | if _, ok := all[k]; ok { 31 | t.Fatal("duplicate entry") 32 | } 33 | all[k] = v 34 | } 35 | } 36 | } 37 | 38 | // Return a new template with all unversioned and V0 template functions. 39 | func newTemplate(ti hcat.TemplateInput) *hcat.Template { 40 | funcMap := AllUnversioned() 41 | // use vault v0 api as that is all that is currently supported 42 | for k, v := range VaultV0() { 43 | funcMap[k] = v 44 | } 45 | // use consul v0 api as default for now as most tests use it 46 | for k, v := range ConsulV0() { 47 | funcMap[k] = v 48 | } 49 | switch ti.FuncMapMerge { 50 | case nil: 51 | default: 52 | // allow passed in option to override defaults 53 | for k, v := range ti.FuncMapMerge { 54 | funcMap[k] = v 55 | } 56 | } 57 | ti.FuncMapMerge = funcMap 58 | return hcat.NewTemplate(ti) 59 | } 60 | 61 | // fake/stub Watcherer for tests 62 | type fakeWatcher struct { 63 | *hcat.Store 64 | } 65 | 66 | func (fakeWatcher) Buffering(hcat.Notifier) bool { return false } 67 | func (f fakeWatcher) Complete(hcat.Notifier) bool { return true } 68 | func (f fakeWatcher) Clients() hcat.Looker { return nil } 69 | func (f fakeWatcher) Recaller(hcat.Notifier) hcat.Recaller { 70 | return func(d dep.Dependency) (value interface{}, found bool) { 71 | 
return f.Store.Recall(d.ID()) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /tfunc/time.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "fmt" 8 | "strconv" 9 | "time" 10 | ) 11 | 12 | // now is function that represents the current time in UTC. This is here 13 | // primarily for the tests to override times. 14 | var now = func() time.Time { return time.Now().UTC() } 15 | 16 | // timestamp returns the current UNIX timestamp in UTC. If an argument is 17 | // specified, it will be used to format the timestamp. 18 | func timestamp(s ...string) (string, error) { 19 | switch len(s) { 20 | case 0: 21 | return now().Format(time.RFC3339), nil 22 | case 1: 23 | if s[0] == "unix" { 24 | return strconv.FormatInt(now().Unix(), 10), nil 25 | } 26 | return now().Format(s[0]), nil 27 | default: 28 | return "", fmt.Errorf("timestamp: wrong number of arguments, "+ 29 | "expected 0 or 1, but got %d", len(s)) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /tfunc/time_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "testing" 10 | "time" 11 | 12 | "github.com/hashicorp/hcat" 13 | ) 14 | 15 | func TestTimeExecute(t *testing.T) { 16 | t.Parallel() 17 | 18 | // overwrite now variable from ./time.go 19 | now = func() time.Time { return time.Unix(0, 0).UTC() } 20 | 21 | cases := []struct { 22 | name string 23 | ti hcat.TemplateInput 24 | i hcat.Watcherer 25 | e string 26 | err bool 27 | }{ 28 | { 29 | "timestamp", 30 | hcat.TemplateInput{ 31 | Contents: `{{ timestamp }}`, 32 | }, 33 | fakeWatcher{hcat.NewStore()}, 34 | "1970-01-01T00:00:00Z", 35 | false, 36 | }, 37 | { 38 | "helper_timestamp__formatted", 39 | hcat.TemplateInput{ 40 | Contents: `{{ timestamp "2006-01-02" }}`, 41 | }, 42 | fakeWatcher{hcat.NewStore()}, 43 | "1970-01-01", 44 | false, 45 | }, 46 | } 47 | 48 | for i, tc := range cases { 49 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 50 | tpl := newTemplate(tc.ti) 51 | 52 | a, err := tpl.Execute(tc.i.Recaller(tpl)) 53 | if (err != nil) != tc.err { 54 | t.Fatal(err) 55 | } 56 | if !bytes.Equal([]byte(tc.e), a) { 57 | t.Errorf("\nexp: %#v\nact: %#v", tc.e, string(a)) 58 | } 59 | }) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /tfunc/transform.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "crypto/md5" 9 | "crypto/sha256" 10 | "encoding/base64" 11 | "encoding/hex" 12 | "encoding/json" 13 | "fmt" 14 | "io/ioutil" 15 | "strings" 16 | 17 | "github.com/BurntSushi/toml" 18 | "github.com/pkg/errors" 19 | yaml "gopkg.in/yaml.v2" 20 | ) 21 | 22 | // base64Decode decodes the given string as a base64 string, returning an error 23 | // if it fails. 
24 | func base64Decode(s string) (string, error) { 25 | v, err := base64.StdEncoding.DecodeString(s) 26 | if err != nil { 27 | return "", errors.Wrap(err, "base64Decode") 28 | } 29 | return string(v), nil 30 | } 31 | 32 | // base64Encode encodes the given value into a string represented as base64. 33 | func base64Encode(s string) (string, error) { 34 | return base64.StdEncoding.EncodeToString([]byte(s)), nil 35 | } 36 | 37 | // base64URLDecode decodes the given string as a URL-safe base64 string. 38 | func base64URLDecode(s string) (string, error) { 39 | v, err := base64.URLEncoding.DecodeString(s) 40 | if err != nil { 41 | return "", errors.Wrap(err, "base64URLDecode") 42 | } 43 | return string(v), nil 44 | } 45 | 46 | // base64URLEncode encodes the given string to be URL-safe. 47 | func base64URLEncode(s string) (string, error) { 48 | return base64.URLEncoding.EncodeToString([]byte(s)), nil 49 | } 50 | 51 | // sha256Hex return the sha256 hex of a string 52 | func sha256Hex(item string) (string, error) { 53 | h := sha256.New() 54 | h.Write([]byte(item)) 55 | output := hex.EncodeToString(h.Sum(nil)) 56 | return output, nil 57 | } 58 | 59 | // md5sum returns the md5 hash of a string 60 | func md5sum(item string) string { 61 | return fmt.Sprintf("%x", md5.Sum([]byte(item))) 62 | } 63 | 64 | // toLower converts the given string (usually by a pipe) to lowercase. 65 | func toLower(s string) (string, error) { 66 | return strings.ToLower(s), nil 67 | } 68 | 69 | // toUpper converts the given string (usually by a pipe) to uppercase. 70 | func toUpper(s string) (string, error) { 71 | return strings.ToUpper(s), nil 72 | } 73 | 74 | // toTitle converts the given string (usually by a pipe) to titlecase. 75 | func toTitle(s string) (string, error) { 76 | return strings.Title(s), nil 77 | } 78 | 79 | // toJSON converts the given structure into a deeply nested JSON string. 
80 | func toJSON(i interface{}) (string, error) { 81 | result, err := json.Marshal(i) 82 | if err != nil { 83 | return "", errors.Wrap(err, "toJSON") 84 | } 85 | return string(bytes.TrimSpace(result)), err 86 | } 87 | 88 | // toJSONPretty converts the given structure into a deeply nested pretty JSON 89 | // string. 90 | func toJSONPretty(i interface{}) (string, error) { 91 | result, err := json.MarshalIndent(i, "", " ") 92 | if err != nil { 93 | return "", errors.Wrap(err, "toJSONPretty") 94 | } 95 | return string(bytes.TrimSpace(result)), err 96 | } 97 | 98 | // toUnescapedJSON converts the given structure into a deeply nested JSON 99 | // string without HTML escaping. 100 | func toUnescapedJSON(i interface{}) (string, error) { 101 | buf := &bytes.Buffer{} 102 | encoder := json.NewEncoder(buf) 103 | encoder.SetEscapeHTML(false) 104 | if err := encoder.Encode(i); err != nil { 105 | return "", errors.Wrap(err, "toUnescapedJSON") 106 | } 107 | return strings.TrimRight(buf.String(), "\r\n"), nil 108 | } 109 | 110 | // toUnescapedJSONPretty converts the given structure into a deeply nested 111 | // pretty JSON string without HTML escaping. 112 | func toUnescapedJSONPretty(i interface{}) (string, error) { 113 | buf := &bytes.Buffer{} 114 | encoder := json.NewEncoder(buf) 115 | encoder.SetEscapeHTML(false) 116 | encoder.SetIndent("", " ") 117 | if err := encoder.Encode(i); err != nil { 118 | return "", errors.Wrap(err, "toUnescapedJSONPretty") 119 | } 120 | return strings.TrimRight(buf.String(), "\r\n"), nil 121 | } 122 | 123 | // toYAML converts the given structure into a deeply nested YAML string. 124 | func toYAML(m map[string]interface{}) (string, error) { 125 | result, err := yaml.Marshal(m) 126 | if err != nil { 127 | return "", errors.Wrap(err, "toYAML") 128 | } 129 | return string(bytes.TrimSpace(result)), nil 130 | } 131 | 132 | // toTOML converts the given structure into a deeply nested TOML string. 
133 | func toTOML(m map[string]interface{}) (string, error) { 134 | buf := bytes.NewBuffer([]byte{}) 135 | enc := toml.NewEncoder(buf) 136 | if err := enc.Encode(m); err != nil { 137 | return "", errors.Wrap(err, "toTOML") 138 | } 139 | result, err := ioutil.ReadAll(buf) 140 | if err != nil { 141 | return "", errors.Wrap(err, "toTOML") 142 | } 143 | return string(bytes.TrimSpace(result)), nil 144 | } 145 | -------------------------------------------------------------------------------- /tfunc/transform_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/hashicorp/hcat" 12 | "github.com/hashicorp/hcat/dep" 13 | ) 14 | 15 | func TestTransformExecute(t *testing.T) { 16 | t.Parallel() 17 | 18 | cases := []struct { 19 | name string 20 | ti hcat.TemplateInput 21 | i hcat.Watcherer 22 | e string 23 | err bool 24 | }{ 25 | { 26 | "func_base64Decode", 27 | hcat.TemplateInput{ 28 | Contents: `{{ base64Decode "aGVsbG8=" }}`, 29 | }, 30 | fakeWatcher{hcat.NewStore()}, 31 | "hello", 32 | false, 33 | }, 34 | { 35 | "func_base64Decode_bad", 36 | hcat.TemplateInput{ 37 | Contents: `{{ base64Decode "aGVsxxbG8=" }}`, 38 | }, 39 | fakeWatcher{hcat.NewStore()}, 40 | "", 41 | true, 42 | }, 43 | { 44 | "func_base64Encode", 45 | hcat.TemplateInput{ 46 | Contents: `{{ base64Encode "hello" }}`, 47 | }, 48 | fakeWatcher{hcat.NewStore()}, 49 | "aGVsbG8=", 50 | false, 51 | }, 52 | { 53 | "func_base64URLDecode", 54 | hcat.TemplateInput{ 55 | Contents: `{{ base64URLDecode "dGVzdGluZzEyMw==" }}`, 56 | }, 57 | fakeWatcher{hcat.NewStore()}, 58 | "testing123", 59 | false, 60 | }, 61 | { 62 | "func_base64URLDecode_bad", 63 | hcat.TemplateInput{ 64 | Contents: `{{ base64URLDecode "aGVsxxbG8=" }}`, 65 | }, 66 | fakeWatcher{hcat.NewStore()}, 67 | "", 68 | true, 69 | }, 70 | { 71 | 
"func_base64URLEncode", 72 | hcat.TemplateInput{ 73 | Contents: `{{ base64URLEncode "testing123" }}`, 74 | }, 75 | fakeWatcher{hcat.NewStore()}, 76 | "dGVzdGluZzEyMw==", 77 | false, 78 | }, 79 | { 80 | "func_sha256", 81 | hcat.TemplateInput{ 82 | Contents: `{{ sha256Hex "hello" }}`, 83 | }, 84 | fakeWatcher{hcat.NewStore()}, 85 | "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", 86 | false, 87 | }, 88 | { 89 | "func_md5sum", 90 | hcat.TemplateInput{ 91 | Contents: `{{ "hello" | md5sum }}`, 92 | }, 93 | fakeWatcher{hcat.NewStore()}, 94 | "5d41402abc4b2a76b9719d911017c592", 95 | false, 96 | }, 97 | { 98 | "helper_toJSON", 99 | hcat.TemplateInput{ 100 | Contents: `{{ "a,b,c" | split "," | toJSON }}`, 101 | }, 102 | fakeWatcher{hcat.NewStore()}, 103 | "[\"a\",\"b\",\"c\"]", 104 | false, 105 | }, 106 | { 107 | "helper_toJSONPretty", 108 | hcat.TemplateInput{ 109 | Contents: `{{ "a,b,c" | split "," | toJSONPretty }}`, 110 | }, 111 | fakeWatcher{hcat.NewStore()}, 112 | "[\n \"a\",\n \"b\",\n \"c\"\n]", 113 | false, 114 | }, 115 | { 116 | "helper_toUnescapedJSON", 117 | hcat.TemplateInput{ 118 | Contents: `{{ "a?b&c,x?y&z" | split "," | toUnescapedJSON }}`, 119 | }, 120 | fakeWatcher{hcat.NewStore()}, 121 | "[\"a?b&c\",\"x?y&z\"]", 122 | false, 123 | }, 124 | { 125 | "helper_toUnescapedJSONPretty", 126 | hcat.TemplateInput{ 127 | Contents: `{{ tree "list" | explode | toUnescapedJSONPretty }}`, 128 | }, 129 | func() hcat.Watcherer { 130 | st := hcat.NewStore() 131 | id := testKVListQueryID("list") 132 | st.Save(id, []*dep.KeyPair{ 133 | {Key: "a", Value: "b&c"}, 134 | {Key: "x", Value: "y&z"}, 135 | {Key: "k", Value: "<>&&"}, 136 | }) 137 | return fakeWatcher{st} 138 | }(), 139 | "{\n \"a\": \"b&c\",\n \"k\": \"<>&&\",\n \"x\": \"y&z\"\n}", 140 | false, 141 | }, 142 | { 143 | "helper_toLower", 144 | hcat.TemplateInput{ 145 | Contents: `{{ "HI" | toLower }}`, 146 | }, 147 | fakeWatcher{hcat.NewStore()}, 148 | "hi", 149 | false, 150 | }, 151 | { 152 | 
"helper_toTitle", 153 | hcat.TemplateInput{ 154 | Contents: `{{ "this is a sentence" | toTitle }}`, 155 | }, 156 | fakeWatcher{hcat.NewStore()}, 157 | "This Is A Sentence", 158 | false, 159 | }, 160 | { 161 | "helper_toTOML", 162 | hcat.TemplateInput{ 163 | Contents: `{{ "{\"foo\":\"bar\"}" | parseJSON | toTOML }}`, 164 | }, 165 | fakeWatcher{hcat.NewStore()}, 166 | "foo = \"bar\"", 167 | false, 168 | }, 169 | { 170 | "helper_toUpper", 171 | hcat.TemplateInput{ 172 | Contents: `{{ "hi" | toUpper }}`, 173 | }, 174 | fakeWatcher{hcat.NewStore()}, 175 | "HI", 176 | false, 177 | }, 178 | { 179 | "helper_toYAML", 180 | hcat.TemplateInput{ 181 | Contents: `{{ "{\"foo\":\"bar\"}" | parseJSON | toYAML }}`, 182 | }, 183 | fakeWatcher{hcat.NewStore()}, 184 | "foo: bar", 185 | false, 186 | }, 187 | } 188 | 189 | for i, tc := range cases { 190 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 191 | tpl := newTemplate(tc.ti) 192 | 193 | a, err := tpl.Execute(tc.i.Recaller(tpl)) 194 | if (err != nil) != tc.err { 195 | t.Fatal(err) 196 | } 197 | if !bytes.Equal([]byte(tc.e), a) { 198 | t.Errorf("\nexp: %#v\nact: %#v", tc.e, string(a)) 199 | } 200 | }) 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /tfunc/vault.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package tfunc 5 | 6 | import ( 7 | "fmt" 8 | "strings" 9 | 10 | "github.com/hashicorp/hcat" 11 | "github.com/hashicorp/hcat/dep" 12 | idep "github.com/hashicorp/hcat/internal/dependency" 13 | ) 14 | 15 | // secretFunc returns or accumulates secret dependencies from Vault. 
16 | func secretFunc(recall hcat.Recaller) interface{} { 17 | return func(s ...string) (interface{}, error) { 18 | if len(s) == 0 { 19 | return nil, nil 20 | } 21 | 22 | path, rest := s[0], s[1:] 23 | data := make(map[string]interface{}) 24 | for _, str := range rest { 25 | if len(str) == 0 { 26 | continue 27 | } 28 | parts := strings.SplitN(str, "=", 2) 29 | if len(parts) != 2 { 30 | return nil, fmt.Errorf("not k=v pair %q", str) 31 | } 32 | 33 | k, v := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) 34 | data[k] = v 35 | } 36 | 37 | var d dep.Dependency 38 | var err error 39 | 40 | isReadQuery := len(rest) == 0 41 | if isReadQuery { 42 | d, err = idep.NewVaultReadQuery(path) 43 | } else { 44 | d, err = idep.NewVaultWriteQuery(path, data) 45 | } 46 | 47 | if err != nil { 48 | return nil, err 49 | } 50 | 51 | if value, ok := recall(d); ok { 52 | return value.(*dep.Secret), nil 53 | } 54 | 55 | return nil, nil 56 | } 57 | } 58 | 59 | // secretsFunc returns or accumulates a list of secret dependencies from Vault. 60 | func secretsFunc(recall hcat.Recaller) interface{} { 61 | return func(s string) ([]string, error) { 62 | var result []string 63 | 64 | if len(s) == 0 { 65 | return result, nil 66 | } 67 | 68 | d, err := idep.NewVaultListQuery(s) 69 | if err != nil { 70 | return nil, err 71 | } 72 | 73 | if value, ok := recall(d); ok { 74 | result = value.([]string) 75 | return result, nil 76 | } 77 | 78 | return result, nil 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /vaulttoken/main_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package vaulttoken 5 | 6 | import ( 7 | "encoding/json" 8 | "io/ioutil" 9 | "log" 10 | "os" 11 | "os/exec" 12 | "path/filepath" 13 | "strings" 14 | "testing" 15 | 16 | "github.com/hashicorp/hcat" 17 | "github.com/hashicorp/vault/api" 18 | ) 19 | 20 | const ( 21 | vaultAddr = "http://127.0.0.1:8222" 22 | vaultToken = "a_token" 23 | ) 24 | 25 | var ( 26 | testVault *vaultServer 27 | testClients *hcat.ClientSet 28 | tokenRoleId string 29 | ) 30 | 31 | func TestMain(m *testing.M) { 32 | os.Exit(main(m)) 33 | } 34 | 35 | // sub-main so I can use defer 36 | func main(m *testing.M) int { 37 | log.SetOutput(ioutil.Discard) 38 | testVault = newTestVault() 39 | defer func() { testVault.Stop() }() 40 | 41 | clients := hcat.NewClientSet() 42 | if err := clients.AddVault(hcat.VaultInput{ 43 | Address: vaultAddr, 44 | Token: vaultToken, 45 | }); err != nil { 46 | panic(err) 47 | } 48 | 49 | testClients = clients 50 | tokenRoleId = vaultTokenSetup(clients) 51 | 52 | return m.Run() 53 | } 54 | 55 | type vaultServer struct { 56 | cmd *exec.Cmd 57 | } 58 | 59 | func (v vaultServer) Stop() error { 60 | if v.cmd != nil && v.cmd.Process != nil { 61 | return v.cmd.Process.Signal(os.Interrupt) 62 | } 63 | return nil 64 | } 65 | 66 | func newTestVault() *vaultServer { 67 | path, err := exec.LookPath("vault") 68 | if err != nil || path == "" { 69 | panic("vault not found on $PATH") 70 | } 71 | args := []string{ 72 | "server", "-dev", "-dev-root-token-id", vaultToken, 73 | "-dev-no-store-token", 74 | "-dev-listen-address", strings.TrimPrefix(vaultAddr, "http://"), 75 | } 76 | cmd := exec.Command("vault", args...) 
77 | cmd.Stdout = ioutil.Discard 78 | cmd.Stderr = ioutil.Discard 79 | 80 | if err := cmd.Start(); err != nil { 81 | panic("vault failed to start: " + err.Error()) 82 | } 83 | return &vaultServer{ 84 | cmd: cmd, 85 | } 86 | } 87 | 88 | // Sets up approle auto-auth for token generation/testing 89 | func vaultTokenSetup(clients *hcat.ClientSet) string { 90 | vc := clients.Vault() 91 | 92 | // vault auth enable approle 93 | err := vc.Sys().EnableAuthWithOptions("approle", 94 | &api.MountInput{ 95 | Type: "approle", 96 | }) 97 | if err != nil && !strings.Contains(err.Error(), "path is already in use") { 98 | panic(err) 99 | } 100 | 101 | // vault policy write foo 'path ...' 102 | err = vc.Sys().PutPolicy("foo", 103 | `path "secret/data/foo" { capabilities = ["read"] }`) 104 | if err != nil { 105 | panic(err) 106 | } 107 | 108 | // vault write auth/approle/role/foo ... 109 | _, err = vc.Logical().Write("auth/approle/role/foo", 110 | map[string]interface{}{ 111 | "token_policies": "foo", 112 | "secret_id_num_uses": 100, 113 | "secret_id_ttl": "5m", 114 | "token_num_users": 10, 115 | "token_ttl": "7m", 116 | "token_max_ttl": "10m", 117 | }) 118 | if err != nil { 119 | panic(err) 120 | } 121 | 122 | var sec *api.Secret 123 | // vault read -field=role_id auth/approle/role/foo/role-id 124 | sec, err = vc.Logical().Read("auth/approle/role/foo/role-id") 125 | if err != nil { 126 | panic(err) 127 | } 128 | role_id := sec.Data["role_id"] 129 | return role_id.(string) 130 | } 131 | 132 | // returns path to token file (which is created by the agent run) 133 | // token file isn't cleaned, so use returned path to remove it when done 134 | func runVaultAgent(clients *hcat.ClientSet, role_id string) string { 135 | dir, err := os.MkdirTemp("", "consul-template-test") 136 | if err != nil { 137 | panic(err) 138 | } 139 | defer os.RemoveAll(dir) 140 | 141 | tokenFile := filepath.Join("", "vatoken.txt") 142 | 143 | role_idPath := filepath.Join(dir, "roleid") 144 | secret_idPath := 
filepath.Join(dir, "secretid") 145 | vaconf := filepath.Join(dir, "vault-agent-config.json") 146 | 147 | // Generate secret_id, need new one for each agent run 148 | // vault write -f -field secret_id auth/approle/role/foo/secret-id 149 | vc := clients.Vault() 150 | sec, err := vc.Logical().Write("auth/approle/role/foo/secret-id", nil) 151 | if err != nil { 152 | panic(err) 153 | } 154 | secret_id := sec.Data["secret_id"].(string) 155 | err = os.WriteFile(secret_idPath, []byte(secret_id), 0o444) 156 | if err != nil { 157 | panic(err) 158 | } 159 | err = os.WriteFile(role_idPath, []byte(role_id), 0o444) 160 | if err != nil { 161 | panic(err) 162 | } 163 | 164 | type obj map[string]interface{} 165 | type list []obj 166 | va := obj{ 167 | "vault": obj{"address": vaultAddr}, 168 | "auto_auth": obj{ 169 | "method": obj{ 170 | "type": "approle", 171 | "config": obj{ 172 | "role_id_file_path": role_idPath, 173 | "secret_id_file_path": secret_idPath, 174 | }, 175 | "wrap_ttl": "5m", 176 | }, 177 | "sinks": list{ 178 | {"sink": obj{"type": "file", "config": obj{"path": tokenFile}}}, 179 | }, 180 | }, 181 | } 182 | txt, err := json.Marshal(va) 183 | if err != nil { 184 | panic(err) 185 | } 186 | err = os.WriteFile(vaconf, txt, 0o644) 187 | if err != nil { 188 | panic(err) 189 | } 190 | 191 | args := []string{ 192 | "agent", "-exit-after-auth", "-config=" + vaconf, 193 | } 194 | cmd := exec.Command("vault", args...) 195 | cmd.Stdout = ioutil.Discard 196 | cmd.Stderr = ioutil.Discard 197 | 198 | if err := cmd.Run(); err != nil { 199 | panic("vault agent failed to run: " + err.Error()) 200 | } 201 | return tokenFile 202 | } 203 | -------------------------------------------------------------------------------- /vaulttoken/notifier.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package vaulttoken 5 | 6 | import ( 7 | "github.com/hashicorp/hcat/dep" 8 | ) 9 | 10 | type callback func(any) bool 11 | 12 | // dep Notifier for use by vault token above and in tests 13 | type callbackNotifier struct { 14 | dep dep.Dependency 15 | fun callback 16 | } 17 | 18 | // returned boolean controls watcher.Watch channel output 19 | // ie. returning false will skip sending it on that channel. 20 | func (n callbackNotifier) Notify(d any) (ok bool) { 21 | if n.fun != nil { 22 | return n.fun(d) 23 | } else { 24 | return true 25 | } 26 | } 27 | 28 | // unique ID for this notifier (amoung the pop of notifiers) 29 | func (n callbackNotifier) ID() string { 30 | return n.dep.ID() 31 | } 32 | -------------------------------------------------------------------------------- /vaulttoken/vault_agent_token.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package vaulttoken 5 | 6 | import ( 7 | "os" 8 | "time" 9 | 10 | "github.com/hashicorp/hcat/dep" 11 | "github.com/pkg/errors" 12 | ) 13 | 14 | // Ensure implements 15 | var _ dep.Dependency = (*VaultAgentTokenQuery)(nil) 16 | 17 | const ( 18 | // VaultAgentTokenSleepTime is the amount of time to sleep between queries, since 19 | // the fsnotify library is not compatible with solaris and other OSes yet. 20 | VaultAgentTokenSleepTime = 15 * time.Second 21 | ) 22 | 23 | // VaultAgentTokenQuery is the dependency to Vault Agent token 24 | type VaultAgentTokenQuery struct { 25 | stopCh chan struct{} 26 | stat os.FileInfo 27 | path string 28 | } 29 | 30 | // NewVaultAgentTokenQuery creates a new dependency. 
31 | func NewVaultAgentTokenQuery(path string) (*VaultAgentTokenQuery, error) { 32 | return &VaultAgentTokenQuery{ 33 | stopCh: make(chan struct{}, 1), 34 | path: path, 35 | }, nil 36 | } 37 | 38 | // Fetch retrieves this dependency and returns the result or any errors that 39 | // occur in the process. 40 | func (d *VaultAgentTokenQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 41 | var token string 42 | select { 43 | case <-d.stopCh: 44 | return "", nil, dep.ErrStopped 45 | case r := <-d.watch(d.stat): 46 | 47 | if r.err != nil { 48 | return "", nil, errors.Wrap(r.err, d.ID()) 49 | } 50 | 51 | raw_token, err := os.ReadFile(d.path) 52 | if err != nil { 53 | return "", nil, errors.Wrap(err, d.ID()) 54 | } 55 | 56 | d.stat = r.stat 57 | token = string(raw_token) 58 | } 59 | 60 | return token, &dep.ResponseMetadata{ 61 | LastIndex: uint64(time.Now().Unix()), 62 | }, nil 63 | } 64 | 65 | // ID returns the human-friendly version of this dependency. 66 | func (d *VaultAgentTokenQuery) ID() string { 67 | return "vault-agent.token" 68 | } 69 | 70 | // Stop halts the dependency's fetch function. 
71 | func (d *VaultAgentTokenQuery) Stop() { 72 | close(d.stopCh) 73 | } 74 | 75 | // Stringer interface reuses ID 76 | func (d *VaultAgentTokenQuery) String() string { 77 | return d.ID() 78 | } 79 | 80 | type watchResult struct { 81 | stat os.FileInfo 82 | err error 83 | } 84 | 85 | // watch watches the file for changes 86 | func (d *VaultAgentTokenQuery) watch(lastStat os.FileInfo) <-chan *watchResult { 87 | ch := make(chan *watchResult, 1) 88 | 89 | go func(lastStat os.FileInfo) { 90 | for { 91 | stat, err := os.Stat(d.path) 92 | if err != nil { 93 | select { 94 | case <-d.stopCh: 95 | return 96 | case ch <- &watchResult{err: err}: 97 | return 98 | } 99 | } 100 | 101 | changed := lastStat == nil || 102 | lastStat.Size() != stat.Size() || 103 | lastStat.ModTime() != stat.ModTime() 104 | 105 | if changed { 106 | select { 107 | case <-d.stopCh: 108 | return 109 | case ch <- &watchResult{stat: stat}: 110 | return 111 | } 112 | } 113 | 114 | time.Sleep(VaultAgentTokenSleepTime) 115 | } 116 | }(lastStat) 117 | 118 | return ch 119 | } 120 | -------------------------------------------------------------------------------- /vaulttoken/vault_agent_token_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package vaulttoken 5 | 6 | import ( 7 | "io/ioutil" 8 | "os" 9 | "path/filepath" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/hashicorp/hcat" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func TestVaultAgentTokenQuery_Fetch(t *testing.T) { 18 | // Don't use t.Parallel() here as the SetToken() calls are global and break 19 | // other tests if run in parallel 20 | 21 | // reset token back to original 22 | vc := testClients.Vault() 23 | defer vc.SetToken(vc.Token()) 24 | 25 | // Set up the Vault token file. 
26 | tokenFile, err := ioutil.TempFile("", "token1") 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | defer os.Remove(tokenFile.Name()) 31 | testWrite(tokenFile.Name(), []byte("token")) 32 | 33 | d, err := NewVaultAgentTokenQuery(tokenFile.Name()) 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | 38 | clientSet := testClients 39 | token, _, err := d.Fetch(clientSet) 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | 44 | assert.Equal(t, "token", token) 45 | 46 | // Update the contents. 47 | testWrite(tokenFile.Name(), []byte("another_token")) 48 | token, _, err = d.Fetch(clientSet) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | assert.Equal(t, "another_token", token) 54 | } 55 | 56 | func TestVaultAgentTokenQuery_Fetch_missingFile(t *testing.T) { 57 | t.Parallel() 58 | 59 | // Use a non-existant token file path. 60 | d, err := NewVaultAgentTokenQuery("/tmp/invalid-file") 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | clientSet := hcat.NewClientSet() 66 | clientSet.AddVault(hcat.VaultInput{ 67 | Token: "foo", 68 | }) 69 | _, _, err = d.Fetch(clientSet) 70 | if err == nil || !strings.Contains(err.Error(), "no such file") { 71 | t.Fatal(err) 72 | } 73 | 74 | // Token should be unaffected. 
75 | assert.Equal(t, "foo", clientSet.Vault().Token()) 76 | } 77 | 78 | func testWrite(path string, contents []byte) error { 79 | if path == "" { 80 | panic("missing path") 81 | } 82 | 83 | parent := filepath.Dir(path) 84 | if _, err := os.Stat(parent); os.IsNotExist(err) { 85 | if err := os.MkdirAll(parent, 0o755); err != nil { 86 | return err 87 | } 88 | } 89 | 90 | f, err := ioutil.TempFile(parent, "") 91 | if err != nil { 92 | return err 93 | } 94 | defer os.Remove(f.Name()) 95 | 96 | if _, err := f.Write(contents); err != nil { 97 | return err 98 | } 99 | 100 | for _, err := range []error{ 101 | f.Sync(), 102 | f.Close(), 103 | os.Chmod(f.Name(), 0o644), 104 | os.Rename(f.Name(), path), 105 | } { 106 | if err != nil { 107 | return err 108 | } 109 | } 110 | 111 | return nil 112 | } 113 | -------------------------------------------------------------------------------- /vaulttoken/vault_token.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package vaulttoken 5 | 6 | import ( 7 | "github.com/hashicorp/hcat/dep" 8 | "github.com/hashicorp/vault/api" 9 | "github.com/pkg/errors" 10 | ) 11 | 12 | // Ensure implements 13 | var _ dep.Dependency = (*VaultTokenQuery)(nil) 14 | 15 | // VaultTokenQuery is the dependency to Vault for a secret 16 | type VaultTokenQuery struct { 17 | stopCh chan struct{} 18 | secret *api.Secret 19 | } 20 | 21 | // NewVaultTokenQuery creates a new dependency. 
22 | func NewVaultTokenQuery(token string) (*VaultTokenQuery, error) { 23 | secret := &api.Secret{ 24 | Auth: &api.SecretAuth{ 25 | ClientToken: token, 26 | Renewable: true, 27 | LeaseDuration: 1, 28 | }, 29 | } 30 | return &VaultTokenQuery{ 31 | stopCh: make(chan struct{}, 1), 32 | secret: secret, 33 | }, nil 34 | } 35 | 36 | // Fetch queries the Vault API 37 | func (d *VaultTokenQuery) Fetch(clients dep.Clients) (interface{}, *dep.ResponseMetadata, error) { 38 | select { 39 | case <-d.stopCh: 40 | return nil, nil, dep.ErrStopped 41 | default: 42 | } 43 | 44 | vaultSecretRenewable := d.secret.Renewable 45 | if d.secret.Auth != nil { 46 | vaultSecretRenewable = d.secret.Auth.Renewable 47 | } 48 | 49 | // ??? event/log if this runs and vaultSecretRenewable is false 50 | 51 | if vaultSecretRenewable { 52 | err := d.renewSecret(clients) 53 | if err != nil { 54 | return nil, nil, errors.Wrap(err, d.ID()) 55 | } 56 | } 57 | 58 | return nil, nil, dep.ErrLeaseExpired 59 | } 60 | 61 | func (d *VaultTokenQuery) renewSecret(clients dep.Clients) error { 62 | renewer, err := clients.Vault().NewRenewer(&api.RenewerInput{ 63 | Secret: d.secret, 64 | }) 65 | if err != nil { 66 | return err 67 | } 68 | go renewer.Renew() 69 | defer renewer.Stop() 70 | 71 | for { 72 | select { 73 | case err := <-renewer.DoneCh(): 74 | return err 75 | case renewal := <-renewer.RenewCh(): 76 | d.secret = renewal.Secret 77 | case <-d.stopCh: 78 | return dep.ErrStopped 79 | } 80 | } 81 | } 82 | 83 | // Stop halts the dependency's fetch function. 84 | func (d *VaultTokenQuery) Stop() { 85 | close(d.stopCh) 86 | } 87 | 88 | // ID returns the human-friendly version of this dependency. 
89 | func (d *VaultTokenQuery) ID() string { 90 | return "vault.token" 91 | } 92 | 93 | // Stringer interface reuses ID 94 | func (d *VaultTokenQuery) String() string { 95 | return d.ID() 96 | } 97 | -------------------------------------------------------------------------------- /vaulttoken/vault_token_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package vaulttoken 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/hashicorp/vault/api" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestNewVaultTokenQuery(t *testing.T) { 15 | t.Parallel() 16 | 17 | cases := []struct { 18 | name string 19 | exp *VaultTokenQuery 20 | err bool 21 | }{ 22 | { 23 | "default", 24 | &VaultTokenQuery{ 25 | secret: &api.Secret{ 26 | Auth: &api.SecretAuth{ 27 | ClientToken: "my-token", 28 | Renewable: true, 29 | LeaseDuration: 1, 30 | }, 31 | }, 32 | }, 33 | false, 34 | }, 35 | } 36 | 37 | for i, tc := range cases { 38 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 39 | act, err := NewVaultTokenQuery("my-token") 40 | if (err != nil) != tc.err { 41 | t.Fatal(err) 42 | } 43 | 44 | if act != nil { 45 | act.stopCh = nil 46 | } 47 | 48 | assert.Equal(t, tc.exp, act) 49 | }) 50 | } 51 | } 52 | 53 | func TestVaultTokenQuery_String(t *testing.T) { 54 | t.Parallel() 55 | 56 | cases := []struct { 57 | name string 58 | exp string 59 | }{ 60 | { 61 | "default", 62 | "vault.token", 63 | }, 64 | } 65 | 66 | for i, tc := range cases { 67 | t.Run(fmt.Sprintf("%d_%s", i, tc.name), func(t *testing.T) { 68 | d, err := NewVaultTokenQuery("my-token") 69 | if err != nil { 70 | t.Fatal(err) 71 | } 72 | assert.Equal(t, tc.exp, d.ID()) 73 | }) 74 | } 75 | } 76 | --------------------------------------------------------------------------------