├── .github └── workflows │ ├── codeql-analysis.yml │ └── go.yml ├── LICENSE ├── README.md ├── client.go ├── client ├── add.go ├── command.go ├── create.go ├── daemon.go ├── daemon_test.go ├── deactivate.go ├── delete.go ├── flock_unix.go ├── flock_windows.go ├── get.go ├── getacl.go ├── getkeys.go ├── getversions.go ├── help.go ├── knox-upstart.conf ├── list_key_templates.go ├── login.go ├── promote.go ├── reactivate.go ├── register.go ├── register_test.go ├── restart_knox.sh ├── tink_keyset_helper.go ├── tink_keyset_helper_test.go ├── unregister.go ├── updateaccess.go └── version.go ├── client_test.go ├── cmd ├── dev_client │ └── main.go ├── dev_server │ └── main.go └── migrate_db │ └── main.go ├── go.mod ├── go.sum ├── knox.go ├── knox_test.go ├── log ├── log.go └── log_test.go └── server ├── api.go ├── api_test.go ├── auth ├── auth.go ├── auth_test.go └── spiffe.go ├── decorators.go ├── http_test.go ├── key_manager.go ├── key_manager_test.go ├── keydb ├── cryptor.go ├── cryptor_test.go ├── keydb.go └── keydb_test.go ├── routes.go └── routes_test.go /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ master ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ master ] 20 | schedule: 21 | - cron: '16 16 * * 4' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'go' ] 36 | 37 | steps: 38 | - name: Set up Go 1.x 39 | uses: actions/setup-go@v2 40 | with: 41 | go-version: ^1.14.6 42 | id: go 43 | 44 | - name: Checkout repository 45 | uses: actions/checkout@v2 46 | 47 | # Initializes the CodeQL tools for scanning. 48 | - name: Initialize CodeQL 49 | uses: github/codeql-action/init@v1 50 | with: 51 | languages: ${{ matrix.language }} 52 | 53 | - name: Build 54 | run: go build -v ./... 55 | 56 | - name: Perform CodeQL Analysis 57 | uses: github/codeql-action/analyze@v1 58 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | permissions: read-all 4 | 5 | on: 6 | push: 7 | branches: [ master ] 8 | pull_request: 9 | branches: [ master ] 10 | 11 | jobs: 12 | 13 | build: 14 | name: Build 15 | runs-on: ubuntu-latest 16 | steps: 17 | 18 | - name: Set up Go 1.x 19 | uses: actions/setup-go@v2 20 | with: 21 | go-version: ^1.14.6 22 | id: go 23 | 24 | - name: Checkout 25 | uses: actions/checkout@v2 26 | 27 | - name: Build 28 | run: go build -v ./... 29 | 30 | - name: Test 31 | run: go test -v ./... 32 | 33 | - name: Fmt 34 | run: diff -u <(echo -n) <(gofmt -d -s .) 35 | 36 | - name: vet 37 | run: go vet ./... 
38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Pinterest 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Knox -- the high level overview 2 | Knox is a service for storing and rotating secrets, keys, and passwords used by other services. 3 | 4 | ## The Problem Knox is Meant to Solve 5 | Pinterest has a plethora of keys or secrets doing things like signing cookies, encrypting data, protecting our network via TLS, accessing our AWS machines, communicating with our third parties, and many more. If these keys became compromised, rotating them (that is, changing the keys) used to be a difficult process, generally involving a deploy and likely a code change. Keys/secrets within Pinterest were stored in git repositories. This meant they were copied all over our company's infrastructure and present on many of our employees' laptops. There was no way to audit who accessed the keys or who had access to them. Knox was built to solve these problems.
6 | 7 | The goals of Knox are: 8 | - Ease of use for developers to access/use confidential secrets, keys, and credentials 9 | - Confidentiality for secrets, keys, and credentials 10 | - Provide mechanisms for key rotation in case of compromise 11 | - Create an audit log to keep track of which systems and users access confidential data 12 | 13 | Read more at https://github.com/pinterest/knox/wiki 14 | 15 | ## Getting knox set up 16 | The first step is to install Go (or use Docker, see below). We require Go 1.16 or later (the client code uses `io.ReadAll` and `os.WriteFile`, which were added in Go 1.16); the authoritative minimum is the `go` directive in `go.mod`. For instructions on setting up Go, please visit https://golang.org/doc/install 17 | 18 | After Go is set up, run `git clone https://github.com/pinterest/knox.git` to get the latest version of the knox code. 19 | 20 | To compile the devserver and devclient binaries, run `go install github.com/pinterest/knox/cmd/dev_server@latest` and `go install github.com/pinterest/knox/cmd/dev_client@latest` (or `go install ./cmd/dev_server ./cmd/dev_client` from inside the cloned repository). The binaries are placed in `$GOPATH/bin` (or `$GOBIN`, if set) and can be executed directly; the dev_client expects the server to be running on localhost. By default, the client uses mTLS with a hardcoded certificate signed for example.com for machine authentication, and has GitHub authentication enabled for users. 21 | 22 | To start your server run: 23 | ```sh 24 | $GOPATH/bin/dev_server 25 | ``` 26 | 27 | To use this client as a user, generate a GitHub personal access token with the `read:org` scope by following these instructions: https://help.github.com/articles/creating-an-access-token-for-command-line-use/. This token lets knox look up your username and the organizations you belong to. With the dev_server running you can now create your first knox key.
28 | 29 | ```sh 30 | export KNOX_USER_AUTH= 31 | echo -n "My first knox secret" | $GOPATH/bin/dev_client create test_service:first_secret 32 | ``` 33 | 34 | You can retrieve the secret using: 35 | ```sh 36 | $GOPATH/bin/dev_client get test_service:first_secret 37 | ``` 38 | 39 | You can see all key IDs using: 40 | ```sh 41 | $GOPATH/bin/dev_client keys 42 | ``` 43 | 44 | To see all available commands run: 45 | ```sh 46 | $GOPATH/bin/dev_client help 47 | ``` 48 | 49 | For production usage, I recommend making your own client, renaming it `knox`, and moving it into you $PATH for ease of use. 50 | 51 | For more information on interacting with knox, use `knox help` or go to https://github.com/pinterest/knox/wiki/Knox-Client 52 | 53 | ## Knox with Docker 54 | 55 | You can run a Docker container to get knox set up, instead of installing Go on your host. 56 | 57 | ```sh 58 | git clone https://github.com/pinterest/knox.git 59 | cd knox 60 | docker run --name knox --rm -v "$PWD":/go/src/github.com/pinterest/knox -it golang /bin/bash 61 | ``` 62 | 63 | This will run a bash shell into the container, mounting a local copy of knox in the go source path. 64 | 65 | You can refer back to the section "Getting knox set up" to set up knox. 66 | -------------------------------------------------------------------------------- /client/add.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pinterest/knox" 7 | ) 8 | 9 | func init() { 10 | cmdAdd.Run = runAdd // break init cycle 11 | } 12 | 13 | var cmdAdd = &Command{ 14 | UsageLine: "add [--key-template template_name] ", 15 | Short: "adds a new key version to knox", 16 | Long: ` 17 | Add will add a new key version to an existing key in knox. Key data of new version should be sent to stdin unless a key-template is specified. 18 | 19 | First way: key data of new version is sent to stdin. 20 | Please run "knox add ". 
21 | 22 | Second way: the key-template option can be used to specify a template to generate the new key version, instead of stdin. For available key templates, run "knox key-templates". 23 | Please run "knox add --key-template ". 24 | 25 | This key version will be set to active upon creation. The version id will be sent to stdout on creation. 26 | 27 | This command uses user access and requires write access in the key's ACL. 28 | 29 | For more about knox, see https://github.com/pinterest/knox. 30 | 31 | See also: knox create, knox promote 32 | `, 33 | } 34 | var addTinkKeyset = cmdAdd.Flag.String("key-template", "", "name of a knox-supported Tink key template") 35 | 36 | func runAdd(cmd *Command, args []string) *ErrorStatus { 37 | if len(args) != 1 { 38 | return &ErrorStatus{fmt.Errorf("add takes only one argument. See 'knox help add'"), false} 39 | } 40 | keyID := args[0] 41 | var data []byte 42 | var err error 43 | if *addTinkKeyset != "" { 44 | data, err = getDataWithTemplate(*addTinkKeyset, keyID) 45 | } else { 46 | data, err = readDataFromStdin() 47 | } 48 | if err != nil { 49 | return &ErrorStatus{err, false} 50 | } 51 | versionID, err := cli.AddVersion(keyID, data) 52 | if err != nil { 53 | return &ErrorStatus{fmt.Errorf("Error adding version: %s", err.Error()), true} 54 | } 55 | fmt.Printf("Added key version %d\n", versionID) 56 | return nil 57 | } 58 | 59 | // getDataWithTemplate returns the data for a new version of a knox identifier that stores Tink keyset. 
60 | func getDataWithTemplate(templateName string, keyID string) ([]byte, error) { 61 | err := obeyNamingRule(templateName, keyID) 62 | if err != nil { 63 | return nil, err 64 | } 65 | // get all versions (primary, active, inactive) of this knox identifier 66 | allVersions, err := cli.NetworkGetKeyWithStatus(keyID, knox.Inactive) 67 | if err != nil { 68 | return nil, fmt.Errorf("error getting key: %s", err.Error()) 69 | } 70 | return addNewTinkKeyset(tinkKeyTemplates[templateName].templateFunc, allVersions.VersionList) 71 | } 72 | -------------------------------------------------------------------------------- /client/command.go: -------------------------------------------------------------------------------- 1 | // This file uses code from http://golang.org/src/cmd/go/main.go 2 | // modified for use with Knox 3 | // 4 | // Copyright (c) 2012 The Go Authors. All rights reserved. 5 | 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions are 8 | // met: 9 | 10 | // * Redistributions of source code must retain the above copyright 11 | // notice, this list of conditions and the following disclaimer. 12 | // * Redistributions in binary form must reproduce the above 13 | // copyright notice, this list of conditions and the following disclaimer 14 | // in the documentation and/or other materials provided with the 15 | // distribution. 16 | // * Neither the name of Google Inc. nor the names of its 17 | // contributors may be used to endorse or promote products derived from 18 | // this software without specific prior written permission. 19 | 20 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 24 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | package client 33 | 34 | import ( 35 | "flag" 36 | "fmt" 37 | "io" 38 | "os" 39 | "strings" 40 | "sync" 41 | "text/template" 42 | "unicode" 43 | "unicode/utf8" 44 | 45 | "github.com/pinterest/knox" 46 | ) 47 | 48 | const defaultTokenFileLocation = ".knox_user_auth" 49 | 50 | var cli knox.APIClient 51 | 52 | // VisibilityParams exposes functions for the knox client to provide information 53 | type VisibilityParams struct { 54 | Logf func(string, ...interface{}) 55 | Errorf func(string, ...interface{}) 56 | SummaryMetrics func(map[string]uint64) 57 | InvokeMetrics func(map[string]string) 58 | GetKeyMetrics func(map[string]string) 59 | } 60 | 61 | var logf = func(string, ...interface{}) {} 62 | var errorf = func(string, ...interface{}) {} 63 | var daemonReportMetrics = func(map[string]uint64) {} 64 | var clientInvokeMetrics = func(map[string]string) {} 65 | var clientGetKeyMetrics = func(map[string]string) {} 66 | 67 | // Run is how to execute commands. It uses global variables and isn't safe to call in parallel. 
68 | func Run( 69 | client knox.APIClient, 70 | p *VisibilityParams, 71 | loginCommand *Command, 72 | ) { 73 | 74 | cli = client 75 | if p != nil { 76 | if p.Logf != nil { 77 | logf = p.Logf 78 | } 79 | if p.Errorf != nil { 80 | errorf = p.Errorf 81 | } 82 | if p.SummaryMetrics != nil { 83 | daemonReportMetrics = p.SummaryMetrics 84 | } 85 | if p.InvokeMetrics != nil { 86 | clientInvokeMetrics = p.InvokeMetrics 87 | } 88 | if p.GetKeyMetrics != nil { 89 | clientGetKeyMetrics = p.GetKeyMetrics 90 | } 91 | if loginCommand == nil { 92 | fatalf("A login command was not supplied, you must supply a login command.") 93 | } 94 | } 95 | commands = append(commands, loginCommand) 96 | flag.Usage = usage 97 | flag.Parse() 98 | 99 | args := flag.Args() 100 | if len(args) < 1 { 101 | usage() 102 | } 103 | 104 | if args[0] == "help" { 105 | help(args[1:]) 106 | return 107 | } 108 | 109 | for _, cmd := range commands { 110 | if cmd.Name() == args[0] && cmd.Run != nil { 111 | cmd.Flag.Usage = func() { cmd.Usage() } 112 | if cmd.CustomFlags { 113 | args = args[1:] 114 | } else { 115 | cmd.Flag.Parse(args[1:]) 116 | args = cmd.Flag.Args() 117 | } 118 | errorStatus := cmd.Run(cmd, args) 119 | var metricsKey string 120 | if errorStatus != nil { 121 | if errorStatus.serverError { 122 | metricsKey = "failure" 123 | } else { 124 | metricsKey = "ignored_failure" 125 | } 126 | } else { 127 | metricsKey = "success" 128 | } 129 | clientInvokeMetrics(map[string]string{ 130 | "metrics_key": metricsKey, 131 | "method_name": fmt.Sprintf("client_%s", cmd.Name()), 132 | }) 133 | if errorStatus != nil { 134 | fatalf(errorStatus.Error()) 135 | } 136 | exit() 137 | return 138 | } 139 | } 140 | 141 | fmt.Fprintf(os.Stderr, "knox: unknown subcommand %q\nRun 'knox help' for usage.\n", args[0]) 142 | setExitStatus(2) 143 | exit() 144 | } 145 | 146 | // Commands lists the available commands and help topics. 147 | // The order here is the order in which they are printed by 'knox help'. 
148 | var commands = []*Command{ 149 | // These commands are related to running knox as a daemon. 150 | cmdDaemon, 151 | cmdRegister, 152 | cmdUnregister, 153 | 154 | // These commands are related to key management by users. 155 | cmdGetKeys, 156 | cmdGet, 157 | cmdGetVersions, 158 | cmdGetACL, 159 | cmdPromote, 160 | cmdCreate, 161 | cmdAdd, 162 | cmdDeactivate, 163 | cmdReactivate, 164 | cmdUpdateAccess, 165 | cmdDelete, 166 | 167 | // These are additional help topics 168 | cmdListKeyTemplates, 169 | cmdVersion, 170 | helpAuth, 171 | } 172 | 173 | // A Command is an implementation of a go command 174 | // like go build or go fix. 175 | type Command struct { 176 | // Run runs the command. 177 | // The args are the arguments after the command name. 178 | Run func(cmd *Command, args []string) *ErrorStatus 179 | 180 | // UsageLine is the one-line usage message. 181 | // The first word in the line is taken to be the command name. 182 | UsageLine string 183 | 184 | // Short is the short description shown in the 'go help' output. 185 | Short string 186 | 187 | // Long is the long message shown in the 'go help ' output. 188 | Long string 189 | 190 | // Flag is a set of flags specific to this command. 191 | Flag flag.FlagSet 192 | 193 | // CustomFlags indicates that the command will do its own 194 | // flag parsing. 195 | CustomFlags bool 196 | } 197 | 198 | // ErrorStatus wraps the error status of knox client command execution 199 | type ErrorStatus struct { 200 | error 201 | serverError bool 202 | } 203 | 204 | // Name returns the command's name: the first word in the usage line. 205 | func (c *Command) Name() string { 206 | name := c.UsageLine 207 | i := strings.Index(name, " ") 208 | if i >= 0 { 209 | name = name[:i] 210 | } 211 | return name 212 | } 213 | 214 | // Usage prints the help string for a command. 
215 | func (c *Command) Usage() { 216 | fmt.Fprintf(os.Stderr, "usage: %s\n\n", c.UsageLine) 217 | fmt.Fprintf(os.Stderr, "%s\n", strings.TrimSpace(c.Long)) 218 | os.Exit(2) 219 | } 220 | 221 | // Runnable reports whether the command can be run; otherwise 222 | // it is a documentation pseudo-command such as importpath. 223 | func (c *Command) Runnable() bool { 224 | return c.Run != nil 225 | } 226 | 227 | var exitStatus = 0 228 | var exitMu sync.Mutex 229 | 230 | func setExitStatus(n int) { 231 | exitMu.Lock() 232 | if exitStatus < n { 233 | exitStatus = n 234 | } 235 | exitMu.Unlock() 236 | } 237 | 238 | // tmpl executes the given template text on data, writing the result to w. 239 | func tmpl(w io.Writer, text string, data interface{}) { 240 | t := template.New("top") 241 | t.Funcs(template.FuncMap{"trim": strings.TrimSpace, "capitalize": capitalize}) 242 | template.Must(t.Parse(text)) 243 | if err := t.Execute(w, data); err != nil { 244 | panic(err) 245 | } 246 | } 247 | 248 | func capitalize(s string) string { 249 | if s == "" { 250 | return s 251 | } 252 | r, n := utf8.DecodeRuneInString(s) 253 | return string(unicode.ToTitle(r)) + s[n:] 254 | } 255 | 256 | func printUsage(w io.Writer) { 257 | tmpl(w, usageTemplate, commands) 258 | } 259 | 260 | func usage() { 261 | // special case "go test -h" 262 | if len(os.Args) > 1 && os.Args[1] == "test" { 263 | help([]string{"testflag"}) 264 | os.Exit(2) 265 | } 266 | printUsage(os.Stderr) 267 | os.Exit(2) 268 | } 269 | 270 | func exit() { 271 | os.Exit(exitStatus) 272 | } 273 | 274 | func fatalf(format string, args ...interface{}) { 275 | errorf(format, args...) 
276 | setExitStatus(1) 277 | exit() 278 | } 279 | -------------------------------------------------------------------------------- /client/create.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | 8 | "github.com/pinterest/knox" 9 | ) 10 | 11 | func init() { 12 | cmdCreate.Run = runCreate // break init cycle 13 | } 14 | 15 | var cmdCreate = &Command{ 16 | UsageLine: "create [--key-template template_name] ", 17 | Short: "creates a new key", 18 | Long: ` 19 | Create will create a new key in knox with input as the primary key version. Key data should be sent to stdin unless a key-template is specified. 20 | 21 | First way: key data is sent to stdin. 22 | Please run "knox create ". 23 | 24 | Second way: the key-template option can be used to specify a template to generate the initial primary key version, instead of stdin. For available key templates, run "knox key-templates". 25 | Please run "knox create --key-template ". 26 | 27 | The original key version id will be print to stdout. 28 | 29 | To create a new key, user credentials are required. The default access list will include the creator of this key and a limited set of site reliablity and security engineers. 30 | 31 | For more about knox, see https://github.com/pinterest/knox. 32 | 33 | See also: knox add, knox get 34 | `, 35 | } 36 | var createTinkKeyset = cmdCreate.Flag.String("key-template", "", "name of a knox-supported Tink key template") 37 | 38 | func runCreate(cmd *Command, args []string) *ErrorStatus { 39 | if len(args) != 1 { 40 | return &ErrorStatus{fmt.Errorf("create takes exactly one argument. 
See 'knox help create'"), false} 41 | } 42 | keyID := args[0] 43 | var data []byte 44 | var err error 45 | if *createTinkKeyset != "" { 46 | templateName := *createTinkKeyset 47 | err = obeyNamingRule(templateName, keyID) 48 | if err != nil { 49 | return &ErrorStatus{err, false} 50 | } 51 | data, err = createNewTinkKeyset(tinkKeyTemplates[templateName].templateFunc) 52 | } else { 53 | data, err = readDataFromStdin() 54 | } 55 | if err != nil { 56 | return &ErrorStatus{err, false} 57 | } 58 | // TODO(devinlundberg): allow ACL to be entered as input 59 | acl := knox.ACL{} 60 | versionID, err := cli.CreateKey(keyID, data, acl) 61 | if err != nil { 62 | return &ErrorStatus{fmt.Errorf("Error adding version: %s", err.Error()), true} 63 | } 64 | fmt.Printf("Created key with initial version %d\n", versionID) 65 | return nil 66 | } 67 | 68 | func readDataFromStdin() ([]byte, error) { 69 | fmt.Println("Reading from stdin...") 70 | data, err := io.ReadAll(os.Stdin) 71 | if err != nil { 72 | return data, fmt.Errorf("problem reading key data: %s", err.Error()) 73 | } 74 | return data, nil 75 | } 76 | -------------------------------------------------------------------------------- /client/daemon.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "math/rand" 10 | "os" 11 | "os/exec" 12 | "path" 13 | "runtime" 14 | "strings" 15 | "time" 16 | 17 | "gopkg.in/fsnotify.v1" 18 | 19 | "github.com/pinterest/knox" 20 | ) 21 | 22 | var cmdDaemon = &Command{ 23 | Run: runDaemon, 24 | UsageLine: "daemon", 25 | Short: "runs a process to keep keys in sync with server", 26 | Long: ` 27 | daemon runs the knox process that will keep keys in sync. 28 | 29 | This process will keep running until sent a kill signal or it crashes. 30 | 31 | This maintains a file system cache of knox keys that is used for all other knox commands. 
32 | 33 | For more about knox, see https://github.com/pinterest/knox. 34 | 35 | See also: knox register, knox unregister 36 | `, 37 | } 38 | 39 | var daemonFolder = "/var/lib/knox" 40 | var daemonToRegister = "/.registered" 41 | var daemonKeys = "/v0/keys/" 42 | 43 | var lockTimeout = 10 * time.Second 44 | var lockRetryTime = 50 * time.Millisecond 45 | 46 | var defaultFilePermission os.FileMode = 0666 47 | var defaultDirPermission os.FileMode = 0777 48 | 49 | var daemonRefreshTime = 10 * time.Minute 50 | 51 | const tinkPrefix = "tink:" 52 | 53 | func runDaemon(cmd *Command, args []string) *ErrorStatus { 54 | 55 | if os.Getenv("KNOX_MACHINE_AUTH") == "" { 56 | hostname, err := os.Hostname() 57 | if err != nil { 58 | return &ErrorStatus{fmt.Errorf("You're on a host with no name: %s", err.Error()), false} 59 | } 60 | os.Setenv("KNOX_MACHINE_AUTH", hostname) 61 | } 62 | 63 | d := daemon{ 64 | dir: daemonFolder, 65 | registerFile: daemonToRegister, 66 | keysDir: daemonKeys, 67 | cli: cli, 68 | } 69 | err := d.initialize() 70 | if err != nil { 71 | return &ErrorStatus{err, false} 72 | } 73 | d.loop(daemonRefreshTime) 74 | return nil 75 | } 76 | 77 | type daemon struct { 78 | dir string 79 | registerFile string 80 | registerKeyFile Keys 81 | keysDir string 82 | cli knox.APIClient 83 | updateErrCount uint64 84 | getKeyErrCount uint64 85 | successCount uint64 86 | } 87 | 88 | func (d *daemon) loop(refresh time.Duration) { 89 | t := time.NewTicker(refresh) 90 | 91 | watcher, err := fsnotify.NewWatcher() 92 | if err != nil { 93 | fatalf("Unable to watch files: %s", err.Error()) 94 | } 95 | watcher.Add(d.registerFilename()) 96 | 97 | for { 98 | logf("Daemon updating all registered keys") 99 | start := time.Now() 100 | err := d.update() 101 | if err != nil { 102 | d.updateErrCount++ 103 | logf("Failed to update keys: %s", err.Error()) 104 | } else { 105 | d.successCount++ 106 | } 107 | logf("Update of keys completed after %d ms", time.Since(start).Milliseconds()) 108 | 109 | 
select { 110 | case event := <-watcher.Events: 111 | // On any change to register file 112 | logf("Got file watcher event: %s on %s", event.Op.String(), event.Name) 113 | case <-t.C: 114 | // add random jitter to prevent a stampede 115 | <-time.After(time.Duration(rand.Intn(10)) * time.Millisecond) 116 | daemonReportMetrics(map[string]uint64{ 117 | "err": d.updateErrCount, 118 | "get_err": d.getKeyErrCount, 119 | "success": d.successCount, 120 | }) 121 | } 122 | } 123 | } 124 | 125 | func (d *daemon) initialize() error { 126 | err := os.MkdirAll(d.dir, defaultDirPermission) 127 | if err != nil { 128 | return fmt.Errorf("Failed to initialize /var/lib/knox (run 'sudo mkdir /var/lib/knox'?): %s", err.Error()) 129 | } 130 | 131 | // Need to chmod due to a umask set on masterless puppet machines 132 | err = os.Chmod(d.dir, defaultDirPermission) 133 | if err != nil { 134 | return fmt.Errorf("Failed to open up directory permissions: %s", err.Error()) 135 | } 136 | err = os.MkdirAll(d.keyDir(), defaultDirPermission) 137 | if err != nil { 138 | return fmt.Errorf("Failed to make key folders: %s", err.Error()) 139 | } 140 | 141 | // Need to chmod due to a umask set on masterless puppet machines 142 | err = os.Chmod(d.keyDir(), defaultDirPermission) 143 | if err != nil { 144 | return fmt.Errorf("Failed to open up directory permissions: %s", err.Error()) 145 | } 146 | _, err = os.Stat(d.registerFilename()) 147 | if os.IsNotExist(err) { 148 | err := os.WriteFile(d.registerFilename(), []byte{}, defaultFilePermission) 149 | if err != nil { 150 | return fmt.Errorf("Failed to initialize registered key file: %s", err.Error()) 151 | } 152 | } else if err != nil { 153 | return err 154 | } 155 | 156 | // Need to chmod due to a umask set on masterless puppet machines 157 | err = os.Chmod(d.registerFilename(), defaultFilePermission) 158 | if err != nil { 159 | return fmt.Errorf("Failed to open up register file permissions: %s", err.Error()) 160 | } 161 | d.registerKeyFile = 
NewKeysFile(d.registerFilename()) 162 | return nil 163 | } 164 | 165 | func (d *daemon) update() error { 166 | err := d.registerKeyFile.Lock() 167 | if err != nil { 168 | return err 169 | } 170 | // defer this so that functions can update the register file. 171 | defer d.registerKeyFile.Unlock() 172 | keyIDs, err := d.registerKeyFile.Get() 173 | if err != nil { 174 | return err 175 | } 176 | logf("Requested keys: %s", keyIDs) 177 | 178 | keyMap := map[string]string{} 179 | existingKeys := map[string]bool{} 180 | for _, k := range keyIDs { 181 | // set default value to empty string 182 | keyMap[k] = "" 183 | existingKeys[k] = false 184 | } 185 | 186 | currentKeyIDs, err := d.currentRegisteredKeys() 187 | if err != nil { 188 | return err 189 | } 190 | logf("Current keys on disk: %s", currentKeyIDs) 191 | 192 | for _, keyID := range currentKeyIDs { 193 | existingKeys[keyID] = true 194 | 195 | if _, present := keyMap[keyID]; present { 196 | key, err := d.cli.CacheGetKey(keyID) 197 | if err != nil { 198 | // Keep going in spite of failure 199 | logf("error getting cache key: %s", err) 200 | // Remove existing cached key with invalid format (saved with previous version clients) 201 | if _, err = os.Stat(d.keyFilename(keyID)); err == nil { 202 | d.deleteKey(keyID) 203 | } 204 | } else { 205 | keyMap[keyID] = key.VersionHash 206 | } 207 | } else { 208 | d.deleteKey(keyID) 209 | } 210 | } 211 | 212 | if len(keyMap) > 0 { 213 | updatedKeys, err := d.cli.GetKeys(keyMap) 214 | if err != nil { 215 | return err 216 | } 217 | logf("Updated keys received from server: %s", updatedKeys) 218 | for _, k := range updatedKeys { 219 | err = d.processKey(k) 220 | existingKeys[k] = true 221 | 222 | if err != nil { 223 | // Keep going in spite of failure 224 | d.getKeyErrCount++ 225 | logf("error processing key: %s", err) 226 | } 227 | } 228 | } 229 | // Find out if we missed anything (useful for humans reading the logs) 230 | // If key was not processed, and is also not current, then it 
didn't exist 231 | notFound := []string{} 232 | for id, exists := range existingKeys { 233 | if !exists { 234 | notFound = append(notFound, id) 235 | } 236 | } 237 | logf("Keys not found on server: %s", notFound) 238 | 239 | return nil 240 | } 241 | 242 | func (d daemon) deleteKey(keyID string) error { 243 | return os.Remove(d.keyFilename(keyID)) 244 | } 245 | 246 | func (d daemon) currentRegisteredKeys() ([]string, error) { 247 | files, err := os.ReadDir(d.keyDir()) 248 | if err != nil { 249 | return nil, err 250 | } 251 | var out []string 252 | for _, f := range files { 253 | out = append(out, f.Name()) 254 | } 255 | return out, nil 256 | } 257 | 258 | func (d daemon) keyDir() string { 259 | return path.Join(d.dir, d.keysDir) 260 | } 261 | 262 | func (d daemon) registerFilename() string { 263 | return path.Join(d.dir, d.registerFile) 264 | } 265 | 266 | func (d daemon) keyFilename(id string) string { 267 | return path.Join(d.dir, d.keysDir, id) 268 | } 269 | 270 | func (d daemon) processKey(keyID string) error { 271 | key, err := d.cli.NetworkGetKey(keyID) 272 | if err != nil { 273 | if err.Error() == "User or machine not authorized" || err.Error() == "Key identifer does not exist" { 274 | // This removes keys that do not exist or the machine is unauthorized to access 275 | d.registerKeyFile.Remove([]string{keyID}) 276 | } 277 | return fmt.Errorf("Error getting key %s: %s", keyID, err.Error()) 278 | } 279 | // Do not cache any new keys if they have invalid content 280 | if key.ID == "" || key.ACL == nil || key.VersionList == nil || key.VersionHash == "" { 281 | return fmt.Errorf("invalid key content returned") 282 | } 283 | 284 | if strings.HasPrefix(keyID, tinkPrefix) { 285 | keysetHandle, _, err := getTinkKeysetHandleFromKnoxVersionList(key.VersionList) 286 | if err != nil { 287 | return fmt.Errorf("Error fetching keyset handle for this tink key %s: %s", keyID, err.Error()) 288 | } 289 | tinkKeyset, err := convertTinkKeysetHandleToBytes(keysetHandle) 290 | if 
err != nil { 291 | return fmt.Errorf("Error converting tink keyset handle to bytes %s: %s", keyID, err.Error()) 292 | } 293 | key.TinkKeyset = base64.StdEncoding.EncodeToString(tinkKeyset) 294 | } 295 | 296 | b, err := json.Marshal(key) 297 | if err != nil { 298 | return fmt.Errorf("Error marshalling key %s: %s", keyID, err.Error()) 299 | } 300 | // Write to tmpfile, mv to normal location. Close + rm on failures 301 | tmpFile, err := os.CreateTemp(d.dir, fmt.Sprintf(".*.%s.tmp", keyID)) 302 | if err != nil { 303 | return fmt.Errorf("Error opening tmp file for key %s: %s", keyID, err.Error()) 304 | } 305 | _, err = tmpFile.Write(b) 306 | if err != nil { 307 | tmpFile.Close() 308 | os.Remove(tmpFile.Name()) 309 | return fmt.Errorf("Error writing key %s to file: %s", keyID, err.Error()) 310 | } 311 | // Done writing 312 | tmpFile.Close() 313 | 314 | err = os.Rename(tmpFile.Name(), d.keyFilename(keyID)) 315 | if err != nil { 316 | os.Remove(tmpFile.Name()) 317 | return fmt.Errorf("Error renaming key %s temporary file: %s", keyID, err.Error()) 318 | } 319 | 320 | err = os.Chmod(d.keyFilename(keyID), defaultFilePermission) 321 | if err != nil { 322 | return fmt.Errorf("Failed to open up key file permissions: %s", err.Error()) 323 | } 324 | return nil 325 | } 326 | 327 | // Keys are an interface for storing a list of key ids (for use with the register file to provide locks) 328 | type Keys interface { 329 | Get() ([]string, error) 330 | Add([]string) error 331 | Overwrite([]string) error 332 | Remove([]string) error 333 | Lock() error 334 | Unlock() error 335 | } 336 | 337 | // KeysFile is an implementation of Keys based on the file system for the register file. 
338 | type KeysFile struct { 339 | fn string // path of the register file 340 | *flock // platform-specific advisory lock state 341 | } 342 | 343 | // NewKeysFile takes in a filename and outputs an implementation of the Keys interface backed by that file. 344 | func NewKeysFile(fn string) Keys { 345 | return &KeysFile{fn, newFlock()} 346 | } 347 | 348 | // Lock acquires an exclusive advisory lock on the file, retrying until lockTimeout elapses. 349 | func (k *KeysFile) Lock() error { 350 | err := k.lock(k, defaultFilePermission, true, lockTimeout) 351 | 352 | // Timeout means someone else is using our lock, which is unusual. 353 | // On Linux, log the lsof output so we can see which processes hold the file. 354 | if errors.Is(err, ErrTimeout) && runtime.GOOS == "linux" { 355 | lockHolders, lsofErr := identifyLockHolders(k.fn) 356 | if lsofErr == nil { 357 | logf("hit timeout, found lock holder information:\n%s", lockHolders) 358 | } 359 | } 360 | 361 | // Annotate error with path to file to make debugging easier 362 | if err != nil { 363 | return fmt.Errorf("unable to obtain lock on file '%s': %s", k.fn, err.Error()) 364 | } 365 | return nil 366 | } 367 | 368 | // Unlock releases the advisory lock taken by Lock. 369 | func (k *KeysFile) Unlock() error { 370 | err := k.unlock(k) 371 | 372 | // Annotate error with path to file to make debugging easier 373 | if err != nil { 374 | return fmt.Errorf("unable to release lock on file '%s': %s", k.fn, err.Error()) 375 | } 376 | return nil 377 | } 378 | 379 | // Get returns the whitespace-separated key IDs stored in the file. It expects Lock to have been called. 380 | func (k *KeysFile) Get() ([]string, error) { 381 | b, err := os.ReadFile(k.fn) 382 | if err != nil { 383 | return nil, err 384 | } 385 | return strings.Fields(string(b)), nil 386 | } 387 | 388 | // Remove deletes the input key IDs from the list; a missing file is treated as empty, duplicates collapse, and the rewritten order is unspecified. It expects Lock to have been called.
389 | func (k *KeysFile) Remove(ks []string) error { 390 | oldKeys, err := k.Get() 391 | if err != nil { 392 | if os.IsNotExist(err) { 393 | oldKeys = []string{} 394 | } else { 395 | return err 396 | } 397 | } 398 | // Use a map to remove any duplicates 399 | newKeys := make(map[string]bool) 400 | for _, oldK := range oldKeys { 401 | removeIt := false 402 | for _, id := range ks { 403 | if id == oldK { 404 | removeIt = true 405 | break 406 | } 407 | } 408 | if !removeIt { 409 | newKeys[oldK] = true 410 | } 411 | } 412 | 413 | var buffer bytes.Buffer 414 | for id := range newKeys { 415 | buffer.WriteString(id) 416 | buffer.WriteByte('\n') 417 | } 418 | return os.WriteFile(k.fn, buffer.Bytes(), defaultFilePermission) 419 | } 420 | 421 | // Add adds the key IDs to the list, rewriting the file only when at least one ID is new. It expects Lock to have been called. 422 | func (k *KeysFile) Add(ks []string) error { 423 | oldKeys, err := k.Get() 424 | if err != nil { 425 | if os.IsNotExist(err) { 426 | oldKeys = []string{} 427 | } else { 428 | return err 429 | } 430 | } 431 | // Use a map to remove any duplicates, including duplicate lines already in the file. 432 | newKeys := make(map[string]bool) 433 | for _, id := range oldKeys { 434 | newKeys[id] = true 435 | } 436 | oldCount := len(newKeys) // distinct IDs already registered; len(oldKeys) over-counts duplicates and could hide a new ID 437 | for _, id := range ks { 438 | newKeys[id] = true 439 | } 440 | if len(newKeys) == oldCount { // Do not write if there are no changes 441 | return nil 442 | } 443 | 444 | var buffer bytes.Buffer 445 | for id := range newKeys { 446 | buffer.WriteString(id) 447 | buffer.WriteByte('\n') 448 | } 449 | return os.WriteFile(k.fn, buffer.Bytes(), defaultFilePermission) 450 | } 451 | 452 | // Overwrite deletes all existing values in the key list and writes the deduplicated input. 
453 | // It expects Lock to have been called.
454 | func (k *KeysFile) Overwrite(ks []string) error { 455 | // Use a map to remove any duplicates 456 | newKeys := make(map[string]bool) 457 | for _, k := range ks { 458 | newKeys[k] = true 459 | } 460 | 461 | var buffer bytes.Buffer 462 | for k := range newKeys { 463 | buffer.WriteString(k) 464 | buffer.WriteByte('\n') 465 | } 466 | return os.WriteFile(k.fn, buffer.Bytes(), 0666) 467 | } 468 | 469 | func identifyLockHolders(filename string) (string, error) { 470 | if runtime.GOOS != "linux" { 471 | return "", errors.New("error identifying lock holder: works only on linux") 472 | } 473 | 474 | cmd := exec.Command("lsof", filename) 475 | out, err := cmd.CombinedOutput() 476 | if err != nil { 477 | return string(out), fmt.Errorf("error identifying lock holder: %s", err.Error()) 478 | } 479 | 480 | return string(out), nil 481 | } 482 | -------------------------------------------------------------------------------- /client/deactivate.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pinterest/knox" 7 | ) 8 | 9 | var cmdDeactivate = &Command{ 10 | Run: runDeactivate, 11 | UsageLine: "deactivate ", 12 | Short: "deactivates a key version", 13 | Long: ` 14 | Deactivate takes an active key version and makes it inactive. 15 | 16 | Inactive keys should not be used at all for any operation. 17 | 18 | Primary keys cannot be deactivated. Only active keys can be deactivated. 19 | 20 | This command requires write access to the key. 21 | 22 | For more about knox, see https://github.com/pinterest/knox. 23 | 24 | See also: knox reactivate, knox promote 25 | `, 26 | } 27 | 28 | func runDeactivate(cmd *Command, args []string) *ErrorStatus { 29 | if len(args) != 2 { 30 | return &ErrorStatus{fmt.Errorf("deactivate takes exactly two argument. 
See 'knox help deactivate'"), false} 31 | } 32 | keyID := args[0] 33 | keyVersion := args[1] 34 | 35 | err := cli.UpdateVersion(keyID, keyVersion, knox.Inactive) 36 | if err != nil { 37 | return &ErrorStatus{fmt.Errorf("Error updating version: %s", err.Error()), true} 38 | } 39 | fmt.Printf("Deactivated %s successfully.\n", keyVersion) 40 | return nil 41 | } 42 | -------------------------------------------------------------------------------- /client/delete.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | var cmdDelete = &Command{ 8 | Run: runDelete, 9 | UsageLine: "delete ", 10 | Short: "deletes an existing key", 11 | Long: ` 12 | This will delete your key and all data from the knox server. This operation is dangerous and requires admin permissions 13 | 14 | For more about knox, see https://github.com/pinterest/knox. 15 | 16 | See also: knox create 17 | `, 18 | } 19 | 20 | func runDelete(cmd *Command, args []string) *ErrorStatus { 21 | if len(args) != 1 { 22 | return &ErrorStatus{fmt.Errorf("create takes exactly one argument. See 'knox help delete'"), false} 23 | } 24 | 25 | err := cli.DeleteKey(args[0]) 26 | if err != nil { 27 | return &ErrorStatus{fmt.Errorf("Error deleting key: %s", err.Error()), true} 28 | } 29 | fmt.Printf("Successfully deleted key\n") 30 | return nil 31 | } 32 | -------------------------------------------------------------------------------- /client/flock_unix.go: -------------------------------------------------------------------------------- 1 | // Based on https://github.com/boltdb/bolt/blob/master/bolt_unix.go 2 | // Copyright boltdb authors 3 | 4 | //go:build !windows && !plan9 && !solaris 5 | // +build !windows,!plan9,!solaris 6 | 7 | package client 8 | 9 | import ( 10 | "errors" 11 | "os" 12 | "syscall" 13 | "time" 14 | ) 15 | 16 | // ErrTimeout is returned when we cannot obtain an exclusive lock 17 | // on the key file. 
18 | var ErrTimeout = errors.New("timeout waiting on lock to become available") 19 | // flock tracks the descriptor used for flock(2) advisory locking; fd is -1 until the file is first opened. 20 | type flock struct { 21 | fd int 22 | } 23 | // newFlock returns a flock with no open descriptor; the file is opened lazily on the first lock attempt. 24 | func newFlock() *flock { 25 | return &flock{-1} 26 | } 27 | 28 | // lock acquires an advisory lock on k's file (exclusive if requested, otherwise shared), polling every lockRetryTime until it succeeds or timeout elapses; a non-positive timeout retries forever. mode is unused on this platform. 29 | func (f *flock) lock(k *KeysFile, mode os.FileMode, exclusive bool, timeout time.Duration) error { 30 | flag := syscall.LOCK_SH 31 | if exclusive { 32 | flag = syscall.LOCK_EX 33 | } 34 | var t time.Time 35 | for { 36 | // If we're beyond our timeout then return an error. 37 | // This can only occur after we've attempted a lock once. 38 | if t.IsZero() { 39 | t = time.Now() 40 | } else if timeout > 0 && time.Since(t) > timeout { 41 | return ErrTimeout 42 | } 43 | 44 | // Otherwise attempt to obtain the lock without blocking (LOCK_NB). 45 | fd, err := f.getFD(k) 46 | if err != nil { 47 | return err 48 | } 49 | err = syscall.Flock(fd, flag|syscall.LOCK_NB) 50 | if err == nil { 51 | return nil 52 | } 53 | if err != syscall.EWOULDBLOCK { 54 | return err 55 | } 56 | // Someone else holds the lock; wait for a bit and try again. 57 | time.Sleep(lockRetryTime) 58 | } 59 | } 60 | 61 | // unlock releases an advisory lock on a file descriptor.
62 | func (f *flock) unlock(k *KeysFile) error { 63 | return syscall.Flock(f.fd, syscall.LOCK_UN) 64 | } 65 | 66 | func (f *flock) getFD(k *KeysFile) (int, error) { 67 | if f.fd != -1 { 68 | return f.fd, nil 69 | } 70 | fd, err := syscall.Open(k.fn, syscall.O_RDWR, 0) 71 | if err != nil { 72 | return -1, err 73 | } 74 | f.fd = fd 75 | return f.fd, nil 76 | } 77 | -------------------------------------------------------------------------------- /client/flock_windows.go: -------------------------------------------------------------------------------- 1 | // Based on https://github.com/boltdb/bolt/blob/master/bolt_windows.go 2 | // Copyright boltdb authors 3 | 4 | package client 5 | 6 | import ( 7 | "errors" 8 | "os" 9 | "syscall" 10 | "time" 11 | "unsafe" 12 | ) 13 | 14 | // LockFileEx code derived from golang build filemutex_windows.go @ v1.5.1 15 | var ( 16 | modkernel32 = syscall.NewLazyDLL("kernel32.dll") 17 | procLockFileEx = modkernel32.NewProc("LockFileEx") 18 | procUnlockFileEx = modkernel32.NewProc("UnlockFileEx") 19 | 20 | // ErrTimeout is returned when we cannot obtain an exclusive lock 21 | // on the key file. 
22 | ErrTimeout = errors.New("timeout waiting on lock to become available") 23 | ) 24 | 25 | const ( 26 | lockExt = ".lock" 27 | 28 | // see https://msdn.microsoft.com/en-us/library/windows/desktop/aa365203(v=vs.85).aspx 29 | flagLockExclusive = 2 30 | flagLockFailImmediately = 1 31 | 32 | // see https://msdn.microsoft.com/en-us/library/windows/desktop/ms681382(v=vs.85).aspx 33 | errLockViolation syscall.Errno = 0x21 34 | ) 35 | 36 | func lockFileEx(h syscall.Handle, flags, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) { 37 | r, _, err := procLockFileEx.Call(uintptr(h), uintptr(flags), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol))) 38 | if r == 0 { 39 | return err 40 | } 41 | return nil 42 | } 43 | 44 | func unlockFileEx(h syscall.Handle, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) { 45 | r, _, err := procUnlockFileEx.Call(uintptr(h), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol)), 0) 46 | if r == 0 { 47 | return err 48 | } 49 | return nil 50 | } 51 | 52 | type flock struct { 53 | lockfile *os.File 54 | } 55 | 56 | func newFlock() *flock { 57 | return &flock{&os.File{}} 58 | } 59 | 60 | // lock acquires an advisory lock on a file descriptor. 61 | func (f *flock) lock(k *KeysFile, mode os.FileMode, exclusive bool, timeout time.Duration) error { 62 | // Create a separate lock file on windows because a process 63 | // cannot share an exclusive lock on the same file. 64 | file, err := os.OpenFile(k.fn+lockExt, os.O_CREATE, mode) 65 | if err != nil { 66 | return err 67 | } 68 | f.lockfile = file 69 | var t time.Time 70 | for { 71 | // If we're beyond our timeout then return an error. 72 | // This can only occur after we've attempted a lock once. 
73 | if t.IsZero() { 74 | t = time.Now() 75 | } else if timeout > 0 && time.Since(t) > timeout { 76 | return ErrTimeout 77 | } 78 | 79 | var flag uint32 = flagLockFailImmediately 80 | if exclusive { 81 | flag |= flagLockExclusive 82 | } 83 | err = lockFileEx(syscall.Handle(f.lockfile.Fd()), flag, 0, 1, 0, &syscall.Overlapped{}) 84 | if err == nil { 85 | return nil 86 | } else if err != errLockViolation { 87 | return err 88 | } 89 | 90 | // Wait for a bit and try again. 91 | time.Sleep(50 * time.Millisecond) 92 | } 93 | } 94 | 95 | // unlock releases an advisory lock on a file descriptor. 96 | func (f *flock) unlock(k *KeysFile) error { 97 | err := unlockFileEx(syscall.Handle(f.lockfile.Fd()), 0, 1, 0, &syscall.Overlapped{}) 98 | f.lockfile.Close() 99 | os.Remove(k.fn + lockExt) 100 | return err 101 | } 102 | -------------------------------------------------------------------------------- /client/get.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "strconv" 8 | 9 | "github.com/pinterest/knox" 10 | ) 11 | 12 | func init() { 13 | cmdGet.Run = runGet // break init cycle 14 | } 15 | 16 | var cmdGet = &Command{ 17 | UsageLine: "get [-v key_version] [-n] [-j] [-a] [--tink-keyset] [--tink-keyset-info] ", 18 | Short: "get a knox key", 19 | Long: ` 20 | Get gets the key data for a key. 21 | 22 | -v specifies the key_version to get. If it is not provided, this returns the primary version. 23 | -j returns the json version of the key as specified in the knox API. 24 | -n forces a network call. This will avoid cache issues where the ACL is out of date. 25 | -a returns all key versions (including inactive ones). Only works when -j is specified. 26 | --tink-keyset retrieve all the primary and active versions of this identifier in knox, combine them, and return one tink keyset. Force to retrieve tink keyset if -n is specified. 
27 | --tink-keyset-info retrieves keyset metadata for primary and active versions without revealing the secret keys. Force to retrieve tink keyset metadata if -n is specified. 28 | 29 | This requires read access to the key. 30 | 31 | For more about knox, see https://github.com/pinterest/knox. 32 | 33 | See also: knox create, knox daemon, knox register, knox keys 34 | `, 35 | } 36 | var getVersion = cmdGet.Flag.String("v", "", "") 37 | var getJSON = cmdGet.Flag.Bool("j", false, "") 38 | var getNetwork = cmdGet.Flag.Bool("n", false, "") 39 | var getAll = cmdGet.Flag.Bool("a", false, "") 40 | var getTinkKeyset = cmdGet.Flag.Bool("tink-keyset", false, "get the stored tink keyset of the given knox identifier entirely") 41 | var getTinkKeysetInfo = cmdGet.Flag.Bool("tink-keyset-info", false, "get the metadata of the stored tink keyset of the given knox identifier") 42 | 43 | func successGetKeyMetric(keyID string) { 44 | clientGetKeyMetrics(map[string]string{ 45 | "key_id": keyID, 46 | "access_result": "success", 47 | "failure_reason": "", 48 | }) 49 | } 50 | 51 | func failureGetKeyMetric(keyID string, err error) { 52 | clientGetKeyMetrics(map[string]string{ 53 | "key_id": keyID, 54 | "access_result": "failure", 55 | "failure_reason": err.Error(), 56 | }) 57 | } 58 | 59 | func runGet(cmd *Command, args []string) *ErrorStatus { 60 | if len(args) != 1 { 61 | return &ErrorStatus{fmt.Errorf("get takes only one argument. 
See 'knox help get'"), false} 62 | } 63 | keyID := args[0] 64 | 65 | var err error 66 | var key *knox.Key 67 | if *getTinkKeyset { 68 | tinkKeysetInBytes, err := retrieveTinkKeyset(keyID, *getNetwork) 69 | if err != nil { 70 | failureGetKeyMetric(keyID, err) 71 | return err 72 | } 73 | fmt.Printf("%s", string(tinkKeysetInBytes)) 74 | successGetKeyMetric(keyID) 75 | return nil 76 | } 77 | if *getTinkKeysetInfo { 78 | tinkKeysetInfo, err := retrieveTinkKeysetInfo(keyID, *getNetwork) 79 | if err != nil { 80 | failureGetKeyMetric(keyID, err) 81 | return err 82 | } 83 | fmt.Println(tinkKeysetInfo) 84 | successGetKeyMetric(keyID) 85 | return nil 86 | } 87 | if *getAll { 88 | // By specifying status as inactive, we can get all key versions (active + inactive + primary) 89 | // from knox server 90 | if *getNetwork { 91 | key, err = cli.NetworkGetKeyWithStatus(keyID, knox.Inactive) 92 | } else { 93 | key, err = cli.GetKeyWithStatus(keyID, knox.Inactive) 94 | } 95 | } else { 96 | if *getNetwork { 97 | key, err = cli.NetworkGetKey(keyID) 98 | } else { 99 | key, err = cli.GetKey(keyID) 100 | } 101 | } 102 | if err != nil { 103 | failureGetKeyMetric(keyID, err) 104 | return &ErrorStatus{fmt.Errorf("Error getting key: %s", err.Error()), true} 105 | } 106 | if *getJSON { 107 | data, err := json.Marshal(key) 108 | if err != nil { 109 | failureGetKeyMetric(keyID, err) 110 | return &ErrorStatus{err, true} 111 | } 112 | fmt.Printf("%s", string(data)) 113 | successGetKeyMetric(keyID) 114 | return nil 115 | } 116 | if key.VersionList != nil { 117 | if *getVersion == "" { 118 | fmt.Printf("%s", string(key.VersionList.GetPrimary().Data)) 119 | successGetKeyMetric(keyID) 120 | return nil 121 | } 122 | for _, v := range key.VersionList { 123 | if strconv.FormatUint(v.ID, 10) == *getVersion { 124 | fmt.Printf("%s", string(v.Data)) 125 | successGetKeyMetric(keyID) 126 | return nil 127 | } 128 | } 129 | } 130 | failureGetKeyMetric(keyID, errors.New("key version not found")) 131 | return 
&ErrorStatus{fmt.Errorf("%s", "Key version not found."), false} 132 | } 133 | 134 | func retrieveTinkKeyset(keyID string, getFromNetwork bool) ([]byte, *ErrorStatus) { 135 | if !isIDforTinkKeyset(keyID) { 136 | return nil, &ErrorStatus{fmt.Errorf("this knox identifier is not for tink keyset"), false} 137 | } 138 | // get the primary and all active versions of this knox identifier. 139 | var primaryAndActiveVersions *knox.Key 140 | var err error 141 | if getFromNetwork { 142 | primaryAndActiveVersions, err = cli.NetworkGetKey(keyID) 143 | } else { 144 | primaryAndActiveVersions, err = cli.GetKey(keyID) 145 | } 146 | if err != nil { 147 | return nil, &ErrorStatus{fmt.Errorf("error getting key: %s", err.Error()), true} 148 | } 149 | keysetHandle, _, err := getTinkKeysetHandleFromKnoxVersionList(primaryAndActiveVersions.VersionList) 150 | if err != nil { 151 | return nil, &ErrorStatus{err, false} 152 | } 153 | tinkKeysetInBytes, err := convertTinkKeysetHandleToBytes(keysetHandle) 154 | if err != nil { 155 | return nil, &ErrorStatus{err, false} 156 | } 157 | return tinkKeysetInBytes, nil 158 | } 159 | 160 | func retrieveTinkKeysetInfo(keyID string, getFromNetwork bool) (string, *ErrorStatus) { 161 | if !isIDforTinkKeyset(keyID) { 162 | return "", &ErrorStatus{fmt.Errorf("this knox identifier is not for tink keyset"), false} 163 | } 164 | // get the primary and all active versions of this knox identifier. 
165 | var primaryAndActiveVersions *knox.Key 166 | var err error 167 | if getFromNetwork { 168 | primaryAndActiveVersions, err = cli.NetworkGetKey(keyID) 169 | } else { 170 | primaryAndActiveVersions, err = cli.GetKey(keyID) 171 | } 172 | if err != nil { 173 | return "", &ErrorStatus{fmt.Errorf("error getting key: %s", err.Error()), true} 174 | } 175 | keysetHandle, tinkKeyIDToKnoxVersionID, err := getTinkKeysetHandleFromKnoxVersionList(primaryAndActiveVersions.VersionList) 176 | if err != nil { 177 | return "", &ErrorStatus{err, false} 178 | } 179 | tinkKeysetInfo, err := getKeysetInfoFromTinkKeysetHandle(keysetHandle, tinkKeyIDToKnoxVersionID) 180 | if err != nil { 181 | return "", &ErrorStatus{err, false} 182 | } 183 | return tinkKeysetInfo, nil 184 | } 185 | -------------------------------------------------------------------------------- /client/getacl.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | ) 7 | 8 | func init() { 9 | cmdGetACL.Run = runGetACL // break init cycle 10 | } 11 | 12 | var cmdGetACL = &Command{ 13 | UsageLine: "acl ", 14 | Short: "gets the ACL for a key", 15 | Long: ` 16 | Acl get the ACL for a key. 17 | 18 | -json: Returns the ACL as a JSON formatted list of access rules, useful for generating files to be used with knox access -acl. 19 | 20 | This doesn't require any access to the key and allows, e.g., to see who has admin access to ask for grants. 21 | 22 | For more about knox, see https://github.com/pinterest/knox. 23 | 24 | See also: knox keys, knox get 25 | `, 26 | } 27 | 28 | var getACLJSON = cmdGetACL.Flag.Bool("json", false, "") 29 | 30 | func runGetACL(cmd *Command, args []string) *ErrorStatus { 31 | if len(args) != 1 { 32 | return &ErrorStatus{fmt.Errorf("acl takes only one argument. 
See 'knox help acl'"), false} 33 | } 34 | 35 | keyID := args[0] 36 | acl, err := cli.GetACL(keyID) 37 | if err != nil { 38 | return &ErrorStatus{fmt.Errorf("Error getting key ACL: %s", err.Error()), true} 39 | } 40 | 41 | if *getACLJSON { 42 | aclEnc, err := json.Marshal(acl) 43 | if err != nil { 44 | // malformated ACL considered as knox server side error 45 | return &ErrorStatus{fmt.Errorf("Could not marshal ACL: %v", acl), true} 46 | } 47 | fmt.Println(string(aclEnc)) 48 | return nil 49 | } 50 | 51 | for _, a := range *acl { 52 | aEnc, err := json.Marshal(a) 53 | if err != nil { 54 | // malformated ACL entry considered as knox server side error 55 | return &ErrorStatus{fmt.Errorf("Could not marshal entry: %v", a), true} 56 | } 57 | fmt.Println(string(aEnc)) 58 | } 59 | return nil 60 | } 61 | -------------------------------------------------------------------------------- /client/getkeys.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | var cmdGetKeys = &Command{ 8 | Run: runGetKeys, 9 | UsageLine: "keys [ ...]", 10 | Short: "gets keys and associated version hash", 11 | Long: ` 12 | Get Keys takes version ids returns matching key ids if they exist. 13 | 14 | If no version ids are given, it returns all key ids. 15 | 16 | This requires valid user or machine authentication, but there are no authorization requirements. 17 | 18 | For more about knox, see https://github.com/pinterest/knox. 
19 | 20 | See also: knox get, knox create, knox daemon 21 | `, 22 | } 23 | 24 | func runGetKeys(cmd *Command, args []string) *ErrorStatus { 25 | m := map[string]string{} 26 | for _, s := range args { 27 | m[s] = "NONE" 28 | } 29 | l, err := cli.GetKeys(m) 30 | if err != nil { 31 | return &ErrorStatus{fmt.Errorf("Error getting keys: %s", err.Error()), true} 32 | } 33 | for _, k := range l { 34 | fmt.Println(k) 35 | } 36 | return nil 37 | } 38 | -------------------------------------------------------------------------------- /client/getversions.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/pinterest/knox" 9 | ) 10 | 11 | func init() { 12 | cmdGetVersions.Run = runGetVersions // break init cycle 13 | } 14 | 15 | var cmdGetVersions = &Command{ 16 | UsageLine: "versions [-s state] [-v] ", 17 | Short: "gets the versions for a key", 18 | Long: ` 19 | versions get all of the version ids for a key. 20 | 21 | -s specifies the minimum state of key to return. By default this is set to active which means active and primary keys are returned. Accepted values include inactive, active, and primary. 22 | -v enables verbose output, which shows the state of each version alongside the version number. 23 | 24 | This requires read access to the key and can use user or machine authentication. 25 | 26 | For more about knox, see https://github.com/pinterest/knox. 27 | 28 | See also: knox keys, knox get 29 | `, 30 | } 31 | var getVersionsState = cmdGetVersions.Flag.String("s", "active", "") 32 | var verboseOutput = cmdGetVersions.Flag.Bool("v", false, "verbose") 33 | 34 | func runGetVersions(cmd *Command, args []string) *ErrorStatus { 35 | if len(args) != 1 { 36 | return &ErrorStatus{fmt.Errorf("get takes only one argument. 
See 'knox help versions'"), false} 37 | } 38 | 39 | var status knox.VersionStatus 40 | switch strings.ToLower(*getVersionsState) { 41 | case "active": 42 | status = knox.Active 43 | case "primary": 44 | status = knox.Primary 45 | case "inactive": 46 | status = knox.Inactive 47 | } 48 | 49 | keyID := args[0] 50 | key, err := cli.GetKeyWithStatus(keyID, status) 51 | if err != nil { 52 | return &ErrorStatus{fmt.Errorf("Error getting key: %s", err.Error()), true} 53 | } 54 | kvl := key.VersionList 55 | for _, v := range kvl { 56 | status, err := json.Marshal(v.Status) 57 | if err != nil { 58 | status = []byte("(unknown)") 59 | } 60 | if *verboseOutput { 61 | fmt.Printf("%d %s\n", v.ID, string(status)) 62 | } else { 63 | fmt.Printf("%d\n", v.ID) 64 | } 65 | } 66 | return nil 67 | } 68 | -------------------------------------------------------------------------------- /client/help.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | ) 7 | 8 | var helpAuth = &Command{ 9 | UsageLine: "auth", 10 | Short: "Explains authentication variables", 11 | Long: ` 12 | 13 | The authentication variables are how the knox client communicates who is performing an action. 14 | 15 | If the $KNOX_USER_AUTH env variable is set, the value will be used as an OAuth token for authenticating the user. 16 | 17 | If the $KNOX_MACHINE_AUTH env variable is set, the value will be used as the current client hostname. 18 | 19 | See also: knox login 20 | `, 21 | } 22 | 23 | var usageTemplate = `Knox is a tool for storing and rotating keys. 24 | 25 | Usage: 26 | 27 | knox command [arguments] 28 | 29 | The commands are: 30 | {{range .}}{{if .Runnable}} 31 | {{.Name | printf "%-14s"}} {{.Short}}{{end}}{{end}} 32 | 33 | Use "knox help [command]" for more information about a command. 
34 | 35 | Additional help topics: 36 | {{range .}}{{if not .Runnable}} 37 | {{.Name | printf "%-14s"}} {{.Short}}{{end}}{{end}} 38 | 39 | Use "knox help [topic]" for more information about that topic. 40 | 41 | ` 42 | 43 | var helpTemplate = `{{if .Runnable}}usage: knox {{.UsageLine}} 44 | 45 | {{end}}{{.Long | trim}} 46 | ` 47 | 48 | // help implements the 'help' command. 49 | func help(args []string) { 50 | if len(args) == 0 { 51 | printUsage(os.Stdout) 52 | // not exit 2: succeeded at 'go help'. 53 | return 54 | } 55 | if len(args) != 1 { 56 | fmt.Fprintf(os.Stderr, "usage: knox help command\n\nToo many arguments given.\n") 57 | os.Exit(2) // failed at 'go help' 58 | } 59 | 60 | arg := args[0] 61 | 62 | for _, cmd := range commands { 63 | if cmd.Name() == arg { 64 | tmpl(os.Stdout, helpTemplate, cmd) 65 | // not exit 2: succeeded at 'knox help cmd'. 66 | return 67 | } 68 | } 69 | 70 | fmt.Fprintf(os.Stderr, "Unknown help topic %#q. Run 'knox help'.\n", arg) 71 | os.Exit(2) // failed at 'knox help cmd' 72 | } 73 | -------------------------------------------------------------------------------- /client/knox-upstart.conf: -------------------------------------------------------------------------------- 1 | # /etc/init/knox.conf 2 | 3 | description "knox daemon" 4 | author "Devin Lundberg" 5 | 6 | # Start when there is a file system and a non-loopback network interface 7 | start on (local-filesystems and net-device-up IFACE!=lo) 8 | 9 | stop on shutdown 10 | 11 | # Automatically Respawn if a panic occurs (i.e. segfault) 12 | respawn 13 | respawn limit 3 5 14 | 15 | script 16 | if [ -f /etc/default/knox ]; then 17 | . 
/etc/default/knox 18 | fi 19 | exec /usr/bin/knox daemon 20 | end script 21 | -------------------------------------------------------------------------------- /client/list_key_templates.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | var cmdListKeyTemplates = &Command{ 8 | Run: runListKeyTemplates, 9 | UsageLine: "key-templates", 10 | Short: "Lists the supported tink key templates", 11 | Long: ` 12 | Lists the supported tink key templates. 13 | `, 14 | } 15 | 16 | func runListKeyTemplates(cmd *Command, args []string) *ErrorStatus { 17 | fmt.Println("The following tink key templates are supported:") 18 | fmt.Println(nameOfSupportedTinkKeyTemplates()) 19 | return nil 20 | } 21 | -------------------------------------------------------------------------------- /client/login.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "net/url" 9 | "os" 10 | "os/user" 11 | "path" 12 | "path/filepath" 13 | 14 | "golang.org/x/crypto/ssh/terminal" 15 | ) 16 | 17 | const DefaultUsageLine = "login [username]" 18 | const DefaultShortDescription = "login as user and save authentication data" 19 | const DefaultLongDescriptionFormat = ` 20 | Will authenticate user via OAuth2 password grant flow if available. Requires user to enter username and password. The authentication data is saved in "%v". 21 | 22 | The optional username argument can specify the user that to log in as otherwise it uses the current os user. 23 | 24 | For more about knox, see https://github.com/pinterest/knox. 
25 | 26 | See also: knox help auth 27 | ` 28 | const DefaultTokenFileLocation = ".knox_user_auth" 29 | // NewLoginCommand builds the 'knox login' command; empty tokenFileLocation, usageLine, shortDescription or longDescription arguments fall back to the package defaults. 30 | func NewLoginCommand( 31 | oauthTokenEndpoint string, 32 | oauthClientID string, 33 | tokenFileLocation string, 34 | usageLine string, 35 | shortDescription string, 36 | longDescription string) *Command { 37 | 38 | runLoginAugmented := func(cmd *Command, args []string) *ErrorStatus { 39 | return runLogin(cmd, oauthClientID, tokenFileLocation, oauthTokenEndpoint, args) 40 | } 41 | // The closure above captures tokenFileLocation by reference, so the path resolved below is what runLogin receives. 42 | if tokenFileLocation == "" { 43 | tokenFileLocation = DefaultTokenFileLocation 44 | } 45 | if !filepath.IsAbs(tokenFileLocation) { 46 | currentUser, err := user.Current() 47 | if err != nil { 48 | fatalf("Error getting OS user:" + err.Error()) 49 | } 50 | 51 | tokenFileLocation = path.Join(currentUser.HomeDir, tokenFileLocation) 52 | } 53 | 54 | if usageLine == "" { 55 | usageLine = DefaultUsageLine 56 | } 57 | if shortDescription == "" { 58 | shortDescription = DefaultShortDescription 59 | } 60 | if longDescription == "" { 61 | longDescription = fmt.Sprintf(DefaultLongDescriptionFormat, tokenFileLocation) 62 | } 63 | // Use the resolved values so caller-supplied usage and short descriptions are honored. 64 | return &Command{ 65 | UsageLine: usageLine, 66 | Short: shortDescription, 67 | Long: longDescription, 68 | Run: runLoginAugmented, 69 | } 70 | } 71 | 72 | type authTokenResp struct { 73 | AccessToken string `json:"access_token"` 74 | Error string `json:"error"` 75 | } 76 | // runLogin performs an OAuth2 password grant against oauthTokenEndpoint and saves the raw token response to tokenFileLocation. 77 | func runLogin( 78 | cmd *Command, 79 | oauthClientID string, 80 | tokenFileLocation string, 81 | oauthTokenEndpoint string, 82 | args []string) *ErrorStatus { 83 | var username string 84 | u, err := user.Current() 85 | if err != nil { 86 | return &ErrorStatus{fmt.Errorf("Error getting OS user: %s", err.Error()), false} 87 | } 88 | switch len(args) { 89 | case 0: 90 | username = u.Username 91 | case 1: 92 | username = args[0] 93 | default: 94 | return &ErrorStatus{fmt.Errorf("Invalid arguments.
See 'knox login -h'"), false} 95 | } 96 | // Prompt for the password without echoing it to the terminal. 97 | fmt.Println("Please enter your password:") 98 | password, err := terminal.ReadPassword(int(os.Stdin.Fd())) 99 | if err != nil { 100 | return &ErrorStatus{fmt.Errorf("Problem getting password: %s", err.Error()), false} 101 | } 102 | var authResp authTokenResp 103 | resp, err := http.PostForm(oauthTokenEndpoint, 104 | url.Values{ 105 | "grant_type": {"password"}, 106 | "client_id": {oauthClientID}, 107 | "username": {username}, 108 | "password": {string(password)}, 109 | }) 110 | if err != nil { 111 | // this is not Knox server error, thus assigning serverError as false 112 | return &ErrorStatus{fmt.Errorf("Error connecting to auth: %s", err.Error()), false} 113 | } 114 | defer resp.Body.Close() 115 | data, err := io.ReadAll(resp.Body) 116 | if err != nil { 117 | return &ErrorStatus{fmt.Errorf("Failed to read data: %s", err.Error()), false} 118 | } 119 | err = json.Unmarshal(data, &authResp) 120 | if err != nil { 121 | return &ErrorStatus{fmt.Errorf("Unexpected response from auth: %s data: %s", err.Error(), string(data)), false} 122 | } 123 | if authResp.Error != "" { 124 | return &ErrorStatus{fmt.Errorf("Fail to authenticate: %q", authResp.Error), false} 125 | } 126 | 127 | err = os.WriteFile(tokenFileLocation, data, 0600) 128 | if err != nil { 129 | return &ErrorStatus{fmt.Errorf("Failed to write auth data to file: %s", err.Error()), false} 130 | } 131 | 132 | return nil 133 | } 134 | -------------------------------------------------------------------------------- /client/promote.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pinterest/knox" 7 | ) 8 | 9 | var cmdPromote = &Command{ 10 | Run: runPromote, 11 | UsageLine: "promote ", 12 | Short: "promotes a key to primary state", 13 | Long: ` 14 | Promote will take an active key version and make it the primary key version. This also makes the current primary key active.
15 | 16 | To use this command, you must have write permissions on the key. 17 | 18 | For more about knox, see https://github.com/pinterest/knox. 19 | 20 | See also: knox reactivate, knox deactivate 21 | `, 22 | } 23 | 24 | func runPromote(cmd *Command, args []string) *ErrorStatus { 25 | if len(args) != 2 { 26 | return &ErrorStatus{fmt.Errorf("promote takes exactly two argument. See 'knox help promote'"), false} 27 | } 28 | keyID := args[0] 29 | versionID := args[1] 30 | 31 | err := cli.UpdateVersion(keyID, versionID, knox.Primary) 32 | if err != nil { 33 | return &ErrorStatus{fmt.Errorf("Error promoting version: %s", err.Error()), true} 34 | } 35 | fmt.Printf("Promoted %s successfully.\n", versionID) 36 | return nil 37 | } 38 | -------------------------------------------------------------------------------- /client/reactivate.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pinterest/knox" 7 | ) 8 | 9 | var cmdReactivate = &Command{ 10 | Run: runReactivate, 11 | UsageLine: "reactivate ", 12 | Short: "Reactivates an inactive key version", 13 | Long: ` 14 | Reactivate makes an inactive key version active. 15 | 16 | Active keys are not used by default, but can still be used if the primary key fails. 17 | Inactive keys should not be used for any purpose. 18 | 19 | This command requires write access to the key. 20 | 21 | For more about knox, see https://github.com/pinterest/knox. 22 | 23 | See also: knox deactivate, knox promote 24 | `, 25 | } 26 | 27 | func runReactivate(cmd *Command, args []string) *ErrorStatus { 28 | if len(args) != 2 { 29 | return &ErrorStatus{fmt.Errorf("reactivate takes exactly two argument. 
See 'knox help reactivate'"), false} 30 | } 31 | keyID := args[0] 32 | versionID := args[1] 33 | 34 | err := cli.UpdateVersion(keyID, versionID, knox.Active) 35 | if err != nil { 36 | return &ErrorStatus{fmt.Errorf("Error reactivating version: %s", err.Error()), true} 37 | } 38 | fmt.Printf("Reactivated %s successfully.\n", versionID) 39 | return nil 40 | } 41 | -------------------------------------------------------------------------------- /client/register.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/pinterest/knox" 7 | "path" 8 | "strconv" 9 | "time" 10 | ) 11 | 12 | func init() { 13 | cmdRegister.Run = runRegister 14 | } 15 | 16 | var cmdRegister = &Command{ 17 | UsageLine: "register [-r] [-k identifier] [-f identifier_file] [-g]", 18 | Short: "register keys to cache locally using daemon", 19 | Long: ` 20 | Register will cache the key in the file system and keep it up to date using the file system. 21 | 22 | -r removes all existing registered keys. -k or -f will instead replace all registered keys with those specified 23 | -k specifies a specific key identifier to register 24 | -f specifies a file containing a new line separated list of key identifiers 25 | -t specifies a timeout for getting the key from the daemon (e.g. '5s', '500ms') 26 | -g gets the key as well 27 | 28 | For a machine to access a certain key, it needs permissions on that key. 29 | 30 | Note that knox register will only update the register file and will return successful 31 | even if the machine does not have access to the key. The daemon will actually retrieve 32 | the key. 33 | 34 | For more about knox, see https://github.com/pinterest/knox. 
35 | 36 | See also: knox unregister, knox daemon 37 | `, 38 | } 39 | 40 | var registerRemove = cmdRegister.Flag.Bool("r", false, "") 41 | var registerKey = cmdRegister.Flag.String("k", "", "") 42 | var registerKeyFile = cmdRegister.Flag.String("f", "", "") 43 | var registerAndGet = cmdRegister.Flag.Bool("g", false, "") 44 | var registerTimeout = cmdRegister.Flag.String("t", "5s", "") 45 | 46 | const registerRecheckTime = 10 * time.Millisecond 47 | 48 | func parseTimeout(val string) (time.Duration, error) { 49 | // For backwards-compatibility, a timeout value that is a simple integer will 50 | // be treated as a number of seconds. This ensures that the historical usage 51 | // of the timeout flag like '-t5' retains the same meaning. 52 | if secs, err := strconv.Atoi(val); err == nil { 53 | return time.Duration(secs) * time.Second, nil 54 | } 55 | 56 | // For all other values, use time.ParseDuration. 57 | return time.ParseDuration(val) 58 | } 59 | 60 | func runRegister(cmd *Command, args []string) *ErrorStatus { 61 | if _, ok := cli.(*knox.UncachedHTTPClient); ok { 62 | fmt.Println("Cannot Register in No Cache mode") 63 | return nil 64 | } 65 | timeout, err := parseTimeout(*registerTimeout) 66 | if err != nil { 67 | return &ErrorStatus{fmt.Errorf("Invalid value for timeout flag: %s", err.Error()), false} 68 | } 69 | 70 | k := NewKeysFile(path.Join(daemonFolder, daemonToRegister)) 71 | if *registerRemove && *registerKey == "" && *registerKeyFile == "" { 72 | // Short circuit & handle `knox register -r`, which is expected to remove all keys 73 | err := k.Lock() 74 | if err != nil { 75 | return &ErrorStatus{fmt.Errorf("There was an error obtaining file lock: %s", err.Error()), false} 76 | } 77 | err = k.Overwrite([]string{}) 78 | if err != nil { 79 | k.Unlock() 80 | return &ErrorStatus{fmt.Errorf("Failed to unregister all keys: %s", err.Error()), false} 81 | } 82 | err = k.Unlock() 83 | if err != nil { 84 | return &ErrorStatus{fmt.Errorf("There was an error unlocking 
register file: %s", err.Error()), false} 85 | } 86 | logf("Successfully unregistered all keys.") 87 | return nil 88 | } else if *registerKey == "" && *registerKeyFile == "" { 89 | return &ErrorStatus{fmt.Errorf("You must include a key or key file to register. see 'knox help register'"), false} 90 | } 91 | // Get the list of keys to add 92 | var ks []string 93 | if *registerKey == "" { 94 | f := NewKeysFile(*registerKeyFile) 95 | ks, err = f.Get() 96 | if err != nil { 97 | return &ErrorStatus{fmt.Errorf("There was an error reading input key file %s", err.Error()), false} 98 | } 99 | } else { 100 | ks = []string{*registerKey} 101 | } 102 | // Handle adding new keys to the registered file 103 | err = k.Lock() 104 | if err != nil { 105 | return &ErrorStatus{fmt.Errorf("There was an error obtaining file lock: %s", err.Error()), false} 106 | } 107 | if *registerRemove { 108 | logf("Attempting to overwrite existing keys with %v.", ks) 109 | err = k.Overwrite(ks) 110 | } else { 111 | err = k.Add(ks) 112 | } 113 | if err != nil { 114 | k.Unlock() 115 | return &ErrorStatus{fmt.Errorf("There was an error registering keys %v: %s", ks, err.Error()), false} 116 | } 117 | err = k.Unlock() 118 | if err != nil { 119 | return &ErrorStatus{fmt.Errorf("There was an error unlocking register file: %s", err.Error()), false} 120 | } 121 | // With -g, poll the daemon cache every registerRecheckTime until the key is available or timeout expires. NOTE(review): only *registerKey is fetched, which is empty when keys come from -f. 122 | if *registerAndGet { 123 | key, err := cli.CacheGetKey(*registerKey) 124 | c := time.After(timeout) 125 | for err != nil { 126 | select { 127 | case <-c: 128 | return &ErrorStatus{fmt.Errorf( 129 | "Error getting key from daemon (hit timeout after %s); check knox logs for details (most recent error: %v)", 130 | timeout.String(), err), false} 131 | case <-time.After(registerRecheckTime): 132 | key, err = cli.CacheGetKey(*registerKey) 133 | } 134 | } 135 | // TODO: add json vs data option?
136 | data, err := json.Marshal(key) 137 | if err != nil { 138 | return &ErrorStatus{err, true} 139 | } 140 | fmt.Printf("%s", string(data)) 141 | return nil 142 | } 143 | logf("Successfully registered keys %v. Keys are updated by the daemon process every %.0f minutes. Check the log for the most recent run.", ks, daemonRefreshTime.Minutes()) 144 | return nil 145 | } 146 | -------------------------------------------------------------------------------- /client/register_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestParseTimeout(t *testing.T) { 9 | testCases := []struct { 10 | str string 11 | dur time.Duration 12 | }{ 13 | {"5", 5 * time.Second}, 14 | {"5s", 5 * time.Second}, 15 | {"0.5s", 500 * time.Millisecond}, 16 | {"500ms", 500 * time.Millisecond}, 17 | } 18 | 19 | for _, tc := range testCases { 20 | r, err := parseTimeout(tc.str) 21 | if err != nil { 22 | t.Errorf("error parsing value %s: %s", tc.str, err) 23 | continue 24 | } 25 | if r != tc.dur { 26 | t.Errorf("mismatch: %s should parse to %s", tc.str, tc.dur.String()) 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /client/restart_knox.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | service knox restart 3 | -------------------------------------------------------------------------------- /client/tink_keyset_helper.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "sort" 9 | "strings" 10 | 11 | "github.com/google/tink/go/aead" 12 | "github.com/google/tink/go/daead" 13 | "github.com/google/tink/go/hybrid" 14 | "github.com/google/tink/go/insecurecleartextkeyset" 15 | "github.com/google/tink/go/keyset" 16 | "github.com/google/tink/go/mac" 17 | 
"github.com/google/tink/go/signature" 18 | "github.com/google/tink/go/streamingaead" 19 | "github.com/pinterest/knox" 20 | 21 | tinkpb "github.com/google/tink/go/proto/tink_go_proto" 22 | ) 23 | 24 | // tinkKeyTemplateInfo represents the info for a supported tink keyset template. 25 | type tinkKeyTemplateInfo struct { 26 | knoxIDPrefix string 27 | templateFunc func() *tinkpb.KeyTemplate 28 | } 29 | 30 | // tinkKeyTemplates contains the supported tink key templates and the correcsponding naming rule for knox identifier 31 | var tinkKeyTemplates = map[string]tinkKeyTemplateInfo{ 32 | "TINK_AEAD_AES256_GCM": {"tink:aead:", aead.AES256GCMKeyTemplate}, 33 | "TINK_AEAD_AES128_GCM": {"tink:aead:", aead.AES128GCMKeyTemplate}, 34 | "TINK_MAC_HMAC_SHA512_256BITTAG": {"tink:mac:", mac.HMACSHA512Tag256KeyTemplate}, 35 | "TINK_DSIG_ECDSA_P256": {"tink:dsig:", signature.ECDSAP256KeyTemplate}, 36 | "TINK_DSIG_ED25519": {"tink:dsig:", signature.ED25519KeyTemplate}, 37 | "TINK_HYBRID_ECIES_P256_HKDF_HMAC_SHA256_AES128_GCM": {"tink:hybrid:", hybrid.ECIESHKDFAES128GCMKeyTemplate}, 38 | "TINK_DAEAD_AES256_SIV": {"tink:daead:", daead.AESSIVKeyTemplate}, 39 | "TINK_SAEAD_AES128_GCM_HKDF_1MB": {"tink:saead:", streamingaead.AES128GCMHKDF1MBKeyTemplate}, 40 | "TINK_SAEAD_AES128_GCM_HKDF_4KB": {"tink:saead:", streamingaead.AES128GCMHKDF4KBKeyTemplate}, 41 | } 42 | 43 | // nameOfSupportedTinkKeyTemplates returns the name of supported tink key templates in sorted order. 44 | func nameOfSupportedTinkKeyTemplates() string { 45 | supportedTemplates := make([]string, 0, len(tinkKeyTemplates)) 46 | for key := range tinkKeyTemplates { 47 | supportedTemplates = append(supportedTemplates, key) 48 | } 49 | sort.Strings(supportedTemplates) 50 | return strings.Join(supportedTemplates, "\n") 51 | } 52 | 53 | // obeyNamingRule checks whether knox identifier start with "tink::". 
54 | func obeyNamingRule(templateName string, knoxIentifier string) error { 55 | templateInfo, ok := tinkKeyTemplates[templateName] 56 | if !ok { 57 | return errors.New("not supported Tink key template. See 'knox key-templates'") 58 | } else if !strings.HasPrefix(knoxIentifier, templateInfo.knoxIDPrefix) { 59 | return fmt.Errorf(" must have prefix '%s'", templateInfo.knoxIDPrefix) 60 | } 61 | return nil 62 | } 63 | 64 | // isIDforTinkKeyset checks whether knox identifier start with "tink::". 65 | func isIDforTinkKeyset(knoxIdentifier string) bool { 66 | for _, templateInfo := range tinkKeyTemplates { 67 | if strings.HasPrefix(knoxIdentifier, templateInfo.knoxIDPrefix) { 68 | return true 69 | } 70 | } 71 | return false 72 | } 73 | 74 | // createNewTinkKeyset creates a new tink keyset contains a single fresh key from the given tink key templateFunc. 75 | func createNewTinkKeyset(templateFunc func() *tinkpb.KeyTemplate) ([]byte, error) { 76 | // Creates a keyset handle that contains a single fresh key 77 | keysetHandle, err := keyset.NewHandle(templateFunc()) 78 | if keysetHandle == nil || err != nil { 79 | return nil, fmt.Errorf("cannot get tink keyset handle: %v", err) 80 | } 81 | return convertTinkKeysetHandleToBytes(keysetHandle) 82 | } 83 | 84 | // convertTinkKeysetHandleToBytes extracts keyset from tink keyset handle and converts it to bytes 85 | func convertTinkKeysetHandleToBytes(keysetHandle *keyset.Handle) ([]byte, error) { 86 | bytesBuffer := new(bytes.Buffer) 87 | writer := keyset.NewBinaryWriter(bytesBuffer) 88 | // To write cleartext keyset handle, must use package "insecurecleartextkeyset" 89 | err := insecurecleartextkeyset.Write(keysetHandle, writer) 90 | if err != nil { 91 | return nil, fmt.Errorf("cannot write tink keyset: %v", err) 92 | } 93 | return bytesBuffer.Bytes(), nil 94 | } 95 | 96 | // addNewTinkKeyset receives a knox version list and a tink key templateFunc, create a new tink keyset contains 97 | // a single fresh key from the given tink 
key templateFunc. Most importantly, the ID of this single fresh key is 98 | // different from the ID of all existing tink keys in the given knox version list (avoid Tink key ID duplications). 99 | func addNewTinkKeyset(templateFunc func() *tinkpb.KeyTemplate, knoxVersionList knox.KeyVersionList) ([]byte, error) { 100 | existingTinkKeysID := make(map[uint32]struct{}) 101 | for _, v := range knoxVersionList { 102 | tinkKeysetForAVersion, err := readTinkKeysetFromBytes(v.Data) 103 | if err != nil { 104 | return nil, err 105 | } 106 | existingTinkKeysID[tinkKeysetForAVersion.PrimaryKeyId] = struct{}{} 107 | } 108 | var keysetHandle *keyset.Handle 109 | var err error 110 | // This loop is for retrying until a non-duplicate key id is generated. 111 | isDuplicated := true 112 | for isDuplicated { 113 | keysetHandle, err = keyset.NewHandle(templateFunc()) 114 | if keysetHandle == nil || err != nil { 115 | return nil, fmt.Errorf("cannot get tink keyset handle: %v", err) 116 | } 117 | newTinkKeyID := keysetHandle.KeysetInfo().PrimaryKeyId 118 | _, isDuplicated = existingTinkKeysID[newTinkKeyID] 119 | } 120 | return convertTinkKeysetHandleToBytes(keysetHandle) 121 | } 122 | 123 | // readTinkKeysetFromBytes extracts tink keyset from bytes. 124 | func readTinkKeysetFromBytes(data []byte) (*tinkpb.Keyset, error) { 125 | bytesBuffer := new(bytes.Buffer) 126 | bytesBuffer.Write(data) 127 | tinkKeyset, err := keyset.NewBinaryReader(bytesBuffer).Read() 128 | if err != nil { 129 | return nil, fmt.Errorf("unexpected error reading tink keyset: %v", err) 130 | } 131 | return tinkKeyset, nil 132 | } 133 | 134 | // getTinkKeysetHandleFromKnoxVersionList returns a tink keyset handle that has all tink keys in the 135 | // received knox version list and a map from tink key IDs to knox version IDs. To be noticed, each 136 | // knox version contains a tink keyset that has a single tink key (tink key has a property, tink key id). 
137 | // This func enumerates the given knox version list, put tink keys from different knox versions into 138 | // one tink keyset "fullTinkKeyset". Also, this func records which tink key is from which knox version 139 | // in a map "tinkKeyIDToKnoxVersionID". 140 | func getTinkKeysetHandleFromKnoxVersionList( 141 | knoxVersionList knox.KeyVersionList, 142 | ) (*keyset.Handle, map[uint32]uint64, error) { 143 | fullTinkKeyset := new(tinkpb.Keyset) 144 | tinkKeyIDToKnoxVersionID := make(map[uint32]uint64) 145 | for _, v := range knoxVersionList { 146 | // the data of each version is a tink keyset that contains a single tink key 147 | keyComponent, err := readTinkKeysetFromBytes(v.Data) 148 | if err != nil { 149 | return nil, nil, err 150 | } 151 | singleKey := keyComponent.Key[0] 152 | if v.Status == knox.Primary { 153 | fullTinkKeyset.PrimaryKeyId = singleKey.KeyId 154 | } 155 | fullTinkKeyset.Key = append(fullTinkKeyset.Key, singleKey) 156 | tinkKeyIDToKnoxVersionID[singleKey.KeyId] = v.ID 157 | } 158 | keysetHandle, err := convertCleartextTinkKeysetToHandle(fullTinkKeyset) 159 | if err != nil { 160 | return nil, nil, err 161 | } 162 | return keysetHandle, tinkKeyIDToKnoxVersionID, nil 163 | } 164 | 165 | // convertCleartextTinkKeysetToHandle converts cleartext tink keyset to tink keyset handle 166 | func convertCleartextTinkKeysetToHandle(cleartextTinkKeyset *tinkpb.Keyset) (*keyset.Handle, error) { 167 | bytesBuffer := new(bytes.Buffer) 168 | writer := keyset.NewBinaryWriter(bytesBuffer) 169 | writer.Write(cleartextTinkKeyset) 170 | reader := keyset.NewBinaryReader(bytesBuffer) 171 | // To get keyset handle from cleartext keyset, must use package "insecurecleartextkeyset" 172 | keysetHandle, err := insecurecleartextkeyset.Read(reader) 173 | if err != nil { 174 | return nil, fmt.Errorf("cannot get tink keyset handle: %v", err) 175 | } 176 | return keysetHandle, nil 177 | } 178 | 179 | // getKeysetInfoFromTinkKeysetHandle returns a string representation of the 
info of the given tink keyset 180 | // handle. The returned string which does not contain any sensitive key material. 181 | func getKeysetInfoFromTinkKeysetHandle( 182 | keysetHandle *keyset.Handle, 183 | tinkKeyIDToKnoxVersionID map[uint32]uint64, 184 | ) (string, error) { 185 | // translate the info from the tink build-in function to json format 186 | keysetInfo := newTinkKeysetInfo(keysetHandle.KeysetInfo(), tinkKeyIDToKnoxVersionID) 187 | keysetInfoForPrint, err := json.MarshalIndent(keysetInfo, "", " ") 188 | if err != nil { 189 | return "", err 190 | } 191 | return string(keysetInfoForPrint), nil 192 | } 193 | 194 | // tinkKeysetInfo translates tink keyset info to JSON format, doesn't contain any actual key material. 195 | type tinkKeysetInfo struct { 196 | PrimaryKeyId uint32 `json:"primary_key_id"` 197 | KeyInfo []*tinkKeyInfo `json:"key_info"` 198 | } 199 | 200 | // tinkKeyInfo translates tink key info to JSON format, doesn't contain any actual key material. 201 | type tinkKeyInfo struct { 202 | TypeUrl string `json:"type_url"` 203 | Status string `json:"status"` 204 | KeyId uint32 `json:"key_id"` 205 | OutputPrefixType string `json:"output_prefix_type"` 206 | KnoxVersionID uint64 `json:"knox_version_id"` 207 | } 208 | 209 | // newTinkKeysetInfo translates Tink keyset info to JSON format. 210 | func newTinkKeysetInfo( 211 | keysetInfo *tinkpb.KeysetInfo, 212 | tinkKeyIDToKnoxVersionID map[uint32]uint64, 213 | ) tinkKeysetInfo { 214 | return tinkKeysetInfo{ 215 | keysetInfo.PrimaryKeyId, 216 | newTinkKeysInfo(keysetInfo.KeyInfo, tinkKeyIDToKnoxVersionID), 217 | } 218 | } 219 | 220 | // newTinkKeyInfo translates Tink key info to JSON format. 
221 | func newTinkKeysInfo( 222 | keysetInfo_KeyInfo []*tinkpb.KeysetInfo_KeyInfo, 223 | tinkKeyIDToKnoxVersionID map[uint32]uint64, 224 | ) []*tinkKeyInfo { 225 | var tinkKeysInfo []*tinkKeyInfo 226 | for _, v := range keysetInfo_KeyInfo { 227 | tinkKeysInfo = append(tinkKeysInfo, &tinkKeyInfo{ 228 | v.TypeUrl, 229 | v.Status.String(), 230 | v.KeyId, 231 | v.OutputPrefixType.String(), 232 | tinkKeyIDToKnoxVersionID[v.KeyId], 233 | }) 234 | } 235 | return tinkKeysInfo 236 | } 237 | -------------------------------------------------------------------------------- /client/tink_keyset_helper_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/golang/protobuf/proto" 10 | "github.com/google/tink/go/aead" 11 | "github.com/google/tink/go/insecurecleartextkeyset" 12 | "github.com/google/tink/go/keyset" 13 | "github.com/google/tink/go/mac" 14 | "github.com/google/tink/go/testkeyset" 15 | "github.com/pinterest/knox" 16 | 17 | tinkpb "github.com/google/tink/go/proto/tink_go_proto" 18 | ) 19 | 20 | func TestNameOfSupportedTinkKeyTemplates(t *testing.T) { 21 | names := []string{ 22 | "TINK_AEAD_AES128_GCM", 23 | "TINK_AEAD_AES256_GCM", 24 | "TINK_DAEAD_AES256_SIV", 25 | "TINK_DSIG_ECDSA_P256", 26 | "TINK_DSIG_ED25519", 27 | "TINK_HYBRID_ECIES_P256_HKDF_HMAC_SHA256_AES128_GCM", 28 | "TINK_MAC_HMAC_SHA512_256BITTAG", 29 | "TINK_SAEAD_AES128_GCM_HKDF_1MB", 30 | "TINK_SAEAD_AES128_GCM_HKDF_4KB", 31 | } 32 | expected := strings.Join(names, "\n") 33 | if expected != nameOfSupportedTinkKeyTemplates() { 34 | t.Fatalf("cannot list name of supported tink key templates correctly") 35 | } 36 | } 37 | 38 | func TestObeyNamingRule(t *testing.T) { 39 | if err := obeyNamingRule("invalid", "invalid"); err == nil { 40 | t.Fatalf("cannot identify invalid tink key template") 41 | } 42 | for k := range tinkKeyTemplates { 43 | illegalKnoxIdentifier 
:= "wrongKnoxIdentifier" 44 | err := obeyNamingRule(k, illegalKnoxIdentifier) 45 | if err == nil { 46 | t.Fatalf("cannot identify illegal knox identifer for template '%s'", k) 47 | } 48 | } 49 | for k, v := range tinkKeyTemplates { 50 | legalKnoxIdentifier := v.knoxIDPrefix + "test" 51 | err := obeyNamingRule(k, legalKnoxIdentifier) 52 | if err != nil { 53 | t.Fatalf("cannot accept legal knox identifer for template '%s'", k) 54 | } 55 | } 56 | } 57 | 58 | func TestIsIDforTinkKeyset(t *testing.T) { 59 | if isIDforTinkKeyset("invalid") { 60 | t.Fatalf("cannot identify knox identifier that is not for tink keyset") 61 | } 62 | for _, templateInfo := range tinkKeyTemplates { 63 | knoxIdentifierForTinkKeyset := templateInfo.knoxIDPrefix + "test" 64 | if !isIDforTinkKeyset(knoxIdentifierForTinkKeyset) { 65 | t.Fatalf("cannot identify knox identifier that is for tink keyset") 66 | } 67 | } 68 | } 69 | 70 | func TestCreateNewTinkKeyset(t *testing.T) { 71 | keyTemplate := mac.HMACSHA512Tag256KeyTemplate 72 | keysetInBytes, err := createNewTinkKeyset(keyTemplate) 73 | if err != nil { 74 | t.Fatalf("cannot create a new tink keyset: %v", err) 75 | } 76 | bytesBuffer := new(bytes.Buffer) 77 | bytesBuffer.Write(keysetInBytes) 78 | tinkKeyset, err := keyset.NewBinaryReader(bytesBuffer).Read() 79 | if err != nil { 80 | t.Fatalf("unexpected error reading tink keyset data: %v", err) 81 | } 82 | if len(tinkKeyset.Key) != 1 { 83 | t.Fatalf("incorrect number of keys in the keyset: %d", len(tinkKeyset.Key)) 84 | } 85 | tinkKey := tinkKeyset.Key[0] 86 | if tinkKeyset.PrimaryKeyId != tinkKey.KeyId { 87 | t.Fatalf("incorrect primary key id, expect %d, got %d", tinkKey.KeyId, tinkKeyset.PrimaryKeyId) 88 | } 89 | if tinkKey.KeyData.TypeUrl != keyTemplate().TypeUrl { 90 | t.Fatalf("incorrect type url, expect %s, got %s", keyTemplate().TypeUrl, tinkKey.KeyData.TypeUrl) 91 | } 92 | keysetHandle, err := testkeyset.NewHandle(tinkKeyset) 93 | if err != nil { 94 | t.Fatalf("unexpected error creating 
new KeysetHandle: %v", err) 95 | } 96 | if _, err = mac.New(keysetHandle); err != nil { 97 | t.Fatalf("cannot get primitive from generated keyset: %s", err) 98 | } 99 | } 100 | 101 | func TestConvertTinkKeysetHandleToBytes(t *testing.T) { 102 | keyTemplate := mac.HMACSHA256Tag128KeyTemplate() 103 | keysetHandle, err := keyset.NewHandle(keyTemplate) 104 | if err != nil { 105 | t.Fatalf("unexpected error: %s", err) 106 | } 107 | keysetInBytes, err := convertTinkKeysetHandleToBytes(keysetHandle) 108 | if err != nil { 109 | t.Fatalf("cannot create convert tink keyset handle to bytes: %v", err) 110 | } 111 | bytesBuffer := new(bytes.Buffer) 112 | bytesBuffer.Write(keysetInBytes) 113 | tinkKeyset, err := keyset.NewBinaryReader(bytesBuffer).Read() 114 | if err != nil { 115 | t.Fatalf("unexpected error reading tink keyset data: %v", err) 116 | } 117 | if err := keyset.Validate(tinkKeyset); err != nil { 118 | t.Fatalf("when convert tink keyset handle to bytes, the keyset becomes invalid") 119 | } 120 | } 121 | 122 | // getDummyKnoxVersionList is a helper for test. It returns a dummy knox version list for testing and a map from 123 | // Tink key ID to knox version ID. The data of each version is a Tink keyset in bytes that contains a single Tink 124 | // key. Argument counts decides how many versions are in this dummy veriosn list. Argument templateFunc decides 125 | // the type of created Tink keyset in each version. 
126 | func getDummyKnoxVersionList( 127 | counts int, 128 | templateFunc func() *tinkpb.KeyTemplate, 129 | ) (knox.KeyVersionList, map[uint32]uint64) { 130 | var dummyVersionList knox.KeyVersionList 131 | tinkKeyIDToKnoxVersionID := make(map[uint32]uint64) 132 | // counts decide how many versions this dummy version list will have 133 | for i := 0; i < counts; i++ { 134 | // get a tink keyset handle that contains a fresh single key and the keyID is not duplicated 135 | var keysetHandle *keyset.Handle 136 | var err error 137 | isDuplicated := true 138 | for isDuplicated { 139 | keysetHandle, err = keyset.NewHandle(templateFunc()) 140 | if keysetHandle == nil || err != nil { 141 | fatalf("cannot get tink keyset handle: %v", err) 142 | } 143 | _, isDuplicated = tinkKeyIDToKnoxVersionID[keysetHandle.KeysetInfo().PrimaryKeyId] 144 | } 145 | // Convert keyset handle to bytes, since the data in each version is bytes 146 | var keysetInBytes []byte 147 | keysetInBytes, err = convertTinkKeysetHandleToBytes(keysetHandle) 148 | if err != nil { 149 | fatalf(err.Error()) 150 | } 151 | // Add a new version to the dummy version list. Only one Primary version, all others are Active version. 152 | var status knox.VersionStatus 153 | if i == 0 { 154 | status = knox.Primary 155 | } else { 156 | status = knox.Active 157 | } 158 | // To be noticed, index i is used as dummy knox version ID and dummy creation time. 
159 | dummyVersionList = append(dummyVersionList, knox.KeyVersion{ 160 | ID: uint64(i), 161 | Data: keysetInBytes, 162 | Status: status, 163 | CreationTime: int64(i), 164 | }) 165 | tinkKeyIDToKnoxVersionID[keysetHandle.KeysetInfo().PrimaryKeyId] = uint64(i) 166 | } 167 | return dummyVersionList, tinkKeyIDToKnoxVersionID 168 | } 169 | 170 | func TestAddNewTinkKeyset(t *testing.T) { 171 | keyTemplate := aead.AES256GCMKeyTemplate 172 | // create a dummy version list has one million Tink keys, this large number of Tink keys is used to 173 | // check whether func addNewTinkKeyset will add duplicated Key 174 | dummyVersionList, tinkKeyIDToKnoxVersionID := getDummyKnoxVersionList(1000000, keyTemplate) 175 | newKeysetInBytes, err := addNewTinkKeyset(keyTemplate, dummyVersionList) 176 | if err != nil { 177 | t.Fatalf("cannot add new Tink keyset: %v", err) 178 | } 179 | // convert bytes to a Tink keyset, and check whether it is a valid keyset 180 | tinkKeyset, err := readTinkKeysetFromBytes(newKeysetInBytes) 181 | if err != nil { 182 | t.Fatalf("unexpected error reading tink keyset data: %v", err) 183 | } 184 | if len(tinkKeyset.Key) != 1 { 185 | t.Fatalf("incorrect number of keys in the keyset: %d", len(tinkKeyset.Key)) 186 | } 187 | tinkKey := tinkKeyset.Key[0] 188 | _, isDuplicated := tinkKeyIDToKnoxVersionID[tinkKey.KeyId] 189 | if isDuplicated { 190 | t.Fatalf("the ID of new Tink key is duplicated") 191 | } 192 | if tinkKeyset.PrimaryKeyId != tinkKey.KeyId { 193 | t.Fatalf("incorrect primary key id, expect %d, got %d", tinkKey.KeyId, tinkKeyset.PrimaryKeyId) 194 | } 195 | if tinkKey.KeyData.TypeUrl != keyTemplate().TypeUrl { 196 | t.Fatalf("incorrect type url, expect %s, got %s", keyTemplate().TypeUrl, tinkKey.KeyData.TypeUrl) 197 | } 198 | keysetHandle, err := testkeyset.NewHandle(tinkKeyset) 199 | if err != nil { 200 | t.Fatalf("unexpected error creating new KeysetHandle: %v", err) 201 | } 202 | if _, err = aead.New(keysetHandle); err != nil { 203 | t.Fatalf("cannot 
get primitive from generated keyset: %s", err) 204 | } 205 | } 206 | 207 | func TestReadTinkKeysetFromBytes(t *testing.T) { 208 | keyTemplate := mac.HMACSHA256Tag128KeyTemplate() 209 | keysetHandle, err := keyset.NewHandle(keyTemplate) 210 | if err != nil { 211 | t.Fatalf("unexpected error: %s", err) 212 | } 213 | bytesBuffer := new(bytes.Buffer) 214 | writer := keyset.NewBinaryWriter(bytesBuffer) 215 | err = insecurecleartextkeyset.Write(keysetHandle, writer) 216 | if err != nil { 217 | t.Fatalf("unexpected error writing tink keyset handle") 218 | } 219 | tinkKeyset, err := readTinkKeysetFromBytes(bytesBuffer.Bytes()) 220 | if err != nil { 221 | t.Fatalf("cannot read tink keyset from bytes") 222 | } 223 | err = keyset.Validate(tinkKeyset) 224 | if err != nil { 225 | t.Fatalf("the result of readTinkKeysetFromBytes is not a valid Tink keyset") 226 | } 227 | } 228 | 229 | func TestGetTinkKeysetHandleFromKnoxVersionList(t *testing.T) { 230 | keyTemplate := aead.AES128GCMKeyTemplate 231 | dummyVersionList, tinkKeyIDtoKnoxVersionID := getDummyKnoxVersionList(1000, keyTemplate) 232 | keysetHandle, mapping, err := getTinkKeysetHandleFromKnoxVersionList(dummyVersionList) 233 | if err != nil { 234 | t.Fatalf(err.Error()) 235 | } 236 | if _, err := aead.New(keysetHandle); err != nil { 237 | t.Fatalf("cannot get primitive from generated keyset handle: %s", err) 238 | } 239 | for k, v := range tinkKeyIDtoKnoxVersionID { 240 | if v != mapping[k] { 241 | t.Fatalf("cannot map tink key id to knox version id correctly") 242 | } 243 | } 244 | } 245 | 246 | func TestConvertCleartextTinkKeysetToHandle(t *testing.T) { 247 | // Create a keyset that contains a single HmacKey. 
248 | keyTemplate := mac.HMACSHA256Tag128KeyTemplate() 249 | keysetHandle, err := keyset.NewHandle(keyTemplate) 250 | if keysetHandle == nil || err != nil { 251 | t.Fatalf("cannot get keyset handle: %v", err) 252 | } 253 | tinkKeyset := insecurecleartextkeyset.KeysetMaterial(keysetHandle) 254 | parsedHandle, err := convertCleartextTinkKeysetToHandle(tinkKeyset) 255 | if err != nil { 256 | t.Fatalf("unexpected error reading keyset: %v", err) 257 | } 258 | parsedKeyset := insecurecleartextkeyset.KeysetMaterial(parsedHandle) 259 | if !proto.Equal(tinkKeyset, parsedKeyset) { 260 | t.Fatalf("parsed keyset (%s) doesn't match original keyset (%s)", parsedKeyset, tinkKeyset) 261 | } 262 | } 263 | 264 | func TestGetKeysetInfoFromTinkKeysetHandle(t *testing.T) { 265 | keyTemplate := aead.AES128GCMKeyTemplate 266 | keysetHandle, err := keyset.NewHandle(keyTemplate()) 267 | if err != nil { 268 | t.Fatalf("unexpected error: %s", err) 269 | } 270 | // 100 is the dummy knox version id 271 | tinkKeyIDToKnoxVersionID := map[uint32]uint64{keysetHandle.KeysetInfo().PrimaryKeyId: 100} 272 | var keysInfo []*tinkKeyInfo 273 | rawKeysetInfo := keysetHandle.KeysetInfo() 274 | keysInfo = append(keysInfo, &tinkKeyInfo{ 275 | rawKeysetInfo.KeyInfo[0].TypeUrl, 276 | rawKeysetInfo.KeyInfo[0].Status.String(), 277 | rawKeysetInfo.KeyInfo[0].KeyId, 278 | rawKeysetInfo.KeyInfo[0].OutputPrefixType.String(), 279 | 100, 280 | }) 281 | keysetInfo := tinkKeysetInfo{ 282 | rawKeysetInfo.PrimaryKeyId, 283 | keysInfo, 284 | } 285 | keysetInfoForPrint, err := json.MarshalIndent(keysetInfo, "", " ") 286 | if err != nil { 287 | t.Fatalf(err.Error()) 288 | } 289 | expected := string(keysetInfoForPrint) 290 | got, err := getKeysetInfoFromTinkKeysetHandle(keysetHandle, tinkKeyIDToKnoxVersionID) 291 | if err != nil || expected != got { 292 | t.Fatalf("cannot get keyset info in json format") 293 | } 294 | } 295 | 296 | func TestNewTinkKeysetInfo(t *testing.T) { 297 | keyTemplate := aead.AES128GCMKeyTemplate 298 | 
keysetHandle, err := keyset.NewHandle(keyTemplate()) 299 | if err != nil { 300 | t.Fatalf("unexpected error: %s", err) 301 | } 302 | // 123456 is the dummy knox version id 303 | tinkKeyIDToKnoxVersionID := map[uint32]uint64{keysetHandle.KeysetInfo().PrimaryKeyId: 123456} 304 | var keysInfo []*tinkKeyInfo 305 | rawKeysetInfo := keysetHandle.KeysetInfo() 306 | keysInfo = append(keysInfo, &tinkKeyInfo{ 307 | rawKeysetInfo.KeyInfo[0].TypeUrl, 308 | rawKeysetInfo.KeyInfo[0].Status.String(), 309 | rawKeysetInfo.KeyInfo[0].KeyId, 310 | rawKeysetInfo.KeyInfo[0].OutputPrefixType.String(), 311 | 123456, 312 | }) 313 | expected, _ := json.Marshal(tinkKeysetInfo{ 314 | rawKeysetInfo.PrimaryKeyId, 315 | keysInfo, 316 | }) 317 | got, _ := json.Marshal(newTinkKeysetInfo(keysetHandle.KeysetInfo(), tinkKeyIDToKnoxVersionID)) 318 | if string(got) != string(expected) { 319 | t.Fatalf("cannot create JSONTinkKeysetInfo correctly") 320 | } 321 | } 322 | 323 | func TestNewTinkKeysInfo(t *testing.T) { 324 | keyTemplate := aead.AES256GCMKeyTemplate 325 | keysetHandle, err := keyset.NewHandle(keyTemplate()) 326 | if err != nil { 327 | t.Fatalf("unexpected error: %s", err) 328 | } 329 | // 1234567890 is the dummy knox version id 330 | tinkKeyIDToKnoxVersionID := map[uint32]uint64{keysetHandle.KeysetInfo().PrimaryKeyId: 1234567890} 331 | var keysInfo []*tinkKeyInfo 332 | rawKeysetInfo := keysetHandle.KeysetInfo() 333 | keysInfo = append(keysInfo, &tinkKeyInfo{ 334 | rawKeysetInfo.KeyInfo[0].TypeUrl, 335 | rawKeysetInfo.KeyInfo[0].Status.String(), 336 | rawKeysetInfo.KeyInfo[0].KeyId, 337 | rawKeysetInfo.KeyInfo[0].OutputPrefixType.String(), 338 | 1234567890, 339 | }) 340 | expected, _ := json.Marshal(keysInfo) 341 | got, _ := json.Marshal(newTinkKeysInfo(keysetHandle.KeysetInfo().KeyInfo, tinkKeyIDToKnoxVersionID)) 342 | if string(got) != string(expected) { 343 | t.Fatalf("cannot create JSONTinkKeysetInfo_KeyInfo correctly") 344 | } 345 | } 346 | 
-------------------------------------------------------------------------------- /client/unregister.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | var cmdUnregister = &Command{ 8 | Run: runUnregister, 9 | UsageLine: "unregister ", 10 | Short: "unregister a key identifier from daemon", 11 | Long: ` 12 | Unregister stops cacheing and refreshing a specific key, deleting the associated files. 13 | 14 | For more about knox, see https://github.com/pinterest/knox. 15 | 16 | See also: knox register, knox daemon 17 | `, 18 | } 19 | 20 | func runUnregister(cmd *Command, args []string) *ErrorStatus { 21 | if len(args) != 1 { 22 | return &ErrorStatus{fmt.Errorf("You must include a key ID to deregister. See 'knox help unregister'"), false} 23 | } 24 | k := NewKeysFile(daemonFolder + daemonToRegister) 25 | err := k.Lock() 26 | if err != nil { 27 | return &ErrorStatus{fmt.Errorf("Error locking the register file: %s", err.Error()), false} 28 | } 29 | defer k.Unlock() 30 | 31 | err = k.Remove([]string{args[0]}) 32 | if err != nil { 33 | return &ErrorStatus{fmt.Errorf("Error removing the key: %s", err.Error()), false} 34 | } 35 | fmt.Println("Unregistered key successfully") 36 | return nil 37 | } 38 | -------------------------------------------------------------------------------- /client/updateaccess.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/pinterest/knox" 9 | ) 10 | 11 | func init() { 12 | cmdUpdateAccess.Run = runUpdateAccess 13 | } 14 | 15 | var cmdUpdateAccess = &Command{ 16 | UsageLine: "access (-acl | {-n|-r|-w|-a} {-M|-U|-G|-P|-S|-N} )", 17 | Short: "access modifies the acl of a key", 18 | Long: ` 19 | Access will add or change the acl on a key by adding a specific access control rule. 
20 | 21 | -acl: Takes in a filename with a JSON formatted list of access rules 22 | 23 | -n: This will update the key so that the given principal has no access. Please note that if there is another rule that gives access that will take precedence. 24 | -r: This will grant the principal read access to the key. They will be able to read the keys data. 25 | -w: This will grant the principal write access to the key. They will be able to rotate keys in addition to all read permissions. 26 | -a: This will grant the principal admin access to the key. They will be able to update ACLs and delete keys in addition to all read and write permissions. 27 | 28 | -M: A specific machine. The principal should be set to the exact hostname. 29 | -U: A specific user. The principal should be set to the ldap username of the user. 30 | -G: A specific user group. The principal should be set to the group name. This takes the format of ou=Security,ou=Prod,ou=groups,dc=pinterest,dc=com in LDAP. 31 | -P: A machine hostname prefix. Prefix matching will be used to determine access. For example, if the principal is set to 'auth' then 'auth004' would match (and so would any hostname beginning with auth). 32 | -S: A specific service. The principal should be set to the exact SPIFFE ID. For example, 'spiffe://example.com/service'. 33 | -N: A service prefix (namespace). The principal should be set to a SPIFFE ID ending with a slash, such as 'spiffe://example.com/namespace/'. This will match all services under that prefix, so for example 'spiffe://example.com/namespace/service' would be allowed. 34 | 35 | This command requires admin access to the key. 36 | 37 | For more about knox, see https://github.com/pinterest/knox. 
38 | 39 | See also: knox create, knox get 40 | `, 41 | } 42 | 43 | var updateAccessACL = cmdUpdateAccess.Flag.String("acl", "", "") 44 | 45 | var updateAccessNone = cmdUpdateAccess.Flag.Bool("n", false, "") 46 | var updateAccessRead = cmdUpdateAccess.Flag.Bool("r", false, "") 47 | var updateAccessWrite = cmdUpdateAccess.Flag.Bool("w", false, "") 48 | var updateAccessAdmin = cmdUpdateAccess.Flag.Bool("a", false, "") 49 | 50 | var updateAccessMachine = cmdUpdateAccess.Flag.Bool("M", false, "") 51 | var updateAccessUser = cmdUpdateAccess.Flag.Bool("U", false, "") 52 | var updateAccessGroup = cmdUpdateAccess.Flag.Bool("G", false, "") 53 | var updateAccessPrefix = cmdUpdateAccess.Flag.Bool("P", false, "") 54 | var updateAccessService = cmdUpdateAccess.Flag.Bool("S", false, "") 55 | var updateAccessServicePrefix = cmdUpdateAccess.Flag.Bool("N", false, "") 56 | 57 | func runUpdateAccess(cmd *Command, args []string) *ErrorStatus { 58 | if *updateAccessACL != "" { 59 | if len(args) != 1 { 60 | return &ErrorStatus{fmt.Errorf("access takes one argument when used with --acl. See 'knox help access'"), false} 61 | } 62 | keyID := args[0] 63 | b, err := os.ReadFile(*updateAccessACL) 64 | if err != nil { 65 | return &ErrorStatus{fmt.Errorf("Could not read acl file: %s", err.Error()), false} 66 | } 67 | acl := []knox.Access{} 68 | err = json.Unmarshal(b, &acl) 69 | if err != nil { 70 | return &ErrorStatus{fmt.Errorf("Could not decode access list properly: %s", err.Error()), false} 71 | } 72 | err = cli.PutAccess(keyID, acl...) 73 | if err != nil { 74 | return &ErrorStatus{fmt.Errorf("Failed to update access: %s", err.Error()), true} 75 | } 76 | fmt.Println("Successfully updated Access") 77 | return nil 78 | } 79 | if len(args) != 2 { 80 | return &ErrorStatus{fmt.Errorf("access takes exactly two arguments. 
See 'knox help access'"), false} 81 | } 82 | keyID := args[0] 83 | principal := args[1] 84 | var access knox.Access 85 | access.ID = principal 86 | switch { 87 | case *updateAccessNone: 88 | access.AccessType = knox.None 89 | case *updateAccessRead: 90 | access.AccessType = knox.Read 91 | case *updateAccessWrite: 92 | access.AccessType = knox.Write 93 | case *updateAccessAdmin: 94 | access.AccessType = knox.Admin 95 | default: 96 | return &ErrorStatus{fmt.Errorf("access requires {-n,-r,-w,-a}. See 'knox help access'"), false} 97 | } 98 | switch { 99 | case *updateAccessMachine: 100 | access.Type = knox.Machine 101 | case *updateAccessUser: 102 | access.Type = knox.User 103 | case *updateAccessGroup: 104 | access.Type = knox.UserGroup 105 | case *updateAccessPrefix: 106 | access.Type = knox.MachinePrefix 107 | case *updateAccessService: 108 | access.Type = knox.Service 109 | case *updateAccessServicePrefix: 110 | access.Type = knox.ServicePrefix 111 | default: 112 | return &ErrorStatus{fmt.Errorf("access requires {-M|-U|-G|-P|-S|-N}. See 'knox help access'"), false} 113 | } 114 | err := cli.PutAccess(keyID, access) 115 | if err != nil { 116 | return &ErrorStatus{fmt.Errorf("Failed to update access: %s", err.Error()), true} 117 | } 118 | fmt.Println("Successfully updated Access") 119 | return nil 120 | } 121 | -------------------------------------------------------------------------------- /client/version.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import "fmt" 4 | 5 | // Version represents the compiled version of the client binary. It can be overridden at compile time with: 6 | // `go build -ldflags "-X github.com/pinterest/knox/client.Version=1.2.3" github.com/pinterest/knox/cmd/dev_client` 7 | // In the above example, knox version would give you `1.2.3`. By default, the version is `devel`. 
8 | var Version string = "devel" 9 | 10 | var cmdVersion = &Command{ 11 | Run: runVersion, 12 | UsageLine: "version", 13 | Short: "Prints the current version of the Knox client", 14 | Long: ` 15 | Prints the current version of the Knox client. 16 | `, 17 | } 18 | 19 | // GetVersion exposes the current client version 20 | func GetVersion() string { 21 | return Version 22 | } 23 | 24 | func runVersion(cmd *Command, args []string) *ErrorStatus { 25 | fmt.Printf("Knox CLI version %s\n", Version) 26 | return nil 27 | } 28 | -------------------------------------------------------------------------------- /cmd/dev_client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "encoding/json" 7 | "log" 8 | "math/rand" 9 | "net/http" 10 | "os" 11 | "os/user" 12 | "time" 13 | 14 | "github.com/pinterest/knox" 15 | "github.com/pinterest/knox/client" 16 | ) 17 | 18 | // certPEMBlock is the certificate signed by the CA to identify the machine using the client 19 | // (Should be pulled from a file or via another process) 20 | const certPEMBlock = `-----BEGIN CERTIFICATE----- 21 | MIIB7TCCAZOgAwIBAgIDEAAEMAoGCCqGSM49BAMCMFExCzAJBgNVBAYTAlVTMQsw 22 | CQYDVQQIEwJDQTEYMBYGA1UEChMPTXkgQ29tcGFueSBOYW1lMRswGQYDVQQDExJ1 23 | c2VPbmx5SW5EZXZPclRlc3QwHhcNMTgwMzAyMDI1NjEyWhcNMTkwMzAyMDI1NjEy 24 | WjBKMQswCQYDVQQGEwJVUzELMAkGA1UECAwCQ0ExGDAWBgNVBAoMD015IENvbXBh 25 | bnkgTmFtZTEUMBIGA1UEAwwLZXhhbXBsZS5jb20wWTATBgcqhkjOPQIBBggqhkjO 26 | PQMBBwNCAAQQTbdQNoE5/j6mgh4HAdbgPyGbuzjpHI/x34p6qPojduUK+ifUW6Mb 27 | bS5Zumjh31K5AmWYt4jWfU82Sb6sxPKXo2EwXzAJBgNVHRMEAjAAMAsGA1UdDwQE 28 | AwIF4DBFBgNVHREEPjA8hhxzcGlmZmU6Ly9leGFtcGxlLmNvbS9zZXJ2aWNlggtl 29 | eGFtcGxlLmNvbYIPd3d3LmV4YW1wbGUuY29tMAoGCCqGSM49BAMCA0gAMEUCIQDO 30 | TaI0ltMPlPDt4XSdWJawZ4euAGXJCyoxHFs8HQK8XwIgVokWyTcajFoP0/ZfzrM5 31 | SihfFJr39Ck4V5InJRHPPtY= 32 | -----END CERTIFICATE-----` 33 | 34 | // keyPEMBlock is the private key that 
should only be available on the machine running this client 35 | // (Should be pulled from a file or via another process) 36 | const keyPEMBlock = `-----BEGIN EC PRIVATE KEY----- 37 | MHcCAQEEIDHDjs9Ug8QvsuKRrtC6QUmz4u++oBJF2VtCZe9gYyzOoAoGCCqGSM49 38 | AwEHoUQDQgAEEE23UDaBOf4+poIeBwHW4D8hm7s46RyP8d+Keqj6I3blCvon1Fuj 39 | G20uWbpo4d9SuQJlmLeI1n1PNkm+rMTylw== 40 | -----END EC PRIVATE KEY-----` 41 | 42 | // hostname is the host running the knox server 43 | const hostname = "localhost:9000" 44 | 45 | // tokenEndpoint and clientID are used by "knox login" if your oauth client supports password flows. 46 | const tokenEndpoint = "https://oauth.token.endpoint.used.for/knox/login" 47 | const clientID = "" 48 | 49 | // keyFolder is the directory where keys are cached 50 | const keyFolder = "/var/lib/knox/v0/keys/" 51 | 52 | const ( 53 | authTypeUser = "user" 54 | authTypeMachine = "machine" 55 | ) 56 | 57 | // authTokenResp is the format of the OAuth response generated by "knox login" 58 | type authTokenResp struct { 59 | AccessToken string `json:"access_token"` 60 | Error string `json:"error"` 61 | } 62 | 63 | // getCert returns the cert in the tls.Certificate format. This should be a config option in prod. 64 | func getCert() (tls.Certificate, error) { 65 | return tls.X509KeyPair([]byte(certPEMBlock), []byte(keyPEMBlock)) 66 | } 67 | 68 | // authHandler is used to generate an authentication header. 69 | // The server expects VersionByte + TypeByte + IDToPassToAuthHandler. 
70 | func authHandler() (string, string, knox.HTTP) { 71 | if s := os.Getenv("KNOX_USER_AUTH"); s != "" { 72 | return "0u" + s, authTypeUser, nil 73 | } 74 | if s := os.Getenv("KNOX_MACHINE_AUTH"); s != "" { 75 | c, _ := getCert() 76 | x509Cert, err := x509.ParseCertificate(c.Certificate[0]) 77 | if err != nil { 78 | return "0t" + s, authTypeMachine, nil 79 | } 80 | if len(x509Cert.Subject.CommonName) > 0 { 81 | return "0t" + x509Cert.Subject.CommonName, authTypeMachine, nil 82 | } else if len(x509Cert.DNSNames) > 0 { 83 | return "0t" + x509Cert.DNSNames[0], authTypeMachine, nil 84 | } else { 85 | return "0t" + s, authTypeMachine, nil 86 | } 87 | } 88 | if s := os.Getenv("KNOX_SERVICE_AUTH"); s != "" { 89 | return "0s" + s, authTypeMachine, nil 90 | } 91 | u, err := user.Current() 92 | if err != nil { 93 | return "", "", nil 94 | } 95 | 96 | d, err := os.ReadFile(u.HomeDir + "/.knox_user_auth") 97 | if err != nil { 98 | return "", "", nil 99 | } 100 | var authResp authTokenResp 101 | err = json.Unmarshal(d, &authResp) 102 | if err != nil { 103 | return "", "", nil 104 | } 105 | 106 | return "0u" + authResp.AccessToken, authTypeUser, nil 107 | } 108 | 109 | func main() { 110 | rand.Seed(time.Now().UTC().UnixNano()) 111 | 112 | tlsConfig := &tls.Config{ 113 | ServerName: "knox", 114 | InsecureSkipVerify: true, 115 | } 116 | 117 | cert, err := getCert() 118 | if err == nil { 119 | tlsConfig.Certificates = []tls.Certificate{cert} 120 | } 121 | 122 | authHandlers := []knox.AuthHandler{authHandler} 123 | cli := &knox.HTTPClient{ 124 | KeyFolder: keyFolder, 125 | UncachedClient: knox.NewUncachedClient(hostname, &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}, authHandlers, ""), 126 | } 127 | 128 | loginCommand := client.NewLoginCommand(clientID, tokenEndpoint, "", "", "", "") 129 | 130 | client.Run( 131 | cli, 132 | &client.VisibilityParams{ 133 | Logf: log.Printf, 134 | Errorf: log.Printf, 135 | SummaryMetrics: func(map[string]uint64) {}, 136 | 
InvokeMetrics: func(map[string]string) {}, 137 | GetKeyMetrics: func(map[string]string) {}, 138 | }, 139 | loginCommand, 140 | ) 141 | } 142 | -------------------------------------------------------------------------------- /cmd/dev_server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/ecdsa" 5 | "crypto/elliptic" 6 | crypto_rand "crypto/rand" 7 | "crypto/tls" 8 | "crypto/x509" 9 | "crypto/x509/pkix" 10 | "encoding/pem" 11 | "expvar" 12 | "flag" 13 | "math/big" 14 | "math/rand" 15 | "net/http" 16 | "os" 17 | "time" 18 | 19 | "github.com/pinterest/knox" 20 | "github.com/pinterest/knox/log" 21 | "github.com/pinterest/knox/server" 22 | "github.com/pinterest/knox/server/auth" 23 | "github.com/pinterest/knox/server/keydb" 24 | ) 25 | 26 | const caCert = `-----BEGIN CERTIFICATE----- 27 | MIIB5jCCAYygAwIBAgIUD/1LTTQNvk3Rp9399flLlimbgngwCgYIKoZIzj0EAwIw 28 | UTELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNBMRgwFgYDVQQKEw9NeSBDb21wYW55 29 | IE5hbWUxGzAZBgNVBAMTEnVzZU9ubHlJbkRldk9yVGVzdDAeFw0xODAzMDIwMTU5 30 | MDBaFw0yMzAzMDEwMTU5MDBaMFExCzAJBgNVBAYTAlVTMQswCQYDVQQIEwJDQTEY 31 | MBYGA1UEChMPTXkgQ29tcGFueSBOYW1lMRswGQYDVQQDExJ1c2VPbmx5SW5EZXZP 32 | clRlc3QwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARbSovOAo4ZimGBOn+tyftX 33 | +GXShKsy2eFdvX9WfYx2NvYnw+RSM/JjRSBhUsCPXuEh/E5lhwRVfUxIlHry1CkS 34 | o0IwQDAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQU 35 | jjNCAZxA5kjDK1ogrwkdziFiDgkwCgYIKoZIzj0EAwIDSAAwRQIgLXo9amyNn1Y3 36 | qLpqrzVF7N7UQ3mxTl01MvnsqvahI08CIQCArwO8KmbPbN5XZrQ2h9zUgbsebwSG 37 | dfOY505yMqiXig== 38 | -----END CERTIFICATE-----` 39 | 40 | var gitSha = expvar.NewString("version") 41 | var service = expvar.NewString("service") 42 | 43 | var ( 44 | flagAddr = flag.String("http", ":9000", "HTTP port to listen on") 45 | ) 46 | 47 | const ( 48 | authTimeout = 10 * time.Second // Calls to auth timeout after 10 seconds 49 | serviceName = "knox_dev" 50 | ) 51 | 52 | func main() { 53 | 
rand.Seed(time.Now().UTC().UnixNano()) 54 | flag.Parse() 55 | accLogger, errLogger := setupLogging("dev", serviceName) 56 | 57 | dbEncryptionKey := []byte("testtesttesttest") 58 | cryptor := keydb.NewAESGCMCryptor(0, dbEncryptionKey) 59 | 60 | tlsCert, tlsKey, err := buildCert() 61 | if err != nil { 62 | errLogger.Fatal("Failed to make TLS key or cert: ", err) 63 | } 64 | 65 | db := keydb.NewTempDB() 66 | 67 | server.AddDefaultAccess(&knox.Access{ 68 | Type: knox.UserGroup, 69 | ID: "security-team", 70 | AccessType: knox.Admin, 71 | }) 72 | 73 | certPool := x509.NewCertPool() 74 | certPool.AppendCertsFromPEM([]byte(caCert)) 75 | 76 | decorators := [](func(http.HandlerFunc) http.HandlerFunc){ 77 | server.Logger(accLogger), 78 | server.AddHeader("Content-Type", "application/json"), 79 | server.AddHeader("X-Content-Type-Options", "nosniff"), 80 | server.Authentication( 81 | []auth.Provider{ 82 | auth.NewMTLSAuthProvider(certPool), 83 | auth.NewGitHubProvider(authTimeout), 84 | auth.NewSpiffeAuthProvider(certPool), 85 | auth.NewSpiffeAuthFallbackProvider(certPool), 86 | }, 87 | nil), 88 | } 89 | 90 | r, err := server.GetRouter(cryptor, db, decorators, make([]server.Route, 0)) 91 | if err != nil { 92 | errLogger.Fatal(err) 93 | } 94 | 95 | http.Handle("/", r) 96 | 97 | errLogger.Fatal(serveTLS(tlsCert, tlsKey, *flagAddr)) 98 | } 99 | 100 | func setupLogging(gitSha, service string) (*log.Logger, *log.Logger) { 101 | accLogger := log.New(os.Stderr, "", 0) 102 | accLogger.SetVersion(gitSha) 103 | accLogger.SetService(service) 104 | 105 | errLogger := log.New(os.Stderr, "", 0) 106 | errLogger.SetVersion(gitSha) 107 | errLogger.SetService(service) 108 | return accLogger, errLogger 109 | } 110 | 111 | func buildCert() (certPEMBlock, keyPEMBlock []byte, err error) { 112 | priv, err := ecdsa.GenerateKey(elliptic.P256(), crypto_rand.Reader) 113 | if err != nil { 114 | return nil, nil, err 115 | } 116 | 117 | notBefore := time.Now() 118 | notAfter := notBefore.Add(24 * time.Hour) 
119 | serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) 120 | serialNumber, err := crypto_rand.Int(crypto_rand.Reader, serialNumberLimit) 121 | if err != nil { 122 | return nil, nil, err 123 | } 124 | 125 | template := x509.Certificate{ 126 | SerialNumber: serialNumber, 127 | Subject: pkix.Name{Organization: []string{"Acme Co"}}, 128 | NotBefore: notBefore, 129 | NotAfter: notAfter, 130 | 131 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, 132 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 133 | BasicConstraintsValid: true, 134 | } 135 | 136 | template.DNSNames = []string{"localhost"} 137 | 138 | derBytes, err := x509.CreateCertificate(crypto_rand.Reader, &template, &template, &priv.PublicKey, priv) 139 | if err != nil { 140 | return nil, nil, err 141 | } 142 | b, err := x509.MarshalECPrivateKey(priv) 143 | if err != nil { 144 | return nil, nil, err 145 | } 146 | 147 | return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b}), nil 148 | } 149 | 150 | // serveTLS sets up TLS using Mozilla reccommendations and then serves http 151 | func serveTLS(certPEMBlock, keyPEMBlock []byte, httpPort string) error { 152 | // This TLS config disables RC4 and SSLv3. 
153 | tlsConfig := &tls.Config{ 154 | NextProtos: []string{"http/1.1"}, 155 | MinVersion: tls.VersionTLS12, 156 | PreferServerCipherSuites: true, 157 | ClientAuth: tls.RequestClientCert, 158 | CipherSuites: []uint16{ 159 | tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 160 | tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 161 | }, 162 | } 163 | 164 | tlsConfig.Certificates = make([]tls.Certificate, 1) 165 | var err error 166 | tlsConfig.Certificates[0], err = tls.X509KeyPair(certPEMBlock, keyPEMBlock) 167 | if err != nil { 168 | return err 169 | } 170 | server := &http.Server{Addr: httpPort, Handler: nil, TLSConfig: tlsConfig} 171 | 172 | return server.ListenAndServeTLS("", "") 173 | } 174 | -------------------------------------------------------------------------------- /cmd/migrate_db/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pinterest/knox" 7 | "github.com/pinterest/knox/server/keydb" 8 | ) 9 | 10 | func moveKeyData(sDB keydb.DB, sCrypt keydb.Cryptor, dDB keydb.DB, dCrypt keydb.Cryptor) error { 11 | dbKeys, err := sDB.GetAll() 12 | if err != nil { 13 | return err 14 | } 15 | newDBKeys := make([]*keydb.DBKey, 0, len(dbKeys)) 16 | for _, dbk := range dbKeys { 17 | k, err := sCrypt.Decrypt(&dbk) 18 | if err != nil { 19 | return err 20 | } 21 | newDBK, err := dCrypt.Encrypt(k) 22 | if err != nil { 23 | return err 24 | } 25 | newDBKeys = append(newDBKeys, newDBK) 26 | } 27 | 28 | err = dDB.Add(newDBKeys...) 
29 | if err != nil { 30 | return err 31 | } 32 | return nil 33 | } 34 | 35 | func generateTestDBWithKeys(crypt keydb.Cryptor) keydb.DB { 36 | source := keydb.NewTempDB() 37 | d := []byte("test") 38 | v1 := knox.KeyVersion{ID: 1, Data: d, Status: knox.Primary, CreationTime: 10} 39 | v2 := knox.KeyVersion{ID: 2, Data: d, Status: knox.Active, CreationTime: 10} 40 | v3 := knox.KeyVersion{ID: 3, Data: d, Status: knox.Inactive, CreationTime: 10} 41 | validKVL := knox.KeyVersionList([]knox.KeyVersion{v1, v2, v3}) 42 | 43 | a1 := knox.Access{ID: "testmachine1", AccessType: knox.Admin, Type: knox.Machine} 44 | a2 := knox.Access{ID: "testuser", AccessType: knox.Write, Type: knox.User} 45 | a3 := knox.Access{ID: "testmachine", AccessType: knox.Read, Type: knox.MachinePrefix} 46 | validACL := knox.ACL([]knox.Access{a1, a2, a3}) 47 | 48 | key := knox.Key{ID: "test_key", ACL: validACL, VersionList: validKVL, VersionHash: validKVL.Hash()} 49 | key2 := knox.Key{ID: "test_key2", ACL: validACL, VersionList: validKVL, VersionHash: validKVL.Hash()} 50 | 51 | dbkey, err := crypt.Encrypt(&key) 52 | if err != nil { 53 | panic(err) 54 | } 55 | dbkey2, err := crypt.Encrypt(&key2) 56 | if err != nil { 57 | panic(err) 58 | } 59 | 60 | source.Add(dbkey, dbkey2) 61 | return source 62 | } 63 | 64 | func main() { 65 | crypt1 := keydb.NewAESGCMCryptor(0, make([]byte, 16)) 66 | crypt2 := keydb.NewAESGCMCryptor(1, make([]byte, 16)) 67 | 68 | source := generateTestDBWithKeys(crypt1) 69 | 70 | dest := keydb.NewTempDB() 71 | 72 | err := moveKeyData(source, crypt1, dest, crypt2) 73 | if err != nil { 74 | panic(err) 75 | } 76 | 77 | fmt.Printf("source: %v, dest: %v", source, dest) 78 | 79 | } 80 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/pinterest/knox 2 | 3 | go 1.21.5 4 | 5 | require ( 6 | github.com/golang/protobuf v1.5.2 7 | github.com/google/tink/go 
v1.6.1 8 | github.com/gorilla/context v1.1.1 9 | github.com/gorilla/mux v1.8.0 10 | golang.org/x/crypto v0.17.0 11 | gopkg.in/fsnotify.v1 v1.4.7 12 | ) 13 | 14 | require ( 15 | github.com/fsnotify/fsnotify v1.4.9 // indirect 16 | github.com/google/go-cmp v0.5.6 // indirect 17 | golang.org/x/sys v0.15.0 // indirect 18 | golang.org/x/term v0.15.0 // indirect 19 | google.golang.org/protobuf v1.33.0 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /log/log_test.go: -------------------------------------------------------------------------------- 1 | // This file uses code from http://golang.org/src/log/log.go 2 | // modified for JSON logging 3 | // 4 | // Copyright (c) 2012 The Go Authors. All rights reserved. 5 | 6 | // Redistribution and use in source and binary forms, with or without 7 | // modification, are permitted provided that the following conditions are 8 | // met: 9 | 10 | // * Redistributions of source code must retain the above copyright 11 | // notice, this list of conditions and the following disclaimer. 12 | // * Redistributions in binary form must reproduce the above 13 | // copyright notice, this list of conditions and the following disclaimer 14 | // in the documentation and/or other materials provided with the 15 | // distribution. 16 | // * Neither the name of Google Inc. nor the names of its 17 | // contributors may be used to endorse or promote products derived from 18 | // this software without specific prior written permission. 19 | 20 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 24 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | package log 33 | 34 | // These tests are too simple. 35 | 36 | import ( 37 | "bytes" 38 | "encoding/json" 39 | "os" 40 | "regexp" 41 | "testing" 42 | ) 43 | 44 | const ( 45 | Rdate = `[0-9][0-9][0-9][0-9]/[0-9][0-9]/[0-9][0-9]` 46 | Rtime = `[0-9][0-9]:[0-9][0-9]:[0-9][0-9]` 47 | Rmicroseconds = `\.[0-9][0-9][0-9][0-9][0-9][0-9]` 48 | Rline = `(83|85):` // must update if the calls to l.Printf / l.Print below move 49 | Rlongfile = `.*/[A-Za-z0-9_\-]+\.go:` + Rline 50 | Rshortfile = `[A-Za-z0-9_\-]+\.go:` + Rline 51 | ) 52 | 53 | type tester struct { 54 | flag int 55 | prefix string 56 | pattern string // regexp that log output must match; we add ^ and expected_text$ always 57 | } 58 | 59 | var tests = []tester{ 60 | // individual pieces: 61 | {0, "", ""}, 62 | {0, "XXX", "XXX"}, 63 | {Ldate, "", Rdate + " "}, 64 | {Ltime, "", Rtime + " "}, 65 | {Ltime | Lmicroseconds, "", Rtime + Rmicroseconds + " "}, 66 | {Lmicroseconds, "", Rtime + Rmicroseconds + " "}, // microsec implies time 67 | {Llongfile, "", Rlongfile + " "}, 68 | {Lshortfile, "", Rshortfile + " "}, 69 | {Llongfile | Lshortfile, "", Rshortfile + " "}, // shortfile overrides longfile 70 | // everything at once: 71 | {Ldate | Ltime | Lmicroseconds | Llongfile, "XXX", "XXX" + Rdate + " " + Rtime + Rmicroseconds + " " + Rlongfile + " "}, 72 | {Ldate | Ltime | Lmicroseconds | Lshortfile, "XXX", "XXX" + Rdate + " " + Rtime + Rmicroseconds + " " 
+ Rshortfile + " "}, 73 | } 74 | 75 | // Test using Println("hello", 23, "world") or using Printf("hello %d world", 23) 76 | func testPrint(t *testing.T, flag int, prefix string, pattern string, useFormat bool) { 77 | var m LogMessage 78 | buf := new(bytes.Buffer) 79 | SetOutput(buf) 80 | SetFlags(flag) 81 | SetPrefix(prefix) 82 | if useFormat { 83 | Printf("hello %d world\n", 23) 84 | } else { 85 | Println("hello", 23, "world") 86 | } 87 | pattern = "^" + pattern + "hello 23 world\n$" 88 | err := json.NewDecoder(buf).Decode(&m) 89 | if err != nil { 90 | t.Errorf("Unexpected error decoding log JSON: %q", err.Error()) 91 | } 92 | if m.ID == "" { 93 | t.Errorf("ID should be set to random value and not empty") 94 | } 95 | payload, ok := m.Payload.(string) 96 | if !ok { 97 | t.Errorf("Payload is not of string type") 98 | } 99 | matched, err4 := regexp.MatchString(pattern, payload) 100 | if err4 != nil { 101 | t.Fatal("pattern did not compile:", err4) 102 | } 103 | if !matched { 104 | t.Errorf("log output should match %q is %q", pattern, payload) 105 | } 106 | SetOutput(os.Stderr) 107 | } 108 | 109 | func TestAll(t *testing.T) { 110 | for _, testcase := range tests { 111 | testPrint(t, testcase.flag, testcase.prefix, testcase.pattern, false) 112 | testPrint(t, testcase.flag, testcase.prefix, testcase.pattern, true) 113 | } 114 | } 115 | 116 | func TestOutput(t *testing.T) { 117 | const testString = "test" 118 | var b bytes.Buffer 119 | var m LogMessage 120 | l := New(&b, "", 0) 121 | l.Print(testString) 122 | err := json.NewDecoder(&b).Decode(&m) 123 | if err != nil { 124 | t.Errorf("Unexpected error decoding log JSON: %q", err.Error()) 125 | } 126 | if m.ID == "" { 127 | t.Errorf("ID should be set to random value and not empty") 128 | } 129 | payload, ok := m.Payload.(string) 130 | if !ok { 131 | t.Errorf("Payload is not of string type") 132 | } 133 | 134 | if expect := testString; payload != expect { 135 | t.Errorf("log output should match %q is %q", expect, payload) 
136 | } 137 | } 138 | 139 | func TestIDUnique(t *testing.T) { 140 | const testString = "test" 141 | var b bytes.Buffer 142 | var m1 LogMessage 143 | var m2 LogMessage 144 | l := New(&b, "", 0) 145 | l.Print(testString) 146 | decoder := json.NewDecoder(&b) 147 | err := decoder.Decode(&m1) 148 | if err != nil { 149 | t.Errorf("Unexpected error decoding log JSON: %q", err.Error()) 150 | } 151 | 152 | l.Print(testString) 153 | err = decoder.Decode(&m2) 154 | if err != nil { 155 | t.Errorf("Unexpected error decoding log JSON: %q", err.Error()) 156 | } 157 | if m1.ID == m2.ID { 158 | t.Errorf("ID should be set to random value and not equal: %q == %q", m1.ID, m2.ID) 159 | } 160 | } 161 | 162 | func TestHostSet(t *testing.T) { 163 | const testString = "test" 164 | var b bytes.Buffer 165 | var m LogMessage 166 | l := New(&b, "", 0) 167 | l.Print(testString) 168 | err := json.NewDecoder(&b).Decode(&m) 169 | if err != nil { 170 | t.Errorf("Unexpected error decoding log JSON: %q", err.Error()) 171 | } 172 | } 173 | 174 | func TestVersionSet(t *testing.T) { 175 | const testString = "test" 176 | var b bytes.Buffer 177 | var m LogMessage 178 | l := New(&b, "", 0) 179 | l.SetVersion("version-secret_panda") 180 | l.Print(testString) 181 | err := json.NewDecoder(&b).Decode(&m) 182 | if err != nil { 183 | t.Errorf("Unexpected error decoding log JSON: %q", err.Error()) 184 | } 185 | if m.Version != "version-secret_panda" { 186 | t.Errorf("Expected service name to be version-secret_panda, got %q", m.Version) 187 | } 188 | } 189 | 190 | func TestServiceSet(t *testing.T) { 191 | const testString = "test" 192 | var b bytes.Buffer 193 | var m LogMessage 194 | l := New(&b, "", 0) 195 | l.SetService("codename-secret_panda") 196 | l.Print(testString) 197 | err := json.NewDecoder(&b).Decode(&m) 198 | if err != nil { 199 | t.Errorf("Unexpected error decoding log JSON: %q", err.Error()) 200 | } 201 | if m.Service != "codename-secret_panda" { 202 | t.Errorf("Expected service name to be 
codename-secret_panda, got %q", m.Service) 203 | } 204 | } 205 | 206 | func TestOutputJSON(t *testing.T) { 207 | data := struct { 208 | Title string 209 | Number int 210 | }{ 211 | "Whoa metrics", 212 | 1337, 213 | } 214 | var m LogMessage 215 | var b bytes.Buffer 216 | l := New(&b, "", LstdFlags) 217 | l.OutputJSON(data) 218 | d := struct { 219 | Title string 220 | Number int 221 | }{} 222 | m.Payload = &d 223 | err := json.NewDecoder(&b).Decode(&m) 224 | if err != nil { 225 | t.Errorf("Unexpected error decoding log JSON: %q", err.Error()) 226 | } 227 | if d.Title != "Whoa metrics" { 228 | t.Errorf("Expected Title to be Whoa metrics, got %q", d.Title) 229 | } 230 | if d.Number != 1337 { 231 | t.Errorf("Expected Number to be 1337, got %d", d.Number) 232 | } 233 | } 234 | 235 | func TestOutputBinary(t *testing.T) { 236 | var m LogMessage 237 | var b bytes.Buffer 238 | l := New(&b, "", LstdFlags) 239 | l.OutputBinary([]byte("")) 240 | err := json.NewDecoder(&b).Decode(&m) 241 | if err != nil { 242 | t.Errorf("Unexpected error decoding log JSON: %q", err.Error()) 243 | } 244 | } 245 | 246 | func TestFlagAndPrefixSetting(t *testing.T) { 247 | var m LogMessage 248 | var b bytes.Buffer 249 | l := New(&b, "Test:", LstdFlags) 250 | f := l.Flags() 251 | if f != LstdFlags { 252 | t.Errorf("Flags 1: expected %x got %x", LstdFlags, f) 253 | } 254 | l.SetFlags(f | Lmicroseconds) 255 | f = l.Flags() 256 | if f != LstdFlags|Lmicroseconds { 257 | t.Errorf("Flags 2: expected %x got %x", LstdFlags|Lmicroseconds, f) 258 | } 259 | p := l.Prefix() 260 | if p != "Test:" { 261 | t.Errorf(`Prefix: expected "Test:" got %q`, p) 262 | } 263 | l.SetPrefix("Reality:") 264 | p = l.Prefix() 265 | if p != "Reality:" { 266 | t.Errorf(`Prefix: expected "Reality:" got %q`, p) 267 | } 268 | // Verify a log message looks right, with our prefix and microseconds present. 
269 | l.Print("hello") 270 | pattern := "^Reality:" + Rdate + " " + Rtime + Rmicroseconds + " hello" 271 | err := json.NewDecoder(&b).Decode(&m) 272 | if err != nil { 273 | t.Errorf("Unexpected error decoding log JSON: %q", err.Error()) 274 | } 275 | payload, ok := m.Payload.(string) 276 | if !ok { 277 | t.Errorf("Payload is not of string type") 278 | } 279 | matched, err := regexp.Match(pattern, []byte(payload)) 280 | if err != nil { 281 | t.Fatalf("pattern %q did not compile: %s", pattern, err) 282 | } 283 | if !matched { 284 | t.Errorf("message did not match pattern %q is %q", pattern, payload) 285 | } 286 | } 287 | -------------------------------------------------------------------------------- /server/api.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "math/rand" 7 | "net/http" 8 | "os" 9 | "time" 10 | 11 | "github.com/gorilla/mux" 12 | 13 | "github.com/pinterest/knox" 14 | "github.com/pinterest/knox/log" 15 | "github.com/pinterest/knox/server/keydb" 16 | ) 17 | 18 | // HTTPError is the error type with knox err subcode and message for logging purposes 19 | type HTTPError struct { 20 | Subcode int 21 | Message string 22 | } 23 | 24 | // errF is a convience method to make an httpError. 25 | func errF(c int, m string) *HTTPError { 26 | return &HTTPError{c, m} 27 | } 28 | 29 | // httpErrResp contain the http codes and messages to be returned back to clients. 30 | type httpErrResp struct { 31 | Code int 32 | Message string 33 | } 34 | 35 | // HTTPErrMap is a mapping from err subcodes to the http err response that will be returned. 
36 | var HTTPErrMap = map[int]*httpErrResp{ 37 | knox.NoKeyIDCode: {http.StatusBadRequest, "Missing Key ID"}, 38 | knox.InternalServerErrorCode: {http.StatusInternalServerError, "Internal Server Error"}, 39 | knox.KeyIdentifierExistsCode: {http.StatusBadRequest, "Key identifer exists"}, 40 | knox.KeyVersionDoesNotExistCode: {http.StatusNotFound, "Key version does not exist"}, 41 | knox.KeyIdentifierDoesNotExistCode: {http.StatusNotFound, "Key identifer does not exist"}, 42 | knox.UnauthenticatedCode: {http.StatusUnauthorized, "User or machine is not authenticated"}, 43 | knox.UnauthorizedCode: {http.StatusForbidden, "User or machine not authorized"}, 44 | knox.NotYetImplementedCode: {http.StatusNotImplemented, "Not yet implemented"}, 45 | knox.NotFoundCode: {http.StatusNotFound, "Route not found"}, 46 | knox.NoKeyDataCode: {http.StatusBadRequest, "Missing Key Data"}, 47 | knox.BadRequestDataCode: {http.StatusBadRequest, "Bad request format"}, 48 | knox.BadKeyFormatCode: {http.StatusBadRequest, "Key ID contains unsupported characters"}, 49 | knox.BadPrincipalIdentifier: {http.StatusBadRequest, "Invalid principal identifier"}, 50 | } 51 | 52 | func combine(f, g func(http.HandlerFunc) http.HandlerFunc) func(http.HandlerFunc) http.HandlerFunc { 53 | return func(h http.HandlerFunc) http.HandlerFunc { 54 | return f(g(h)) 55 | } 56 | } 57 | 58 | // GetRouterFromKeyManager creates the mux router that serves knox routes from a key manager 59 | func GetRouterFromKeyManager( 60 | cryptor keydb.Cryptor, 61 | keyManager KeyManager, 62 | decorators [](func(http.HandlerFunc) http.HandlerFunc), 63 | additionalRoutes []Route) (*mux.Router, error) { 64 | existingRouteIds := map[string]Route{} 65 | existingRouteMethodAndPaths := map[string]map[string]Route{} 66 | allRoutes := append(routes[:], additionalRoutes[:]...) 
67 | 68 | for _, route := range allRoutes { 69 | if _, routeExists := existingRouteIds[route.Id]; routeExists { 70 | return nil, fmt.Errorf( 71 | "There are ID conflicts for the route with ID: '%v'", 72 | route.Id, 73 | ) 74 | } 75 | childMap, methodExists := existingRouteMethodAndPaths[route.Method] 76 | if !methodExists { 77 | childMap := map[string]Route{ 78 | route.Path: route, 79 | } 80 | existingRouteMethodAndPaths[route.Method] = childMap 81 | } else { 82 | if conflictingRoute, pathExists := childMap[route.Path]; pathExists { 83 | return nil, fmt.Errorf( 84 | "There are Method/Path conflicts for the following Route IDs: ('%v' and '%v')", 85 | conflictingRoute.Id, route.Id, 86 | ) 87 | } 88 | } 89 | 90 | existingRouteMethodAndPaths[route.Method][route.Path] = route 91 | existingRouteIds[route.Id] = route 92 | } 93 | 94 | r := mux.NewRouter() 95 | 96 | decorator := func(f http.HandlerFunc) http.HandlerFunc { return f } 97 | for i := range decorators { 98 | j := len(decorators) - i - 1 99 | decorator = combine(decorators[j], decorator) 100 | } 101 | 102 | r.NotFoundHandler = setupRoute("404", keyManager)(decorator(WriteErr(errF(knox.NotFoundCode, "")))) 103 | 104 | for _, route := range allRoutes { 105 | addRoute(r, route, decorator, keyManager) 106 | } 107 | return r, nil 108 | } 109 | 110 | // GetRouter creates the mux router that serves knox routes. 111 | // All routes are declared in this file. Each handler itself takes in the db and 112 | // auth provider interfaces and returns a handler that the is processed through 113 | // the API Middleware. 
114 | func GetRouter( 115 | cryptor keydb.Cryptor, 116 | db keydb.DB, 117 | decorators [](func(http.HandlerFunc) http.HandlerFunc), 118 | additionalRoutes []Route) (*mux.Router, error) { 119 | m := NewKeyManager(cryptor, db) 120 | 121 | return GetRouterFromKeyManager(cryptor, m, decorators, additionalRoutes) 122 | } 123 | 124 | func addRoute( 125 | router *mux.Router, 126 | route Route, 127 | routeDecorator func(f http.HandlerFunc) http.HandlerFunc, 128 | keyManager KeyManager) { 129 | handler := setupRoute(route.Id, keyManager)(parseParams(route.Parameters)(routeDecorator(route.ServeHTTP))) 130 | router.Handle(route.Path, handler).Methods(route.Method) 131 | } 132 | 133 | // Parameter is an interface through which route-specific Knox API Parameters 134 | // can be specified 135 | type Parameter interface { 136 | Name() string 137 | Get(r *http.Request) (string, bool) 138 | } 139 | 140 | // UrlParameter is an implementation of the Parameter interface that extracts 141 | // parameter values from the URL as referenced in section 3.3 of RFC2396. 142 | type UrlParameter string 143 | 144 | // Get returns the value of the URL parameter 145 | func (p UrlParameter) Get(r *http.Request) (string, bool) { 146 | s, ok := mux.Vars(r)[string(p)] 147 | return s, ok 148 | } 149 | 150 | // Name defines the URL-embedded key that this parameter maps to 151 | func (p UrlParameter) Name() string { 152 | return string(p) 153 | } 154 | 155 | // RawQueryParameter is an implementation of the Parameter interface that 156 | // extracts the complete query string from the request URL 157 | // as referenced in section 3.4 of RFC2396. 158 | type RawQueryParameter string 159 | 160 | // Get returns the value of the entire query string 161 | func (p RawQueryParameter) Get(r *http.Request) (string, bool) { 162 | return r.URL.RawQuery, true 163 | } 164 | 165 | // Name represents the key-name that will be set for the raw query string 166 | // in the `parameters` map of the route handler function. 
167 | func (p RawQueryParameter) Name() string { 168 | return string(p) 169 | } 170 | 171 | // QueryParameter is an implementation of the Parameter interface that extracts 172 | // specific parameter values from the query string of the request URL 173 | // as referenced in section 3.4 of RFC2396. 174 | type QueryParameter string 175 | 176 | // Get returns the value of the query string parameter 177 | func (p QueryParameter) Get(r *http.Request) (string, bool) { 178 | val, ok := r.URL.Query()[string(p)] 179 | if !ok { 180 | return "", false 181 | } 182 | return val[0], true 183 | } 184 | 185 | // Name defines the URL-embedded key that this parameter maps to 186 | func (p QueryParameter) Name() string { 187 | return string(p) 188 | } 189 | 190 | // PostParameter is an implementation of the Parameter interface that 191 | // extracts values embedded in the web form transmitted in the 192 | // request body 193 | type PostParameter string 194 | 195 | // Get returns the value of the appropriate parameter from the request body 196 | func (p PostParameter) Get(r *http.Request) (string, bool) { 197 | err := r.ParseForm() 198 | if err != nil { 199 | return "", false 200 | } 201 | k, ok := r.PostForm[string(p)] 202 | if !ok { 203 | return "", ok 204 | } 205 | return k[0], ok 206 | } 207 | 208 | // Name represents the key corresponding to this parameter in the request form 209 | func (p PostParameter) Name() string { 210 | return string(p) 211 | } 212 | 213 | // Route is a struct that defines a path and method-specific 214 | // HTTP route on the Knox server 215 | type Route struct { 216 | // Handler represents the handler function that is responsible for serving 217 | // this route 218 | Handler func(db KeyManager, principal knox.Principal, parameters map[string]string) (interface{}, *HTTPError) 219 | 220 | // Id represents A unique string identifier that represents this specific 221 | // route 222 | Id string 223 | 224 | // Path represents the relative HTTP path (or prefix) 
that must be specified 225 | // in order to invoke this route 226 | Path string 227 | 228 | // Method represents the HTTP method that must be specified in order to 229 | // invoke this route 230 | Method string 231 | 232 | // Parameters is an array that represents the route-specific parameters 233 | // that will be passed to the handler function 234 | Parameters []Parameter 235 | } 236 | 237 | // WriteErr returns a function that can encode error information and set an 238 | // HTTP error response code in the specified HTTP response writer 239 | func WriteErr(apiErr *HTTPError) http.HandlerFunc { 240 | return func(w http.ResponseWriter, r *http.Request) { 241 | resp := new(knox.Response) 242 | hostname, err := os.Hostname() 243 | if err != nil { 244 | panic("Hostname is required:" + err.Error()) 245 | } 246 | resp.Host = hostname 247 | resp.Timestamp = time.Now().UnixNano() 248 | resp.Status = "error" 249 | resp.Code = apiErr.Subcode 250 | resp.Message = apiErr.Message 251 | code := HTTPErrMap[apiErr.Subcode].Code 252 | w.WriteHeader(code) 253 | setAPIError(r, apiErr) 254 | 255 | if err := json.NewEncoder(w).Encode(resp); err != nil { 256 | // It is unclear what to do here since the server failed to write the response. 257 | log.Println(err.Error()) 258 | } 259 | } 260 | } 261 | 262 | // WriteData returns a function that can write arbitrary data to the specified 263 | // HTTP response writer 264 | func WriteData(w http.ResponseWriter, data interface{}) { 265 | r := new(knox.Response) 266 | r.Message = "" 267 | r.Code = knox.OKCode 268 | r.Status = "ok" 269 | hostname, err := os.Hostname() 270 | if err != nil { 271 | panic("Hostname is required:" + err.Error()) 272 | } 273 | r.Host = hostname 274 | r.Timestamp = time.Now().UnixNano() 275 | r.Data = data 276 | if err := json.NewEncoder(w).Encode(r); err != nil { 277 | // It is unclear what to do here since the server failed to write the response. 
278 | log.Println(err.Error()) 279 | } 280 | } 281 | 282 | // ServeHTTP runs API middleware and calls the underlying handler function. 283 | func (r Route) ServeHTTP(w http.ResponseWriter, req *http.Request) { 284 | db := getDB(req) 285 | principal := GetPrincipal(req) 286 | ps := GetParams(req) 287 | data, err := r.Handler(db, principal, ps) 288 | 289 | if err != nil { 290 | WriteErr(err)(w, req) 291 | } else { 292 | WriteData(w, data) 293 | } 294 | } 295 | 296 | // Users besides creator who have default access to all keys. 297 | // This is by default empty and should be expanded by the main function. 298 | var defaultAccess []knox.Access 299 | 300 | // AddDefaultAccess adds an access to every created key. 301 | func AddDefaultAccess(a *knox.Access) { 302 | defaultAccess = append(defaultAccess, *a) 303 | } 304 | 305 | var accessCallback func(knox.AccessCallbackInput) (bool, error) 306 | 307 | // SetAccessCallback adds a callback. 308 | func SetAccessCallback(callback func(knox.AccessCallbackInput) (bool, error)) { 309 | accessCallback = callback 310 | } 311 | 312 | // Extra validators to apply on principals submitted to Knox. 313 | var extraPrincipalValidators []knox.PrincipalValidator 314 | 315 | // AddPrincipalValidator applies additional, custom validation on principals 316 | // submitted to Knox for adding into ACLs. Can be used to set custom business 317 | // logic for e.g. what kind of machine or service prefixes are acceptable. 318 | func AddPrincipalValidator(validator knox.PrincipalValidator) { 319 | extraPrincipalValidators = append(extraPrincipalValidators, validator) 320 | } 321 | 322 | // newKeyVersion creates a new KeyVersion with correctly set defaults. 323 | func newKeyVersion(d []byte, s knox.VersionStatus) knox.KeyVersion { 324 | version := knox.KeyVersion{} 325 | version.Data = d 326 | version.Status = s 327 | version.CreationTime = time.Now().UnixNano() 328 | // This is only 63 bits of randomness, but it appears to be the fastest way. 
329 | version.ID = uint64(rand.Int63()) 330 | return version 331 | } 332 | 333 | // NewKey creates a new Key with correctly set defaults. 334 | func newKey(id string, acl knox.ACL, d []byte, u knox.Principal) knox.Key { 335 | key := knox.Key{} 336 | key.ID = id 337 | 338 | creatorAccess := knox.Access{ID: u.GetID(), AccessType: knox.Admin, Type: knox.User} 339 | key.ACL = acl.Add(creatorAccess) 340 | for _, a := range defaultAccess { 341 | key.ACL = key.ACL.Add(a) 342 | } 343 | 344 | key.VersionList = []knox.KeyVersion{newKeyVersion(d, knox.Primary)} 345 | key.VersionHash = key.VersionList.Hash() 346 | return key 347 | } 348 | -------------------------------------------------------------------------------- /server/api_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "crypto/tls" 6 | "encoding/json" 7 | "fmt" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "strings" 12 | "testing" 13 | "time" 14 | 15 | "github.com/pinterest/knox" 16 | "github.com/pinterest/knox/server/auth" 17 | "github.com/pinterest/knox/server/keydb" 18 | ) 19 | 20 | const TESTVAL int = 1 21 | 22 | type mockAuthFail struct{} 23 | 24 | func (a mockAuthFail) Authenticate(r *http.Request) (knox.Principal, error) { 25 | return nil, fmt.Errorf("Error!") 26 | } 27 | func (a mockAuthFail) IsUser(p knox.Principal) bool { 28 | return false 29 | } 30 | 31 | type mockAuthTrue struct{} 32 | 33 | func (a mockAuthTrue) Authenticate(r *http.Request) (knox.Principal, error) { 34 | return nil, nil 35 | } 36 | func (a mockAuthTrue) IsUser(p knox.Principal) bool { 37 | return true 38 | } 39 | 40 | func mockFailureHandler(m KeyManager, principal knox.Principal, parameters map[string]string) (interface{}, *HTTPError) { 41 | return nil, errF(knox.InternalServerErrorCode, "") 42 | } 43 | 44 | func mockHandler(m KeyManager, principal knox.Principal, parameters map[string]string) (interface{}, *HTTPError) { 45 | return 
TESTVAL, nil 46 | } 47 | 48 | func additionalMockHandler(m KeyManager, principal knox.Principal, parameters map[string]string) (interface{}, *HTTPError) { 49 | return "The meaning of life is 42", nil 50 | } 51 | 52 | func mockAccessCallback(input knox.AccessCallbackInput) (bool, error) { 53 | return true, nil 54 | } 55 | 56 | func mockRoute() Route { 57 | return Route{ 58 | Method: "GET", 59 | Path: "/v0/keys/", 60 | Handler: mockHandler, 61 | Id: "test1", 62 | Parameters: []Parameter{}, 63 | } 64 | } 65 | 66 | func additionalMockRoute() Route { 67 | return Route{ 68 | Method: "GET", 69 | Path: "/v0/custom/", 70 | Handler: additionalMockHandler, 71 | Id: "a-custom-route", 72 | Parameters: []Parameter{}, 73 | } 74 | } 75 | 76 | func mockFailureRoute() Route { 77 | return Route{ 78 | Method: "GET", 79 | Path: "/v0/keys/", 80 | Handler: mockFailureHandler, 81 | Id: "test2", 82 | Parameters: []Parameter{}, 83 | } 84 | } 85 | 86 | func TestAddDefaultAccess(t *testing.T) { 87 | dUID := "testuser2" 88 | u2 := auth.NewUser(dUID, []string{}) 89 | a := &knox.Access{ID: dUID, AccessType: knox.Read, Type: knox.User} 90 | 91 | AddDefaultAccess(a) 92 | id := "testkeyid" 93 | uid := "testuser" 94 | acl := knox.ACL([]knox.Access{}) 95 | data := []byte("testdata") 96 | u := auth.NewUser(uid, []string{}) 97 | key := newKey(id, acl, data, u) 98 | if !u.CanAccess(key.ACL, knox.Admin) { 99 | t.Fatal("creator does not have access to his key") 100 | } 101 | if !u2.CanAccess(key.ACL, knox.Read) { 102 | t.Fatal("default access does not have access to his key") 103 | } 104 | if len(key.ACL) != 2 { 105 | text, _ := json.Marshal(key.ACL) 106 | t.Fatal("The Key's ACL is too big: " + string(text)) 107 | } 108 | defaultAccess = []knox.Access{} 109 | 110 | } 111 | 112 | func TestSetAccessCallback(t *testing.T) { 113 | defer SetAccessCallback(nil) 114 | 115 | SetAccessCallback(mockAccessCallback) 116 | 117 | input := knox.AccessCallbackInput{} 118 | 119 | if accessCallback == nil { 120 | 
t.Fatal("accessCallback should not be nil") 121 | } 122 | 123 | canAccess, err := accessCallback(input) 124 | if err != nil { 125 | t.Fatal("accessCallback should not return an error") 126 | } 127 | 128 | if !canAccess { 129 | t.Fatal("accessCallback should return true") 130 | } 131 | } 132 | 133 | func TestParseFormParameter(t *testing.T) { 134 | p := PostParameter("key") 135 | 136 | r, err := http.NewRequest("POST", "http://www.com/?key=nope", strings.NewReader("nokey=yup")) 137 | if err != nil { 138 | t.Fatal(err.Error()) 139 | } 140 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") 141 | s, ok := p.Get(r) 142 | if ok { 143 | t.Fatal("Key parameter should not be present in post form") 144 | } 145 | 146 | r, err = http.NewRequest("POST", "http://www.com/?key=nope", strings.NewReader("key=yup")) 147 | if err != nil { 148 | t.Fatal(err.Error()) 149 | } 150 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") 151 | s, ok = p.Get(r) 152 | if !ok { 153 | t.Fatal("Key parameter should be present in post form") 154 | } 155 | if s != "yup" { 156 | t.Fatal("Key should be yup") 157 | } 158 | 159 | // This should cause some problems 160 | r, err = http.NewRequest("POST", "http://www.com/?key=nope", nil) 161 | if err != nil { 162 | t.Fatal(err.Error()) 163 | } 164 | s, ok = p.Get(r) 165 | if ok { 166 | t.Fatal("Key parameter should not be present in nil request body") 167 | } 168 | 169 | } 170 | 171 | func checkinternalServerErrorResponse(t *testing.T, w *httptest.ResponseRecorder) { 172 | if w.Code != HTTPErrMap[knox.InternalServerErrorCode].Code { 173 | t.Fatal("unexpected response code") 174 | } 175 | var resp knox.Response 176 | err := json.Unmarshal([]byte(w.Body.String()), &resp) 177 | if err != nil { 178 | t.Fatal("Test returned invalid JSON data") 179 | } 180 | if resp.Data != nil { 181 | t.Fatal("Test returned invalid data") 182 | } 183 | if resp.Status != "error" { 184 | t.Fatal("unexpected status") 185 | } 
186 | if resp.Code != knox.InternalServerErrorCode { 187 | t.Fatal("unexpected error code") 188 | } 189 | if resp.Message != "" { 190 | t.Fatal("Wrong message") 191 | } 192 | if resp.Host == "" { 193 | t.Fatal("no hostname present") 194 | } 195 | if resp.Timestamp > time.Now().UnixNano() { 196 | t.Fatal("time is in the future") 197 | } 198 | } 199 | 200 | func TestErrorHandler(t *testing.T) { 201 | testErr := errF(knox.InternalServerErrorCode, "") 202 | handler := WriteErr(testErr) 203 | 204 | w := httptest.NewRecorder() 205 | handler(w, nil) 206 | checkinternalServerErrorResponse(t, w) 207 | } 208 | 209 | func TestNewKeyVersion(t *testing.T) { 210 | data := []byte("testdata") 211 | status := knox.Active 212 | beforeTime := time.Now().UnixNano() 213 | version := newKeyVersion(data, status) 214 | afterTime := time.Now().UnixNano() 215 | if !bytes.Equal(version.Data, data) { 216 | t.Fatal("version data mismatch") 217 | } 218 | if version.CreationTime < beforeTime || version.CreationTime > afterTime { 219 | t.Fatal("Creation time does not fit in time bounds") 220 | } 221 | if version.Status != status { 222 | t.Fatal("version status doesn't match") 223 | } 224 | version2 := newKeyVersion(data, status) 225 | if version.ID == version2.ID { 226 | t.Fatal("version ids are deterministic") 227 | } 228 | } 229 | 230 | func TestNewKey(t *testing.T) { 231 | id := "testkeyid" 232 | uid := "testuser" 233 | acl := knox.ACL([]knox.Access{{ID: "testmachine", AccessType: knox.Admin, Type: knox.Machine}}) 234 | data := []byte("testdata") 235 | u := auth.NewUser(uid, []string{}) 236 | key := newKey(id, acl, data, u) 237 | if key.ID != id { 238 | t.Fatal("ID does not match: " + key.ID + "!=" + id) 239 | } 240 | if len(key.VersionList) != 1 || !bytes.Equal(key.VersionList[0].Data, data) { 241 | t.Fatal("data does not match: " + string(key.VersionList[0].Data) + "!=" + string(data)) 242 | } 243 | if !u.CanAccess(key.ACL, knox.Admin) { 244 | t.Fatal("creator does not have access to his 
key") 245 | } 246 | if len(key.ACL) != len(defaultAccess)+2 { 247 | text, _ := json.Marshal(key.ACL) 248 | t.Fatal("The Key's ACL is too big: " + string(text)) 249 | } 250 | 251 | } 252 | 253 | func TestBuildRequest(t *testing.T) { 254 | u := auth.NewUser("Man", []string{}) 255 | m := auth.NewMachine("Robot") 256 | params := map[string]string{"data": "secret", "keyID": "not_secret"} 257 | req := &http.Request{ 258 | URL: &url.URL{Path: "path"}, 259 | TLS: &tls.ConnectionState{CipherSuite: 1}, 260 | Method: "GET", 261 | RemoteAddr: "10.1.1.1", 262 | } 263 | 264 | r := buildRequest(req, m, params) 265 | if r.Method != "GET" { 266 | t.Errorf("Method Should be GET not %q", r.Method) 267 | } 268 | if r.RemoteAddr != "10.1.1.1" { 269 | t.Errorf("Method Should be 10.1.1.1 not %q", r.RemoteAddr) 270 | } 271 | if r.Path != "path" { 272 | t.Errorf("Path Should be path not %q", r.Path) 273 | } 274 | if r.Principal != "Robot" { 275 | t.Errorf("Principal Should be Robot not %q", r.Principal) 276 | } 277 | if r.AuthType != "machine" { 278 | t.Errorf("AuthType Should be machine not %q", r.AuthType) 279 | } 280 | if r.TLSCipher != 1 { 281 | t.Errorf("TLSCipher Should be 1 not %d", r.TLSCipher) 282 | } 283 | if v, ok := r.Parameters["data"]; !ok || v == "secret_sauce" { 284 | t.Fatal("data should be scrubbed, but still present.") 285 | } 286 | r = buildRequest(req, u, params) 287 | if r.AuthType != "user" { 288 | t.Errorf("AuthType Should be user not %q", r.AuthType) 289 | } 290 | r = buildRequest(req, nil, params) 291 | if r.AuthType != "" { 292 | t.Errorf("AuthType Should be \"\" not %q", r.AuthType) 293 | } 294 | if r.Principal != "" { 295 | t.Errorf("Principal Should be \"\" not %q", r.Principal) 296 | } 297 | } 298 | 299 | func TestScrub(t *testing.T) { 300 | r := scrub(map[string]string{"keyID": "not_secret", "data": "secret_sauce"}) 301 | if v, ok := r["keyID"]; !ok || v != "not_secret" { 302 | t.Fatal("KeyID not expected to be scrubbed.") 303 | } 304 | if v, ok := 
r["data"]; !ok || v == "secret_sauce" { 305 | t.Fatal("data should be scrubbed, but still present.") 306 | } 307 | } 308 | 309 | func TestDuplicateRouteId(t *testing.T) { 310 | cryptor := keydb.NewAESGCMCryptor(0, []byte("testtesttesttest")) 311 | db := keydb.NewTempDB() 312 | decorators := [](func(http.HandlerFunc) http.HandlerFunc){} 313 | additionalRoutes := []Route{ 314 | { 315 | Method: "POST", 316 | Id: "getkeys", 317 | Path: "/v3/foobar/", 318 | Handler: getKeysHandler, 319 | Parameters: []Parameter{ 320 | RawQueryParameter("queryString"), 321 | }, 322 | }, 323 | } 324 | 325 | _, err := GetRouter(cryptor, db, decorators, additionalRoutes) 326 | if err == nil { 327 | t.Fatal("Expected an error when two routes were provided with duplicate IDs") 328 | } 329 | expectedErrorMessage := fmt.Sprintf("There are ID conflicts for the route with ID: '%v'", "getkeys") 330 | 331 | if err.Error() != expectedErrorMessage { 332 | t.Fatalf( 333 | "The incorrect error message was returned for a duplicate ID. "+ 334 | "Expected: '%v'. 
Actual: '%v'", 335 | expectedErrorMessage, err, 336 | ) 337 | } 338 | } 339 | 340 | func TestDuplicateMethodAndPath(t *testing.T) { 341 | cryptor := keydb.NewAESGCMCryptor(0, []byte("testtesttesttest")) 342 | db := keydb.NewTempDB() 343 | decorators := [](func(http.HandlerFunc) http.HandlerFunc){} 344 | additionalRoutes := []Route{ 345 | { 346 | Method: "GET", 347 | Id: "a-unique-id", 348 | Path: "/v0/keys/", 349 | Handler: getKeysHandler, 350 | Parameters: []Parameter{ 351 | RawQueryParameter("queryString"), 352 | }, 353 | }, 354 | } 355 | 356 | _, err := GetRouter(cryptor, db, decorators, additionalRoutes) 357 | if err == nil { 358 | t.Fatal("Expected an error when two routes were provided with duplicate IDs") 359 | } 360 | expectedErrorMessage := fmt.Sprintf( 361 | "There are Method/Path conflicts for the following Route IDs: ('%v' and '%v')", 362 | "getkeys", "a-unique-id") 363 | 364 | if err.Error() != expectedErrorMessage { 365 | t.Fatalf( 366 | "The incorrect error message was returned for a duplicate ID. "+ 367 | "Expected: '%v'. Actual: '%v'", 368 | expectedErrorMessage, err, 369 | ) 370 | } 371 | } 372 | 373 | func TestAdditionalRouteFunctionality(t *testing.T) { 374 | cryptor := keydb.NewAESGCMCryptor(0, []byte("testtesttesttest")) 375 | db := keydb.NewTempDB() 376 | decorators := [](func(http.HandlerFunc) http.HandlerFunc){} 377 | additionalRoutes := []Route{ 378 | additionalMockRoute(), 379 | } 380 | router, err := GetRouter(cryptor, db, decorators, additionalRoutes) 381 | if err != nil { 382 | t.Fatalf("Did not expect an error while creating router. Details: %v", err) 383 | } 384 | 385 | r, reqErr := http.NewRequest("GET", "/v0/custom/", bytes.NewBufferString("")) 386 | if reqErr != nil { 387 | t.Fatalf("Error while setting up test. 
Details: %v", err) 388 | } 389 | 390 | w := httptest.NewRecorder() 391 | router.ServeHTTP(w, r) 392 | resp := &knox.Response{} 393 | decoder := json.NewDecoder(w.Body) 394 | err = decoder.Decode(resp) 395 | if err != nil { 396 | t.Fatalf("Error while getting data from additional route. Details: %v", err) 397 | } 398 | 399 | expectedResponse := "The meaning of life is 42" 400 | if resp.Data != expectedResponse { 401 | t.Fatalf("Error while getting data from additional route. Expected: '%v'. Actual: '%v'", 402 | expectedResponse, resp.Data, 403 | ) 404 | } 405 | } 406 | -------------------------------------------------------------------------------- /server/auth/spiffe.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "crypto/x509/pkix" 5 | "encoding/asn1" 6 | "errors" 7 | ) 8 | 9 | // This code is from https://github.com/spiffe/go-spiffe 10 | 11 | var oidExtensionSubjectAltName = asn1.ObjectIdentifier{2, 5, 29, 17} 12 | 13 | func getURINamesFromSANExtension(sanExtension []byte) (uris []string, err error) { 14 | // RFC 5280, 4.2.1.6 15 | 16 | // SubjectAltName ::= GeneralNames 17 | // 18 | // GeneralNames ::= SEQUENCE SIZE (1..MAX) OF GeneralName 19 | // 20 | // GeneralName ::= CHOICE { 21 | // otherName [0] OtherName, 22 | // rfc822Name [1] IA5String, 23 | // dNSName [2] IA5String, 24 | // x400Address [3] ORAddress, 25 | // directoryName [4] Name, 26 | // ediPartyName [5] EDIPartyName, 27 | // uniformResourceIdentifier [6] IA5String, 28 | // iPAddress [7] OCTET STRING, 29 | // registeredID [8] OBJECT IDENTIFIER } 30 | var seq asn1.RawValue 31 | var rest []byte 32 | if rest, err = asn1.Unmarshal(sanExtension, &seq); err != nil { 33 | return uris, err 34 | } else if len(rest) != 0 { 35 | err = errors.New("x509: trailing data after X.509 extension") 36 | return uris, err 37 | } 38 | if !seq.IsCompound || seq.Tag != 16 || seq.Class != 0 { 39 | err = asn1.StructuralError{Msg: "bad SAN sequence"} 40 
| return uris, err 41 | } 42 | 43 | rest = seq.Bytes 44 | for len(rest) > 0 { 45 | var v asn1.RawValue 46 | rest, err = asn1.Unmarshal(rest, &v) 47 | if err != nil { 48 | return uris, err 49 | } 50 | if v.Tag == 6 { 51 | uris = append(uris, string(v.Bytes)) 52 | } 53 | } 54 | 55 | return uris, err 56 | } 57 | 58 | // GetURINamesFromExtensions retrieves URIs from the SAN extension of a slice of extensions 59 | func GetURINamesFromExtensions(extensions *[]pkix.Extension) (uris []string, err error) { 60 | for _, ext := range *extensions { 61 | if ext.Id.Equal(oidExtensionSubjectAltName) { 62 | uris, err = getURINamesFromSANExtension(ext.Value) 63 | if err != nil { 64 | return uris, err 65 | } 66 | } 67 | } 68 | 69 | return uris, nil 70 | } 71 | -------------------------------------------------------------------------------- /server/decorators.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | 8 | "github.com/gorilla/context" 9 | "github.com/pinterest/knox" 10 | "github.com/pinterest/knox/log" 11 | "github.com/pinterest/knox/server/auth" 12 | ) 13 | 14 | type contextKey int 15 | 16 | const ( 17 | apiErrorContext contextKey = iota 18 | principalContext 19 | paramsContext 20 | dbContext 21 | idContext 22 | ) 23 | 24 | // GetAPIError gets the HTTP error that will be returned from the server. 
25 | func GetAPIError(r *http.Request) *HTTPError { 26 | if rv := context.Get(r, apiErrorContext); rv != nil { 27 | return rv.(*HTTPError) 28 | } 29 | return nil 30 | } 31 | 32 | func setAPIError(r *http.Request, val *HTTPError) { 33 | context.Set(r, apiErrorContext, val) 34 | } 35 | 36 | // GetPrincipal gets the principal authenticated through the authentication decorator 37 | func GetPrincipal(r *http.Request) knox.Principal { 38 | ctx := getOrInitializePrincipalContext(r) 39 | return ctx.GetCurrentPrincipal() 40 | } 41 | 42 | // SetPrincipal sets the principal authenticated through the authentication decorator. 43 | // For security reasons, this method will only set the Principal in the context for 44 | // the first invocation. Subsequent invocations WILL cause a panic. 45 | func SetPrincipal(r *http.Request, val knox.Principal) { 46 | ctx := getOrInitializePrincipalContext(r) 47 | ctx.SetCurrentPrincipal(val) 48 | } 49 | 50 | // GetParams gets the parameters for the request through the parameters context. 51 | func GetParams(r *http.Request) map[string]string { 52 | if rv := context.Get(r, paramsContext); rv != nil { 53 | return rv.(map[string]string) 54 | } 55 | return nil 56 | } 57 | 58 | func setParams(r *http.Request, val map[string]string) { 59 | context.Set(r, paramsContext, val) 60 | } 61 | 62 | func getDB(r *http.Request) KeyManager { 63 | if rv := context.Get(r, dbContext); rv != nil { 64 | return rv.(KeyManager) 65 | } 66 | return nil 67 | } 68 | 69 | func setDB(r *http.Request, val KeyManager) { 70 | context.Set(r, dbContext, val) 71 | } 72 | 73 | func getOrInitializePrincipalContext(r *http.Request) auth.PrincipalContext { 74 | if ctx := context.Get(r, principalContext); ctx != nil { 75 | return ctx.(auth.PrincipalContext) 76 | } 77 | ctx := auth.NewPrincipalContext(r) 78 | context.Set(r, principalContext, ctx) 79 | return ctx 80 | } 81 | 82 | // GetRouteID gets the short form function name for the route being called. Used for logging/metrics. 
83 | func GetRouteID(r *http.Request) string { 84 | if rv := context.Get(r, idContext); rv != nil { 85 | return rv.(string) 86 | } 87 | return "" 88 | } 89 | 90 | func setRouteID(r *http.Request, val string) { 91 | context.Set(r, idContext, val) 92 | } 93 | 94 | // AddHeader adds a HTTP header to the response 95 | func AddHeader(k, v string) func(http.HandlerFunc) http.HandlerFunc { 96 | return func(f http.HandlerFunc) http.HandlerFunc { 97 | return func(w http.ResponseWriter, r *http.Request) { 98 | w.Header().Set(k, v) 99 | f(w, r) 100 | } 101 | } 102 | } 103 | 104 | // Logger logs the request and response information in json format to the logger given. 105 | func Logger(logger *log.Logger) func(http.HandlerFunc) http.HandlerFunc { 106 | return func(f http.HandlerFunc) http.HandlerFunc { 107 | return func(w http.ResponseWriter, r *http.Request) { 108 | f(w, r) 109 | p := GetPrincipal(r) 110 | params := GetParams(r) 111 | apiError := GetAPIError(r) 112 | agent := r.Header.Get("User-Agent") 113 | if agent == "" { 114 | agent = "unknown" 115 | } 116 | e := &reqLog{ 117 | Type: "access", 118 | StatusCode: 200, 119 | Request: buildRequest(r, p, params), 120 | UserAgent: agent, 121 | } 122 | if apiError != nil { 123 | e.Code = apiError.Subcode 124 | e.StatusCode = HTTPErrMap[apiError.Subcode].Code 125 | e.Msg = apiError.Message 126 | } 127 | logger.OutputJSON(e) 128 | } 129 | } 130 | } 131 | 132 | type reqLog struct { 133 | Type string `json:"type"` 134 | Code int `json:"code"` 135 | StatusCode int `json:"status_code"` 136 | Request request `json:"request"` 137 | Msg string `json:"msg"` 138 | UserAgent string `json:"userAgent"` 139 | } 140 | 141 | type request struct { 142 | Method string `json:"method"` 143 | Path string `json:"path"` 144 | Parameters map[string]string `json:"parameters"` 145 | ParsedQuery map[string]string `json:"parsed_query_string"` 146 | Principal string `json:"principal"` 147 | FallbackPrincipals []string `json:"fallback_principals"` 148 | 
AuthType string `json:"auth_type"` 149 | RequestURI string `json:"request_uri"` 150 | RemoteAddr string `json:"remote_addr"` 151 | TLSServer string `json:"tls_server"` 152 | TLSCipher uint16 `json:"tls_cipher"` 153 | TLSVersion uint16 `json:"tls_version"` 154 | TLSResumed bool `json:"tls_resumed"` 155 | TLSUnique []byte `json:"tls_session_id"` 156 | } 157 | 158 | func scrub(params map[string]string) map[string]string { 159 | // Don't log any secret information (cause its secret) 160 | if _, ok := params["data"]; ok { 161 | params["data"] = "" 162 | } 163 | return params 164 | } 165 | 166 | func buildRequest(req *http.Request, p knox.Principal, params map[string]string) request { 167 | params = scrub(params) 168 | 169 | r := request{ 170 | Method: req.Method, 171 | Parameters: params, 172 | RemoteAddr: req.RemoteAddr, 173 | } 174 | if qs, ok := params["queryString"]; ok { 175 | keyMap, _ := url.ParseQuery(qs) 176 | m := map[string]string{} 177 | for k := range keyMap { 178 | for _, v := range keyMap[k] { 179 | m[k] = v 180 | } 181 | } 182 | r.ParsedQuery = m 183 | } 184 | if req.URL != nil { 185 | r.Path = req.URL.Path 186 | } 187 | if p != nil { 188 | r.Principal = p.GetID() 189 | r.AuthType = p.Type() 190 | if mux, ok := p.(knox.PrincipalMux); ok { 191 | r.FallbackPrincipals = mux.GetIDs() 192 | } 193 | } else { 194 | r.Principal = "" 195 | r.AuthType = "" 196 | } 197 | if req.TLS != nil { 198 | r.TLSServer = req.TLS.ServerName 199 | r.TLSCipher = req.TLS.CipherSuite 200 | r.TLSVersion = req.TLS.Version 201 | r.TLSResumed = req.TLS.DidResume 202 | r.TLSUnique = req.TLS.TLSUnique 203 | } 204 | return r 205 | } 206 | 207 | // ProviderMatcher is a function that determines whether or not the specified 208 | // authentication provider is suitable for the specified HTTP request. 
It is 209 | // expected to return a boolean value detailing whether or not the specified 210 | // provider is a match and is also expected to return any applicable 211 | // authentication payload that would then be passed to the provider. 212 | type ProviderMatcher func(provider auth.Provider, request *http.Request) (providerSupportsRequest bool, authenticationPayload string) 213 | 214 | // Authentication sets the principal or returns an error if the principal cannot be authenticated. 215 | func Authentication(providers []auth.Provider, matcher ProviderMatcher) func(http.HandlerFunc) http.HandlerFunc { 216 | if matcher == nil { 217 | matcher = providerMatch 218 | } 219 | 220 | return func(f http.HandlerFunc) http.HandlerFunc { 221 | return func(w http.ResponseWriter, r *http.Request) { 222 | var defaultPrincipal knox.Principal 223 | allPrincipals := map[string]knox.Principal{} 224 | errReturned := fmt.Errorf("No matching authentication providers found") 225 | 226 | for _, p := range providers { 227 | if match, payload := matcher(p, r); match { 228 | principal, errAuthenticate := p.Authenticate(payload, r) 229 | if errAuthenticate != nil { 230 | errReturned = errAuthenticate 231 | continue 232 | } 233 | if defaultPrincipal == nil { 234 | // First match is considered the default principal to use. 235 | defaultPrincipal = principal 236 | } 237 | 238 | // We record the name of the provider to be used in logging, so we can record 239 | // information about which provider authenticated which principal later on. 
240 | allPrincipals[p.Name()] = principal 241 | } 242 | } 243 | if defaultPrincipal == nil { 244 | WriteErr(errF(knox.UnauthenticatedCode, errReturned.Error()))(w, r) 245 | return 246 | } 247 | 248 | SetPrincipal(r, knox.NewPrincipalMux(defaultPrincipal, allPrincipals)) 249 | f(w, r) 250 | return 251 | } 252 | } 253 | } 254 | 255 | func providerMatch(provider auth.Provider, request *http.Request) (providerSupportsRequest bool, payload string) { 256 | authorizationHeaderValue := request.Header.Get("Authorization") 257 | 258 | if len(authorizationHeaderValue) > 2 && authorizationHeaderValue[0] == provider.Version() && authorizationHeaderValue[1] == provider.Type() { 259 | return true, authorizationHeaderValue[2:] 260 | } 261 | return false, "" 262 | } 263 | 264 | func parseParams(parameters []Parameter) func(http.HandlerFunc) http.HandlerFunc { 265 | return func(f http.HandlerFunc) http.HandlerFunc { 266 | return func(w http.ResponseWriter, r *http.Request) { 267 | var ps = make(map[string]string) 268 | for _, p := range parameters { 269 | if s, ok := p.Get(r); ok { 270 | ps[p.Name()] = s 271 | } 272 | } 273 | setParams(r, ps) 274 | f(w, r) 275 | } 276 | } 277 | } 278 | 279 | func setupRoute(id string, m KeyManager) func(http.HandlerFunc) http.HandlerFunc { 280 | return func(f http.HandlerFunc) http.HandlerFunc { 281 | return func(w http.ResponseWriter, r *http.Request) { 282 | setDB(r, m) 283 | setRouteID(r, id) 284 | f(w, r) 285 | } 286 | } 287 | } 288 | -------------------------------------------------------------------------------- /server/key_manager.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pinterest/knox" 7 | "github.com/pinterest/knox/server/keydb" 8 | ) 9 | 10 | // KeyManager is the interface for logic related to managing keys. 
11 | type KeyManager interface { 12 | GetAllKeyIDs() ([]string, error) 13 | GetUpdatedKeyIDs(map[string]string) ([]string, error) 14 | GetKey(id string, status knox.VersionStatus) (*knox.Key, error) 15 | AddNewKey(*knox.Key) error 16 | DeleteKey(id string) error 17 | UpdateAccess(string, ...knox.Access) error 18 | AddVersion(string, *knox.KeyVersion) error 19 | UpdateVersion(keyID string, versionID uint64, s knox.VersionStatus) error 20 | } 21 | 22 | // NewKeyManager builds a struct for interfacing with the keydb. 23 | func NewKeyManager(c keydb.Cryptor, db keydb.DB) KeyManager { 24 | return &keyManager{c, db} 25 | } 26 | 27 | type keyManager struct { 28 | cryptor keydb.Cryptor 29 | db keydb.DB 30 | } 31 | 32 | func (m *keyManager) GetAllKeyIDs() ([]string, error) { 33 | keys, err := m.db.GetAll() 34 | if err != nil { 35 | return nil, err 36 | } 37 | output := []string{} 38 | for _, k := range keys { 39 | output = append(output, k.ID) 40 | } 41 | return output, nil 42 | } 43 | 44 | func (m *keyManager) GetUpdatedKeyIDs(versions map[string]string) ([]string, error) { 45 | keys, err := m.db.GetAll() 46 | if err != nil { 47 | return nil, err 48 | } 49 | output := []string{} 50 | for _, k := range keys { 51 | if v, ok := versions[k.ID]; ok && k.VersionHash != v { 52 | output = append(output, k.ID) 53 | } 54 | } 55 | return output, nil 56 | } 57 | 58 | func (m *keyManager) GetKey(id string, status knox.VersionStatus) (*knox.Key, error) { 59 | encK, err := m.db.Get(id) 60 | if err != nil { 61 | return nil, err 62 | } 63 | k, err := m.cryptor.Decrypt(encK) 64 | if err != nil { 65 | return nil, fmt.Errorf("Error decrypting key: %s", err.Error()) 66 | } 67 | switch status { 68 | case knox.Inactive: 69 | return k, nil 70 | case knox.Active: 71 | k.VersionList = k.VersionList.GetActive() 72 | return k, nil 73 | case knox.Primary: 74 | k.VersionList = knox.KeyVersionList{*k.VersionList.GetPrimary()} 75 | return k, nil 76 | default: 77 | return nil, knox.ErrInvalidStatus 78 | } 
79 | } 80 | 81 | func (m *keyManager) AddNewKey(k *knox.Key) error { 82 | if err := k.Validate(); err != nil { 83 | return err 84 | } 85 | dbk, err := m.cryptor.Encrypt(k) 86 | if err != nil { 87 | return err 88 | } 89 | return m.db.Add(dbk) 90 | } 91 | 92 | func (m *keyManager) DeleteKey(id string) error { 93 | return m.db.Remove(id) 94 | } 95 | 96 | func (m *keyManager) UpdateAccess(id string, acl ...knox.Access) error { 97 | encK, err := m.db.Get(id) 98 | if err != nil { 99 | return err 100 | } 101 | newEncK := encK.Copy() 102 | for _, a := range acl { 103 | newEncK.ACL = newEncK.ACL.Add(a) 104 | } 105 | err = newEncK.ACL.Validate() 106 | if err != nil { 107 | return err 108 | } 109 | return m.db.Update(newEncK) 110 | } 111 | 112 | func (m *keyManager) AddVersion(id string, v *knox.KeyVersion) error { 113 | encK, err := m.db.Get(id) 114 | if err != nil { 115 | return err 116 | } 117 | 118 | k, err := m.cryptor.Decrypt(encK) 119 | if err != nil { 120 | return fmt.Errorf("Error decrypting key: %s", err.Error()) 121 | } 122 | 123 | k.VersionList = append(k.VersionList, *v) 124 | k.VersionHash = k.VersionList.Hash() 125 | err = k.Validate() 126 | if err != nil { 127 | return err 128 | } 129 | encV, err := m.cryptor.EncryptVersion(k, v) 130 | if err != nil { 131 | return err 132 | } 133 | 134 | newEncK := encK.Copy() 135 | newEncK.VersionList = append(newEncK.VersionList, *encV) 136 | newEncK.VersionHash = k.VersionList.Hash() 137 | 138 | return m.db.Update(newEncK) 139 | } 140 | 141 | func (m *keyManager) UpdateVersion(keyID string, versionID uint64, s knox.VersionStatus) error { 142 | encK, err := m.db.Get(keyID) 143 | if err != nil { 144 | return err 145 | } 146 | k, err := m.cryptor.Decrypt(encK) 147 | if err != nil { 148 | return fmt.Errorf("Error decrypting key: %s", err.Error()) 149 | } 150 | // Validate the change makes sense 151 | kvl, err := k.VersionList.Update(versionID, s) 152 | if err != nil { 153 | return err 154 | } 155 | k.VersionHash = kvl.Hash() 
156 | err = k.Validate() 157 | if err != nil { 158 | return err 159 | } 160 | newEncK := encK.Copy() 161 | for j, v := range newEncK.VersionList { 162 | for _, nv := range kvl { 163 | if v.ID == nv.ID { 164 | newEncK.VersionList[j].Status = nv.Status 165 | } 166 | } 167 | } 168 | newEncK.VersionHash = k.VersionHash 169 | return m.db.Update(newEncK) 170 | } 171 | -------------------------------------------------------------------------------- /server/key_manager_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "reflect" 5 | "sort" 6 | "testing" 7 | 8 | "github.com/pinterest/knox" 9 | "github.com/pinterest/knox/server/auth" 10 | "github.com/pinterest/knox/server/keydb" 11 | ) 12 | 13 | func GetMocks() (KeyManager, knox.Principal, knox.ACL) { 14 | db := keydb.NewTempDB() 15 | cryptor := keydb.NewAESGCMCryptor(10, []byte("testtesttesttest")) 16 | m := NewKeyManager(cryptor, db) 17 | acl := knox.ACL([]knox.Access{}) 18 | u := auth.NewUser("test", []string{}) 19 | return m, u, acl 20 | } 21 | 22 | type mockPrincipal struct { 23 | ID string 24 | } 25 | 26 | func (p mockPrincipal) CanAccess(a knox.ACL, t knox.AccessType) bool { 27 | return true 28 | } 29 | 30 | func (p mockPrincipal) GetID() string { 31 | return p.ID 32 | } 33 | 34 | func TestGetAllKeyIDs(t *testing.T) { 35 | m, u, acl := GetMocks() 36 | keys, err := m.GetAllKeyIDs() 37 | if err != nil { 38 | t.Fatalf("%s is not nil", err) 39 | } 40 | if len(keys) != 0 { 41 | t.Fatal("database should have no keys in it") 42 | } 43 | 44 | key1 := newKey("id1", acl, []byte("data"), u) 45 | m.AddNewKey(&key1) 46 | if err != nil { 47 | t.Fatalf("%s is not nil", err) 48 | } 49 | 50 | keys, err = m.GetAllKeyIDs() 51 | if err != nil { 52 | t.Fatal(err.Error()) 53 | } 54 | if len(keys) == 1 { 55 | if keys[0] != key1.ID { 56 | t.Fatalf("%s does not match %s", keys[0], key1.ID) 57 | } 58 | } else if len(keys) != 0 { 59 | t.Fatal("Unexpected # of keys in 
get all keys response") 60 | } 61 | 62 | key2 := newKey("id2", acl, []byte("data"), u) 63 | m.AddNewKey(&key2) 64 | if err != nil { 65 | t.Fatalf("%s is not nil", err) 66 | } 67 | 68 | keys, err = m.GetAllKeyIDs() 69 | if err != nil { 70 | t.Fatal(err.Error()) 71 | } 72 | if len(keys) == 2 { 73 | if keys[0] == key1.ID { 74 | if keys[1] != key2.ID { 75 | t.Fatalf("%s does not match %s", keys[1], key2.ID) 76 | } 77 | } else if keys[0] == key2.ID { 78 | if keys[1] != key2.ID { 79 | t.Fatalf("%s does not match %s", keys[1], key1.ID) 80 | } 81 | } else { 82 | t.Fatal("Unexpected key ID returned") 83 | } 84 | } else if len(keys) != 1 { 85 | t.Fatal("Unexpected # of keys in get all keys response") 86 | } 87 | 88 | err = m.DeleteKey(key1.ID) 89 | if err != nil { 90 | t.Fatalf("%s is not nil", err) 91 | } 92 | keys, err = m.GetAllKeyIDs() 93 | if err != nil { 94 | t.Fatal(err.Error()) 95 | } 96 | if len(keys) == 1 { 97 | if keys[0] != key2.ID { 98 | t.Fatalf("%s does not match %s", keys[0], key2.ID) 99 | } 100 | } else if len(keys) != 2 { 101 | t.Fatal("Unexpected # of keys in get all keys response") 102 | } 103 | } 104 | 105 | func TestGetUpdatedKeyIDs(t *testing.T) { 106 | m, u, acl := GetMocks() 107 | keys, err := m.GetUpdatedKeyIDs(map[string]string{}) 108 | if err != nil { 109 | t.Fatalf("%s is not nil", err) 110 | } 111 | if len(keys) != 0 { 112 | t.Fatal("database should have no keys in it") 113 | } 114 | 115 | key1 := newKey("id1", acl, []byte("data"), u) 116 | m.AddNewKey(&key1) 117 | if err != nil { 118 | t.Fatalf("%s is not nil", err) 119 | } 120 | 121 | keys, err = m.GetUpdatedKeyIDs(map[string]string{key1.ID: "NOT_THE_HASH"}) 122 | if err != nil { 123 | t.Fatal(err.Error()) 124 | } 125 | if len(keys) == 1 { 126 | if keys[0] != key1.ID { 127 | t.Fatalf("%s does not match %s", keys[0], key1.ID) 128 | } 129 | } else if len(keys) != 0 { 130 | t.Fatal("Unexpected # of keys in get all keys response") 131 | } 132 | 133 | keys, err = 
m.GetUpdatedKeyIDs(map[string]string{key1.ID: key1.VersionHash}) 134 | if len(keys) != 0 { 135 | t.Fatal("database should have no keys in it") 136 | } 137 | 138 | key2 := newKey("id2", acl, []byte("data"), u) 139 | m.AddNewKey(&key2) 140 | if err != nil { 141 | t.Fatalf("%s is not nil", err) 142 | } 143 | 144 | keys, err = m.GetUpdatedKeyIDs(map[string]string{key2.ID: "NOT_THE_HASH"}) 145 | if err != nil { 146 | t.Fatal(err.Error()) 147 | } 148 | if len(keys) == 1 { 149 | if keys[0] != key2.ID { 150 | t.Fatalf("%s does not match %s", keys[0], key2.ID) 151 | } 152 | } else if len(keys) != 0 { 153 | t.Fatal("Unexpected # of keys in get all keys response") 154 | } 155 | 156 | keys, err = m.GetUpdatedKeyIDs(map[string]string{key2.ID: "NOT_THE_HASH", key1.ID: "NOT_THE_HASH"}) 157 | if len(keys) != 2 { 158 | t.Fatalf("Expect 2 keys not %d", len(keys)) 159 | } 160 | if keys[0] == key1.ID { 161 | if keys[1] != key2.ID { 162 | t.Fatalf("%s does not match %s", keys[1], key2.ID) 163 | } 164 | } else if keys[0] == key2.ID { 165 | if keys[1] != key1.ID { 166 | t.Fatalf("%s does not match %s", keys[1], key1.ID) 167 | } 168 | } else { 169 | t.Fatal("Unexpected key ID returned") 170 | } 171 | 172 | keys, err = m.GetUpdatedKeyIDs(map[string]string{key2.ID: key2.VersionHash, key1.ID: "NOT_THE_HASH"}) 173 | if len(keys) != 1 { 174 | t.Fatalf("Expect 1 key not %d", len(keys)) 175 | } 176 | if keys[0] != key1.ID { 177 | t.Fatalf("%s does not match %s", keys[0], key1.ID) 178 | } 179 | 180 | keys, err = m.GetUpdatedKeyIDs(map[string]string{key2.ID: key2.VersionHash, key1.ID: key1.VersionHash}) 181 | if len(keys) != 0 { 182 | t.Fatal("expected no keys") 183 | } 184 | 185 | } 186 | 187 | func TestAddNewKey(t *testing.T) { 188 | m, u, acl := GetMocks() 189 | key1 := newKey("id1", acl, []byte("data"), u) 190 | 191 | key, err := m.GetKey(key1.ID, knox.Active) 192 | if err == nil { 193 | t.Fatal("Should be an error") 194 | } 195 | 196 | err = m.AddNewKey(&key1) 197 | if err != nil { 198 | 
t.Fatalf("%s is not nil", err) 199 | } 200 | 201 | key, err = m.GetKey(key1.ID, knox.Active) 202 | if err != nil { 203 | t.Fatalf("%s is not nil", err) 204 | } 205 | if !reflect.DeepEqual(key, &key1) { 206 | t.Fatal("keys are not equal") 207 | } 208 | 209 | pKey, err := m.GetKey(key1.ID, knox.Primary) 210 | if err != nil { 211 | t.Fatalf("%s is not nil", err) 212 | } 213 | 214 | aKey, err := m.GetKey(key1.ID, knox.Active) 215 | if err != nil { 216 | t.Fatalf("%s is not nil", err) 217 | } 218 | if !reflect.DeepEqual(pKey, aKey) { 219 | t.Fatal("keys are not equal") 220 | } 221 | 222 | iKey, err := m.GetKey(key1.ID, knox.Inactive) 223 | if err != nil { 224 | t.Fatalf("%s is not nil", err) 225 | } 226 | if !reflect.DeepEqual(pKey, iKey) { 227 | t.Fatal("keys are not equal") 228 | } 229 | if !reflect.DeepEqual(iKey, aKey) { 230 | t.Fatal("keys are not equal") 231 | } 232 | 233 | err = m.DeleteKey(key1.ID) 234 | if err != nil { 235 | t.Fatalf("%s is not nil", err) 236 | } 237 | 238 | key, err = m.GetKey(key1.ID, knox.Active) 239 | if err == nil { 240 | t.Fatal("Should be an error") 241 | } 242 | } 243 | 244 | func TestUpdateAccess(t *testing.T) { 245 | m, u, acl := GetMocks() 246 | key1 := newKey("id1", acl, []byte("data"), u) 247 | access := knox.Access{Type: knox.User, ID: "grootan", AccessType: knox.Read} 248 | access2 := knox.Access{Type: knox.UserGroup, ID: "group", AccessType: knox.Write} 249 | access3 := knox.Access{Type: knox.Machine, ID: "machine", AccessType: knox.Read} 250 | err := m.UpdateAccess(key1.ID, access) 251 | if err == nil { 252 | t.Fatal("Should be an error") 253 | } 254 | 255 | err = m.AddNewKey(&key1) 256 | if err != nil { 257 | t.Fatalf("%s is not nil", err) 258 | } 259 | 260 | err = m.UpdateAccess(key1.ID, access) 261 | if err != nil { 262 | t.Fatalf("%s is not nil", err) 263 | } 264 | err = m.UpdateAccess(key1.ID, access2, access3) 265 | if err != nil { 266 | t.Fatalf("%s is not nil", err) 267 | } 268 | 269 | key, err := m.GetKey(key1.ID, 
knox.Active) 270 | if err != nil { 271 | t.Fatalf("%s is not nil", err) 272 | } 273 | if len(key.ACL) != 4 { 274 | t.Fatalf("%d acl rules instead of expected 4", len(key.ACL)) 275 | } 276 | for _, a := range key.ACL { 277 | switch a.ID { 278 | case access.ID: 279 | if access.Type != a.Type { 280 | t.Fatalf("%d does not equal %d", access.Type, a.Type) 281 | } 282 | if access.AccessType != a.AccessType { 283 | t.Fatalf("%d does not equal %d", access.AccessType, a.AccessType) 284 | } 285 | case access2.ID: 286 | if access2.Type != a.Type { 287 | t.Fatalf("%d does not equal %d", access2.Type, a.Type) 288 | } 289 | if access2.AccessType != a.AccessType { 290 | t.Fatalf("%d does not equal %d", access2.AccessType, a.AccessType) 291 | } 292 | case access3.ID: 293 | if access3.Type != a.Type { 294 | t.Fatalf("%d does not equal %d", access3.Type, a.Type) 295 | } 296 | if access3.AccessType != a.AccessType { 297 | t.Fatalf("%d does not equal %d", access3.AccessType, a.AccessType) 298 | } 299 | case u.GetID(): 300 | continue 301 | default: 302 | t.Fatalf("unknown acl value for key %v", a) 303 | } 304 | } 305 | } 306 | 307 | func TestAddUpdateVersion(t *testing.T) { 308 | m, u, acl := GetMocks() 309 | var key *knox.Key 310 | key1 := newKey("id1", acl, []byte("data"), u) 311 | kv := newKeyVersion([]byte("data2"), knox.Active) 312 | access := knox.Access{Type: knox.User, ID: "grootan", AccessType: knox.Read} 313 | err := m.UpdateAccess(key1.ID, access) 314 | if err == nil { 315 | t.Fatal("Should be an error") 316 | } 317 | 318 | err = m.AddNewKey(&key1) 319 | if err != nil { 320 | t.Fatalf("%s is not nil", err) 321 | } 322 | 323 | key, err = m.GetKey(key1.ID, knox.Active) 324 | if err != nil { 325 | t.Fatalf("%s is not nil", err) 326 | } 327 | if !reflect.DeepEqual(key, &key1) { 328 | t.Fatal("keys are not equal") 329 | } 330 | 331 | err = m.AddVersion(key1.ID, &kv) 332 | if err != nil { 333 | t.Fatalf("%s is not nil", err) 334 | } 335 | 336 | key, err = m.GetKey(key1.ID, 
knox.Active) 337 | if err != nil { 338 | t.Fatalf("%s is not nil", err) 339 | } 340 | if key.ID != key1.ID { 341 | t.Fatalf("%s does not equal %s", key.ID, key1.ID) 342 | } 343 | if len(key.VersionList) != 2 { 344 | t.Fatalf("%d does not equal %d", len(key.VersionList), 2) 345 | } 346 | if key.VersionHash == key1.VersionHash { 347 | t.Fatalf("%s does equal %s", key.VersionHash, key1.VersionHash) 348 | } 349 | sort.Sort(key.VersionList) 350 | sort.Sort(key1.VersionList) 351 | for _, kv1 := range key.VersionList { 352 | if kv1.Status == knox.Primary { 353 | if !reflect.DeepEqual(kv1, key1.VersionList[0]) { 354 | t.Fatal("primary versions are not equal") 355 | } 356 | } 357 | if kv1.Status == knox.Active { 358 | if !reflect.DeepEqual(kv1, kv) { 359 | t.Fatal("active versions are not equal") 360 | } 361 | } 362 | if kv1.Status == knox.Inactive { 363 | t.Fatal("No key versions should be inactive") 364 | } 365 | } 366 | 367 | err = m.UpdateVersion(key1.ID, kv.ID, knox.Primary) 368 | if err != nil { 369 | t.Fatalf("%s is not nil", err) 370 | } 371 | 372 | key, err = m.GetKey(key1.ID, knox.Active) 373 | if err != nil { 374 | t.Fatalf("%s is not nil", err) 375 | } 376 | if key.ID != key1.ID { 377 | t.Fatalf("%s does not equal %s", key.ID, key1.ID) 378 | } 379 | if key.VersionHash == key1.VersionHash { 380 | t.Fatalf("%s does equal %s", key.VersionHash, key1.VersionHash) 381 | } 382 | if len(key.VersionList) != 2 { 383 | t.Fatalf("%d does not equal %d", len(key.VersionList), 2) 384 | } 385 | sort.Sort(key.VersionList) 386 | kv1 := key.VersionList[0] 387 | if kv1.Status != knox.Primary { 388 | t.Fatalf("%d does equal %d", kv1.Status, knox.Primary) 389 | } 390 | if kv1.ID != kv.ID { 391 | t.Fatalf("%d does equal %d", kv1.ID, kv.ID) 392 | } 393 | if string(kv1.Data) != string(kv.Data) { 394 | t.Fatalf("%s does equal %s", string(kv1.Data), string(kv.Data)) 395 | } 396 | if kv1.CreationTime != kv.CreationTime { 397 | t.Fatalf("%d does equal %d", kv1.CreationTime, kv.CreationTime) 
398 | } 399 | 400 | kv1 = key.VersionList[1] 401 | if kv1.Status != knox.Active { 402 | t.Fatalf("%d does equal %d", kv1.Status, knox.Primary) 403 | } 404 | if kv1.ID != key1.VersionList[0].ID { 405 | t.Fatalf("%d does equal %d", kv1.ID, key1.VersionList[0].ID) 406 | } 407 | if string(kv1.Data) != string(key1.VersionList[0].Data) { 408 | t.Fatalf("%s does equal %s", string(kv1.Data), string(key1.VersionList[0].Data)) 409 | } 410 | if kv1.CreationTime != key1.VersionList[0].CreationTime { 411 | t.Fatalf("%d does equal %d", kv1.CreationTime, key1.VersionList[0].CreationTime) 412 | } 413 | 414 | err = m.UpdateVersion(key1.ID, key1.VersionList[0].ID, knox.Inactive) 415 | if err != nil { 416 | t.Fatalf("%s is not nil", err) 417 | } 418 | 419 | key, err = m.GetKey(key1.ID, knox.Active) 420 | if err != nil { 421 | t.Fatalf("%s is not nil", err) 422 | } 423 | if key.ID != key1.ID { 424 | t.Fatalf("%s does not equal %s", key.ID, key1.ID) 425 | } 426 | if key.VersionHash == key1.VersionHash { 427 | t.Fatalf("%s does equal %s", key.VersionHash, key1.VersionHash) 428 | } 429 | if len(key.VersionList) != 1 { 430 | t.Fatalf("%d does not equal %d", len(key.VersionList), 1) 431 | } 432 | kv1 = key.VersionList[0] 433 | if kv1.Status != knox.Primary { 434 | t.Fatalf("%d does equal %d", kv1.Status, knox.Primary) 435 | } 436 | if kv1.ID != kv.ID { 437 | t.Fatalf("%d does equal %d", kv1.ID, kv.ID) 438 | } 439 | if string(kv1.Data) != string(kv.Data) { 440 | t.Fatalf("%s does equal %s", string(kv1.Data), string(kv.Data)) 441 | } 442 | if kv1.CreationTime != kv.CreationTime { 443 | t.Fatalf("%d does equal %d", kv1.CreationTime, kv.CreationTime) 444 | } 445 | } 446 | 447 | func TestGetInactiveKeyVersions(t *testing.T) { 448 | m, u, acl := GetMocks() 449 | 450 | keyOrig := newKey("id1", acl, []byte("data"), u) 451 | kv := newKeyVersion([]byte("data2"), knox.Active) 452 | 453 | // Create key and add version so we have two versions 454 | err := m.AddNewKey(&keyOrig) 455 | if err != nil { 456 
| t.Fatalf("%s is not nil", err) 457 | } 458 | 459 | err = m.AddVersion(keyOrig.ID, &kv) 460 | if err != nil { 461 | t.Fatalf("%s is not nil", err) 462 | } 463 | 464 | // Get active versions and deactivate one of them 465 | key, err := m.GetKey(keyOrig.ID, knox.Active) 466 | if err != nil { 467 | t.Fatalf("%s is not nil", err) 468 | } 469 | 470 | kvID0 := key.VersionList[0].ID 471 | kvID1 := key.VersionList[1].ID 472 | 473 | // Deactivate one of these versions 474 | err = m.UpdateVersion(keyOrig.ID, kvID1, knox.Inactive) 475 | if err != nil { 476 | t.Fatalf("%s is not nil", err) 477 | } 478 | 479 | // Reading active key versions should now list only one version 480 | key, err = m.GetKey(keyOrig.ID, knox.Active) 481 | if err != nil { 482 | t.Fatalf("%s is not nil", err) 483 | } 484 | 485 | if len(key.VersionList) != 1 { 486 | t.Fatalf("Wanted one key version, got: %d", len(key.VersionList)) 487 | } 488 | if key.VersionList[0].ID != kvID0 { 489 | t.Fatalf("Inactive key id was listed as ctive") 490 | } 491 | 492 | // Reading active/inactive key versions should now list both 493 | key, err = m.GetKey(keyOrig.ID, knox.Inactive) 494 | if err != nil { 495 | t.Fatalf("%s is not nil", err) 496 | } 497 | 498 | if len(key.VersionList) != 2 { 499 | t.Fatalf("Wanted two key versions, got: %d", len(key.VersionList)) 500 | } 501 | } 502 | -------------------------------------------------------------------------------- /server/keydb/cryptor.go: -------------------------------------------------------------------------------- 1 | package keydb 2 | 3 | import ( 4 | "bytes" 5 | "crypto/aes" 6 | "crypto/cipher" 7 | "crypto/rand" 8 | "encoding/binary" 9 | "fmt" 10 | 11 | "github.com/pinterest/knox" 12 | ) 13 | 14 | var ErrCryptorVersion = fmt.Errorf("Cryptor version does not match") 15 | 16 | // Cryptor is an interface for converting a knox Key to a DB Key 17 | type Cryptor interface { 18 | Decrypt(*DBKey) (*knox.Key, error) 19 | Encrypt(*knox.Key) (*DBKey, error) 20 | 
EncryptVersion(*knox.Key, *knox.KeyVersion) (*EncKeyVersion, error) 21 | } 22 | 23 | // NewAESGCMCryptor creates a Cryptor that performs AES GCM AEAD encryption on key data. 24 | func NewAESGCMCryptor(version byte, keyData []byte) Cryptor { 25 | return &aesGCMCryptor{keyData, version} 26 | } 27 | 28 | // aesGCMCryptor does AES encryption, but does not include correct associated data. 29 | type aesGCMCryptor struct { 30 | keyData []byte 31 | version byte 32 | } 33 | 34 | func (c *aesGCMCryptor) EncryptVersion(k *knox.Key, v *knox.KeyVersion) (*EncKeyVersion, error) { 35 | b, err := aes.NewCipher(c.keyData) 36 | if err != nil { 37 | return nil, err 38 | } 39 | gcm, err := cipher.NewGCM(b) 40 | if err != nil { 41 | return nil, err 42 | } 43 | nonce := make([]byte, gcm.NonceSize()) 44 | if _, err := rand.Read(nonce); err != nil { 45 | return nil, err 46 | } 47 | 48 | ciphertext := gcm.Seal(nil, nonce, v.Data, c.generateAD(k.ID, v.ID, v.CreationTime)) 49 | 50 | return &EncKeyVersion{ 51 | ID: v.ID, 52 | EncData: ciphertext, 53 | Status: v.Status, 54 | CreationTime: v.CreationTime, 55 | CryptoMetadata: buildMetadata(c.version, nonce), 56 | }, nil 57 | } 58 | 59 | // generateAD generates the data to be signed with key version versionid|creationtime|keyid 60 | func (c *aesGCMCryptor) generateAD(kid string, vid uint64, creation int64) []byte { 61 | idBytes := make([]byte, binary.MaxVarintLen64) 62 | binary.PutUvarint(idBytes, vid) 63 | creationBytes := make([]byte, binary.MaxVarintLen64) 64 | binary.PutVarint(creationBytes, creation) 65 | 66 | b := bytes.NewBuffer(idBytes) 67 | b.Write(creationBytes) 68 | b.WriteString(kid) 69 | return b.Bytes() 70 | } 71 | 72 | func (c *aesGCMCryptor) decryptVersion(k *DBKey, v *EncKeyVersion) (*knox.KeyVersion, error) { 73 | md := aesCryptoMetadata(v.CryptoMetadata) 74 | if md.Version() != c.version { 75 | return nil, ErrCryptorVersion 76 | } 77 | b, err := aes.NewCipher(c.keyData) 78 | if err != nil { 79 | return nil, err 80 | } 81 | 
gcm, err := cipher.NewGCM(b) 82 | if err != nil { 83 | return nil, err 84 | } 85 | 86 | plaintext, err := gcm.Open(nil, md.Nonce(), v.EncData, c.generateAD(k.ID, v.ID, v.CreationTime)) 87 | if err != nil { 88 | return nil, err 89 | } 90 | 91 | return &knox.KeyVersion{ 92 | ID: v.ID, 93 | Data: plaintext, 94 | Status: v.Status, 95 | CreationTime: v.CreationTime, 96 | }, nil 97 | } 98 | 99 | func (c *aesGCMCryptor) Encrypt(k *knox.Key) (*DBKey, error) { 100 | dbVersions := make([]EncKeyVersion, len(k.VersionList)) 101 | for i, v := range k.VersionList { 102 | dbv, err := c.EncryptVersion(k, &v) 103 | if err != nil { 104 | return nil, err 105 | } 106 | dbVersions[i] = *dbv 107 | } 108 | 109 | newKey := DBKey{ 110 | ID: k.ID, 111 | ACL: k.ACL, 112 | VersionList: dbVersions, 113 | VersionHash: k.VersionHash, 114 | } 115 | return &newKey, nil 116 | } 117 | 118 | func (c *aesGCMCryptor) Decrypt(k *DBKey) (*knox.Key, error) { 119 | versions := make([]knox.KeyVersion, len(k.VersionList)) 120 | for i, v := range k.VersionList { 121 | dbv, err := c.decryptVersion(k, &v) 122 | if err != nil { 123 | return nil, err 124 | } 125 | versions[i] = *dbv 126 | } 127 | 128 | newKey := knox.Key{ 129 | ID: k.ID, 130 | ACL: k.ACL, 131 | VersionList: versions, 132 | VersionHash: k.VersionHash, 133 | } 134 | return &newKey, nil 135 | } 136 | 137 | type aesCryptoMetadata []byte 138 | 139 | func (c aesCryptoMetadata) Version() byte { 140 | return c[0] 141 | } 142 | 143 | func (c aesCryptoMetadata) Nonce() []byte { 144 | return c[1:] 145 | } 146 | 147 | func buildMetadata(version byte, nonce []byte) aesCryptoMetadata { 148 | c := make([]byte, len(nonce)+1) 149 | c[0] = version 150 | copy(c[1:], nonce) 151 | return c 152 | } 153 | -------------------------------------------------------------------------------- /server/keydb/cryptor_test.go: -------------------------------------------------------------------------------- 1 | package keydb 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | 
"github.com/pinterest/knox" 8 | ) 9 | 10 | var testSecret = []byte("testtesttesttest") 11 | 12 | func makeTestKey() *knox.Key { 13 | return &knox.Key{ 14 | ID: "testID", 15 | ACL: knox.ACL([]knox.Access{{Type: knox.User, ID: "testUser", AccessType: knox.Read}}), 16 | VersionList: knox.KeyVersionList([]knox.KeyVersion{makeTestVersion()}), 17 | VersionHash: "testHash", 18 | } 19 | } 20 | 21 | func makeTestVersion() knox.KeyVersion { 22 | return knox.KeyVersion{ 23 | ID: 12345, 24 | Data: []byte("data"), 25 | Status: knox.Primary, 26 | CreationTime: 1, 27 | } 28 | } 29 | 30 | func TestEncryptDecryptVersion(t *testing.T) { 31 | k := makeTestKey() 32 | dbKey := &DBKey{ 33 | ID: k.ID, 34 | ACL: k.ACL, 35 | VersionHash: k.VersionHash, 36 | } 37 | v := k.VersionList.GetPrimary() 38 | crypt := &aesGCMCryptor{testSecret, 10} 39 | encV, err := crypt.EncryptVersion(k, v) 40 | if err != nil { 41 | t.Fatalf("%s is not nil", err) 42 | } 43 | decV, err := crypt.decryptVersion(dbKey, encV) 44 | if err != nil { 45 | t.Fatalf("%s is not nil", err) 46 | } 47 | if !reflect.DeepEqual(decV, v) { 48 | t.Fatal("decrypted key does not equal key") 49 | } 50 | } 51 | 52 | func TestEncryptDecryptKey(t *testing.T) { 53 | k := makeTestKey() 54 | crypt := NewAESGCMCryptor(10, testSecret) 55 | encK, err := crypt.Encrypt(k) 56 | if err != nil { 57 | t.Fatalf("%s is not nil", err) 58 | } 59 | decK, err := crypt.Decrypt(encK) 60 | if err != nil { 61 | t.Fatalf("%s is not nil", err) 62 | } 63 | if !reflect.DeepEqual(decK, k) { 64 | t.Fatal("decrypted key does not equal key") 65 | } 66 | } 67 | 68 | func TestBadKeyData(t *testing.T) { 69 | k := makeTestKey() 70 | crypt := NewAESGCMCryptor(0, []byte("notAESlen")) 71 | _, err := crypt.Encrypt(k) 72 | if err == nil { 73 | t.Fatal("error is nil for a bad key") 74 | } 75 | 76 | _, err = crypt.Decrypt(&DBKey{VersionList: []EncKeyVersion{{CryptoMetadata: []byte{0}}}}) 77 | if err == nil { 78 | t.Fatal("error is nil for bad data") 79 | } 80 | } 81 | 82 | func 
TestBadCryptorVersion(t *testing.T) { 83 | k := makeTestKey() 84 | crypt := NewAESGCMCryptor(10, testSecret) 85 | encK, err := crypt.Encrypt(k) 86 | if err != nil { 87 | t.Fatalf("%s is not nil", err) 88 | } 89 | 90 | crypt2 := NewAESGCMCryptor(1, testSecret) 91 | _, err = crypt2.Decrypt(encK) 92 | if err == nil { 93 | t.Fatalf("err is nil on bad crypter version") 94 | } 95 | } 96 | 97 | func TestBadCiphertext(t *testing.T) { 98 | k := makeTestKey() 99 | crypt := NewAESGCMCryptor(10, testSecret) 100 | encK, err := crypt.Encrypt(k) 101 | if err != nil { 102 | t.Fatalf("%s is not nil", err) 103 | } 104 | encK.VersionList[0].EncData = []byte("invalidciphertext") 105 | _, err = crypt.Decrypt(encK) 106 | if err == nil { 107 | t.Fatal("error is nil for bad ciphertext") 108 | } 109 | } 110 | 111 | func TestAESMetadata(t *testing.T) { 112 | version := byte(1) 113 | nonce := []byte("abcd") 114 | cm := buildMetadata(version, nonce) 115 | if string(cm.Nonce()) != string(nonce) { 116 | t.Fatalf("nonces are not equal: %s expected: %s", string(cm.Nonce()), string(nonce)) 117 | } 118 | if cm.Version() != version { 119 | t.Fatalf("%d does not equal %d", cm.Version(), version) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /server/keydb/keydb.go: -------------------------------------------------------------------------------- 1 | package keydb 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/json" 6 | "fmt" 7 | "sync" 8 | "time" 9 | 10 | "github.com/pinterest/knox" 11 | ) 12 | 13 | var ErrDBVersion = fmt.Errorf("DB version does not match") 14 | 15 | // DBKey is a struct for the json serialization of keys in the database. 16 | type DBKey struct { 17 | ID string `json:"id"` 18 | ACL knox.ACL `json:"acl"` 19 | VersionList []EncKeyVersion `json:"versions"` 20 | VersionHash string `json:"hash"` 21 | // The version should be set by the db provider and is not part of the data. 
22 | DBVersion int64 `json:"-"` 23 | } 24 | 25 | // Copy provides a deep copy of database keys so that VersionLists can be edited in a copy. 26 | func (k *DBKey) Copy() *DBKey { 27 | versionList := make([]EncKeyVersion, len(k.VersionList)) 28 | copy(versionList, k.VersionList) 29 | acl := make([]knox.Access, len(k.ACL)) 30 | copy(acl, k.ACL) 31 | return &DBKey{ 32 | ID: k.ID, 33 | ACL: acl, 34 | VersionList: versionList, 35 | VersionHash: k.VersionHash, 36 | DBVersion: k.DBVersion, 37 | } 38 | } 39 | 40 | // EncKeyVersion is a struct for encrypting key data 41 | type EncKeyVersion struct { 42 | ID uint64 `json:"id"` 43 | EncData []byte `json:"data"` 44 | Status knox.VersionStatus `json:"status"` 45 | CreationTime int64 `json:"ts"` 46 | CryptoMetadata []byte `json:"crypt"` 47 | } 48 | 49 | // DB is the underlying database connection that KeyDB uses for all of its operations. 50 | // 51 | // This interface should not contain any business logic and should only deal with formatting 52 | // and database specific logic. 53 | type DB interface { 54 | // Get returns the key specified by the ID. 55 | Get(id string) (*DBKey, error) 56 | // GetAll returns all of the keys in the database. 57 | GetAll() ([]DBKey, error) 58 | 59 | // Update makes an update to DBKey indexed by its ID. 60 | // It will fail if the key has been changed since the specified version. 61 | Update(key *DBKey) error 62 | // Add adds the key(s) to the DB (it will fail if the key id exists). 63 | Add(keys ...*DBKey) error 64 | // Remove permanently removes the key specified by the ID. 65 | Remove(id string) error 66 | } 67 | 68 | // NewTempDB creates a new TempDB with no data. 69 | func NewTempDB() DB { 70 | return &TempDB{} 71 | } 72 | 73 | // TempDB is an in memory DB that does no replication across servers and starts 74 | // out fresh everytime. It is written for testing and simple dev work. 
75 | type TempDB struct { 76 | sync.RWMutex 77 | keys []DBKey 78 | err error 79 | } 80 | 81 | // SetError is used to set the error the TempDB for testing purposes. 82 | func (db *TempDB) SetError(err error) { 83 | db.Lock() 84 | defer db.Unlock() 85 | db.err = err 86 | } 87 | 88 | // Get gets stored db key from TempDB. 89 | func (db *TempDB) Get(id string) (*DBKey, error) { 90 | db.RLock() 91 | defer db.RUnlock() 92 | if db.err != nil { 93 | return nil, db.err 94 | } 95 | for _, k := range db.keys { 96 | if k.ID == id { 97 | return &k, nil 98 | } 99 | } 100 | return nil, knox.ErrKeyIDNotFound 101 | } 102 | 103 | // GetAll gets all keys from TempDB. 104 | func (db *TempDB) GetAll() ([]DBKey, error) { 105 | db.RLock() 106 | defer db.RUnlock() 107 | if db.err != nil { 108 | return nil, db.err 109 | } 110 | return db.keys, nil 111 | } 112 | 113 | // Update looks for an existing key and updates the key in the database. 114 | func (db *TempDB) Update(key *DBKey) error { 115 | db.Lock() 116 | defer db.Unlock() 117 | if db.err != nil { 118 | return db.err 119 | } 120 | for i, dbk := range db.keys { 121 | if dbk.ID == key.ID { 122 | if dbk.DBVersion != key.DBVersion { 123 | return ErrDBVersion 124 | } 125 | k := key.Copy() 126 | k.DBVersion = time.Now().UnixNano() 127 | db.keys[i] = *k 128 | return nil 129 | } 130 | } 131 | return knox.ErrKeyIDNotFound 132 | } 133 | 134 | // Add adds the key(s) to the DB (it will fail if the key id exists). 135 | func (db *TempDB) Add(keys ...*DBKey) error { 136 | db.Lock() 137 | defer db.Unlock() 138 | if db.err != nil { 139 | return db.err 140 | } 141 | for _, key := range keys { 142 | for _, oldK := range db.keys { 143 | if oldK.ID == key.ID { 144 | return knox.ErrKeyExists 145 | } 146 | } 147 | } 148 | for _, key := range keys { 149 | k := key.Copy() 150 | k.DBVersion = time.Now().UnixNano() 151 | 152 | db.keys = append(db.keys, *k) 153 | } 154 | return nil 155 | 156 | } 157 | 158 | // Remove will remove the key id from the database. 
159 | func (db *TempDB) Remove(id string) error { 160 | db.Lock() 161 | defer db.Unlock() 162 | if db.err != nil { 163 | return db.err 164 | } 165 | for i, k := range db.keys { 166 | if k.ID == id { 167 | db.keys = append(db.keys[:i], db.keys[i+1:]...) 168 | return nil 169 | } 170 | } 171 | return knox.ErrKeyIDNotFound 172 | } 173 | 174 | // SQLDB provides a generic way to use SQL providers as Knox DBs. 175 | type SQLDB struct { 176 | getStmt *sql.Stmt 177 | getAllStmt *sql.Stmt 178 | UpdateStmt *sql.Stmt 179 | AddStmt *sql.Stmt 180 | RemoveStmt *sql.Stmt 181 | db sql.DB 182 | } 183 | 184 | var sqlCreateKeys = `CREATE TABLE IF NOT EXISTS secrets ( 185 | id VARCHAR(512) PRIMARY KEY, 186 | acl TEXT NOT NULL, 187 | version_hash TEXT NOT NULL, 188 | versions TEXT NOT NULL, 189 | last_updated BIGINT NOT NULL 190 | );` 191 | 192 | // NewPostgreSQLDB will create a SQLDB with the necessary statements for using postgres. 193 | func NewPostgreSQLDB(sqlDB *sql.DB) (DB, error) { 194 | db := &SQLDB{} 195 | var err error 196 | _, err = sqlDB.Exec(sqlCreateKeys) 197 | if err != nil { 198 | return nil, err 199 | } 200 | db.getStmt, err = sqlDB.Prepare("SELECT id, acl, version_hash, versions, last_updated FROM secrets WHERE id=$1") 201 | if err != nil { 202 | return nil, err 203 | } 204 | db.getAllStmt, err = sqlDB.Prepare("SELECT id, acl, version_hash, versions, last_updated FROM secrets") 205 | if err != nil { 206 | return nil, err 207 | } 208 | db.UpdateStmt, err = sqlDB.Prepare("UPDATE secrets SET versions=$1, version_hash=$2,last_updated=$3,acl=$4 WHERE id=$5 AND last_updated=$6") 209 | if err != nil { 210 | return nil, err 211 | } 212 | db.AddStmt, err = sqlDB.Prepare("INSERT INTO secrets (id, acl, versions, version_hash, last_updated) VALUES ($1,$2,$3,$4,$5)") 213 | if err != nil { 214 | return nil, err 215 | } 216 | db.RemoveStmt, err = sqlDB.Prepare("DELETE FROM secrets WHERE id=$1") 217 | if err != nil { 218 | return nil, err 219 | } 220 | return db, nil 221 | } 222 | 223 
| // NewSQLDB creates a table and prepared statements suitable for mysql and sqlite databases. 224 | func NewSQLDB(sqlDB *sql.DB) (DB, error) { 225 | db := &SQLDB{} 226 | var err error 227 | _, err = sqlDB.Exec(sqlCreateKeys) 228 | if err != nil { 229 | return nil, err 230 | } 231 | db.getStmt, err = sqlDB.Prepare("SELECT id, acl, version_hash, versions, last_updated FROM secrets WHERE id=?") 232 | if err != nil { 233 | return nil, err 234 | } 235 | db.getAllStmt, err = sqlDB.Prepare("SELECT id, acl, version_hash, versions, last_updated FROM secrets") 236 | if err != nil { 237 | return nil, err 238 | } 239 | db.UpdateStmt, err = sqlDB.Prepare("UPDATE secrets SET versions=?, version_hash=?,last_updated=?,acl=? WHERE id=? AND last_updated=?") 240 | if err != nil { 241 | return nil, err 242 | } 243 | db.AddStmt, err = sqlDB.Prepare("INSERT INTO secrets (id, acl, versions, version_hash, last_updated) VALUES (?,?,?,?,?)") 244 | if err != nil { 245 | return nil, err 246 | } 247 | db.RemoveStmt, err = sqlDB.Prepare("DELETE FROM secrets WHERE id=?") 248 | if err != nil { 249 | return nil, err 250 | } 251 | return db, nil 252 | } 253 | 254 | // Get will return the key given its key ID. 255 | func (db *SQLDB) Get(id string) (*DBKey, error) { 256 | var key DBKey 257 | var acl, versions []byte 258 | err := db.getStmt.QueryRow(id).Scan(&key.ID, &acl, &key.VersionHash, &versions, &key.DBVersion) 259 | if err != nil { 260 | return nil, knox.ErrKeyIDNotFound 261 | } 262 | err = json.Unmarshal(acl, &key.ACL) 263 | if err != nil { 264 | return nil, err 265 | } 266 | err = json.Unmarshal(versions, &key.VersionList) 267 | if err != nil { 268 | return nil, err 269 | } 270 | return &key, nil 271 | } 272 | 273 | // GetAll returns all of the keys in the database. 
274 | func (db *SQLDB) GetAll() ([]DBKey, error) { 275 | var keys []DBKey 276 | rows, err := db.getAllStmt.Query() 277 | if err != nil { 278 | return nil, err 279 | } 280 | for rows.Next() { 281 | var key DBKey 282 | var acl, versions []byte 283 | err := rows.Scan(&key.ID, &acl, &key.VersionHash, &versions, &key.DBVersion) 284 | if err != nil { 285 | return nil, err 286 | } 287 | err = json.Unmarshal(acl, &key.ACL) 288 | if err != nil { 289 | return nil, err 290 | } 291 | err = json.Unmarshal(versions, &key.VersionList) 292 | if err != nil { 293 | return nil, err 294 | } 295 | keys = append(keys, key) 296 | } 297 | err = rows.Err() 298 | if err != nil { 299 | return nil, err 300 | } 301 | return keys, nil 302 | } 303 | 304 | // Update makes an update to DBKey indexed by its ID. 305 | // It will fail if the key has been changed since the specified version. 306 | func (db *SQLDB) Update(key *DBKey) error { 307 | versions, err := json.Marshal(key.VersionList) 308 | if err != nil { 309 | return err 310 | } 311 | acl, err := json.Marshal(key.ACL) 312 | if err != nil { 313 | return err 314 | } 315 | updateTime := time.Now().UnixNano() 316 | r, err := db.UpdateStmt.Exec(versions, key.VersionHash, updateTime, acl, key.ID, key.DBVersion) 317 | if err != nil { 318 | return err 319 | } 320 | affected, err := r.RowsAffected() 321 | if err != nil { 322 | // This likely shouldn't return an error if rows affected is not implemented. 323 | return err 324 | } 325 | if affected == 0 { 326 | rs, err := db.getStmt.Query(key.ID) 327 | defer rs.Close() 328 | if err != nil { 329 | return err 330 | } 331 | if !rs.Next() { 332 | return knox.ErrKeyIDNotFound 333 | } 334 | return ErrDBVersion 335 | } 336 | return nil 337 | } 338 | 339 | // Add adds the key version (it will fail if the key id exists). 340 | func (db *SQLDB) Add(keys ...*DBKey) error { 341 | // For loop is the dumbest way to this; should refactor into one query/transaction. 
342 | for _, key := range keys { 343 | versions, err := json.Marshal(key.VersionList) 344 | if err != nil { 345 | return err 346 | } 347 | acl, err := json.Marshal(key.ACL) 348 | if err != nil { 349 | return err 350 | } 351 | updateTime := time.Now().UnixNano() 352 | _, err = db.AddStmt.Exec(key.ID, acl, versions, key.VersionHash, updateTime) 353 | if err != nil { 354 | // Not sure how to properly differentiate here... 355 | return knox.ErrKeyExists 356 | } 357 | // Not checking rows affected because I assume the db will return an error on primary key collision. 358 | } 359 | return nil 360 | } 361 | 362 | // Remove permanently removes the key specified by the ID. 363 | func (db *SQLDB) Remove(id string) error { 364 | r, err := db.RemoveStmt.Exec(id) 365 | if err != nil { 366 | return err 367 | } 368 | affected, err := r.RowsAffected() 369 | if err != nil { 370 | // This likely shouldn't return an error if rows affected is not implemented. 371 | return err 372 | } 373 | if affected == 0 { 374 | return knox.ErrKeyIDNotFound 375 | } 376 | return nil 377 | } 378 | -------------------------------------------------------------------------------- /server/keydb/keydb_test.go: -------------------------------------------------------------------------------- 1 | package keydb 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | "time" 8 | 9 | "github.com/pinterest/knox" 10 | ) 11 | 12 | func newEncKeyVersion(d []byte, s knox.VersionStatus) EncKeyVersion { 13 | version := EncKeyVersion{} 14 | version.EncData = d 15 | version.Status = s 16 | version.CreationTime = time.Now().UnixNano() 17 | // This is only 63 bits of randomness, but it appears to be the fastest way. 
18 | version.ID = uint64(rand.Int63()) 19 | return version 20 | } 21 | 22 | func newDBKey(id string, d []byte, version int64) DBKey { 23 | key := DBKey{} 24 | key.ID = id 25 | 26 | key.ACL = knox.ACL{} 27 | key.DBVersion = version 28 | 29 | key.VersionList = []EncKeyVersion{newEncKeyVersion(d, knox.Primary)} 30 | return key 31 | } 32 | 33 | func TestTemp(t *testing.T) { 34 | db := NewTempDB() 35 | timeout := 100 * time.Millisecond 36 | TesterAddGet(t, db, timeout) 37 | TesterAddUpdate(t, db, timeout) 38 | TesterAddRemove(t, db, timeout) 39 | } 40 | 41 | func TestDBCopy(t *testing.T) { 42 | a := knox.Access{} 43 | v := EncKeyVersion{} 44 | r := DBKey{ 45 | ID: "id1", 46 | ACL: []knox.Access{a}, 47 | VersionList: []EncKeyVersion{v}, 48 | VersionHash: "hash1", 49 | DBVersion: 1, 50 | } 51 | b := r.Copy() 52 | b.ID = "id2" 53 | if r.ID == b.ID { 54 | t.Error("Ids are equal after copy") 55 | } 56 | b.DBVersion = 2 57 | if r.DBVersion == b.DBVersion { 58 | t.Error("DBVersion are equal after copy") 59 | } 60 | b.VersionHash = "hash2" 61 | if r.VersionHash == b.VersionHash { 62 | t.Error("VersionHash are equal after copy") 63 | } 64 | b.ACL[0].ID = "pi" 65 | if r.ACL[0].ID == b.ACL[0].ID { 66 | t.Error("ACL[0].ID are equal after copy") 67 | } 68 | b.VersionList[0].ID = 17 69 | if r.VersionList[0].ID == b.VersionList[0].ID { 70 | t.Error("VersionList[0].ID are equal after copy") 71 | } 72 | 73 | } 74 | 75 | func TestTempErrs(t *testing.T) { 76 | db := &TempDB{} 77 | err := fmt.Errorf("Does not compute... EXTERMINATE! 
EXTERMINATE!") 78 | db.SetError(err) 79 | TesterErrs(t, db, err) 80 | } 81 | 82 | func TesterErrs(t *testing.T, db DB, expErr error) { 83 | k := newDBKey("TesterErrs1", []byte("ab"), 0) 84 | go func() { 85 | _, err := db.GetAll() 86 | if err != expErr { 87 | t.Errorf("%s does not equal %s", err, expErr) 88 | } 89 | }() 90 | go func() { 91 | err := db.Add(&k) 92 | if err != expErr { 93 | t.Errorf("%s does not equal %s", err, expErr) 94 | } 95 | }() 96 | go func() { 97 | err := db.Remove(k.ID) 98 | if err != expErr { 99 | t.Errorf("%s does not equal %s", err, expErr) 100 | } 101 | }() 102 | go func() { 103 | err := db.Update(&k) 104 | if err != expErr { 105 | t.Errorf("%s does not equal %s", err, expErr) 106 | } 107 | }() 108 | go func() { 109 | _, err := db.Get(k.ID) 110 | if err != expErr { 111 | t.Errorf("%s does not equal %s", err, expErr) 112 | } 113 | }() 114 | } 115 | 116 | func TesterAddGet(t *testing.T, db DB, timeout time.Duration) { 117 | origKeys, err := db.GetAll() 118 | if err != nil { 119 | t.Fatalf("%s not nil", err) 120 | } 121 | k := newDBKey("TestAddGet1", []byte("a"), 0) 122 | err = db.Add(&k) 123 | if err != nil { 124 | t.Fatalf("%s not nil", err) 125 | } 126 | complete := false 127 | timer := time.Tick(timeout) 128 | for !complete { 129 | select { 130 | case <-timer: 131 | t.Fatal("Timed out waiting TestAddGet1 to get added") 132 | case <-time.Tick(1 * time.Millisecond): 133 | newK, err := db.Get(k.ID) 134 | if err == nil { 135 | if newK.ID != k.ID { 136 | t.Fatalf("%s does not equal %s", newK.ID, k.ID) 137 | } 138 | if len(newK.VersionList) != 1 { 139 | t.Fatalf("%d does not equal 1", len(newK.VersionList)) 140 | } 141 | if newK.VersionList[0].EncData[0] != k.VersionList[0].EncData[0] { 142 | t.Fatalf("%c does not equal %c", newK.VersionList[0].EncData[0], k.VersionList[0].EncData[0]) 143 | } 144 | complete = true 145 | } else if err != knox.ErrKeyIDNotFound { 146 | t.Fatal(err) 147 | } 148 | } 149 | } 150 | keys, err := db.GetAll() 151 | if 
err != nil { 152 | t.Fatalf("%s not nil", err) 153 | } 154 | if len(keys) != len(origKeys)+1 { 155 | t.Fatal("key list length did not grow by 1") 156 | } 157 | 158 | err = db.Add(&k) 159 | if err != knox.ErrKeyExists { 160 | t.Fatalf("%s does not equal %s", err, knox.ErrKeyExists) 161 | } 162 | } 163 | 164 | func TesterAddUpdate(t *testing.T, db DB, timeout time.Duration) { 165 | _, err := db.GetAll() 166 | if err != nil { 167 | t.Fatalf("%s not nil", err) 168 | } 169 | k := newDBKey("TesterAddUpdate1", []byte("a"), 0) 170 | err = db.Update(&k) 171 | if err != knox.ErrKeyIDNotFound { 172 | t.Fatalf("%s does not equal %s", err, knox.ErrKeyIDNotFound) 173 | } 174 | err = db.Add(&k) 175 | if err != nil { 176 | t.Fatalf("%s not nil", err) 177 | } 178 | complete := false 179 | timer := time.Tick(timeout) 180 | var version int64 181 | for !complete { 182 | select { 183 | case <-timer: 184 | t.Fatal("Timed out waiting TesterAddUpdate1 to get added") 185 | case <-time.Tick(1 * time.Millisecond): 186 | newK, err := db.Get(k.ID) 187 | if err == nil { 188 | version = newK.DBVersion 189 | complete = true 190 | } else if err != knox.ErrKeyIDNotFound { 191 | t.Fatal(err) 192 | } 193 | } 194 | } 195 | if version == 0 { 196 | t.Fatal("version number did not initialize to non zero value") 197 | } 198 | err = db.Update(&k) 199 | if err != ErrDBVersion { 200 | t.Fatalf("%s does not equal %s", err, ErrDBVersion) 201 | } 202 | 203 | k.VersionList = append(k.VersionList, newEncKeyVersion([]byte("b"), knox.Active)) 204 | k.DBVersion = version 205 | err = db.Update(&k) 206 | if err != nil { 207 | t.Fatalf("%s not nil", err) 208 | } 209 | complete = false 210 | timer = time.Tick(timeout) 211 | for !complete { 212 | select { 213 | case <-timer: 214 | t.Fatal("Timed out waiting TesterAddUpdate1 to get added") 215 | case <-time.Tick(1 * time.Millisecond): 216 | newK, err := db.Get(k.ID) 217 | if err == nil && len(newK.VersionList) != 1 { 218 | if len(newK.VersionList) != 2 { 219 | 
t.Fatalf("%d does not equal 2", len(newK.VersionList)) 220 | } 221 | var pk, ak EncKeyVersion 222 | if newK.VersionList[0].Status == knox.Primary { 223 | pk = newK.VersionList[0] 224 | ak = newK.VersionList[1] 225 | } else { 226 | pk = newK.VersionList[1] 227 | ak = newK.VersionList[0] 228 | } 229 | if string(pk.EncData) != "a" { 230 | t.Fatalf("%s does not equal a", string(pk.EncData)) 231 | } 232 | if string(ak.EncData) != "b" { 233 | t.Fatalf("%s does not equal b", string(ak.EncData)) 234 | } 235 | version = newK.DBVersion 236 | complete = true 237 | } else if err != nil { 238 | t.Fatal(err) 239 | } 240 | } 241 | } 242 | } 243 | 244 | func TesterAddRemove(t *testing.T, db DB, timeout time.Duration) { 245 | _, err := db.GetAll() 246 | if err != nil { 247 | t.Fatalf("%s not nil", err) 248 | } 249 | k := newDBKey("TesterAddRemove1", []byte("a"), 0) 250 | err = db.Remove(k.ID) 251 | if err != knox.ErrKeyIDNotFound { 252 | t.Fatalf("%s does not equal %s", err, knox.ErrKeyIDNotFound) 253 | } 254 | err = db.Add(&k) 255 | if err != nil { 256 | t.Fatalf("%s not nil", err) 257 | } 258 | complete := false 259 | timer := time.Tick(timeout) 260 | for !complete { 261 | select { 262 | case <-timer: 263 | t.Fatal("Timed out waiting TestAddGet1 to get added") 264 | case <-time.Tick(1 * time.Millisecond): 265 | _, err := db.Get(k.ID) 266 | if err == nil { 267 | complete = true 268 | } else if err != knox.ErrKeyIDNotFound { 269 | t.Fatal(err) 270 | } 271 | } 272 | } 273 | err = db.Remove(k.ID) 274 | if err != nil { 275 | t.Fatalf("%s not nil", err) 276 | } 277 | complete = false 278 | timer = time.Tick(timeout) 279 | for !complete { 280 | select { 281 | case <-timer: 282 | t.Fatal("Timed out waiting TestAddGet1 to get added") 283 | case <-time.Tick(1 * time.Millisecond): 284 | _, err := db.Get(k.ID) 285 | if err == knox.ErrKeyIDNotFound { 286 | complete = true 287 | } else if err != knox.ErrKeyIDNotFound && err != nil { 288 | t.Fatal(err) 289 | } 290 | } 291 | } 292 | } 293 | 
--------------------------------------------------------------------------------