├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── container ├── asis.go ├── asis_test.go ├── basic.go ├── basic_test.go ├── container.go ├── container_test.go ├── errors.go ├── plain1.go ├── plain1_test.go ├── v1.go └── v1_test.go ├── engine ├── aes.go ├── benchmark_test.go ├── engine.go ├── engine_test.go ├── json.go ├── json_test.go └── main_test.go ├── keys ├── key.go ├── key_test.go ├── keychain.go ├── keychain_test.go └── main_test.go ├── main.go ├── proxy ├── backend.go ├── discovery.go ├── discovery_test.go ├── proxy.go ├── proxy_test.go ├── readonly.go ├── readonly_test.go ├── router.go └── router_test.go ├── proxystarter.go └── test.sh /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | /etcvault 3 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | go: 4 | - 1.4 5 | 6 | install: 7 | - go get golang.org/x/tools/cmd/cover 8 | 9 | script: 10 | - ./test.sh 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Shota Fukumori (sora_h) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # etcvault - proxy for etcd, adding transparent encryption 2 | 3 | ## Features 4 | 5 | - Works as a reverse proxy to etcd 6 | - Can discover other etcd members 7 | - Supports etcd 2.0.x 8 | - Transparent value decryption for GET 9 | - Transparent value encryption for POST, PUT, PATCH 10 | - Multiple keys 11 | 12 | ## Motivation 13 | 14 | Maintaining multiple etcd clusters is hard. We wanted to use the same etcd cluster across services, throughout our entire infrastructure. 15 | 16 | But etcd currently has no ACL-like feature: every server can read every value, even values it doesn't need (e.g. credentials for a different service). That's why I developed etcvault. 17 | 18 | There is also an ongoing RFC for ACLs in etcd: https://github.com/coreos/etcd/blob/master/Documentation/rfc/api_security.md 19 | 20 | ## Example 21 | 22 | Generate a key first. 23 | 24 | ``` 25 | $ mkdir /tmp/keychain 26 | $ etcvault keygen -save /tmp/keychain my-key 27 | ``` 28 | 29 | Start etcd and etcvault.
30 | 31 | ``` 32 | $ etcd -listen-client-urls http://127.0.0.1:2379 & 33 | $ etcvault start -listen http://127.0.0.1:2381 -initial-backends http://127.0.0.1:2379 -keychain /tmp/keychain & 34 | ``` 35 | 36 | Set a plain-text value 37 | 38 | ``` 39 | $ etcdctl --peers http://127.0.0.1:2381 set greeting hello 40 | hello 41 | $ etcdctl get greeting 42 | hello 43 | ``` 44 | 45 | Try encryption/decryption 46 | 47 | ``` 48 | (this encrypts "hello" with "my-key") 49 | $ etcdctl --peers http://127.0.0.1:2381 set greeting 'ETCVAULT::plain:my-key:hello::ETCVAULT' 50 | hello 51 | 52 | $ etcdctl --peers http://127.0.0.1:2381 get greeting 53 | hello 54 | 55 | (reading etcd directly only returns the encrypted value) 56 | $ etcdctl --peers http://127.0.0.1:2379 get greeting 57 | ETCVAULT::1:my-key::CMOAuEHp/gcbUFvRuQDDMtpIEl/MQ/2OeYT8sluZs8Fc+YjEalDGHzYSn5MM9FafD9fGMHg9ODPYKNk83i1xXZ9zRhKWeuvG8VrU0DlIQ0hdV3px2hDgJppQBYGfr7QVs/0CKaDFUpkMPuhp6dGkzJ+73ZllL3BTb5UjdW3yizYUB82Qs3fwEUZJnLTCvuejxzMF64weInQXnTBkVrt1Mq/QjBWVJvZty8vvAeEHDKo6n5NpgVlZrn48yVHdKWBzO2z5mQO4VK3MPfLUMPQgUsOBqqbUd4N/NjfxCmPL3cO+Y3FD4WiPvbKGGz6IjFnPr7MoWs8etV+vIC/33gOGSQ==::ETCVAULT 58 | ``` 59 | 60 | You can _transform_ an `ETCVAULT::...::ETCVAULT` string (encrypt a plain container, or decrypt an encrypted one) with the `transform` command: 61 | 62 | ``` 63 | $ etcvault transform -keychain /tmp/keychain 'ETCVAULT::1:my-key::CMOAuEHp/gcbUFvRuQDDMtpIEl/MQ/2OeYT8sluZs8Fc+YjEalDGHzYSn5MM9FafD9fGMHg9ODPYKNk83i1xXZ9zRhKWeuvG8VrU0DlIQ0hdV3px2hDgJppQBYGfr7QVs/0CKaDFUpkMPuhp6dGkzJ+73ZllL3BTb5UjdW3yizYUB82Qs3fwEUZJnLTCvuejxzMF64weInQXnTBkVrt1Mq/QjBWVJvZty8vvAeEHDKo6n5NpgVlZrn48yVHdKWBzO2z5mQO4VK3MPfLUMPQgUsOBqqbUd4N/NjfxCmPL3cO+Y3FD4WiPvbKGGz6IjFnPr7MoWs8etV+vIC/33gOGSQ==::ETCVAULT' 64 | hello 65 | ``` 66 | 67 | ## Detailed Usage 68 | 69 | ### Generate keys 70 | 71 | ``` 72 | $ etcvault keygen NAME 73 | $ etcvault keygen -save /path/to/keychain/directory NAME 74 | ``` 75 | 76 | For more options, see the command's help.
77 | 78 | ### Start proxy 79 | 80 | ``` 81 | $ etcvault start -keychain /path/to/keychain/directory -listen http://localhost:2381 -initial-backends http://etcd:2379 82 | ``` 83 | 84 | ## Options 85 | 86 | - `-listen`: URL to listen on. 87 | - `-advertise-url`: URL to advertise. Used in `/v2/members` and `/v2/machines` responses. 88 | - `-keychain`: Path to the directory containing key files. 89 | 90 | ### Discovery options 91 | 92 | Either `-initial-backends` or `-discovery-srv` must be present. Backends are discovered using etcd's API. 93 | 94 | - `-initial-backends`: Comma-separated etcd client URLs (e.g. `http://etcd-1:2379,http://etcd-2:2379,...`) 95 | - `-discovery-srv`: FQDN to look up `_etcd-server._tcp` and `_etcd-server-ssl._tcp` SRV records. 96 | 97 | ### TLS support 98 | 99 | etcvault supports HTTPS both for its connections to etcd and for its own listener. 100 | 101 | #### Listening on HTTPS 102 | 103 | Just pass an HTTPS URL to `-listen` (e.g. `https://localhost:2381`). Valid certificate options are required. 104 | 105 | #### CA and key files 106 | 107 | - client: 108 | - `-client-ca-file` 109 | - Used to validate the server certificate of etcd's client port. 110 | - Also, when etcvault is listening on HTTPS and `-listen-key-file`/`-listen-cert-file` are not given, this CA certificate is used to validate the certificates of etcvault's clients. 111 | - `-client-key-file`, `-client-cert-file` 112 | - Used as the client certificate presented to etcd's client port. 113 | - Also, when etcvault is listening on HTTPS and `-listen-key-file`/`-listen-cert-file` are not given, this certificate is used as etcvault's server certificate. 114 | 115 | - listen: 116 | - `-listen-ca-file` 117 | - When present with `-listen-key-file` and `-listen-cert-file`, etcvault validates its clients' certificates using this CA file.
118 | - (only valid when `-listen-key-file` and `-listen-cert-file` are present) 119 | - `-listen-key-file`, `-listen-cert-file` 120 | - When present, etcvault doesn't use the `-client-*` files for its TLS server. 121 | - This certificate is used as etcvault's server certificate. 122 | 123 | - peer: 124 | - `-peer-ca-file` 125 | - Used to validate the server certificate of etcd's peer port. 126 | - `-peer-key-file`, `-peer-cert-file` 127 | - Used as the client certificate presented to etcd's peer port. 128 | - __Note:__ etcvault communicates with etcd peer ports when using the `-discovery-srv` option. If you're not using it, you can omit `-peer-*`. 129 | 130 | ## Key distribution 131 | 132 | There's no single best way to distribute keys; use your server provisioning tools. 133 | 134 | Here are the files required for encryption/decryption: 135 | 136 | - Hosts that only encrypt 137 | - Place `${KEYCHAIN_DIR}/${KEY_NAME}.pub` 138 | - Hosts that need to decrypt 139 | - Place `${KEYCHAIN_DIR}/${KEY_NAME}.pem` 140 | - `${KEY_NAME}.pub` is not necessary. 141 | 142 | 143 | ## FAQ 144 | 145 | ### Why does etcvault communicate with etcd *peer* ports? 146 | 147 | etcvault communicates with etcd peer ports only when you're using the `-discovery-srv` option, because the SRV records point to peer ports.
148 | 149 | ## License 150 | 151 | MIT License 152 | -------------------------------------------------------------------------------- /container/asis.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | type Asis struct { 8 | Content string 9 | } 10 | 11 | func ParseAsis(str string) (*Asis, error) { 12 | basic, err := ParseBasic(str) 13 | if err != nil { 14 | return nil, err 15 | } 16 | 17 | if !(basic.Version == "asis") { 18 | return nil, ErrDifferentVersion 19 | } 20 | 21 | return &Asis{ 22 | Content: basic.Content, 23 | }, nil 24 | } 25 | 26 | func (container *Asis) Version() string { 27 | return "asis" 28 | } 29 | 30 | func (container *Asis) String() string { 31 | return fmt.Sprintf("ETCVAULT::asis:%s::ETCVAULT", container.Content) 32 | } 33 | -------------------------------------------------------------------------------- /container/asis_test.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestParseAsis(t *testing.T) { 8 | container, err := ParseAsis("ETCVAULT::asis:content::ETCVAULT") 9 | 10 | if err != nil { 11 | t.Fatalf("unexpected err: %#v", err) 12 | } 13 | 14 | if container.Version() != "asis" { 15 | t.Errorf("unexpected container.Version: %#v", container.Version()) 16 | } 17 | 18 | if container.Content != "content" { 19 | t.Errorf("unexpected container.Content: %#v", container.Content) 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /container/basic.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type Basic struct { 9 | Version string 10 | Content string `json:"-"` 11 | } 12 | 13 | func ParseBasic(str string) (*Basic, error) { 14 | // ETCVAULT:::::ETCVAULT (at least 21 chars) 15 | if
len(str) < 21 { 16 | return nil, ErrInvalid 17 | } 18 | if !strings.HasPrefix(str, "ETCVAULT::") || !strings.HasSuffix(str, "::ETCVAULT") { // both markers must be anchored at the ends of str, not merely appear somewhere in it 19 | return nil, ErrInvalid 20 | } 21 | inner := str[10 : len(str)-10] 22 | versionAndContent := strings.SplitN(inner, ":", 2) 23 | 24 | if len(versionAndContent) < 2 { 25 | return nil, ErrParse 26 | } 27 | 28 | return &Basic{ 29 | Version: versionAndContent[0], 30 | Content: versionAndContent[1], 31 | }, nil 32 | } 33 | 34 | func (container *Basic) String() string { 35 | return fmt.Sprintf("ETCVAULT::%s:%s::ETCVAULT", container.Version, container.Content) 36 | } 37 | -------------------------------------------------------------------------------- /container/basic_test.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestParseBasic(t *testing.T) { 8 | container, err := ParseBasic("ETCVAULT::42:foo::ETCVAULT") 9 | 10 | if err != nil { 11 | t.Fatalf("unexpected err: %#v", err) 12 | } 13 | 14 | if container.Version != "42" { 15 | t.Errorf("unexpected container.Version: %#v", container.Version) 16 | } 17 | 18 | if container.Content != "foo" { 19 | t.Errorf("unexpected container.Content: %#v", container.Content) 20 | } 21 | } 22 | 23 | func TestParseBasicInvalid(t *testing.T) { 24 | _, err := ParseBasic("foo") 25 | 26 | if err != ErrInvalid { 27 | t.Errorf("unexpected err: %#v", err) 28 | } 29 | } 30 | 31 | func TestParseBasicNoTrail(t *testing.T) { 32 | _, err := ParseBasic("ETCVAULT::foo") 33 | 34 | if err != ErrInvalid { 35 | t.Errorf("unexpected err: %#v", err) 36 | } 37 | } 38 | 39 | func TestParseBasicNoHead(t *testing.T) { 40 | _, err := ParseBasic("foo::ETCVAULT") 41 | 42 | if err != ErrInvalid { 43 | t.Errorf("unexpected err: %#v", err) 44 | } 45 | } 46 | 47 | func TestParseBasicVersionOrContentMissing(t *testing.T) { 48 | _, err := ParseBasic("ETCVAULT::foo::ETCVAULT") 49 | 50 | if err != ErrParse { 51 |
t.Errorf("unexpected err: %#v", err) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /container/container.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | type Container interface { 4 | Version() string 5 | String() string 6 | } 7 | 8 | // Parse parses str as an ETCVAULT container and dispatches on its version tag. 9 | // Errors from ParseBasic (e.g. ErrInvalid for text that isn't a container) are returned as-is. 10 | func Parse(str string) (Container, error) { 11 | basic, err := ParseBasic(str) 12 | 13 | if err != nil { 14 | return nil, err 15 | } 16 | 17 | switch basic.Version { 18 | case "asis": 19 | return toContainer(ParseAsis(str)) 20 | case "1": 21 | return toContainer(ParseV1(str)) 22 | case "plain1", "plain": 23 | return toContainer(ParsePlain1(str)) 24 | default: 25 | return nil, ErrUnknownVersion 26 | } 27 | } 28 | 29 | // toContainer returns an untyped nil Container whenever err is non-nil, so a failed 30 | // parse never yields a non-nil interface that wraps a nil *Asis / *V1 / *Plain1. 31 | func toContainer(c Container, err error) (Container, error) { 32 | if err != nil { 33 | return nil, err 34 | } 35 | return c, nil 36 | } 37 | -------------------------------------------------------------------------------- /container/container_test.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestParseForV1(t *testing.T) { 9 | rawResult, err := Parse("ETCVAULT::1:key::aGVsbG8=::ETCVAULT") 10 | 11 | if err != nil { 12 | t.Errorf("unexpected error %#v", err) 13 | } 14 | 15 | result, ok := rawResult.(*V1) 16 | if !ok { 17 | t.Fatalf("V1 has not returned") 18 | } 19 | 20 | if result.Version() != "1" { 21 | t.Errorf("unexpected version %#v", result.Version()) 22 | } 23 | 24 | if result.KeyName != "key" { 25 | t.Errorf("unexpected KeyName %#v", result.KeyName) 26 | } 27 | 28 | if result.ContentKey != nil { 29 | t.Errorf("unexpected ContentKey %#v", result.ContentKey) 30 | } 31 | 32 | if !bytes.Equal(result.Content, []byte("hello")) { 33 | t.Errorf("unexpected Content %#v", result.Content) 34 | } 35 | } 36 | 37 | func TestParseForPlain1(t *testing.T) { 38 | rawResult, err := Parse("ETCVAULT::plain1:key:helo::ETCVAULT") 39 | 40 | if err != nil { 41 | t.Errorf("unexpected error %#v", err) 42 | } 43 | 44 | result, ok := rawResult.(*Plain1) 45 | if !ok { 46 |
t.Fatalf("Plain1 has not returned") 47 | } 48 | 49 | if result.Version() != "plain1" { 50 | t.Errorf("unexpected version %#v", result.Version()) 51 | } 52 | 53 | if result.KeyName != "key" { 54 | t.Errorf("unexpected KeyName %#v", result.KeyName) 55 | } 56 | 57 | if result.Content != "helo" { 58 | t.Errorf("unexpected Content %#v", result.Content) 59 | } 60 | } 61 | 62 | func TestParseForAsis(t *testing.T) { 63 | rawResult, err := Parse("ETCVAULT::asis:helo::ETCVAULT") 64 | 65 | if err != nil { 66 | t.Errorf("unexpected error %#v", err) 67 | } 68 | 69 | result, ok := rawResult.(*Asis) 70 | if !ok { 71 | t.Fatalf("Asis has not returned") 72 | } 73 | 74 | if result.Version() != "asis" { 75 | t.Errorf("unexpected version %#v", result.Version()) 76 | } 77 | 78 | if result.Content != "helo" { 79 | t.Errorf("unexpected Content %#v", result.Content) 80 | } 81 | } 82 | 83 | func TestParseForUnknown(t *testing.T) { 84 | result, err := Parse("ETCVAULT::unknown:XXX::ETCVAULT") 85 | 86 | if result != nil { 87 | t.Errorf("unexpected result %#v", result) 88 | } 89 | 90 | if err != ErrUnknownVersion { 91 | t.Errorf("unexpected error %#v", err) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /container/errors.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | var ErrParse = errors.New("couldn't parse") 8 | var ErrInvalid = errors.New("it's not in container form (invalid)") 9 | var ErrDifferentVersion = errors.New("it's in different version") 10 | var ErrUnknownVersion = errors.New("Unknown version") 11 | -------------------------------------------------------------------------------- /container/plain1.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type Plain1 struct { 9 | KeyName string 10 | Content string `json:"-"`
11 | } 12 | 13 | func ParsePlain1(str string) (*Plain1, error) { 14 | basic, err := ParseBasic(str) 15 | if err != nil { 16 | return nil, err 17 | } 18 | 19 | if !(basic.Version == "plain" || basic.Version == "plain1") { 20 | return nil, ErrDifferentVersion 21 | } 22 | 23 | keyAndContent := strings.SplitN(basic.Content, ":", 2) 24 | 25 | if len(keyAndContent) < 2 { 26 | return nil, ErrParse 27 | } 28 | 29 | return &Plain1{ 30 | KeyName: keyAndContent[0], 31 | Content: keyAndContent[1], 32 | }, nil 33 | } 34 | 35 | func (container *Plain1) Version() string { 36 | return "plain1" 37 | } 38 | 39 | func (container *Plain1) String() string { 40 | return fmt.Sprintf("ETCVAULT::plain:%s:%s::ETCVAULT", container.KeyName, container.Content) 41 | } 42 | -------------------------------------------------------------------------------- /container/plain1_test.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestParsePlain(t *testing.T) { 8 | container, err := ParsePlain1("ETCVAULT::plain:foo:content::ETCVAULT") 9 | 10 | if err != nil { 11 | t.Fatalf("unexpected err: %#v", err) 12 | } 13 | 14 | if container.Version() != "plain1" { 15 | t.Errorf("unexpected container.Version: %#v", container.Version()) 16 | } 17 | 18 | if container.KeyName != "foo" { 19 | t.Errorf("unexpected container.KeyName: %#v", container.KeyName) 20 | } 21 | 22 | if container.Content != "content" { 23 | t.Errorf("unexpected container.Content: %#v", container.Content) 24 | } 25 | } 26 | 27 | func TestParsePlain1(t *testing.T) { 28 | container, err := ParsePlain1("ETCVAULT::plain1:foo:content::ETCVAULT") 29 | 30 | if err != nil { 31 | t.Fatalf("unexpected err: %#v", err) 32 | } 33 | 34 | if container.Version() != "plain1" { 35 | t.Errorf("unexpected container.Version: %#v", container.Version()) 36 | } 37 | 38 | if container.KeyName != "foo" { 39 | t.Errorf("unexpected container.KeyName: %#v",
container.KeyName) 40 | } 41 | 42 | if container.Content != "content" { 43 | t.Errorf("unexpected container.Content: %#v", container.Content) 44 | } 45 | } 46 | 47 | func TestParsePlain1Invalid(t *testing.T) { 48 | _, err := ParsePlain1("foo") 49 | 50 | if err != ErrInvalid { 51 | t.Errorf("unexpected err: %#v", err) 52 | } 53 | } 54 | 55 | func TestParsePlain1NoTrail(t *testing.T) { 56 | _, err := ParsePlain1("ETCVAULT::foo") 57 | 58 | if err != ErrInvalid { 59 | t.Errorf("unexpected err: %#v", err) 60 | } 61 | } 62 | 63 | func TestParsePlain1NoHead(t *testing.T) { 64 | _, err := ParsePlain1("foo::ETCVAULT") 65 | 66 | if err != ErrInvalid { 67 | t.Errorf("unexpected err: %#v", err) 68 | } 69 | } 70 | 71 | func TestParsePlain1NoKeyOrContent(t *testing.T) { 72 | _, err := ParsePlain1("ETCVAULT::plain:foo::ETCVAULT") 73 | 74 | if err != ErrParse { 75 | t.Errorf("unexpected err: %#v", err) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /container/v1.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "encoding/base64" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | type V1 struct { 10 | KeyName string 11 | ContentKey []byte `json:"-"` 12 | Content []byte `json:"-"` 13 | } 14 | 15 | func ParseV1(str string) (*V1, error) { 16 | basic, err := ParseBasic(str) 17 | if err != nil { 18 | return nil, err 19 | } 20 | 21 | if basic.Version != "1" { 22 | return nil, ErrDifferentVersion 23 | } 24 | 25 | parts := strings.SplitN(basic.Content, ":", 3) // key name, format, content 26 | 27 | if len(parts) < 3 { 28 | return nil, ErrParse 29 | } 30 | 31 | keyName := parts[0] 32 | format := parts[1] 33 | contentPart := parts[2] 34 | 35 | var contentKey []byte 36 | var content []byte 37 | if format == "long" { 38 | contentKeyAndContent := strings.SplitN(contentPart, ",", 2) // "<content key>,<content>" 39 | if len(contentKeyAndContent) < 2 { // check the split result (not parts): a missing comma would make [1] below panic 40 | return nil, ErrParse 41 | } 42 | 43 | contentKey, err =
base64.StdEncoding.DecodeString(contentKeyAndContent[0]) 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | content, err = base64.StdEncoding.DecodeString(contentKeyAndContent[1]) 49 | if err != nil { 50 | return nil, err 51 | } 52 | } else { 53 | content, err = base64.StdEncoding.DecodeString(contentPart) 54 | if err != nil { 55 | return nil, err 56 | } 57 | } 58 | 59 | return &V1{ 60 | KeyName: keyName, 61 | ContentKey: contentKey, 62 | Content: content, 63 | }, nil 64 | } 65 | 66 | func (container *V1) Version() string { 67 | return "1" 68 | } 69 | 70 | func (container *V1) String() string { 71 | encodedContent := base64.StdEncoding.EncodeToString(container.Content) 72 | if container.ContentKey == nil { 73 | return fmt.Sprintf("ETCVAULT::1:%s::%s::ETCVAULT", container.KeyName, encodedContent) 74 | } else { 75 | encodedContentKey := base64.StdEncoding.EncodeToString(container.ContentKey) 76 | return fmt.Sprintf("ETCVAULT::1:%s:long:%s,%s::ETCVAULT", container.KeyName, encodedContentKey, encodedContent) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /container/v1_test.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestV1ParseShort(t *testing.T) { 9 | result, err := ParseV1("ETCVAULT::1:key::aGVsbG8=::ETCVAULT") 10 | 11 | if err != nil { 12 | t.Fatalf("unexpected error %#v", err) 13 | } 14 | 15 | if result.Version() != "1" { 16 | t.Errorf("unexpected version %#v", result.Version()) 17 | } 18 | 19 | if result.KeyName != "key" { 20 | t.Errorf("unexpected KeyName %#v", result.KeyName) 21 | } 22 | 23 | if result.ContentKey != nil { 24 | t.Errorf("unexpected ContentKey %#v", result.ContentKey) 25 | } 26 | 27 | if !bytes.Equal(result.Content, []byte("hello")) { 28 | t.Errorf("unexpected Content %#v", result.Content) 29 | } 30 | } 31 | 32 | func TestV1ParseLong(t *testing.T) { 33 |
result, err := ParseV1("ETCVAULT::1:key:long:aG9sYQ==,aGVsbG8=::ETCVAULT") 34 | 35 | if err != nil { 36 | t.Fatalf("unexpected error %#v", err) 37 | } 38 | 39 | if result.Version() != "1" { 40 | t.Errorf("unexpected version %#v", result.Version()) 41 | } 42 | 43 | if result.KeyName != "key" { 44 | t.Errorf("unexpected KeyName %#v", result.KeyName) 45 | } 46 | 47 | if !bytes.Equal(result.ContentKey, []byte(`hola`)) { 48 | t.Errorf("unexpected ContentKey %#v", result.ContentKey) 49 | } 50 | 51 | if !bytes.Equal(result.Content, []byte(`hello`)) { 52 | t.Errorf("unexpected Content %#v", result.Content) 53 | } 54 | } 55 | 56 | func TestV1ParseInvalid(t *testing.T) { 57 | result, err := ParseV1("hello") 58 | 59 | if result != nil { 60 | t.Errorf("unexpected result %#v", result) 61 | } 62 | if err != ErrInvalid { 63 | t.Errorf("unexpected error %#v", err) 64 | } 65 | } 66 | 67 | func TestV1ParseError(t *testing.T) { 68 | result, err := ParseV1("ETCVAULT::1::ETCVAULT") 69 | 70 | if result != nil { 71 | t.Errorf("unexpected result %#v", result) 72 | } 73 | if err != ErrParse { 74 | t.Errorf("unexpected error %#v", err) 75 | } 76 | } 77 | 78 | func TestV1ParseNotV1(t *testing.T) { 79 | result, err := ParseV1("ETCVAULT::42:foo::ETCVAULT") 80 | 81 | if result != nil { 82 | t.Errorf("unexpected result %#v", result) 83 | } 84 | if err != ErrDifferentVersion { 85 | t.Errorf("unexpected error %#v", err) 86 | } 87 | } 88 | 89 | func TestV1StringShort(t *testing.T) { 90 | container := &V1{ 91 | KeyName: "key", 92 | Content: []byte("hello"), 93 | } 94 | 95 | result := container.String() 96 | 97 | if result != "ETCVAULT::1:key::aGVsbG8=::ETCVAULT" { 98 | t.Errorf("unexpected string %#v", result) 99 | } 100 | } 101 | 102 | func TestV1StringLong(t *testing.T) { 103 | container := &V1{ 104 | KeyName: "key", 105 | ContentKey: []byte("hola"), 106 | Content: []byte("hello"), 107 | } 108 | 109 | result := container.String() 110 | 111 | if result !=
"ETCVAULT::1:key:long:aG9sYQ==,aGVsbG8=::ETCVAULT" { 112 | t.Errorf("unexpected string %#v", result) 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /engine/aes.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | ciphers "crypto/cipher" 5 | "errors" 6 | ) 7 | 8 | var ErrInvalidPadding = errors.New("pkcs7 padding invalid") 9 | var ErrInvalidLength = errors.New("invalid length; it should be multiple of aes block size") 10 | 11 | func encryptAesWithPkcs7Padding(cipherPtr *ciphers.Block, origMsgPtr *[]byte) *[]byte { 12 | cipher := *cipherPtr 13 | blockSize := cipher.BlockSize() 14 | 15 | msg := *(addPkcs7Padding(cipher.BlockSize(), origMsgPtr)) 16 | encryptedMsg := make([]byte, len(msg)) 17 | buf := make([]byte, blockSize) 18 | for i := 0; i < len(msg); i += blockSize { // NOTE(review): every block is encrypted independently (ECB, no IV), so equal plaintext blocks give equal ciphertext blocks 19 | beg, end := i, i+blockSize 20 | cipher.Encrypt(buf, msg[beg:end]) 21 | copy(encryptedMsg[beg:end], buf) 22 | } 23 | return &encryptedMsg 24 | } 25 | 26 | func decryptAesWithPkcs7Padding(cipherPtr *ciphers.Block, encryptedMsgPtr *[]byte) (*[]byte, error) { 27 | cipher := *cipherPtr 28 | encryptedMsg := *encryptedMsgPtr 29 | blockSize := cipher.BlockSize() 30 | 31 | if len(encryptedMsg) == 0 || len(encryptedMsg)%blockSize != 0 { // empty input has no padding byte to read 32 | return nil, ErrInvalidLength 33 | } 34 | 35 | msg := make([]byte, len(encryptedMsg)) 36 | buf := make([]byte, blockSize) 37 | for i := 0; i < len(encryptedMsg); i += blockSize { 38 | beg, end := i, i+blockSize 39 | cipher.Decrypt(buf, encryptedMsg[beg:end]) 40 | copy(msg[beg:end], buf) 41 | } 42 | return removePkcs7Padding(cipher.BlockSize(), &msg) 43 | } 44 | 45 | func addPkcs7Padding(blockSize int, origMsgPtr *[]byte) *[]byte { 46 | origMsg := *origMsgPtr 47 | 48 | paddingLength := blockSize - (len(origMsg) % blockSize) 49 | 50 | msg := make([]byte, len(origMsg), len(origMsg)+paddingLength) 51 | copy(msg, origMsg) 52 | 53 | for i := 0; i < paddingLength; i++ { 54 |
msg = append(msg, byte(paddingLength)) 55 | } 56 | 57 | return &msg 58 | } 59 | 60 | func removePkcs7Padding(blockSize int, paddedMsgPtr *[]byte) (*[]byte, error) { 61 | paddedMsg := *paddedMsgPtr 62 | // reject input that is empty or not a whole number of blocks; the indexing below would otherwise panic 63 | if len(paddedMsg) == 0 || len(paddedMsg)%blockSize != 0 { 64 | return nil, ErrInvalidLength 65 | } 66 | // validate padding: PKCS#7 appends 1..blockSize bytes, each equal to the padding length 67 | paddingLength := int(paddedMsg[len(paddedMsg)-1]) 68 | if paddingLength == 0 || paddingLength > blockSize { 69 | return nil, ErrInvalidPadding 70 | } 71 | for _, padding := range paddedMsg[len(paddedMsg)-paddingLength:] { 72 | if int(padding) != paddingLength { 73 | return nil, ErrInvalidPadding 74 | } 75 | } 76 | msg := paddedMsg[:len(paddedMsg)-paddingLength] 77 | 78 | return &msg, nil 79 | } 80 | -------------------------------------------------------------------------------- /engine/benchmark_test.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func BenchmarkEncryptV1Short(b *testing.B) { 8 | engine := NewEngine(testKeychain) 9 | b.ResetTimer() 10 | for i := 0; i < b.N; i++ { 11 | _, _ = engine.Transform("ETCVAULT::plain1:the-key:this text should be encrypted::ETCVAULT") 12 | } 13 | } 14 | 15 | func BenchmarkEncryptV1Long(b *testing.B) { 16 | engine := NewEngine(testKeychain) 17 | b.ResetTimer() 18 | for i := 0; i < b.N; i++ { 19 | _, _ = engine.Transform("ETCVAULT::plain1:the-key:this text should be encrypted aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa::ETCVAULT") 20 | } 21 | } 22 | 23 | func BenchmarkDecryptV1Short(b *testing.B) { 24 | engine := NewEngine(testKeychain) 25 | b.ResetTimer() 26 | for i := 0; i < b.N; i++ { 27 | _, _ = engine.Transform("ETCVAULT::1:the-key::oXKv3edU7AjUXK1+7+Ng7y5tjByLzMe8MRL2lCxlsE03pHS2AXnd3mvar5dkbgeTU4dY8lcMPYAqRGXi2y9YJ7MD+8vKpkORczLYOBTiSXY8cuttvWY+ffjeJMSsLiHn0tDdtjvCtshSBTe9vLz75yyW8J91DUm9CriHWtQhaXw=::ETCVAULT") 28 | } 29 | } 30 | 31 | func BenchmarkDecryptV1Long(b *testing.B) { 32 | engine := NewEngine(testKeychain) 33 | b.ResetTimer() 34 | for i := 0; i < b.N; i++ { 35 | _, _ =
engine.Transform("ETCVAULT::1:the-key:long:JRrn3XxO/HJEu/xYblTkxooOGvFkvnHz4AyinTceZMI2ybRbS2TyoOS+fTGZTTdUMnQ0gKhqH/KsCBjtvW/lw+CXEXVooCmpRCRyVYJIu/FH+oarHIGkpDTeJruEVaL1Jlvo0gb9Ea4zeZuKSiabY+puoTHVCEm1sEN8pHE48xA=,6LaTIBRfKOMBfHq/2JaF/ooeVe97GLGe5gJB8DBYMI30q8mynk9DoMgDKX4ROoiUXatFhSS20hvIIZEUwt62qN7ksivXSb9OybZwU22h6Kw=::ETCVAULT") 36 | } 37 | } 38 | 39 | func BenchmarkPlain(b *testing.B) { 40 | engine := NewEngine(testKeychain) 41 | b.ResetTimer() 42 | for i := 0; i < b.N; i++ { 43 | _, _ = engine.Transform("i'm plain text.") 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /engine/engine.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/rand" 6 | "crypto/rsa" 7 | "crypto/sha256" 8 | "errors" 9 | "fmt" 10 | "github.com/sorah/etcvault/container" 11 | "github.com/sorah/etcvault/keys" 12 | ) 13 | 14 | var ErrNoPrivateKey = errors.New("no private key provided") 15 | var ErrTooShortKey = errors.New("key too short; couldn't generate 16, 24, and 32 bytes aes key") 16 | 17 | type Transformable interface { 18 | Transform(text string) (string, error) 19 | TransformEtcdJsonResponse(jsonData []byte) ([]byte, error) 20 | GetKeychain() *keys.Keychain 21 | } 22 | 23 | type Engine struct { 24 | Keychain *keys.Keychain 25 | } 26 | 27 | func NewEngine(keychain *keys.Keychain) *Engine { 28 | return &Engine{ 29 | Keychain: keychain, 30 | } 31 | } 32 | 33 | func (engine *Engine) GetKeychain() *keys.Keychain { 34 | return engine.Keychain 35 | } 36 | 37 | func (engine *Engine) Transform(text string) (string, error) { 38 | s, _, e := engine.TransformAndParse(text) 39 | return s, e 40 | } 41 | 42 | func (engine *Engine) TransformAndParse(text string) (string, container.Container, error) { 43 | // FIXME: test for this 44 | rawContainer, err := container.Parse(text) 45 | 46 | if err != nil { 47 | if err == container.ErrInvalid { 48 |
return text, nil, nil 49 | } else { 50 | return "", nil, err 51 | } 52 | } 53 | 54 | switch c := rawContainer.(type) { 55 | case *container.Plain1: 56 | result, err := engine.TransformPlain1(c) 57 | return result, c, err 58 | case *container.Asis: 59 | result, err := engine.TransformAsis(c) 60 | return result, c, err 61 | case *container.V1: 62 | result, err := engine.TransformV1(c) 63 | return result, c, err 64 | } 65 | // unreachable: container.Parse only yields the container types handled above 66 | panic(fmt.Errorf("BUG: unsupported container type %#v", rawContainer)) 67 | } 68 | 69 | func (engine *Engine) TransformPlain1(c *container.Plain1) (string, error) { 70 | key, err := engine.Keychain.Find(c.KeyName) 71 | if err != nil { 72 | return "", err 73 | } 74 | 75 | encryptedContent, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, key.Public, []byte(c.Content), []byte{}) 76 | if err == rsa.ErrMessageTooLong { 77 | return engine.transformPlain1Long(key, c) 78 | } 79 | if err != nil { 80 | return "", err 81 | } 82 | 83 | result := &container.V1{ 84 | KeyName: key.Name, 85 | Content: encryptedContent, 86 | } 87 | 88 | return result.String(), nil 89 | } 90 | 91 | func (engine *Engine) TransformAsis(c *container.Asis) (string, error) { 92 | return c.Content, nil 93 | } 94 | 95 | func (engine *Engine) transformPlain1Long(key *keys.Key, c *container.Plain1) (string, error) { 96 | hash := sha256.New() 97 | 98 | rsaMaxLength := ((key.Public.N.BitLen() + 7) / 8) - (2 * hash.Size()) - 2 99 | contentKeyLength := 32 100 | if rsaMaxLength < contentKeyLength { 101 | contentKeyLength = 24 102 | } 103 | if rsaMaxLength < contentKeyLength { 104 | contentKeyLength = 16 105 | } 106 | if rsaMaxLength < contentKeyLength { 107 | return "", ErrTooShortKey 108 | } 109 | 110 | contentKey := make([]byte, contentKeyLength) 111 | if _, err := rand.Read(contentKey); err != nil { 112 | return "", err 113 | } 114 | 115 | encryptedContentKey, err := rsa.EncryptOAEP(hash, rand.Reader, key.Public, contentKey, []byte{}) 116 | if err != nil { 117 | return
"", err 118 | } 119 | 120 | cipher, err := aes.NewCipher(contentKey) 121 | if err != nil { 122 | return "", err 123 | } 124 | 125 | content := []byte(c.Content) 126 | encryptedContent := *(encryptAesWithPkcs7Padding(&cipher, &content)) 127 | 128 | result := &container.V1{ 129 | KeyName: key.Name, 130 | ContentKey: encryptedContentKey, 131 | Content: encryptedContent, 132 | } 133 | return result.String(), nil 134 | } 135 | 136 | func (engine *Engine) TransformV1(c *container.V1) (string, error) { 137 | if c.ContentKey == nil { 138 | return engine.transformV1Short(c) 139 | } else { 140 | return engine.transformV1Long(c) 141 | } 142 | } 143 | 144 | func (engine *Engine) transformV1Short(c *container.V1) (string, error) { 145 | key, err := engine.Keychain.Find(c.KeyName) 146 | if err != nil { 147 | return "", err 148 | } 149 | if key.Private == nil { 150 | return "", ErrNoPrivateKey 151 | } 152 | 153 | hash := sha256.New() 154 | decryptedContent, err := rsa.DecryptOAEP(hash, rand.Reader, key.Private, c.Content, []byte{}) 155 | if err != nil { 156 | return "", err 157 | } 158 | return string(decryptedContent), nil 159 | } 160 | 161 | // transformV1Long decrypts a "long" V1 container: the RSA-OAEP (SHA-256) wrapped content key 162 | // is unwrapped with the private key, then used as the AES key to decrypt Content. 163 | func (engine *Engine) transformV1Long(c *container.V1) (string, error) { 164 | key, err := engine.Keychain.Find(c.KeyName) 165 | if err != nil { 166 | return "", err 167 | } 168 | if key.Private == nil { 169 | return "", ErrNoPrivateKey 170 | } 171 | 172 | hash := sha256.New() 173 | decryptedContentKey, err := rsa.DecryptOAEP(hash, rand.Reader, key.Private, c.ContentKey, []byte{}) 174 | if err != nil { 175 | return "", err 176 | } 177 | 178 | block, err := aes.NewCipher(decryptedContentKey) // named block so the aes package isn't shadowed 179 | if err != nil { 180 | return "", err 181 | } 182 | 183 | decryptedContent, err := decryptAesWithPkcs7Padding(&block, &c.Content) 184 | if err != nil { 185 | return "", err // propagate length/padding failures instead of silently yielding "" 186 | } 187 | 188 | return string(*decryptedContent), nil 189 | } 190 | -------------------------------------------------------------------------------- /engine/engine_test.go:
-------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestTransformPlainToPlain(t *testing.T) { 9 | engine := NewEngine(testKeychain) 10 | 11 | result, err := engine.Transform("plain text") 12 | 13 | if err != nil { 14 | t.Errorf("unexpected err: %#v", err) 15 | } 16 | if result != "plain text" { 17 | t.Errorf("unexpected result: %#v", result) 18 | } 19 | } 20 | 21 | func TestTransformPlainRoundtrip(t *testing.T) { 22 | engine := NewEngine(testKeychain) 23 | 24 | encryptedText, err := engine.Transform("ETCVAULT::plain:the-key:this text should be encrypted::ETCVAULT") 25 | if err != nil { 26 | t.Errorf("unexpected err: %#v", err) 27 | } 28 | if strings.Index(encryptedText, "this text should be encrypted") != -1 { 29 | t.Errorf("encrypted text contains original text: %#v", encryptedText) 30 | } 31 | if strings.Index(encryptedText, "ETCVAULT::1:the-key::") != 0 { 32 | t.Errorf("encrypted text unexpected: %#v", encryptedText) 33 | } 34 | 35 | plainText, err := engine.Transform(encryptedText) 36 | if err != nil { 37 | t.Errorf("2 unexpected err: %#v", err) 38 | } 39 | if plainText != "this text should be encrypted" { 40 | t.Errorf("unexpected result: %#v", plainText) 41 | } 42 | } 43 | 44 | func TestTransformV1RoundtripShort(t *testing.T) { 45 | engine := NewEngine(testKeychain) 46 | 47 | encryptedText, err := engine.Transform("ETCVAULT::plain:the-key:this text should be encrypted::ETCVAULT") 48 | if err != nil { 49 | t.Errorf("1 unexpected err: %#v", err) 50 | } 51 | if strings.Index(encryptedText, "this text should be encrypted") != -1 { 52 | t.Errorf("encrypted text contains original text: %#v", encryptedText) 53 | } 54 | if strings.Index(encryptedText, "ETCVAULT::1:the-key::") != 0 { 55 | t.Errorf("encrypted text unexpected: %#v", encryptedText) 56 | } 57 | 58 | plainText, err := engine.Transform(encryptedText) 59 | if err != nil { 60 | t.Errorf("2
unexpected err: %#v", err) 61 | } 62 | if plainText != "this text should be encrypted" { 63 | t.Errorf("unexpected result: %#v", plainText) 64 | } 65 | } 66 | 67 | func TestTransformV1DecryptionShort(t *testing.T) { 68 | engine := NewEngine(testKeychain) 69 | decryptedText, err := engine.Transform("ETCVAULT::1:the-key::oXKv3edU7AjUXK1+7+Ng7y5tjByLzMe8MRL2lCxlsE03pHS2AXnd3mvar5dkbgeTU4dY8lcMPYAqRGXi2y9YJ7MD+8vKpkORczLYOBTiSXY8cuttvWY+ffjeJMSsLiHn0tDdtjvCtshSBTe9vLz75yyW8J91DUm9CriHWtQhaXw=::ETCVAULT") 70 | 71 | if err != nil { 72 | t.Errorf("1 unexpected err: %#v", err) 73 | } 74 | if decryptedText != "this text should be encrypted" { 75 | t.Errorf("unexpected text %#v", decryptedText) 76 | } 77 | } 78 | 79 | func TestTransformV1RoundtripLong(t *testing.T) { 80 | engine := NewEngine(testKeychain) 81 | 82 | encryptedText, err := engine.Transform("ETCVAULT::plain:the-key:this text is too long so this should be long format aaaaaaaaaaaaaaaaaaaaaaaaaa::ETCVAULT") 83 | if err != nil { 84 | t.Errorf("1 unexpected err: %#v", err.Error()) 85 | } 86 | if strings.Index(encryptedText, "this text is too long so this should be long format aaaaaaaaaaaaaaaaaaaaaaaaaa") != -1 { 87 | t.Errorf("encrypted text contains original text: %#v", encryptedText) 88 | } 89 | if strings.Index(encryptedText, "ETCVAULT::1:the-key:long:") != 0 { 90 | t.Errorf("encrypted text unexpected: %#v", encryptedText) 91 | } 92 | 93 | plainText, err := engine.Transform(encryptedText) 94 | if err != nil { 95 | t.Errorf("2 unexpected err: %#v", err) 96 | } 97 | if plainText != "this text is too long so this should be long format aaaaaaaaaaaaaaaaaaaaaaaaaa" { 98 | t.Errorf("unexpected result: %#v", plainText) 99 | } 100 | } 101 | 102 | func TestTransformV1DecryptionLong(t *testing.T) { 103 | engine := NewEngine(testKeychain) 104 | decryptedText, err :=
engine.Transform("ETCVAULT::1:the-key:long:JRrn3XxO/HJEu/xYblTkxooOGvFkvnHz4AyinTceZMI2ybRbS2TyoOS+fTGZTTdUMnQ0gKhqH/KsCBjtvW/lw+CXEXVooCmpRCRyVYJIu/FH+oarHIGkpDTeJruEVaL1Jlvo0gb9Ea4zeZuKSiabY+puoTHVCEm1sEN8pHE48xA=,6LaTIBRfKOMBfHq/2JaF/ooeVe97GLGe5gJB8DBYMI30q8mynk9DoMgDKX4ROoiUXatFhSS20hvIIZEUwt62qN7ksivXSb9OybZwU22h6Kw=::ETCVAULT") 105 | if err != nil { 106 | t.Errorf("1 unexpected err: %#v", err) 107 | } 108 | if decryptedText != "this text is too long so this should be long format aaaaaaaaaaaaaaaaaaaaaaaaaa" { 109 | t.Errorf("unexpected text %#v", decryptedText) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /engine/json.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | // TransformEtcdJsonResponse runs TransformAndParse on every string "value" at node.value, node.**.nodes[].value, prevNode.value and prevNode.**.nodes[].value and re-marshals the response; a body that is not a JSON object is returned unchanged. 8 | func (engine *Engine) TransformEtcdJsonResponse(jsonData []byte) ([]byte, error) { 9 | var data interface{} 10 | if err := json.Unmarshal(jsonData, &data); err != nil { 11 | return jsonData, nil // not JSON at all: pass the body through untouched 12 | } 13 | root, ok := data.(map[string]interface{}) 14 | if !ok { 15 | return jsonData, nil 16 | } 17 | if nodeRaw, ok := root["node"]; ok { 18 | if node, ok := nodeRaw.(map[string]interface{}); ok { 19 | engine.transformEtcdJsonResponse0(&node, 0) 20 | } 21 | } 22 | 23 | if nodeRaw, ok := root["prevNode"]; ok { 24 | if node, ok := nodeRaw.(map[string]interface{}); ok { 25 | engine.transformEtcdJsonResponse0(&node, 0) 26 | } 27 | } 28 | 29 | return json.Marshal(data) 30 | } 31 | // transformEtcdJsonResponse0 rewrites node["value"] in place via TransformAndParse (adding "_etcvault" with the parsed container, or "_etcvault_error" on failure) and recurses into node["nodes"], giving up beyond depth 100. 32 | func (engine *Engine) transformEtcdJsonResponse0(nodePtr *map[string]interface{}, depth int) { 33 | if depth > 100 { 34 | return 35 | } 36 | 37 | node := *nodePtr 38 | 39 | if value, ok := node["value"]; ok { 40 | if str, ok := value.(string); ok { 41 | newValue, container, err := engine.TransformAndParse(str) 42 | if err == nil { 43 | node["value"] = newValue 44 | if container != nil { 45 | node["_etcvault"] =
map[string]interface{}{ 46 | "version": container.Version(), 47 | "container": container, 48 | } 49 | } 50 | } else { 51 | node["_etcvault_error"] = err.Error() 52 | } 53 | } 54 | } 55 | 56 | if nodesRaw, ok := node["nodes"]; ok { 57 | if nodes, ok := nodesRaw.([]interface{}); ok { 58 | for _, subNodeRaw := range nodes { 59 | subNode, ok := subNodeRaw.(map[string]interface{}) 60 | if !ok { 61 | continue 62 | } 63 | 64 | engine.transformEtcdJsonResponse0(&subNode, depth+1) 65 | } 66 | } 67 | } 68 | 69 | return 70 | } 71 | -------------------------------------------------------------------------------- /engine/json_test.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestTransformEtcdJsonResponse(t *testing.T) { 9 | tests := []struct { 10 | Name string 11 | Case []byte 12 | Expect []byte 13 | }{ 14 | { 15 | Name: "non-container (node.value)", 16 | Case: []byte(`{"node": {"value": "non-container"}}`), 17 | Expect: []byte(`{"node":{"value":"non-container"}}`), 18 | }, 19 | { 20 | Name: "plain (node.value)", 21 | Case: []byte(`{"node": {"value": "ETCVAULT::asis:plain::ETCVAULT"}}`), 22 | Expect: []byte(`{"node":{"_etcvault":{"container":{"Content":"plain"},"version":"asis"},"value":"plain"}}`), 23 | }, 24 | { 25 | Name: "plain (prevNode.value)", 26 | Case: []byte(`{"prevNode": {"value": "ETCVAULT::asis:plain::ETCVAULT"}}`), 27 | Expect: []byte(`{"prevNode":{"_etcvault":{"container":{"Content":"plain"},"version":"asis"},"value":"plain"}}`), 28 | }, 29 | { 30 | Name: "both (node.value, prevNode.value)", 31 | Case: []byte(`{"node": {"value": "ETCVAULT::asis:plain::ETCVAULT"}, "prevNode": {"value": "ETCVAULT::asis:plain::ETCVAULT"}}`), 32 | Expect: []byte(`{"node":{"_etcvault":{"container":{"Content":"plain"},"version":"asis"},"value":"plain"},"prevNode":{"_etcvault":{"container":{"Content":"plain"},"version":"asis"},"value":"plain"}}`), 33 | }, 34 | { 35
| Name: "inside directory (node.nodes[0].value)", 36 | Case: []byte(`{"node": {"nodes": [{"value": "ETCVAULT::asis:plain::ETCVAULT"}]}}`), 37 | Expect: []byte(`{"node":{"nodes":[{"_etcvault":{"container":{"Content":"plain"},"version":"asis"},"value":"plain"}]}}`), 38 | }, 39 | { 40 | Name: "inside directory, multiple (node.nodes[0].value, node.nodes[1].value)", 41 | Case: []byte(`{"node": {"nodes": [{"value": "ETCVAULT::asis:plain::ETCVAULT"}, {"value": "ETCVAULT::asis:plain::ETCVAULT"}]}}`), 42 | Expect: []byte(`{"node":{"nodes":[{"_etcvault":{"container":{"Content":"plain"},"version":"asis"},"value":"plain"},{"_etcvault":{"container":{"Content":"plain"},"version":"asis"},"value":"plain"}]}}`), 43 | }, 44 | { 45 | Name: "nested, inside directory (node.nodes[0].nodes[0].value)", 46 | Case: []byte(`{"node": {"nodes": [{"nodes": [{"value": "ETCVAULT::asis:plain::ETCVAULT"}]}]}}`), 47 | Expect: []byte(`{"node":{"nodes":[{"nodes":[{"_etcvault":{"container":{"Content":"plain"},"version":"asis"},"value":"plain"}]}]}}`), 48 | }, 49 | } 50 | 51 | engine := NewEngine(testKeychain) 52 | 53 | for _, test := range tests { 54 | transformedJson, err := engine.TransformEtcdJsonResponse(test.Case) 55 | 56 | if err != nil { 57 | t.Errorf("%s:\n\tunexpected err: %s", test.Name, err.Error()) 58 | } 59 | 60 | if !reflect.DeepEqual(transformedJson, test.Expect) { 61 | t.Errorf("%s:\n\tunexpected result: %s", test.Name, transformedJson) 62 | } 63 | } 64 | } 65 | 66 | func TestTransformEtcdJsonResponseFailures(t *testing.T) { 67 | tests := []struct { 68 | Name string 69 | Case []byte 70 | Expect []byte 71 | }{ 72 | { 73 | Name: "plain (node.value)", 74 | Case: []byte(`{"node": {"value": "ETCVAULT::plain1::ETCVAULT"}}`), 75 | Expect: []byte(`{"node":{"_etcvault_error":"couldn't parse","value":"ETCVAULT::plain1::ETCVAULT"}}`), 76 | }, 77 | { 78 | Name: "plain (prevNode.value)", 79 | Case: []byte(`{"prevNode": {"value": "ETCVAULT::plain1::ETCVAULT"}}`), 80 | Expect:
[]byte(`{"prevNode":{"_etcvault_error":"couldn't parse","value":"ETCVAULT::plain1::ETCVAULT"}}`), 81 | }, 82 | { 83 | Name: "both (node.value, prevNode.value)", 84 | Case: []byte(`{"node": {"value": "ETCVAULT::plain1::ETCVAULT"}, "prevNode": {"value": "ETCVAULT::plain1::ETCVAULT"}}`), 85 | Expect: []byte(`{"node":{"_etcvault_error":"couldn't parse","value":"ETCVAULT::plain1::ETCVAULT"},"prevNode":{"_etcvault_error":"couldn't parse","value":"ETCVAULT::plain1::ETCVAULT"}}`), 86 | }, 87 | { 88 | Name: "inside directory (node.nodes[0].value)", 89 | Case: []byte(`{"node": {"nodes": [{"value": "ETCVAULT::plain1::ETCVAULT"}]}}`), 90 | Expect: []byte(`{"node":{"nodes":[{"_etcvault_error":"couldn't parse","value":"ETCVAULT::plain1::ETCVAULT"}]}}`), 91 | }, 92 | { 93 | Name: "inside directory, multiple (node.nodes[0].value, node.nodes[1].value)", 94 | Case: []byte(`{"node": {"nodes": [{"value": "ETCVAULT::plain1::ETCVAULT"}, {"value": "ETCVAULT::plain1::ETCVAULT"}]}}`), 95 | Expect: []byte(`{"node":{"nodes":[{"_etcvault_error":"couldn't parse","value":"ETCVAULT::plain1::ETCVAULT"},{"_etcvault_error":"couldn't parse","value":"ETCVAULT::plain1::ETCVAULT"}]}}`), 96 | }, 97 | { 98 | Name: "nested, inside directory (node.nodes[0].nodes[0].value)", 99 | Case: []byte(`{"node": {"nodes": [{"nodes": [{"value": "ETCVAULT::plain1::ETCVAULT"}]}]}}`), 100 | Expect: []byte(`{"node":{"nodes":[{"nodes":[{"_etcvault_error":"couldn't parse","value":"ETCVAULT::plain1::ETCVAULT"}]}]}}`), 101 | }, 102 | } 103 | 104 | engine := NewEngine(testKeychain) 105 | 106 | for _, test := range tests { 107 | transformedJson, err := engine.TransformEtcdJsonResponse(test.Case) 108 | 109 | if err != nil { 110 | t.Errorf("%s:\n\tunexpected err: %s", test.Name, err.Error()) 111 | } 112 | 113 | if !reflect.DeepEqual(transformedJson, test.Expect) { 114 | t.Errorf("%s:\n\t expected result: %s\n\tunexpected result: %s", test.Name, test.Expect, transformedJson) 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /engine/main_test.go: 
-------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "github.com/sorah/etcvault/keys" 5 | "io/ioutil" 6 | "os" 7 | "path" 8 | "testing" 9 | ) 10 | 11 | var testRsaPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- 12 | Name: the-key 13 | 14 | MIICXAIBAAKBgQDE0H3AjeUvlOA5ueZ1q6hukF+aRFbW2h8qW2OIw88+EN4qLani 15 | lTvTUO3V91hGhHe2CnnUOey1iAHnSPGx66XW3oN/Wuk+wK1tg1ivcCLHIOlRu22g 16 | 8DuS8TC92jhjkFVCgGasXNFGECiyF6J9WsYrF6F/OKvUVpEjWgyRMPMMuQIDAQAB 17 | AoGAMOlbhyH8ZhHKk64GfxHU/v00NSNsrWJxwlYJ63A2LceFXtgQUzYhMwf2w2j/ 18 | 8C51jbEWy85FbGvLhU4UetIEWW0OK5Y+J2juGD0ez1FX+EzmiO+khpGtYQ6OY56a 19 | 3g4FPsUuCj1gw2oBDDQ2e38RyqY9Nj3PWo4H5Y7ZbSWwSQ0CQQDSNABnC7AiM2K3 20 | 5uXqZiXx68RoLrYtGkXhgyZBIUZ+g6nbhBqpPEI9pql55yCjmx/zeY6VVipOffO2 21 | EEUpdnG/AkEA77G9SK8lqxMeH+GRL70jYNXBqdxYhKrWlFzom+VrHIyo//limocH 22 | dPJiEEIyPJQXeru2r2mWxVg98q+j3CUvhwJAIzebKaiHpfM+Atmog5EBonqBuYK5 23 | +ux/8LxsWFUe3mtoteJ4JQp3fqTBmC7lBQQkYkJnZRW+mM/5WPN44u15OQJBAJPO 24 | Wbehcav9vPzR3vK+QjurdKHnI5qjsnCInlPL8/IF9wzp3tkFXR7LfJckCtB6TcQ8 25 | Ttn6VaPZ11F456WQNK8CQETVQARcp/v4bWtVHfJKyBcx92FkclVNXae5aHpmvIjI 26 | LUu9LpYOrkcaL1d7SFPhWZUsI+crYKuLAb9tXG/AnJY= 27 | -----END RSA PRIVATE KEY-----`) 28 | 29 | var testRsaPublicKey = []byte(`-----BEGIN RSA PUBLIC KEY----- 30 | Name: the-key 31 | 32 | MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDE0H3AjeUvlOA5ueZ1q6hukF+a 33 | RFbW2h8qW2OIw88+EN4qLanilTvTUO3V91hGhHe2CnnUOey1iAHnSPGx66XW3oN/ 34 | Wuk+wK1tg1ivcCLHIOlRu22g8DuS8TC92jhjkFVCgGasXNFGECiyF6J9WsYrF6F/ 35 | OKvUVpEjWgyRMPMMuQIDAQAB 36 | -----END RSA PUBLIC KEY-----`) 37 | 38 | var testKeychain *keys.Keychain 39 | // TestMain installs the test key pair into a temporary keychain shared by the engine tests and removes that directory once they have run. 40 | func TestMain(m *testing.M) { 41 | tmpDir, err := ioutil.TempDir("", "engine_test") 42 | if err != nil { 43 | panic(err) 44 | } 45 | 46 | testKeychain = keys.NewKeychain(tmpDir) 47 | 48 | if err :=
ioutil.WriteFile(path.Join(testKeychain.Path, "the-key.pem"), testRsaPrivateKey, 0600); err != nil { 49 | panic(err) 50 | } 51 | if err := ioutil.WriteFile(path.Join(testKeychain.Path, "pubkey.pub"), testRsaPublicKey, 0644); err != nil { 52 | panic(err) 53 | } 54 | 55 | // os.Exit skips deferred functions, so the temporary keychain is removed explicitly before exiting. 56 | code := m.Run() 57 | if err := os.RemoveAll(testKeychain.Path); err != nil { 58 | panic(err) 59 | } 60 | 61 | os.Exit(code) 62 | } 63 | -------------------------------------------------------------------------------- /keys/key.go: -------------------------------------------------------------------------------- 1 | package keys 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "crypto/x509" 7 | "encoding/pem" 8 | "errors" 9 | "io/ioutil" 10 | ) 11 | 12 | var ErrMissingPem = errors.New("invalid pem (couldn't decode)") 13 | var ErrInvalidPem = errors.New("invalid pem (type should be RSA PUBLIC KEY, PUBLIC KEY, RSA PRIVATE KEY, or PRIVATE KEY)") 14 | var ErrNotRsaKey = errors.New("invalid pem (key is not RSA public key or private key)") 15 | // Key is a named RSA key; Private is nil when only the public half is known. 16 | type Key struct { 17 | Name string 18 | Public *rsa.PublicKey 19 | Private *rsa.PrivateKey 20 | } 21 | // NewPrivateKey wraps rsaPrivateKey together with its public half, calling Precompute to speed up later private-key operations. 22 | func NewPrivateKey(name string, rsaPrivateKey *rsa.PrivateKey) *Key { 23 | pubKey := rsaPrivateKey.Public().(*rsa.PublicKey) 24 | rsaPrivateKey.Precompute() 25 | 26 | return &Key{ 27 | Name: name, 28 | Public: pubKey, 29 | Private: rsaPrivateKey, 30 | } 31 | } 32 | // NewPublicKey wraps rsaPublicKey in a Key that has no private half. 33 | func NewPublicKey(name string, rsaPublicKey *rsa.PublicKey) *Key { 34 | return &Key{ 35 | Name: name, 36 | Public: rsaPublicKey, 37 | } 38 | } 39 | // LoadKey decodes the first PEM block in pemBytes (a PKIX public key or a PKCS#1 private key) and names the key after the block's "Name" header. 40 | func LoadKey(pemBytes []byte) (*Key, error) { 41 | block, _ := pem.Decode(pemBytes) 42 | if block == nil { 43 | return nil, ErrMissingPem 44 | } 45 | 46 | name := block.Headers["Name"] 47 | 48 | switch block.Type { 49 | case "PUBLIC KEY", "RSA PUBLIC KEY": 50 | parsedKey, err := x509.ParsePKIXPublicKey(block.Bytes) 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | var pubKey *rsa.PublicKey 56 | var ok bool 57 | if pubKey,
ok = parsedKey.(*rsa.PublicKey); !ok { 58 | return nil, ErrNotRsaKey 59 | } 60 | 61 | return NewPublicKey(name, pubKey), nil 62 | 63 | case "PRIVATE KEY", "RSA PRIVATE KEY": 64 | privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | return NewPrivateKey(name, privateKey), nil 70 | 71 | default: 72 | return nil, ErrInvalidPem 73 | } 74 | } 75 | // LoadKeyFromFile reads the file at path and parses it with LoadKey. 76 | func LoadKeyFromFile(path string) (*Key, error) { 77 | bytes, err := ioutil.ReadFile(path) 78 | if err != nil { 79 | return nil, err 80 | } 81 | 82 | return LoadKey(bytes) 83 | } 84 | // GenerateKey creates a new RSA private key of the given bit size, named name. 85 | func GenerateKey(name string, bits int) (*Key, error) { 86 | rsaKey, err := rsa.GenerateKey(rand.Reader, bits) 87 | if err != nil { 88 | return nil, err 89 | } 90 | 91 | return NewPrivateKey(name, rsaKey), nil 92 | } 93 | // PublicPem encodes the public half as a "PUBLIC KEY" PEM block (PKIX DER) carrying a Name header. 94 | func (key *Key) PublicPem() []byte { 95 | der, err := x509.MarshalPKIXPublicKey(key.Public) 96 | 97 | // MarshalPKIXPublicKey only fails for unsupported key types, never for *rsa.PublicKey 98 | if err != nil { 99 | panic(err) 100 | } 101 | 102 | block := &pem.Block{ 103 | Type: "PUBLIC KEY", 104 | Headers: map[string]string{"Name": key.Name}, 105 | Bytes: der, 106 | } 107 | 108 | return pem.EncodeToMemory(block) 109 | } 110 | // PrivatePem encodes the private key as PKCS#1 DER in a "PRIVATE KEY" PEM block with a Name header (the form LoadKey reads back), or returns nil for a public-only key. NOTE(review): "PRIVATE KEY" conventionally labels PKCS#8, so other tools may reject these files. 111 | func (key *Key) PrivatePem() []byte { 112 | if key.Private == nil { 113 | return nil 114 | } 115 | 116 | der := x509.MarshalPKCS1PrivateKey(key.Private) 117 | block := &pem.Block{ 118 | Type: "PRIVATE KEY", 119 | Headers: map[string]string{"Name": key.Name}, 120 | Bytes: der, 121 | } 122 | 123 | return pem.EncodeToMemory(block) 124 | } 125 | -------------------------------------------------------------------------------- /keys/key_test.go: -------------------------------------------------------------------------------- 1 | package keys 2 | 3 | import ( 4 | "crypto/rsa" 5 | "crypto/x509" 6 | "encoding/pem" 7 | "reflect" 8 | "testing" 9 | ) 10 | 11 | func TestNewPrivateKey(t *testing.T) { 12 | key := NewPrivateKey("foo", &rsaKey) 13 | 14 | if
key.Name != "foo" { 15 | t.Errorf("unexpected key.Name %#v", key.Name) 16 | } 17 | 18 | if key.Private.E != rsaKey.E { 19 | t.Errorf("unexpected key.Private %#v", key.Private) 20 | } 21 | 22 | if key.Public.E != rsaKey.Public().(*rsa.PublicKey).E { 23 | t.Errorf("unexpected key.Public %#v", key.Public) 24 | } 25 | } 26 | 27 | func TestNewPublicKey(t *testing.T) { 28 | pubKey := rsaKey.Public().(*rsa.PublicKey) 29 | key := NewPublicKey("foo", pubKey) 30 | 31 | if key.Name != "foo" { 32 | t.Errorf("unexpected key.Name %#v", key.Name) 33 | } 34 | 35 | if key.Private != nil { 36 | t.Errorf("unexpected key.Private %#v", key.Private) 37 | } 38 | 39 | if key.Public.E != rsaKey.Public().(*rsa.PublicKey).E { 40 | t.Errorf("unexpected key.Public %#v", key.Public) 41 | } 42 | } 43 | 44 | func TestLoadKeyPublic(t *testing.T) { 45 | key, err := LoadKey(testRsaPublicKey) 46 | 47 | if err != nil { 48 | t.Fatalf("error %#v", err) 49 | } 50 | 51 | if key.Name != "the-key" { 52 | t.Errorf("unexpected key.Name %#v", key.Name) 53 | } 54 | 55 | if key.Private != nil { 56 | t.Errorf("unexpected key.Private %#v", key.Private) 57 | } 58 | 59 | if key.Public.E != rsaKey.Public().(*rsa.PublicKey).E { 60 | t.Errorf("unexpected key.Public %#v", key.Public) 61 | } 62 | } 63 | 64 | func TestLoadKeyPublicNoHeader(t *testing.T) { 65 | key, err := LoadKey(testRsaPublicKeyNoHeader) 66 | 67 | if err != nil { 68 | t.Fatalf("error %#v", err) 69 | } 70 | 71 | if key.Name != "" { 72 | t.Errorf("unexpected key.Name %#v", key.Name) 73 | } 74 | 75 | if key.Private != nil { 76 | t.Errorf("unexpected key.Private %#v", key.Private) 77 | } 78 | 79 | if key.Public.E != rsaKey.Public().(*rsa.PublicKey).E { 80 | t.Errorf("unexpected key.Public %#v", key.Public) 81 | } 82 | } 83 | 84 | func TestLoadKeyPublicButNotRsa(t *testing.T) { 85 | key, err := LoadKey(testEcdsaPublicKey) 86 | 87 | if err != ErrNotRsaKey { 88 | t.Errorf("unexpected error %#v", err) 89 | } 90 | 91 | if key != nil { 92 | t.Errorf("unexpected key %#v", key) 93
| } 94 | } 95 | 96 | func TestLoadKeyPublicButInvalid(t *testing.T) { 97 | // broken 98 | key, err := LoadKey([]byte(`-----BEGIN RSA PUBLIC KEY----- 99 | Wbehcav9vPzR3vK+QjurdKHnI5qjsnCInlPL8/IF9wzp3tkFXR7LfJckCtB6TcQ8 100 | Ttn6VaPZ11F456WQNK8CQETVQARcp/v4bWtVHfJKyBcx92FkclVNXae5aHpmvIjI 101 | LUu9LpYOrkcaL1d7SFPhWZUsI+crYKuLAb9tXG/AnJY= 102 | -----END RSA PRIVATE KEY-----`)) 103 | 104 | if err == nil { 105 | t.Errorf("unexpected error %#v", err) 106 | } 107 | 108 | if key != nil { 109 | t.Errorf("unexpected key %#v", key) 110 | } 111 | } 112 | 113 | func TestLoadKeyPrivate(t *testing.T) { 114 | key, err := LoadKey(testRsaPrivateKey) 115 | 116 | if err != nil { 117 | t.Fatalf("error %#v", err) 118 | } 119 | 120 | if key.Name != "the-key" { 121 | t.Errorf("unexpected key.Name %#v", key.Name) 122 | } 123 | 124 | if key.Private == nil { 125 | t.Fatalf("unexpected key.Private %#v", key.Private) 126 | } 127 | 128 | if key.Private.E != rsaKey.E { 129 | t.Errorf("unexpected key.Private %#v", key.Private) 130 | } 131 | } 132 | 133 | func TestLoadKeyPrivateNoHeader(t *testing.T) { 134 | key, err := LoadKey(testRsaPrivateKeyNoHeader) 135 | 136 | if err != nil { 137 | t.Fatalf("error %#v", err) 138 | } 139 | 140 | if key.Name != "" { 141 | t.Errorf("unexpected key.Name %#v", key.Name) 142 | } 143 | 144 | if key.Private == nil { 145 | t.Fatalf("unexpected key.Private %#v", key.Private) 146 | } 147 | 148 | if key.Private.E != rsaKey.E { 149 | t.Errorf("unexpected key.Private %#v", key.Private) 150 | } 151 | } 152 | 153 | func TestLoadKeyPrivateButInvalid(t *testing.T) { 154 | // broken 155 | key, err := LoadKey([]byte(`-----BEGIN RSA PRIVATE KEY----- 156 | GSIb3DQEBAQUAA4GNADCBiQKBgQDE0H3AjeUvlOA5ueZ1q6hukF+aRFbW2h8qW2O 157 | Iw88+EN4qLanilTvTUO3V91hGhHe2CnnUOey1iAHnSPGx66XW3oNWuk+wK1tg1iv 158 | cCLHIOlRu22g8DuS8TC92jhjkFVCgGasXNFGECiyF6J9WsYrF6FOKvUVpEjWgyRM 159 | PMMuQIDAQAB 160 | -----END RSA PUBLIC KEY-----`)) 161 | 162 | if err == nil { 163 | t.Errorf("unexpected error
%#v", err) 164 | } 165 | 166 | if key != nil { 167 | t.Errorf("unexpected key %#v", key) 168 | } 169 | } 170 | 171 | func TestLoadKeyMissing(t *testing.T) { 172 | key, err := LoadKey([]byte{}) 173 | 174 | if err != ErrMissingPem { 175 | t.Errorf("unexpected error %#v", err) 176 | } 177 | 178 | if key != nil { 179 | t.Errorf("unexpected key %#v", key) 180 | } 181 | } 182 | 183 | func TestLoadKeyInvalid(t *testing.T) { 184 | key, err := LoadKey([]byte(`-----BEGIN SOMETHING KEY----- 185 | PMMuQIDAQAB 186 | -----END SOMETHING KEY-----`)) 187 | if err != ErrMissingPem { 188 | t.Errorf("unexpected error %#v", err) 189 | } 190 | 191 | if key != nil { 192 | t.Errorf("unexpected key %#v", key) 193 | } 194 | } 195 | 196 | func TestGenerateKey(t *testing.T) { 197 | key, err := GenerateKey("foo", 1024) 198 | 199 | if err != nil { 200 | t.Fatalf("error %#v", err) 201 | } 202 | 203 | if key.Name != "foo" { 204 | t.Errorf("unexpected key.Name %#v", key.Name) 205 | } 206 | } 207 | 208 | func TestPublicPem(t *testing.T) { 209 | key := NewPrivateKey("foo", &rsaKey) 210 | 211 | der, err := x509.MarshalPKIXPublicKey(key.Public) 212 | if err != nil { 213 | t.Errorf("err %#v", err) 214 | } 215 | 216 | pemBytes := key.PublicPem() 217 | 218 | pem, _ := pem.Decode(pemBytes) 219 | 220 | if pem == nil { 221 | t.Fatalf("couldn't decode pem: %#v", pemBytes) 222 | } 223 | 224 | if pem.Type != "PUBLIC KEY" { 225 | t.Errorf("pem unexpected type %#v", pem.Type) 226 | } 227 | if pem.Headers["Name"] != "foo" { 228 | t.Errorf("pem unexpected Header['Name'] %#v", pem.Headers["Name"]) 229 | } 230 | if !reflect.DeepEqual(pem.Bytes, der) { 231 | t.Errorf("pem unexpected bytes %#v\nbut: %#v", der, pem.Bytes) 232 | } 233 | } 234 | 235 | func TestPrivatePem(t *testing.T) { 236 | key := NewPrivateKey("foo", &rsaKey) 237 | der := x509.MarshalPKCS1PrivateKey(key.Private) 238 | 239 | pemBytes := key.PrivatePem() 240 | 241 | pem, _ := pem.Decode(pemBytes) 242 | 243 | if pem == nil { 244 | t.Fatalf("couldn't
decode pem: %#v", pemBytes) 245 | } 246 | 247 | if pem.Type != "PRIVATE KEY" { 248 | t.Errorf("pem unexpected type %#v", pem.Type) 249 | } 250 | if pem.Headers["Name"] != "foo" { 251 | t.Errorf("pem unexpected Header['Name'] %#v", pem.Headers["Name"]) 252 | } 253 | if !reflect.DeepEqual(pem.Bytes, der) { 254 | t.Errorf("pem unexpected bytes %#v\nbut: %#v", der, pem.Bytes) 255 | } 256 | } 257 | 258 | func TestPrivatePemOnPublicKey(t *testing.T) { 259 | pubKey := rsaKey.Public().(*rsa.PublicKey) 260 | key := NewPublicKey("foo", pubKey) 261 | 262 | pemBytes := key.PrivatePem() 263 | 264 | if pemBytes != nil { 265 | t.Errorf("unexpected value %#v", pemBytes) 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /keys/keychain.go: -------------------------------------------------------------------------------- 1 | package keys 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "os" 9 | "path" 10 | "path/filepath" 11 | "strings" 12 | ) 13 | 14 | var ErrKeyNotFound = errors.New("couldn't find specified key") 15 | var ErrKeyAlreadyExists = errors.New("another key already exists with same name") 16 | // Keychain stores keys as PEM files in the Path directory (<name>.pem for private keys, <name>.pub for public-only keys); Find keeps keys loaded from .pem files in Cache. 17 | type Keychain struct { 18 | Path string 19 | Cache map[string]*Key 20 | } 21 | // NewKeychain returns a Keychain rooted at path with an empty cache. 22 | func NewKeychain(path string) *Keychain { 23 | return &Keychain{ 24 | Path: path, 25 | Cache: make(map[string]*Key), 26 | } 27 | } 28 | // Find returns the key called name, preferring <name>.pem over <name>.pub; only keys loaded from a .pem file are cached. It returns ErrKeyNotFound when neither file can be stat'ed. 29 | func (keychain *Keychain) Find(name string) (*Key, error) { 30 | if key, ok := keychain.Cache[name]; ok { 31 | return key, nil 32 | } 33 | 34 | privateKeyPath := path.Join(keychain.Path, name+".pem") 35 | publicKeyPath := path.Join(keychain.Path, name+".pub") 36 | 37 | if _, err := os.Stat(privateKeyPath); err == nil { 38 | key, err := LoadKeyFromFile(privateKeyPath) 39 | if err != nil { 40 | return nil, err 41 | } 42 | keychain.Cache[name] = key 43 | return key, nil 44 | } else if _, err := os.Stat(publicKeyPath); err == nil { 45 | key, err := LoadKeyFromFile(publicKeyPath) 46 | if
err != nil { 47 | return nil, err 48 | } 49 | return key, nil 50 | } else { 51 | return nil, ErrKeyNotFound 52 | } 53 | } 54 | // Save writes key to <Name>.pub (public-only, mode 0644) or <Name>.pem (mode 0600); it returns ErrKeyAlreadyExists if Find already resolves key.Name. 55 | func (keychain *Keychain) Save(key *Key) error { 56 | if _, err := keychain.Find(key.Name); err == nil { 57 | return ErrKeyAlreadyExists 58 | } 59 | if key.Private == nil { 60 | publicKeyPath := path.Join(keychain.Path, key.Name+".pub") 61 | return ioutil.WriteFile(publicKeyPath, key.PublicPem(), 0644) 62 | } 63 | // Private key material is written readable by the owner only. 64 | privateKeyPath := path.Join(keychain.Path, key.Name+".pem") 65 | return ioutil.WriteFile(privateKeyPath, key.PrivatePem(), 0600) 66 | } 67 | 68 | // List returns the names of all keys that have a .pub or .pem file, deduplicated, in unspecified order. 69 | func (keychain *Keychain) List() []string { 70 | namesMap := make(map[string]bool) 71 | 72 | addNames := func(ext string) { 73 | matches, err := filepath.Glob(path.Join(keychain.Path, fmt.Sprintf("*.%s", ext))) 74 | if err != nil { 75 | log.Printf("error looking for key list (%s): %s", ext, err.Error()) 76 | return 77 | } 78 | for _, keyPath := range matches { 79 | name := strings.TrimSuffix(path.Base(keyPath), fmt.Sprintf(".%s", ext)) 80 | namesMap[name] = true 81 | } 82 | } 83 | 84 | addNames("pub") 85 | addNames("pem") 86 | 87 | names := make([]string, 0, len(namesMap)) 88 | for name := range namesMap { 89 | names = append(names, name) 90 | } 91 | return names 92 | } 93 | // ListForEncryption returns the names of keys that have a .pem (private key) file. NOTE(review): the engine's encryption path only needs key.Public while decryption needs key.Private, so the Encryption/Decryption naming of these two lists looks swapped - confirm intent. 94 | func (keychain *Keychain) ListForEncryption() []string { 95 | matches, err := filepath.Glob(path.Join(keychain.Path, "*.pem")) 96 | if err != nil { 97 | log.Printf("error looking for key list (pem): %s", err.Error()) 98 | return []string{} 99 | } 100 | names := make([]string, 0, len(matches)) 101 | for _, keyPath := range matches { 102 | name := strings.TrimSuffix(path.Base(keyPath), ".pem") 103 | names = append(names, name) 104 | } 105 | return names 106 | } 107 | // ListForDecryption returns every key name known to the keychain (same as List). 108 | func (keychain *Keychain) ListForDecryption() []string { 109 | return keychain.List() 110 | } 111 | -------------------------------------------------------------------------------- /keys/keychain_test.go: 
-------------------------------------------------------------------------------- 1 | package keys 2 | 3 | import ( 4 | "crypto/rsa" 5 | "io/ioutil" 6 | "os" 7 | "path" 8 | "reflect" 9 | "sort" 10 | "testing" 11 | ) 12 | 13 | // helpers 14 | // GetKeychain returns a Keychain rooted at a fresh temporary directory; callers remove it with DestroyKeychain. 15 | func GetKeychain() *Keychain { 16 | tmpDir, err := ioutil.TempDir("", "keychain_test") 17 | if err != nil { 18 | panic(err) 19 | } 20 | return NewKeychain(tmpDir) 21 | } 22 | // DestroyKeychain deletes the keychain's directory and everything in it. 23 | func DestroyKeychain(kc *Keychain) { 24 | if err := os.RemoveAll(kc.Path); err != nil { 25 | panic(err) 26 | } 27 | } 28 | 29 | // test 30 | 31 | func TestKeychainFindBothPrivateAndPublicKey(t *testing.T) { 32 | keychain := GetKeychain() 33 | defer DestroyKeychain(keychain) 34 | 35 | if err := ioutil.WriteFile(path.Join(keychain.Path, "the-key.pem"), testRsaPrivateKey, 0600); err != nil { 36 | panic(err) 37 | } 38 | if err := ioutil.WriteFile(path.Join(keychain.Path, "the-key.pub"), testRsaPublicKey, 0644); err != nil { 39 | panic(err) 40 | } 41 | 42 | key, err := keychain.Find("the-key") 43 | if err != nil { 44 | t.Fatalf("unexpected error %#v", err) 45 | } 46 | 47 | if key.Name != "the-key" { 48 | t.Errorf("unexpected key.Name %#v", key.Name) 49 | } 50 | 51 | if key.Private.E != rsaKey.E { 52 | t.Errorf("unexpected key.Private %#v", key.Private) 53 | } 54 | 55 | if key.Public.E != rsaKey.Public().(*rsa.PublicKey).E { 56 | t.Errorf("unexpected key.Public %#v", key.Public) 57 | } 58 | } 59 | 60 | func TestKeychainFindPrivateKey(t *testing.T) { 61 | keychain := GetKeychain() 62 | defer DestroyKeychain(keychain) 63 | 64 | if err := ioutil.WriteFile(path.Join(keychain.Path, "the-key.pem"), testRsaPrivateKey, 0600); err != nil { 65 | panic(err) 66 | } 67 | 68 | key, err := keychain.Find("the-key") 69 | if err != nil { 70 | t.Fatalf("unexpected error %#v", err) 71 | } 72 | 73 | if key.Name != "the-key" { 74 | t.Errorf("unexpected key.Name %#v", key.Name) 75 | } 76 | 77 | if key.Private.E != rsaKey.E { 78 | t.Errorf("unexpected key.Private %#v",
key.Private) 79 | } 80 | 81 | if key.Public.E != rsaKey.Public().(*rsa.PublicKey).E { 82 | t.Errorf("unexpected key.Public %#v", key.Public) 83 | } 84 | } 85 | 86 | func TestKeychainFindPublicKey(t *testing.T) { 87 | keychain := GetKeychain() 88 | defer DestroyKeychain(keychain) 89 | 90 | if err := ioutil.WriteFile(path.Join(keychain.Path, "the-key.pub"), testRsaPublicKey, 0644); err != nil { 91 | panic(err) 92 | } 93 | 94 | key, err := keychain.Find("the-key") 95 | if err != nil { 96 | t.Fatalf("unexpected error %#v", err) 97 | } 98 | 99 | if key.Name != "the-key" { 100 | t.Errorf("unexpected key.Name %#v", key.Name) 101 | } 102 | 103 | if key.Private != nil { 104 | t.Errorf("unexpected key.Private %#v", key.Private) 105 | } 106 | 107 | if key.Public.E != rsaKey.Public().(*rsa.PublicKey).E { 108 | t.Errorf("unexpected key.Public %#v", key.Public) 109 | } 110 | } 111 | 112 | func TestKeychainFindUnexist(t *testing.T) { 113 | keychain := GetKeychain() 114 | defer DestroyKeychain(keychain) 115 | 116 | _, err := keychain.Find("the-key") 117 | if err != ErrKeyNotFound { 118 | t.Errorf("unexpected error %#v", err) 119 | } 120 | } 121 | 122 | func TestKeychainSavePrivateKey(t *testing.T) { 123 | keychain := GetKeychain() 124 | defer DestroyKeychain(keychain) 125 | 126 | key, err := LoadKey(testRsaPrivateKey) 127 | if err != nil { 128 | panic(err) 129 | } 130 | 131 | key.Name = "new-key" 132 | 133 | err = keychain.Save(key) 134 | 135 | if err != nil { 136 | t.Errorf("unexpected error %#v", err.Error()) 137 | } 138 | 139 | filepath := path.Join(keychain.Path, "new-key.pem") 140 | 141 | if fi, err := os.Stat(filepath); err == nil { 142 | if fi.Mode() != 0600 { 143 | t.Errorf("unexpected file mode %v", fi.Mode()) 144 | } 145 | } else { 146 | t.Errorf("unexpected file stat failure: %s", err.Error()) 147 | } 148 | 149 | bytes, err := ioutil.ReadFile(filepath) 150 | if err != nil { 151 | t.Errorf("failed to read file %s", err.Error()) 152 | } 153 | 154 | if
!reflect.DeepEqual(key.PrivatePem(), bytes) { 155 | t.Errorf("key file content unexpected %#v", bytes) 156 | } 157 | } 158 | 159 | func TestKeychainSavePublicKey(t *testing.T) { 160 | keychain := GetKeychain() 161 | defer DestroyKeychain(keychain) 162 | 163 | key, err := LoadKey(testRsaPublicKey) 164 | if err != nil { 165 | panic(err) 166 | } 167 | 168 | key.Name = "new-key" 169 | 170 | err = keychain.Save(key) 171 | 172 | if err != nil { 173 | t.Errorf("unexpected error %#v", err.Error()) 174 | } 175 | 176 | filepath := path.Join(keychain.Path, "new-key.pub") 177 | 178 | if _, err := os.Stat(filepath); err != nil { 179 | t.Errorf("unexpected file stat failure: %s", err.Error()) 180 | } 181 | 182 | bytes, err := ioutil.ReadFile(filepath) 183 | if err != nil { 184 | t.Errorf("failed to read file %s", err.Error()) 185 | } 186 | 187 | if !reflect.DeepEqual(key.PublicPem(), bytes) { 188 | t.Errorf("key file content unexpected %#v", bytes) 189 | } 190 | } 191 | 192 | func TestKeychainSaveAlreadyExist(t *testing.T) { 193 | keychain := GetKeychain() 194 | defer DestroyKeychain(keychain) 195 | 196 | if err := ioutil.WriteFile(path.Join(keychain.Path, "the-key.pem"), testRsaPrivateKey, 0600); err != nil { 197 | panic(err) 198 | } 199 | if err := ioutil.WriteFile(path.Join(keychain.Path, "the-key.pub"), testRsaPublicKey, 0644); err != nil { 200 | panic(err) 201 | } 202 | 203 | key, err := LoadKey(testRsaPrivateKey) 204 | if err != nil { 205 | panic(err) 206 | } 207 | key.Name = "the-key" 208 | 209 | err = keychain.Save(key) 210 | 211 | if err != ErrKeyAlreadyExists { 212 | t.Errorf("unexpected error %#v", err) 213 | } 214 | } 215 | 216 | func TestKeychainListKeys(t *testing.T) { 217 | keychain := GetKeychain() 218 | defer DestroyKeychain(keychain) 219 | 220 | if err := ioutil.WriteFile(path.Join(keychain.Path, "the-key.pem"), testRsaPrivateKey, 0600); err != nil { 221 | panic(err) 222 | } 223 | if err := ioutil.WriteFile(path.Join(keychain.Path, "the-key.pub"),
testRsaPublicKey, 0644); err != nil { 224 | panic(err) 225 | } 226 | 227 | if err := ioutil.WriteFile(path.Join(keychain.Path, "privonly.pem"), testRsaPrivateKey, 0600); err != nil { 228 | panic(err) 229 | } 230 | 231 | if err := ioutil.WriteFile(path.Join(keychain.Path, "pubonly.pub"), testRsaPublicKey, 0644); err != nil { 232 | panic(err) 233 | } 234 | 235 | var list []string 236 | 237 | list = keychain.List() 238 | sort.Strings(list) 239 | if !reflect.DeepEqual(list, []string{"privonly", "pubonly", "the-key"}) { 240 | t.Errorf("unexpected List result: %#v", list) 241 | } 242 | 243 | list = keychain.ListForEncryption() 244 | sort.Strings(list) 245 | if !reflect.DeepEqual(list, []string{"privonly", "the-key"}) { 246 | t.Errorf("unexpected ListForEncryption result: %#v", list) 247 | } 248 | 249 | list = keychain.ListForDecryption() 250 | sort.Strings(list) 251 | if !reflect.DeepEqual(list, []string{"privonly", "pubonly", "the-key"}) { 252 | t.Errorf("unexpected ListForDecryption result: %#v", list) 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /keys/main_test.go: -------------------------------------------------------------------------------- 1 | package keys 2 | 3 | import ( 4 | "crypto/rsa" 5 | "crypto/x509" 6 | "encoding/pem" 7 | "os" 8 | "testing" 9 | ) 10 | 11 | var testRsaPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- 12 | Name: the-key 13 | 14 | MIICXAIBAAKBgQDE0H3AjeUvlOA5ueZ1q6hukF+aRFbW2h8qW2OIw88+EN4qLani 15 | lTvTUO3V91hGhHe2CnnUOey1iAHnSPGx66XW3oN/Wuk+wK1tg1ivcCLHIOlRu22g 16 | 8DuS8TC92jhjkFVCgGasXNFGECiyF6J9WsYrF6F/OKvUVpEjWgyRMPMMuQIDAQAB 17 | AoGAMOlbhyH8ZhHKk64GfxHU/v00NSNsrWJxwlYJ63A2LceFXtgQUzYhMwf2w2j/ 18 | 8C51jbEWy85FbGvLhU4UetIEWW0OK5Y+J2juGD0ez1FX+EzmiO+khpGtYQ6OY56a 19 | 3g4FPsUuCj1gw2oBDDQ2e38RyqY9Nj3PWo4H5Y7ZbSWwSQ0CQQDSNABnC7AiM2K3 20 | 5uXqZiXx68RoLrYtGkXhgyZBIUZ+g6nbhBqpPEI9pql55yCjmx/zeY6VVipOffO2 21 | EEUpdnG/AkEA77G9SK8lqxMeH+GRL70jYNXBqdxYhKrWlFzom+VrHIyo//limocH 22 |
dPJiEEIyPJQXeru2r2mWxVg98q+j3CUvhwJAIzebKaiHpfM+Atmog5EBonqBuYK5 23 | +ux/8LxsWFUe3mtoteJ4JQp3fqTBmC7lBQQkYkJnZRW+mM/5WPN44u15OQJBAJPO 24 | Wbehcav9vPzR3vK+QjurdKHnI5qjsnCInlPL8/IF9wzp3tkFXR7LfJckCtB6TcQ8 25 | Ttn6VaPZ11F456WQNK8CQETVQARcp/v4bWtVHfJKyBcx92FkclVNXae5aHpmvIjI 26 | LUu9LpYOrkcaL1d7SFPhWZUsI+crYKuLAb9tXG/AnJY= 27 | -----END RSA PRIVATE KEY-----`) 28 | 29 | var testRsaPublicKey = []byte(`-----BEGIN RSA PUBLIC KEY----- 30 | Name: the-key 31 | 32 | MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDE0H3AjeUvlOA5ueZ1q6hukF+a 33 | RFbW2h8qW2OIw88+EN4qLanilTvTUO3V91hGhHe2CnnUOey1iAHnSPGx66XW3oN/ 34 | Wuk+wK1tg1ivcCLHIOlRu22g8DuS8TC92jhjkFVCgGasXNFGECiyF6J9WsYrF6F/ 35 | OKvUVpEjWgyRMPMMuQIDAQAB 36 | -----END RSA PUBLIC KEY-----`) 37 | 38 | var testRsaPrivateKeyNoHeader = []byte(`-----BEGIN RSA PRIVATE KEY----- 39 | MIICXAIBAAKBgQDE0H3AjeUvlOA5ueZ1q6hukF+aRFbW2h8qW2OIw88+EN4qLani 40 | lTvTUO3V91hGhHe2CnnUOey1iAHnSPGx66XW3oN/Wuk+wK1tg1ivcCLHIOlRu22g 41 | 8DuS8TC92jhjkFVCgGasXNFGECiyF6J9WsYrF6F/OKvUVpEjWgyRMPMMuQIDAQAB 42 | AoGAMOlbhyH8ZhHKk64GfxHU/v00NSNsrWJxwlYJ63A2LceFXtgQUzYhMwf2w2j/ 43 | 8C51jbEWy85FbGvLhU4UetIEWW0OK5Y+J2juGD0ez1FX+EzmiO+khpGtYQ6OY56a 44 | 3g4FPsUuCj1gw2oBDDQ2e38RyqY9Nj3PWo4H5Y7ZbSWwSQ0CQQDSNABnC7AiM2K3 45 | 5uXqZiXx68RoLrYtGkXhgyZBIUZ+g6nbhBqpPEI9pql55yCjmx/zeY6VVipOffO2 46 | EEUpdnG/AkEA77G9SK8lqxMeH+GRL70jYNXBqdxYhKrWlFzom+VrHIyo//limocH 47 | dPJiEEIyPJQXeru2r2mWxVg98q+j3CUvhwJAIzebKaiHpfM+Atmog5EBonqBuYK5 48 | +ux/8LxsWFUe3mtoteJ4JQp3fqTBmC7lBQQkYkJnZRW+mM/5WPN44u15OQJBAJPO 49 | Wbehcav9vPzR3vK+QjurdKHnI5qjsnCInlPL8/IF9wzp3tkFXR7LfJckCtB6TcQ8 50 | Ttn6VaPZ11F456WQNK8CQETVQARcp/v4bWtVHfJKyBcx92FkclVNXae5aHpmvIjI 51 | LUu9LpYOrkcaL1d7SFPhWZUsI+crYKuLAb9tXG/AnJY= 52 | -----END RSA PRIVATE KEY-----`) 53 | 54 | var testRsaPublicKeyNoHeader = []byte(`-----BEGIN RSA PUBLIC KEY----- 55 | MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDE0H3AjeUvlOA5ueZ1q6hukF+a 56 | RFbW2h8qW2OIw88+EN4qLanilTvTUO3V91hGhHe2CnnUOey1iAHnSPGx66XW3oN/ 57 | 
Wuk+wK1tg1ivcCLHIOlRu22g8DuS8TC92jhjkFVCgGasXNFGECiyF6J9WsYrF6F/ 58 | OKvUVpEjWgyRMPMMuQIDAQAB 59 | -----END RSA PUBLIC KEY-----`) 60 | 61 | var testEcdsaPrivateKey = []byte(`-----BEGIN PRIVATE KEY----- 62 | MGgCAQEEHF+pP6QjO+LH97mzJlaiqZ1y5DynKEjUSXy7hVSgBwYFK4EEACGhPAM6 63 | AASwJR+5yutBOBaKlxjheM+VPm4kfeXoxnjN85OHAfYeyEPS95kZZKqbpvX8d8NF 64 | Z4+YLPZEMaBs7g== 65 | -----END PRIVATE KEY-----`) 66 | 67 | var testEcdsaPublicKey = []byte(`-----BEGIN PUBLIC KEY----- 68 | ME4wEAYHKoZIzj0CAQYFK4EEACEDOgAEsCUfucrrQTgWipcY4XjPlT5uJH3l6MZ4 69 | zfOThwH2HshD0veZGWSqm6b1/HfDRWePmCz2RDGgbO4= 70 | -----END PUBLIC KEY-----`) 71 | 72 | var rsaKey rsa.PrivateKey 73 | 74 | func TestMain(m *testing.M) { 75 | pem, _ := pem.Decode(testRsaPrivateKey) 76 | if pem == nil { 77 | panic("invalid pem") 78 | } 79 | 80 | priv, err := x509.ParsePKCS1PrivateKey(pem.Bytes) 81 | if err != nil { 82 | panic(err) 83 | } 84 | rsaKey = *priv 85 | 86 | os.Exit(m.Run()) 87 | } 88 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "github.com/codegangsta/cli" 7 | "github.com/sorah/etcvault/engine" 8 | "github.com/sorah/etcvault/keys" 9 | "io" 10 | "net/url" 11 | "os" 12 | "strings" 13 | "time" 14 | ) 15 | 16 | func main() { 17 | app := cli.NewApp() 18 | app.Name = "etcvault" 19 | app.Usage = "proxy for etcd, adding transparent encryption" 20 | app.Version = "0.3.1" 21 | 22 | app.Commands = []cli.Command{ 23 | { 24 | Name: "start", 25 | Usage: "start etcvault proxy", 26 | Action: actionStart, 27 | Flags: []cli.Flag{ 28 | cli.StringFlag{ 29 | Name: "keychain", 30 | Usage: "Path to directory for keys", 31 | }, 32 | cli.StringFlag{ 33 | Name: "listen", 34 | Value: "http://localhost:2381", 35 | Usage: "URL to listen. 
Specify https as scheme to listen HTTPS.", 36 | }, 37 | cli.StringFlag{ 38 | Name: "advertise-url", 39 | Value: "http://localhost:2381", 40 | Usage: "Client URL to advertise. Usually specify etcvault's URL", 41 | }, 42 | 43 | cli.StringFlag{ 44 | Name: "discovery-srv", 45 | Usage: "domain to fetch SRV records for backend etcd", 46 | }, 47 | cli.StringFlag{ 48 | Name: "initial-backends", 49 | Usage: "backend urls to fetch backend etcd members, separeted by comma", 50 | }, 51 | cli.StringFlag{ 52 | Name: "client-ca-file", 53 | Usage: "TLS CA file to verify certificate of etcd client ports (https://...:2379/)", 54 | }, 55 | cli.StringFlag{ 56 | Name: "client-cert-file", 57 | Usage: "TLS certficate file to send when communicating with etcd client ports (https://...:2379/)", 58 | }, 59 | cli.StringFlag{ 60 | Name: "client-key-file", 61 | Usage: "key for -client-cert-file", 62 | }, 63 | cli.StringFlag{ 64 | Name: "peer-ca-file", 65 | Usage: "TLS CA file to verify certificate of etcd peer ports (https://...:2380/)", 66 | }, 67 | cli.StringFlag{ 68 | Name: "peer-cert-file", 69 | Usage: "TLS certficate file to send when communicating with etcd peer ports (https://...:2380/)", 70 | }, 71 | cli.StringFlag{ 72 | Name: "peer-key-file", 73 | Usage: "key for -peer-cert-file", 74 | }, 75 | cli.StringFlag{ 76 | Name: "listen-ca-file", 77 | Usage: "When listening HTTPS and this is present, etcvault will validate its client with using this CA certificate. If not present, -client-ca-file will be used.", 78 | }, 79 | cli.StringFlag{ 80 | Name: "listen-cert-file", 81 | Usage: "When listening HTTPS and this is present, etcvault will use this certificate to listen. 
If not present, -client-cert-file will be used.", 82 | }, 83 | cli.StringFlag{ 84 | Name: "listen-key-file", 85 | Usage: "key for -listen-cert-file", 86 | }, 87 | cli.IntFlag{ 88 | Name: "discovery-interval", 89 | Value: 120, 90 | Usage: "Interval (in second) to refresh backends with specified discovery method", 91 | }, 92 | cli.BoolFlag{ 93 | Name: "readonly", 94 | Usage: "if set, etcvault will reject non GET requests", 95 | }, 96 | }, 97 | }, 98 | { 99 | Name: "keygen", 100 | Usage: "Generate new private key with specified name", 101 | Action: actionKeygen, 102 | Flags: []cli.Flag{ 103 | cli.StringFlag{ 104 | Name: "save", 105 | Usage: "Save generated key into specfied directory (keychain)", 106 | }, 107 | cli.IntFlag{ 108 | Name: "bits", 109 | Value: 2048, 110 | Usage: "RSA key bit length to generate", 111 | }, 112 | }, 113 | }, 114 | { 115 | Name: "transform", 116 | Usage: "transform ETCVAULT* strings (from argument or stdin) to appropriate strings", 117 | Action: actionTransform, 118 | Flags: []cli.Flag{ 119 | cli.StringFlag{ 120 | Name: "keychain", 121 | Usage: "Path to directory for keys", 122 | }, 123 | cli.BoolFlag{ 124 | Name: "stdin", 125 | Usage: "Read from stdin", 126 | }, 127 | }, 128 | }, 129 | } 130 | 131 | app.Run(os.Args) 132 | } 133 | 134 | func actionKeygen(ctx *cli.Context) { 135 | if len(ctx.Args()) < 1 { 136 | fmt.Fprintln(os.Stderr, "specify key name") 137 | os.Exit(1) 138 | } 139 | 140 | name := ctx.Args()[0] 141 | bits := ctx.Int("bits") 142 | 143 | key, err := keys.GenerateKey(name, bits) 144 | if err != nil { 145 | panic(err) 146 | } 147 | 148 | saveDir := ctx.String("save") 149 | 150 | if saveDir == "" { 151 | fmt.Printf("%s", key.PrivatePem()) 152 | } else { 153 | keychain := keys.NewKeychain(saveDir) 154 | keychain.Save(key) 155 | } 156 | } 157 | 158 | func actionTransform(ctx *cli.Context) { 159 | keychainDir := ctx.String("keychain") 160 | if keychainDir == "" { 161 | fmt.Fprintln(os.Stderr, "Specify -keychain option") 162 | 
os.Exit(1) 163 | } 164 | 165 | keychain := keys.NewKeychain(keychainDir) 166 | engine := engine.NewEngine(keychain) 167 | 168 | if ctx.Bool("stdin") { 169 | reader := bufio.NewReader(os.Stderr) 170 | for { 171 | line, err := reader.ReadString('\n') 172 | if err != nil { 173 | if err == io.EOF { 174 | break 175 | } else { 176 | panic(err) 177 | } 178 | } 179 | 180 | origStr := strings.TrimRight(line, "\n") 181 | 182 | str, err := engine.Transform(origStr) 183 | if err == nil { 184 | fmt.Println(str) 185 | } else { 186 | fmt.Println(origStr) 187 | fmt.Fprintf(os.Stderr, "ERR: %s\n", err.Error()) 188 | } 189 | } 190 | } else { 191 | for _, origStr := range ctx.Args() { 192 | str, err := engine.Transform(origStr) 193 | if err == nil { 194 | fmt.Println(str) 195 | } else { 196 | fmt.Println(origStr) 197 | fmt.Fprintf(os.Stderr, "ERR: %s", err.Error()) 198 | } 199 | } 200 | } 201 | } 202 | 203 | func actionStart(ctx *cli.Context) { 204 | keychainDir := ctx.String("keychain") 205 | if keychainDir == "" { 206 | fmt.Fprintln(os.Stderr, "Specify -keychain option") 207 | os.Exit(1) 208 | } 209 | 210 | discoverySrvDomain := ctx.String("discovery-srv") 211 | initialBackendUrlStrings := ctx.String("initial-backends") 212 | if discoverySrvDomain == "" && initialBackendUrlStrings == "" { 213 | fmt.Fprintln(os.Stderr, "Specify -discovery-srv or -initial-backends option") 214 | os.Exit(1) 215 | } 216 | if discoverySrvDomain != "" && initialBackendUrlStrings != "" { 217 | fmt.Fprintln(os.Stderr, "Only specifying only either -discovery-srv or -initial-backends is accepted.") 218 | os.Exit(1) 219 | } 220 | 221 | clientCaFilePath := ctx.String("client-ca-file") 222 | clientCertFilePath := ctx.String("client-cert-file") 223 | clientKeyFilePath := ctx.String("client-key-file") 224 | if (clientCertFilePath != "" || clientKeyFilePath != "") && !(clientCertFilePath != "" && clientKeyFilePath != "") { 225 | fmt.Fprintln(os.Stderr, "provide both -client-cert-file and -client-key-file") 226 | 
os.Exit(1) 227 | } 228 | 229 | peerCaFilePath := ctx.String("peer-ca-file") 230 | peerCertFilePath := ctx.String("peer-cert-file") 231 | peerKeyFilePath := ctx.String("peer-key-file") 232 | if (peerCertFilePath != "" || peerKeyFilePath != "") && !(peerCertFilePath != "" && peerKeyFilePath != "") { 233 | fmt.Fprintln(os.Stderr, "provide both -peer-cert-file and -peer-key-file") 234 | os.Exit(1) 235 | } 236 | 237 | listenCaFilePath := ctx.String("listen-ca-file") 238 | listenCertFilePath := ctx.String("listen-cert-file") 239 | listenKeyFilePath := ctx.String("listen-key-file") 240 | if (listenCertFilePath != "" || listenKeyFilePath != "") && !(listenCertFilePath != "" && listenKeyFilePath != "") { 241 | fmt.Fprintln(os.Stderr, "provide both -listen-cert-file and -listen-key-file") 242 | os.Exit(1) 243 | } 244 | 245 | discoveryInterval := ctx.Int("discovery-interval") 246 | 247 | readonly := ctx.Bool("readonly") 248 | 249 | listenUrl, err := url.Parse(ctx.String("listen")) 250 | if err != nil { 251 | fmt.Fprintf(os.Stderr, "couldn't parse -listen as URL: %s\n", err.Error()) 252 | os.Exit(1) 253 | } 254 | if listenUrl.Path != "" && listenUrl.Path != "/ " { 255 | fmt.Fprintf(os.Stderr, "-listen URL shouldn't include path: %s\n", listenUrl.Path) 256 | os.Exit(1) 257 | } 258 | if !(clientCertFilePath != "" && clientKeyFilePath != "") && listenUrl.Scheme == "https" { 259 | fmt.Fprintln(os.Stderr, "provide both -cert-file and -key-file when listen https") 260 | os.Exit(1) 261 | } 262 | 263 | advertiseUrl := ctx.String("advertise-url") 264 | 265 | starter := &ProxyStarter{ 266 | Listen: listenUrl, 267 | keychainDir: keychainDir, 268 | DiscoverySrvDomain: discoverySrvDomain, 269 | initialBackendUrlStrings: initialBackendUrlStrings, 270 | clientCaFilePath: clientCaFilePath, 271 | clientCertFilePath: clientCertFilePath, 272 | clientKeyFilePath: clientKeyFilePath, 273 | peerCaFilePath: peerCaFilePath, 274 | peerCertFilePath: peerCertFilePath, 275 | peerKeyFilePath: 
peerKeyFilePath, 276 | listenCaFilePath: listenCaFilePath, 277 | listenCertFilePath: listenCertFilePath, 278 | listenKeyFilePath: listenKeyFilePath, 279 | discoveryInterval: time.Duration(discoveryInterval) * time.Second, 280 | readonly: readonly, 281 | AdvertiseUrl: advertiseUrl, 282 | } 283 | 284 | starter.Start() 285 | } 286 | -------------------------------------------------------------------------------- /proxy/backend.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "log" 5 | "net/url" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | type Backend struct { 11 | sync.Mutex 12 | Url *url.URL 13 | Available bool 14 | nextCheckInterval time.Duration 15 | resumeTimer *time.Timer 16 | } 17 | 18 | func NewBackend(url *url.URL) *Backend { 19 | return &Backend{ 20 | Url: url, 21 | Available: true, 22 | nextCheckInterval: time.Duration(time.Second) * 15, 23 | resumeTimer: nil, 24 | } 25 | } 26 | 27 | func (backend *Backend) Fail() { 28 | backend.Lock() 29 | defer backend.Unlock() 30 | 31 | if !backend.Available { 32 | return 33 | } 34 | 35 | backend.Available = false 36 | checkInterval := backend.nextCheckInterval 37 | backend.nextCheckInterval = checkInterval * 2 38 | 39 | backend.resumeTimer = time.AfterFunc(backend.nextCheckInterval, func() { 40 | backend.Lock() 41 | if !backend.Available { 42 | backend.Available = true 43 | backend.Unlock() 44 | } else { 45 | backend.Unlock() 46 | return 47 | } 48 | 49 | log.Printf("Backend %s resumed (automatically)", backend.Url.String()) 50 | }) 51 | 52 | log.Printf("Backend %s marked as failure, will resume after %s", backend.Url.String(), checkInterval.String()) 53 | } 54 | 55 | func (backend *Backend) Ok() { 56 | backend.Lock() 57 | defer backend.Unlock() 58 | 59 | wasUnavailable := !backend.Available 60 | backend.Available = true 61 | backend.nextCheckInterval = time.Duration(time.Second) * 15 62 | 63 | if backend.resumeTimer != nil { 64 | 
backend.resumeTimer.Stop() 65 | backend.resumeTimer = nil 66 | } 67 | 68 | if wasUnavailable { 69 | log.Printf("Backend %s resumed", backend.Url.String()) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /proxy/discovery.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "net" 9 | "net/http" 10 | "net/url" 11 | ) 12 | 13 | // for testing... 14 | var lookupSRV = net.LookupSRV 15 | 16 | type etcdMember struct { 17 | ClientURLs []string 18 | PeerURLs []string 19 | Name string 20 | } 21 | 22 | type etcdMembers struct { 23 | Members []etcdMember 24 | } 25 | 26 | func DiscoverBackendsFromDns(transport *http.Transport, domain string) ([]*Backend, error) { 27 | _, records, errA := lookupSRV("etcd-server", "tcp", domain) 28 | if errA != nil { 29 | log.Printf("error when looking up _etcd-server._tcp.%s: %s", domain, errA.Error()) 30 | } 31 | 32 | _, ssl_records, errB := lookupSRV("etcd-server-ssl", "tcp", domain) 33 | if errB != nil { 34 | log.Printf("error when looking up _etcd-server-ssl._tcp.%s: %s", domain, errB.Error()) 35 | } 36 | 37 | if errA != nil && errB != nil { 38 | return nil, errA 39 | } 40 | 41 | urls := make([]*url.URL, 0, len(records)+len(ssl_records)) 42 | 43 | makeUrl := func(srv *net.SRV, scheme string) *url.URL { 44 | var target string 45 | if srv.Target[len(srv.Target)-1] == '.' 
{ 46 | target = srv.Target[0 : len(srv.Target)-1] 47 | } else { 48 | target = srv.Target 49 | } 50 | 51 | hostPort := net.JoinHostPort(target, fmt.Sprintf("%d", srv.Port)) 52 | 53 | u := &url.URL{ 54 | Scheme: scheme, 55 | Host: hostPort, 56 | } 57 | return u 58 | } 59 | 60 | for _, srv := range ssl_records { 61 | urls = append(urls, makeUrl(srv, "https")) 62 | } 63 | 64 | for _, srv := range records { 65 | urls = append(urls, makeUrl(srv, "http")) 66 | } 67 | 68 | return DiscoverBackendsFromEtcdPeer(transport, urls), nil 69 | } 70 | 71 | func DiscoverBackendsFromEtcdPeer(transport *http.Transport, urls []*url.URL) []*Backend { 72 | return fetchBackendsFromEtcd(transport, urls, "/members", false) 73 | } 74 | 75 | func DiscoverBackendsFromEtcd(transport *http.Transport, urls []*url.URL) []*Backend { 76 | return fetchBackendsFromEtcd(transport, urls, "/v2/members", true) 77 | } 78 | 79 | func fetchBackendsFromEtcd(transport *http.Transport, urls []*url.URL, path string, wrapped bool) []*Backend { 80 | client := &http.Client{Transport: transport} 81 | 82 | for _, origUrl := range urls { 83 | u := new(url.URL) 84 | *u = *origUrl 85 | 86 | u.Path = path 87 | 88 | resp, err := client.Get(u.String()) 89 | if err != nil { 90 | log.Printf("error when retrieving %s: %s", u.String(), err.Error()) 91 | continue 92 | } 93 | 94 | respBody, err := ioutil.ReadAll(resp.Body) 95 | if err != nil { 96 | continue 97 | } 98 | err = resp.Body.Close() 99 | if err != nil { 100 | panic(err) 101 | } 102 | 103 | var members []etcdMember 104 | if wrapped { 105 | jsonData := &etcdMembers{} 106 | err = json.Unmarshal(respBody, jsonData) 107 | members = jsonData.Members 108 | } else { 109 | jsonData := []etcdMember{} 110 | err = json.Unmarshal(respBody, &jsonData) 111 | members = jsonData 112 | } 113 | 114 | if err != nil { 115 | log.Printf("error when parsing response from %s: %s", u.String(), err.Error()) 116 | continue 117 | } 118 | 119 | backends := make([]*Backend, 0, len(members)) 120 | 121 
| for _, member := range members { 122 | if len(member.ClientURLs) < 1 { 123 | continue 124 | } 125 | clientUrl, err := url.Parse(member.ClientURLs[0]) 126 | if err != nil { 127 | continue 128 | } 129 | backend := NewBackend(clientUrl) 130 | 131 | backends = append(backends, backend) 132 | } 133 | 134 | return backends 135 | } 136 | 137 | return []*Backend{} 138 | } 139 | -------------------------------------------------------------------------------- /proxy/discovery_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "crypto/tls" 5 | "encoding/json" 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "strconv" 12 | "testing" 13 | ) 14 | 15 | func membersMock(count int, peer bool) (server *httptest.Server, tlsServer *httptest.Server, transport *http.Transport) { 16 | type memberT struct { 17 | ClientURLs []string 18 | PeerURLs []string 19 | Name string 20 | Id string 21 | } 22 | 23 | members := make([]memberT, 0, count) 24 | for i := 0; i < count; i++ { 25 | member := memberT{ 26 | ClientURLs: []string{fmt.Sprintf("http://member-%d:2379", i)}, 27 | PeerURLs: []string{fmt.Sprintf("http://member-%d:2380", i)}, 28 | Name: fmt.Sprintf("member-%d", i), 29 | Id: fmt.Sprintf("%x", i), 30 | } 31 | members = append(members, member) 32 | } 33 | 34 | var path, host string 35 | var membersJson []byte 36 | var err error 37 | 38 | if peer { 39 | host = "node:2380" 40 | path = "/members" 41 | membersJson, err = json.Marshal(members) 42 | if err != nil { 43 | panic(err) 44 | } 45 | } else { 46 | host = "node:2379" 47 | path = "/v2/members" 48 | membersJson, err = json.Marshal(struct { 49 | Members []memberT 50 | }{ 51 | Members: members, 52 | }) 53 | if err != nil { 54 | panic(err) 55 | } 56 | 57 | } 58 | 59 | server = httptest.NewServer(http.HandlerFunc(func(resp http.ResponseWriter, request *http.Request) { 60 | if request.URL.Path == path && request.Method == "GET" && 
request.Host == host { 61 | resp.Header().Add("Content-Type", "application/json") 62 | resp.WriteHeader(200) 63 | _, _ = resp.Write(membersJson) 64 | } else { 65 | http.Error(resp, "not found", 404) 66 | } 67 | })) 68 | 69 | tlsServer = httptest.NewTLSServer(http.HandlerFunc(func(resp http.ResponseWriter, request *http.Request) { 70 | if request.URL.Path == path && request.Method == "GET" { 71 | resp.Header().Add("Content-Type", "application/json") 72 | resp.WriteHeader(200) 73 | _, _ = resp.Write(membersJson) 74 | } else { 75 | http.Error(resp, "not found", 404) 76 | } 77 | })) 78 | 79 | transport = &http.Transport{ 80 | TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, 81 | Proxy: func(req *http.Request) (*url.URL, error) { 82 | if req.URL.Scheme == "https" { 83 | return nil, nil 84 | } 85 | u, _ := url.Parse(server.URL) 86 | return u, nil 87 | }, 88 | } 89 | 90 | return 91 | } 92 | 93 | // ---- 94 | 95 | func TestDiscoverBackendsFromEtcd(t *testing.T) { 96 | testServer, tlsServer, transport := membersMock(3, false) 97 | defer testServer.Close() 98 | defer tlsServer.Close() 99 | 100 | u, err := url.Parse("http://node:2379") 101 | if err != nil { 102 | panic(err) 103 | } 104 | 105 | backends := DiscoverBackendsFromEtcd(transport, []*url.URL{u}) 106 | 107 | if len(backends) != 3 { 108 | t.Errorf("unexpected backends size %d", len(backends)) 109 | return 110 | } 111 | 112 | if backends[0].Url.String() != "http://member-0:2379" { 113 | t.Errorf("unexpected backends[0] url %s", backends[0].Url.String()) 114 | } 115 | if backends[1].Url.String() != "http://member-1:2379" { 116 | t.Errorf("unexpected backends[1] url %s", backends[1].Url.String()) 117 | } 118 | if backends[2].Url.String() != "http://member-2:2379" { 119 | t.Errorf("unexpected backends[2] url %s", backends[2].Url.String()) 120 | } 121 | 122 | if u.String() != "http://node:2379" { 123 | t.Errorf("url changed %s", u.String()) 124 | } 125 | } 126 | 127 | func TestDiscoverBackendsFromEtcdPeer(t 
*testing.T) { 128 | testServer, tlsServer, transport := membersMock(3, true) 129 | defer testServer.Close() 130 | defer tlsServer.Close() 131 | 132 | u, err := url.Parse("http://node:2380") 133 | if err != nil { 134 | panic(err) 135 | } 136 | 137 | backends := DiscoverBackendsFromEtcdPeer(transport, []*url.URL{u}) 138 | 139 | if len(backends) != 3 { 140 | t.Errorf("unexpected backends size %d", len(backends)) 141 | return 142 | } 143 | 144 | if backends[0].Url.String() != "http://member-0:2379" { 145 | t.Errorf("unexpected backends[0] url %s", backends[0].Url.String()) 146 | } 147 | if backends[1].Url.String() != "http://member-1:2379" { 148 | t.Errorf("unexpected backends[1] url %s", backends[1].Url.String()) 149 | } 150 | if backends[2].Url.String() != "http://member-2:2379" { 151 | t.Errorf("unexpected backends[2] url %s", backends[2].Url.String()) 152 | } 153 | 154 | if u.String() != "http://node:2380" { 155 | t.Errorf("url changed %s", u.String()) 156 | } 157 | } 158 | 159 | func TestDiscoverBackendsFromDns(t *testing.T) { 160 | testServer, tlsServer, transport := membersMock(3, true) 161 | defer testServer.Close() 162 | defer tlsServer.Close() 163 | 164 | lookupSRV = func(service, proto, name string) (string, []*net.SRV, error) { 165 | if service == "etcd-server" && proto == "tcp" && name == "example.org" { 166 | return "", []*net.SRV{ 167 | { 168 | Target: "node.", 169 | Port: 2380, 170 | Priority: 0, 171 | Weight: 0, 172 | }, 173 | }, nil 174 | } 175 | return "", []*net.SRV{}, &net.DNSError{Err: "no such host", Name: "", Server: "", IsTimeout: false} 176 | } 177 | defer func() { lookupSRV = net.LookupSRV }() 178 | 179 | backends, err := DiscoverBackendsFromDns(transport, "example.org") 180 | 181 | if err != nil { 182 | t.Errorf("err %s", err.Error()) 183 | } 184 | 185 | if len(backends) != 3 { 186 | t.Errorf("unexpected backends size %d", len(backends)) 187 | return 188 | } 189 | 190 | if backends[0].Url.String() != "http://member-0:2379" { 191 | 
t.Errorf("unexpected backends[0] url %s", backends[0].Url.String()) 192 | } 193 | if backends[1].Url.String() != "http://member-1:2379" { 194 | t.Errorf("unexpected backends[1] url %s", backends[1].Url.String()) 195 | } 196 | if backends[2].Url.String() != "http://member-2:2379" { 197 | t.Errorf("unexpected backends[2] url %s", backends[2].Url.String()) 198 | } 199 | } 200 | 201 | func TestDiscoverBackendsFromDnsTls(t *testing.T) { 202 | testServer, tlsServer, transport := membersMock(3, true) 203 | defer testServer.Close() 204 | defer tlsServer.Close() 205 | 206 | tlsServerUrl, _ := url.Parse(tlsServer.URL) 207 | 208 | tlsServerHost, tlsServerPortStr, _ := net.SplitHostPort(tlsServerUrl.Host) 209 | tlsServerPort, _ := strconv.Atoi(tlsServerPortStr) 210 | 211 | lookupSRV = func(service, proto, name string) (string, []*net.SRV, error) { 212 | if service == "etcd-server-ssl" && proto == "tcp" && name == "example.org" { 213 | return "", []*net.SRV{ 214 | { 215 | Target: tlsServerHost, 216 | Port: uint16(tlsServerPort), 217 | Priority: 0, 218 | Weight: 0, 219 | }, 220 | }, nil 221 | } 222 | return "", []*net.SRV{}, &net.DNSError{Err: "no such host", Name: "", Server: "", IsTimeout: false} 223 | } 224 | defer func() { lookupSRV = net.LookupSRV }() 225 | 226 | backends, err := DiscoverBackendsFromDns(transport, "example.org") 227 | 228 | if err != nil { 229 | t.Errorf("err %s", err.Error()) 230 | } 231 | 232 | if len(backends) != 3 { 233 | t.Errorf("unexpected backends size %d", len(backends)) 234 | return 235 | } 236 | 237 | if backends[0].Url.String() != "http://member-0:2379" { 238 | t.Errorf("unexpected backends[0] url %s", backends[0].Url.String()) 239 | } 240 | if backends[1].Url.String() != "http://member-1:2379" { 241 | t.Errorf("unexpected backends[1] url %s", backends[1].Url.String()) 242 | } 243 | if backends[2].Url.String() != "http://member-2:2379" { 244 | t.Errorf("unexpected backends[2] url %s", backends[2].Url.String()) 245 | } 246 | } 247 | 248 | func 
TestDiscoverBackendsFromDnsError(t *testing.T) { 249 | testServer, tlsServer, transport := membersMock(3, true) 250 | defer testServer.Close() 251 | defer tlsServer.Close() 252 | 253 | lookupSRV = func(service, proto, name string) (string, []*net.SRV, error) { 254 | return "", []*net.SRV{}, &net.DNSError{Err: "no such host", Name: "", Server: "", IsTimeout: false} 255 | } 256 | defer func() { lookupSRV = net.LookupSRV }() 257 | 258 | backends, err := DiscoverBackendsFromDns(transport, "example.org") 259 | 260 | if _, ok := err.(*net.DNSError); !ok { 261 | t.Errorf("unexpected err %s", err.Error()) 262 | } 263 | 264 | if len(backends) != 0 { 265 | t.Errorf("unexpected backends size %d", len(backends)) 266 | return 267 | } 268 | } 269 | -------------------------------------------------------------------------------- /proxy/proxy.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "github.com/sorah/etcvault/engine" 8 | "io" 9 | "io/ioutil" 10 | "log" 11 | "net/http" 12 | ) 13 | 14 | type ClosableBuffer struct { 15 | *bytes.Buffer 16 | } 17 | 18 | func (buf ClosableBuffer) Close() error { 19 | return nil 20 | } 21 | 22 | // Hop-by-hop headers (borrowed from httputil.ReverseProxy) 23 | // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html 24 | var singleHopHeaders = []string{ 25 | "Connection", 26 | "Keep-Alive", 27 | "Proxy-Authenticate", 28 | "Proxy-Authorization", 29 | "Te", 30 | "Trailers", 31 | "Transfer-Encoding", 32 | "Upgrade", 33 | } 34 | 35 | type Proxy struct { 36 | Transport *http.Transport 37 | Router *Router 38 | Engine engine.Transformable 39 | AdvertiseUrl string 40 | } 41 | 42 | func NewProxy(transport *http.Transport, router *Router, e engine.Transformable, advertiseUrl string) http.Handler { 43 | return &Proxy{ 44 | Transport: transport, 45 | Router: router, 46 | Engine: e, 47 | AdvertiseUrl: advertiseUrl, 48 | } 49 | } 50 | 51 | func 
(proxy *Proxy) ServeHTTP(response http.ResponseWriter, request *http.Request) { 52 | if request.URL.Path == "/v2/members" { 53 | proxy.serveMembersRequest(response, request) 54 | } else if request.URL.Path == "/v2/machines" { 55 | proxy.serveMachinesRequest(response, request) 56 | } else if request.URL.Path == "/_etcvault/keys" { 57 | proxy.serveEtcvaultKeysRequest(response, request) 58 | } else { 59 | proxy.serveProxyRequest(response, request) 60 | } 61 | } 62 | 63 | func (proxy *Proxy) serveProxyRequest(response http.ResponseWriter, request *http.Request) { 64 | backendRequest := new(http.Request) 65 | // copy 66 | *backendRequest = *request 67 | backendRequest.Header = make(http.Header) 68 | 69 | backendRequest.Proto = "HTTP/1.1" 70 | backendRequest.ProtoMajor = 1 71 | backendRequest.ProtoMinor = 1 72 | backendRequest.Close = false 73 | 74 | copyHeader(request.Header, backendRequest.Header) 75 | removeSingleHopHeaders(&backendRequest.Header) 76 | 77 | if (backendRequest.Method == "POST" || backendRequest.Method == "PUT" || backendRequest.Method == "PATCH") && backendRequest.Body != nil { 78 | origBody := backendRequest.Body 79 | defer origBody.Close() 80 | 81 | if err := backendRequest.ParseForm(); err != nil { 82 | log.Printf("couldn't parse form: %s", err.Error()) 83 | http.Error(response, "couldn't parse form", 400) 84 | return 85 | } 86 | 87 | if backendRequest.PostForm != nil { 88 | origValue := backendRequest.PostForm.Get("value") 89 | value, err := proxy.Engine.Transform(origValue) 90 | if err == nil { 91 | backendRequest.PostForm.Set("value", value) 92 | } else { 93 | log.Printf("failed to transform value: %s", err.Error()) 94 | } 95 | newFormString := backendRequest.PostForm.Encode() 96 | backendRequest.Body = ClosableBuffer{bytes.NewBufferString(newFormString)} 97 | backendRequest.ContentLength = int64(len(newFormString)) 98 | } 99 | } 100 | 101 | var backendResponse *http.Response 102 | 103 | var closeNotifyCh <-chan bool 104 | closeNotifier, ok := 
response.(http.CloseNotifier) 105 | if ok { 106 | closeNotifyCh = closeNotifier.CloseNotify() 107 | } else { 108 | closeNotifyCh = make(<-chan bool) 109 | } 110 | 111 | completeCh := make(chan bool, 2) 112 | closed := false 113 | go func() { 114 | select { 115 | case <-closeNotifyCh: 116 | log.Printf("Request connection closed; cancelling ongoing backend request") 117 | closed = true 118 | proxy.Transport.CancelRequest(backendRequest) 119 | case <-completeCh: 120 | } 121 | if backendResponse != nil { 122 | backendResponse.Body.Close() 123 | } 124 | }() 125 | defer func() { 126 | completeCh <- true 127 | }() 128 | 129 | backends := proxy.Router.ShuffledAvailableBackends() 130 | for _, backend := range backends { 131 | backendRequest.URL.Scheme = backend.Url.Scheme 132 | backendRequest.URL.Host = backend.Url.Host 133 | 134 | var err error 135 | backendResponse, err = proxy.Transport.RoundTrip(backendRequest) 136 | if err != nil { 137 | log.Printf("backend %s response error: %s", backend.Url.String(), err.Error()) 138 | backend.Fail() 139 | continue 140 | } 141 | backend.Ok() 142 | break 143 | } 144 | 145 | if backendResponse == nil { 146 | log.Printf("all backends not available...") 147 | http.Error(response, "backends all unavailable", http.StatusBadGateway) 148 | return 149 | } 150 | 151 | defer backendResponse.Body.Close() 152 | 153 | removeSingleHopHeaders(&backendResponse.Header) 154 | copyHeader(backendResponse.Header, response.Header()) 155 | 156 | if backendResponse.Header.Get("Content-Type") == "application/json" { 157 | json, err := ioutil.ReadAll(backendResponse.Body) 158 | if closed { 159 | return 160 | } 161 | if err != nil { 162 | panic(err) 163 | } 164 | 165 | transformedJson, err := proxy.Engine.TransformEtcdJsonResponse(json) 166 | if err == nil { 167 | response.Header().Set("Content-Length", fmt.Sprintf("%d", len(transformedJson)+1)) 168 | response.WriteHeader(backendResponse.StatusCode) 169 | response.Write(transformedJson) 170 | 
response.Write([]byte("\n")) 171 | } else { 172 | fmt.Printf("transform error %s\n", err.Error()) 173 | response.WriteHeader(backendResponse.StatusCode) 174 | response.Write(json) 175 | } 176 | } else { 177 | response.WriteHeader(backendResponse.StatusCode) 178 | io.Copy(response, backendResponse.Body) 179 | } 180 | } 181 | 182 | func (proxy *Proxy) serveMembersRequest(response http.ResponseWriter, request *http.Request) { 183 | if request.Method != "GET" { 184 | http.Error(response, "not supported; communicate with etcd directly", http.StatusMethodNotAllowed) 185 | return 186 | } 187 | 188 | type memberT struct { 189 | ClientURLs []string 190 | PeerURLs []string 191 | Name string 192 | Id string 193 | } 194 | 195 | jsonBytes, err := json.Marshal(struct { 196 | Members []memberT 197 | }{ 198 | Members: []memberT{ 199 | { 200 | ClientURLs: []string{proxy.AdvertiseUrl}, 201 | Name: "etcvault", 202 | Id: "deadbeef", 203 | }, 204 | }, 205 | }) 206 | 207 | if err != nil { 208 | http.Error(response, "failed to marshal", 500) 209 | log.Printf("failed to marshal /v2/members: %s", err.Error()) 210 | return 211 | } 212 | 213 | response.Header().Add("Content-Type", "application/json") 214 | response.Header().Add("Server", "etcvault") 215 | response.WriteHeader(200) 216 | response.Write(jsonBytes) 217 | } 218 | 219 | func (proxy *Proxy) serveMachinesRequest(response http.ResponseWriter, request *http.Request) { 220 | if request.Method != "GET" { 221 | http.Error(response, "not supported; communicate with etcd directly", http.StatusMethodNotAllowed) 222 | return 223 | } 224 | 225 | response.Header().Add("Content-Type", "text/plain") 226 | response.Header().Add("Server", "etcvault") 227 | response.WriteHeader(200) 228 | response.Write([]byte(proxy.AdvertiseUrl)) 229 | } 230 | 231 | func (proxy *Proxy) serveEtcvaultKeysRequest(response http.ResponseWriter, request *http.Request) { 232 | if request.Method != "GET" { 233 | http.Error(response, "not found", http.StatusNotFound) 234 
| return 235 | } 236 | 237 | request.ParseForm() 238 | var list []string 239 | if request.FormValue("encryption") != "" { 240 | list = proxy.Engine.GetKeychain().ListForEncryption() 241 | } else { 242 | list = proxy.Engine.GetKeychain().List() 243 | } 244 | 245 | response.Header().Add("Content-Type", "text/plain") 246 | response.Header().Add("Server", "etcvault") 247 | response.WriteHeader(200) 248 | 249 | for _, name := range list { 250 | response.Write([]byte(name)) 251 | response.Write([]byte("\n")) 252 | } 253 | } 254 | 255 | func copyHeader(source, destination http.Header) { 256 | for key, values := range source { 257 | for _, value := range values { 258 | destination.Add(key, value) 259 | } 260 | } 261 | } 262 | 263 | func removeSingleHopHeaders(header *http.Header) { 264 | for _, name := range singleHopHeaders { 265 | header.Del(name) 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /proxy/proxy_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "github.com/sorah/etcvault/engine" 8 | "math/rand" 9 | "net/http" 10 | "net/http/httptest" 11 | "net/url" 12 | "reflect" 13 | "strings" 14 | "testing" 15 | "time" 16 | ) 17 | 18 | type mockEngine struct { 19 | engine.Engine 20 | } 21 | 22 | func (e *mockEngine) Transform(str string) (string, error) { 23 | return fmt.Sprintf("<%s>", str), nil 24 | } 25 | 26 | func etcdMock(notify func(request *http.Request)) (cancel func(), serverUrl *url.URL, deadServerUrl *url.URL, deadServer *httptest.Server, transport *http.Transport) { 27 | server := httptest.NewServer(http.HandlerFunc(func(resp http.ResponseWriter, request *http.Request) { 28 | _ = request.ParseForm() 29 | notify(request) 30 | if request.URL.Path == "/v2/keys/greeting" && request.Method == "GET" { 31 | resp.Header().Add("Content-Type", "application/json") 32 | resp.WriteHeader(200) 33 | _, _ 
= resp.Write([]byte(`{"action":"get","node":{"key":"/greeting","value":"hello","modifiedIndex":1,"createdIndex":1}}`)) 34 | 35 | } else if request.URL.Path == "/v2/keys/greeting" && request.Method == "PUT" { 36 | resp.Header().Add("Content-Type", "application/json") 37 | resp.WriteHeader(200) 38 | _, _ = resp.Write([]byte(`{"action":"set","node":{"key":"/greeting","value":"hola","modifiedIndex":2,"createdIndex":2},"prevNode":{"key":"/greeting","value":"ETCVAULT::asis:hello::ETCVAULT","modifiedIndex":1,"createdIndex":1}}`)) 39 | 40 | } else if request.URL.Path == "/v2/keys/greeting" && request.Method == "POST" { 41 | resp.Header().Add("Content-Type", "application/json") 42 | resp.WriteHeader(200) 43 | _, _ = resp.Write([]byte(`{"action":"create","node":{"key":"/greeting/1","value:"hola","modifiedIndex":2,"createdIndex":2}}`)) 44 | 45 | } else if request.URL.Path == "/error" && request.Method == "GET" { 46 | resp.Header().Add("Content-Type", "application/json") 47 | resp.WriteHeader(200) 48 | _, _ = resp.Write([]byte(`{"action":"create","node":{"key":"`)) 49 | } else if request.URL.Path == "/text" && request.Method == "GET" { 50 | resp.Header().Add("Content-Type", "text/plain") 51 | resp.WriteHeader(200) 52 | _, _ = resp.Write([]byte(`it works!`)) 53 | 54 | } else if request.URL.Path == "/headers" && request.Method == "GET" { 55 | resp.Header().Set("Connection", "hello!") 56 | resp.Header().Set("Keep-Alive", "hello!") 57 | resp.Header().Set("Proxy-Authenticate", "hello!") 58 | resp.Header().Set("Proxy-Authorization", "hello!") 59 | resp.Header().Set("Te", "hello!") 60 | resp.Header().Set("Trailers", "hello!") 61 | resp.Header().Set("Upgrade", "hello!") 62 | resp.Header().Set("X-My-Original", "hello!") 63 | resp.Header().Set("Content-Type", "application/json") 64 | resp.WriteHeader(200) 65 | _, _ = resp.Write([]byte("{}")) 66 | } else { 67 | http.Error(resp, "not found", 404) 68 | } 69 | })) 70 | 71 | deadServer = httptest.NewServer(http.HandlerFunc(func(resp 
http.ResponseWriter, request *http.Request) { 72 | request.URL.Host = "dead" 73 | notify(request) 74 | http.Error(resp, "dead", 500) 75 | })) 76 | 77 | serverUrl, _ = url.Parse(server.URL) 78 | deadServerUrl, _ = url.Parse(deadServer.URL) 79 | 80 | transport = &http.Transport{} 81 | 82 | cancel = func() { 83 | server.Close() 84 | deadServer.Close() 85 | } 86 | 87 | return 88 | } 89 | 90 | func TestProxyGet(t *testing.T) { 91 | cancel, serverURL, _, _, transport := etcdMock(func(request *http.Request) { 92 | }) 93 | defer cancel() 94 | 95 | backends := []*Backend{ 96 | NewBackend(serverURL), 97 | } 98 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 99 | return backends, nil 100 | }) 101 | 102 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://localhost:2381") 103 | 104 | request, _ := http.NewRequest("GET", "http://localhost/v2/keys/greeting", nil) 105 | recorder := httptest.NewRecorder() 106 | proxyHandler.ServeHTTP(recorder, request) 107 | 108 | if recorder.Code != 200 { 109 | t.Errorf("unexpected response code: %d", recorder.Code) 110 | } 111 | if strings.Contains(recorder.Body.String(), "") { 112 | t.Errorf("unexpected response body: %s", recorder.Body.String()) 113 | } 114 | if header := recorder.Header().Get("Content-Type"); header != "application/json" { 115 | t.Errorf("unexpected Content-Type: %s", recorder.Header().Get("Content-Type")) 116 | } 117 | } 118 | 119 | func TestProxyPost(t *testing.T) { 120 | received := "" 121 | cancel, serverURL, _, _, transport := etcdMock(func(request *http.Request) { 122 | received = request.FormValue("value") 123 | }) 124 | defer cancel() 125 | 126 | backends := []*Backend{ 127 | NewBackend(serverURL), 128 | } 129 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 130 | return backends, nil 131 | }) 132 | 133 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://localhost:2381") 134 | 135 | recorder := httptest.NewRecorder() 136 | request, _ := 
http.NewRequest("POST", "http://localhost/v2/keys/greeting", bytes.NewBufferString("value=hola")) 137 | request.Header.Add("Content-Type", "application/x-www-form-urlencoded") 138 | 139 | proxyHandler.ServeHTTP(recorder, request) 140 | 141 | if recorder.Code != 200 { 142 | t.Errorf("unexpected response code: %s") 143 | } 144 | if received != "" { 145 | t.Errorf("unexpected request form value: %s", received) 146 | } 147 | if strings.Contains(recorder.Body.String(), "") { 148 | t.Errorf("unexpected response body: %s", recorder.Body.String()) 149 | } 150 | if header := recorder.Header().Get("Content-Type"); header != "application/json" { 151 | t.Errorf("unexpected Content-Type: %s", recorder.Header().Get("Content-Type")) 152 | } 153 | } 154 | 155 | func TestProxyPut(t *testing.T) { 156 | received := "" 157 | cancel, serverURL, _, _, transport := etcdMock(func(request *http.Request) { 158 | received = request.FormValue("value") 159 | }) 160 | defer cancel() 161 | 162 | backends := []*Backend{ 163 | NewBackend(serverURL), 164 | } 165 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 166 | return backends, nil 167 | }) 168 | 169 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://localhost:2381") 170 | 171 | recorder := httptest.NewRecorder() 172 | request, _ := http.NewRequest("PUT", "http://localhost/v2/keys/greeting", bytes.NewBufferString("value=hola")) 173 | request.Header.Add("Content-Type", "application/x-www-form-urlencoded") 174 | proxyHandler.ServeHTTP(recorder, request) 175 | 176 | if recorder.Code != 200 { 177 | t.Errorf("unexpected response code: %d", recorder.Code) 178 | } 179 | if received != "" { 180 | t.Errorf("unexpected request form value: %s", received) 181 | } 182 | if strings.Contains(recorder.Body.String(), "") { 183 | t.Errorf("unexpected response body: %s", recorder.Body.String()) 184 | } 185 | if header := recorder.Header().Get("Content-Type"); header != "application/json" { 186 | t.Errorf("unexpected 
Content-Type: %s", recorder.Header().Get("Content-Type")) 187 | } 188 | } 189 | 190 | func TestProxyBackendFailure(t *testing.T) { 191 | cancel, _, deadServerURL, _, transport := etcdMock(func(request *http.Request) { 192 | }) 193 | cancel() 194 | 195 | deadBackend := NewBackend(deadServerURL) 196 | backends := []*Backend{ 197 | deadBackend, 198 | } 199 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 200 | return backends, nil 201 | }) 202 | 203 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://localhost:2381") 204 | 205 | request, _ := http.NewRequest("GET", "http://localhost/v2/keys/greeting", nil) 206 | recorder := httptest.NewRecorder() 207 | proxyHandler.ServeHTTP(recorder, request) 208 | 209 | if recorder.Code != http.StatusBadGateway { 210 | t.Errorf("unexpected response code: %d", recorder.Code) 211 | } 212 | if deadBackend.Available { 213 | t.Errorf("unexpected deadBackend available") 214 | } 215 | } 216 | 217 | func TestProxyBackendRetry(t *testing.T) { 218 | cancel, serverURL, deadServerURL, deadServer, transport := etcdMock(func(request *http.Request) { 219 | }) 220 | defer cancel() 221 | deadServer.Close() 222 | rand.Seed(1) 223 | 224 | deadBackend := NewBackend(deadServerURL) 225 | backend := NewBackend(serverURL) 226 | backends := []*Backend{ 227 | deadBackend, 228 | backend, 229 | } 230 | 231 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 232 | return backends, nil 233 | }) 234 | 235 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://localhost:2381") 236 | 237 | request, _ := http.NewRequest("GET", "http://localhost/v2/keys/greeting", nil) 238 | recorder := httptest.NewRecorder() 239 | proxyHandler.ServeHTTP(recorder, request) 240 | 241 | if recorder.Code != 200 { 242 | t.Errorf("unexpected response code: %d", recorder.Code) 243 | } 244 | if strings.Contains(recorder.Body.String(), "") { 245 | t.Errorf("unexpected response body: %s", recorder.Body.String()) 246 | } 247 | if 
header := recorder.Header().Get("Content-Type"); header != "application/json" { 248 | t.Errorf("unexpected Content-Type: %s", recorder.Header().Get("Content-Type")) 249 | } 250 | if deadBackend.Available { 251 | t.Errorf("unexpected deadBackend available") 252 | } 253 | if !backend.Available { 254 | t.Errorf("unexpected backend unavailable") 255 | } 256 | } 257 | 258 | func TestProxyBackendFailureBackendNoRequest(t *testing.T) { 259 | cancel, serverURL, deadServerURL, _, transport := etcdMock(func(request *http.Request) { 260 | if request.URL.Host == "dead" { 261 | t.Errorf("unexpected request to dead") 262 | } 263 | }) 264 | defer cancel() 265 | 266 | deadBackend := NewBackend(deadServerURL) 267 | deadBackend.Available = false 268 | 269 | backend := NewBackend(serverURL) 270 | backends := []*Backend{ 271 | deadBackend, 272 | backend, 273 | } 274 | 275 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 276 | return backends, nil 277 | }) 278 | 279 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://localhost:2381") 280 | 281 | request, _ := http.NewRequest("GET", "http://localhost/v2/keys/greeting", nil) 282 | recorder := httptest.NewRecorder() 283 | proxyHandler.ServeHTTP(recorder, request) 284 | 285 | if recorder.Code != 200 { 286 | t.Errorf("unexpected response code: %d", recorder.Code) 287 | } 288 | if strings.Contains(recorder.Body.String(), "") { 289 | t.Errorf("unexpected response body: %s", recorder.Body.String()) 290 | } 291 | if header := recorder.Header().Get("Content-Type"); header != "application/json" { 292 | t.Errorf("unexpected Content-Type: %s", recorder.Header().Get("Content-Type")) 293 | } 294 | if deadBackend.Available { 295 | t.Errorf("unexpected deadBackend available") 296 | } 297 | if !backend.Available { 298 | t.Errorf("unexpected backend unavailable") 299 | } 300 | } 301 | 302 | func TestProxyInvalidJsonResponse(t *testing.T) { 303 | cancel, serverURL, _, _, transport := etcdMock(func(request *http.Request) { 
304 | }) 305 | defer cancel() 306 | 307 | backends := []*Backend{ 308 | NewBackend(serverURL), 309 | } 310 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 311 | return backends, nil 312 | }) 313 | 314 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://localhost:2381") 315 | 316 | request, _ := http.NewRequest("GET", "http://localhost/error", nil) 317 | recorder := httptest.NewRecorder() 318 | proxyHandler.ServeHTTP(recorder, request) 319 | 320 | if recorder.Code != 200 { 321 | t.Errorf("unexpected response code: %d", recorder.Code) 322 | } 323 | if recorder.Body.String() != "{\"action\":\"create\",\"node\":{\"key\":\"\n" { 324 | t.Errorf("unexpected response body: %#v", recorder.Body.String()) 325 | } 326 | if header := recorder.Header().Get("Content-Type"); header != "application/json" { 327 | t.Errorf("unexpected Content-Type: %s", recorder.Header().Get("Content-Type")) 328 | } 329 | } 330 | 331 | func TestProxyNonJsonResponse(t *testing.T) { 332 | cancel, serverURL, _, _, transport := etcdMock(func(request *http.Request) { 333 | }) 334 | defer cancel() 335 | 336 | backends := []*Backend{ 337 | NewBackend(serverURL), 338 | } 339 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 340 | return backends, nil 341 | }) 342 | 343 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://localhost:2381") 344 | 345 | request, _ := http.NewRequest("GET", "http://localhost/text", nil) 346 | recorder := httptest.NewRecorder() 347 | proxyHandler.ServeHTTP(recorder, request) 348 | 349 | if recorder.Code != 200 { 350 | t.Errorf("unexpected response code: %d", recorder.Code) 351 | } 352 | if recorder.Body.String() != "it works!" 
{ 353 | t.Errorf("unexpected response body: %s", recorder.Body.String()) 354 | } 355 | if header := recorder.Header().Get("Content-Type"); header != "text/plain" { 356 | t.Errorf("unexpected Content-Type: %s", recorder.Header().Get("Content-Type")) 357 | } 358 | } 359 | 360 | func TestProxyHeadersToBackend(t *testing.T) { 361 | receivedHeader := http.Header{} 362 | cancel, serverURL, _, _, transport := etcdMock(func(request *http.Request) { 363 | receivedHeader = request.Header 364 | }) 365 | defer cancel() 366 | 367 | backends := []*Backend{ 368 | NewBackend(serverURL), 369 | } 370 | 371 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 372 | return backends, nil 373 | }) 374 | 375 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://localhost:2381") 376 | 377 | request, _ := http.NewRequest("GET", "http://localhost/v2/keys/greeting", nil) 378 | request.Header.Set("Connection", "hello!") 379 | request.Header.Set("Keep-Alive", "hello!") 380 | request.Header.Set("Proxy-Authenticate", "hello!") 381 | request.Header.Set("Proxy-Authorization", "hello!") 382 | request.Header.Set("Te", "hello!") 383 | request.Header.Set("Trailers", "hello!") 384 | request.Header.Set("Transfer-Encoding", "hello!") 385 | request.Header.Set("Upgrade", "hello!") 386 | request.Header.Set("X-My-Original", "hello!") 387 | 388 | recorder := httptest.NewRecorder() 389 | proxyHandler.ServeHTTP(recorder, request) 390 | 391 | if recorder.Code != 200 { 392 | t.Errorf("unexpected response code: %d", recorder.Code) 393 | } 394 | if receivedHeader.Get("X-My-Original") != "hello!" { 395 | t.Errorf("unexpected request header %s to backend: %s", "Connection", receivedHeader.Get("Connection")) 396 | } 397 | if receivedHeader.Get("Connection") == "hello!" { 398 | t.Errorf("unexpected request header %s to backend: %s", "Connection", receivedHeader.Get("Connection")) 399 | } 400 | if receivedHeader.Get("Keep-Alive") == "hello!" 
{ 401 | t.Errorf("unexpected request header %s to backend: %s", "Keep-Alive", receivedHeader.Get("Keep-Alive")) 402 | } 403 | if receivedHeader.Get("Proxy-Authenticate") == "hello!" { 404 | t.Errorf("unexpected request header %s to backend: %s", "Proxy-Authenticate", receivedHeader.Get("Proxy-Authenticate")) 405 | } 406 | if receivedHeader.Get("Proxy-Authorization") == "hello!" { 407 | t.Errorf("unexpected request header %s to backend: %s", "Proxy-Authorization", receivedHeader.Get("Proxy-Authorization")) 408 | } 409 | if receivedHeader.Get("Te") == "hello!" { 410 | t.Errorf("unexpected request header %s to backend: %s", "Te", receivedHeader.Get("Te")) 411 | } 412 | if receivedHeader.Get("Trailers") == "hello!" { 413 | t.Errorf("unexpected request header %s to backend: %s", "Trailers", receivedHeader.Get("Trailers")) 414 | } 415 | if receivedHeader.Get("Transfer-Encoding") == "hello!" { 416 | t.Errorf("unexpected request header %s to backend: %s", "Transfer-Encoding", receivedHeader.Get("Transfer-Encoding")) 417 | } 418 | if receivedHeader.Get("Upgrade") == "hello!" 
{ 419 | t.Errorf("unexpected request header %s to backend: %s", "Upgrade", receivedHeader.Get("Upgrade")) 420 | } 421 | } 422 | 423 | func TestProxyHeadersFromBackend(t *testing.T) { 424 | cancel, serverURL, _, _, transport := etcdMock(func(request *http.Request) { 425 | }) 426 | defer cancel() 427 | 428 | backends := []*Backend{ 429 | NewBackend(serverURL), 430 | } 431 | 432 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 433 | return backends, nil 434 | }) 435 | 436 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://localhost:2381") 437 | 438 | request, _ := http.NewRequest("GET", "http://localhost/headers", nil) 439 | 440 | recorder := httptest.NewRecorder() 441 | proxyHandler.ServeHTTP(recorder, request) 442 | 443 | if recorder.Code != 200 { 444 | t.Errorf("unexpected response code: %d", recorder.Code) 445 | } 446 | receivedHeader := recorder.Header() 447 | if receivedHeader.Get("X-My-Original") != "hello!" { 448 | t.Errorf("unexpected response header %s from backend: %s", "Connection", receivedHeader.Get("Connection")) 449 | } 450 | if receivedHeader.Get("Connection") == "hello!" { 451 | t.Errorf("unexpected response header %s from backend: %s", "Connection", receivedHeader.Get("Connection")) 452 | } 453 | if receivedHeader.Get("Keep-Alive") == "hello!" { 454 | t.Errorf("unexpected response header %s from backend: %s", "Keep-Alive", receivedHeader.Get("Keep-Alive")) 455 | } 456 | if receivedHeader.Get("Proxy-Authenticate") == "hello!" { 457 | t.Errorf("unexpected response header %s from backend: %s", "Proxy-Authenticate", receivedHeader.Get("Proxy-Authenticate")) 458 | } 459 | if receivedHeader.Get("Proxy-Authorization") == "hello!" { 460 | t.Errorf("unexpected response header %s from backend: %s", "Proxy-Authorization", receivedHeader.Get("Proxy-Authorization")) 461 | } 462 | if receivedHeader.Get("Te") == "hello!" 
{ 463 | t.Errorf("unexpected response header %s from backend: %s", "Te", receivedHeader.Get("Te")) 464 | } 465 | if receivedHeader.Get("Trailers") == "hello!" { 466 | t.Errorf("unexpected response header %s from backend: %s", "Trailers", receivedHeader.Get("Trailers")) 467 | } 468 | if receivedHeader.Get("Upgrade") == "hello!" { 469 | t.Errorf("unexpected response header %s from backend: %s", "Upgrade", receivedHeader.Get("Upgrade")) 470 | } 471 | } 472 | 473 | func TestProxyMembersRequest(t *testing.T) { 474 | cancel, _, _, _, transport := etcdMock(func(request *http.Request) { 475 | }) 476 | defer cancel() 477 | 478 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 479 | return []*Backend{}, nil 480 | }) 481 | 482 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://advertise-url") 483 | 484 | request, _ := http.NewRequest("GET", "http://localhost/v2/members", nil) 485 | 486 | recorder := httptest.NewRecorder() 487 | proxyHandler.ServeHTTP(recorder, request) 488 | 489 | if recorder.Code != 200 { 490 | t.Errorf("unexpected response code: %d", recorder.Code) 491 | } 492 | if header := recorder.Header().Get("Content-Type"); header != "application/json" { 493 | t.Errorf("unexpected Content-Type: %s", recorder.Header().Get("Content-Type")) 494 | } 495 | 496 | expectedJson := map[string]interface{}{} 497 | err := json.Unmarshal([]byte(`{"Members":[{"ClientURLs":["http://advertise-url"],"PeerURLs":null,"Name":"etcvault","Id":"deadbeef"}]}`), &expectedJson) 498 | if err != nil { 499 | panic(err) 500 | } 501 | 502 | responseJson := map[string]interface{}{} 503 | err = json.Unmarshal(recorder.Body.Bytes(), &responseJson) 504 | if err != nil { 505 | t.Errorf("response body couldn't parse as JSON: %s\n%s", err.Error(), recorder.Body.String()) 506 | } 507 | 508 | if !reflect.DeepEqual(responseJson, expectedJson) { 509 | t.Errorf("unexpected response body: %s", recorder.Body.String()) 510 | } 511 | 512 | } 513 | 514 | func TestProxyMachines(t 
*testing.T) { 515 | cancel, _, _, _, transport := etcdMock(func(request *http.Request) { 516 | }) 517 | defer cancel() 518 | 519 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 520 | return []*Backend{}, nil 521 | }) 522 | 523 | proxyHandler := NewProxy(transport, router, &mockEngine{}, "http://advertise-url") 524 | 525 | request, _ := http.NewRequest("GET", "http://localhost/v2/machines", nil) 526 | 527 | recorder := httptest.NewRecorder() 528 | proxyHandler.ServeHTTP(recorder, request) 529 | 530 | if recorder.Code != 200 { 531 | t.Errorf("unexpected response code: %d", recorder.Code) 532 | } 533 | if recorder.Body.String() != "http://advertise-url" { 534 | t.Errorf("unexpected response body: %s", recorder.Body.String()) 535 | } 536 | } 537 | -------------------------------------------------------------------------------- /proxy/readonly.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/sorah/etcvault/engine" 5 | "net/http" 6 | ) 7 | 8 | func NewReadonlyProxy(transport *http.Transport, router *Router, e engine.Transformable, advertiseUrl string) http.Handler { 9 | return readonlyHandler(NewProxy(transport, router, e, advertiseUrl)) 10 | } 11 | 12 | func readonlyHandler(handler http.Handler) http.Handler { 13 | return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { 14 | if request.Method != "GET" { 15 | // I prefer method not allowed, but following etcd's proxy mode behavior for compat 16 | response.WriteHeader(http.StatusNotImplemented) 17 | return 18 | } 19 | 20 | handler.ServeHTTP(response, request) 21 | }) 22 | } 23 | -------------------------------------------------------------------------------- /proxy/readonly_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | 
func TestReadonlyProxyGet(t *testing.T) { 12 | cancel, serverURL, _, _, transport := etcdMock(func(request *http.Request) { 13 | if request.Method != "GET" { 14 | t.Errorf("Received non GET request") 15 | } 16 | }) 17 | defer cancel() 18 | 19 | backends := []*Backend{ 20 | NewBackend(serverURL), 21 | } 22 | 23 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 24 | return backends, nil 25 | }) 26 | 27 | proxyHandler := NewReadonlyProxy(transport, router, &mockEngine{}, "http://localhost") 28 | 29 | request, _ := http.NewRequest("GET", "http://localhost/v2/keys/greeting", nil) 30 | recorder := httptest.NewRecorder() 31 | proxyHandler.ServeHTTP(recorder, request) 32 | 33 | if recorder.Code != 200 { 34 | t.Errorf("unexpected response code: %d", recorder.Code) 35 | } 36 | if strings.Contains(recorder.Body.String(), "") { 37 | t.Errorf("unexpected response body: %s", recorder.Body.String()) 38 | } 39 | if header := recorder.Header().Get("Content-Type"); header != "application/json" { 40 | t.Errorf("unexpected Content-Type: %s", recorder.Header().Get("Content-Type")) 41 | } 42 | } 43 | 44 | func TestReadonlyProxyPost(t *testing.T) { 45 | cancel, serverURL, _, _, transport := etcdMock(func(request *http.Request) { 46 | if request.Method != "GET" { 47 | t.Errorf("Received non GET request") 48 | } 49 | }) 50 | defer cancel() 51 | 52 | backends := []*Backend{ 53 | NewBackend(serverURL), 54 | } 55 | 56 | router := NewRouter(time.Hour*24, func() ([]*Backend, error) { 57 | return backends, nil 58 | }) 59 | 60 | proxyHandler := NewReadonlyProxy(transport, router, &mockEngine{}, "http://localhost") 61 | 62 | request, _ := http.NewRequest("POST", "http://localhost/v2/keys/greeting", nil) 63 | recorder := httptest.NewRecorder() 64 | proxyHandler.ServeHTTP(recorder, request) 65 | 66 | if recorder.Code != 501 { 67 | t.Errorf("unexpected response code: %d", recorder.Code) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- 
/proxy/router.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | "math/rand" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | const ( 12 | backendFilterAll = iota 13 | backendFilterFailed = iota 14 | backendFilterAvailable = iota 15 | ) 16 | 17 | var ErrAlreadyUpdateStarted = errors.New("Periodical updating is already running") 18 | 19 | type BackendUpdateFunc func() ([]*Backend, error) 20 | 21 | type Router struct { 22 | sync.RWMutex 23 | backends []*Backend 24 | UpdateFunc BackendUpdateFunc 25 | UpdateInterval time.Duration 26 | updateStopCh chan bool 27 | } 28 | 29 | func NewRouter(interval time.Duration, updateFunc BackendUpdateFunc) *Router { 30 | router := &Router{ 31 | backends: []*Backend{}, 32 | UpdateFunc: updateFunc, 33 | UpdateInterval: interval, 34 | updateStopCh: nil, 35 | } 36 | router.Update() 37 | return router 38 | } 39 | 40 | func (router *Router) StartUpdate() error { 41 | router.Lock() 42 | defer router.Unlock() 43 | 44 | if router.updateStopCh != nil { 45 | return ErrAlreadyUpdateStarted 46 | } 47 | 48 | router.updateStopCh = make(chan bool) 49 | 50 | go func() { 51 | for { 52 | select { 53 | case <-router.updateStopCh: 54 | return 55 | case <-time.After(router.UpdateInterval): 56 | router.Update() 57 | } 58 | } 59 | }() 60 | 61 | log.Println("Started periodical update of backends") 62 | 63 | return nil 64 | } 65 | 66 | func (router *Router) StopUpdate() { 67 | router.Lock() 68 | defer router.Unlock() 69 | 70 | if router.updateStopCh != nil { 71 | router.updateStopCh <- true 72 | log.Println("Stopped periodical update of backends") 73 | } 74 | } 75 | 76 | func (router *Router) Update() { 77 | router.Lock() 78 | defer router.Unlock() 79 | 80 | newBackends, err := router.UpdateFunc() 81 | if err == nil { 82 | router.backends = newBackends 83 | } else { 84 | log.Printf("Failed to update backends: %s", err.Error()) 85 | } 86 | } 87 | 88 | func (router *Router) 
getBackends(filter int) []*Backend { 89 | router.RLock() 90 | defer router.RUnlock() 91 | 92 | filteredBackends := make([]*Backend, 0, len(router.backends)) 93 | 94 | for _, backend := range router.backends { 95 | switch filter { 96 | case backendFilterAll: 97 | case backendFilterFailed: 98 | if backend.Available { 99 | continue 100 | } 101 | case backendFilterAvailable: 102 | if !backend.Available { 103 | continue 104 | } 105 | } 106 | 107 | filteredBackends = append(filteredBackends, backend) 108 | } 109 | 110 | return filteredBackends 111 | } 112 | 113 | func (router *Router) Backends() []*Backend { 114 | return router.getBackends(backendFilterAll) 115 | } 116 | 117 | func (router *Router) FailedBackends() []*Backend { 118 | return router.getBackends(backendFilterFailed) 119 | } 120 | 121 | func (router *Router) AvailableBackends() []*Backend { 122 | return router.getBackends(backendFilterAvailable) 123 | } 124 | 125 | func (router *Router) ShuffledAvailableBackends() []*Backend { 126 | backends := router.AvailableBackends() 127 | shuffledBackends := make([]*Backend, len(backends)) 128 | 129 | pattern := rand.Perm(len(backends)) 130 | for i, idx := range pattern { 131 | shuffledBackends[i] = backends[idx] 132 | } 133 | 134 | return shuffledBackends 135 | } 136 | -------------------------------------------------------------------------------- /proxy/router_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func generateBackendsForTest(count int) []*Backend { 11 | backends := make([]*Backend, 0, count) 12 | 13 | for i := 0; i < count; i++ { 14 | u, _ := url.Parse(fmt.Sprintf("http://backend-%d", i)) 15 | backends = append(backends, NewBackend(u)) 16 | } 17 | 18 | return backends 19 | } 20 | 21 | func TestBackends(t *testing.T) { 22 | router := NewRouter(time.Second*60, func() ([]*Backend, error) { 23 | return 
generateBackendsForTest(3), nil 24 | }) 25 | 26 | backends := router.Backends() 27 | 28 | if len(backends) != 3 { 29 | t.Errorf("Unexpected backends length %d", len(backends)) 30 | return 31 | } 32 | 33 | if backends[0].Url.Host != "backend-0" { 34 | t.Errorf("Unexpected backends[0] url %s", backends[0].Url.Host) 35 | } 36 | if backends[1].Url.Host != "backend-1" { 37 | t.Errorf("Unexpected backends[1] url %s", backends[1].Url.Host) 38 | } 39 | if backends[2].Url.Host != "backend-2" { 40 | t.Errorf("Unexpected backends[2] url %s", backends[2].Url.Host) 41 | } 42 | } 43 | 44 | func TestAvailableBackends(t *testing.T) { 45 | backendsSource := generateBackendsForTest(3) 46 | router := NewRouter(time.Second*60, func() ([]*Backend, error) { 47 | return backendsSource, nil 48 | }) 49 | 50 | backendsSource[0].Fail() 51 | 52 | backends := router.AvailableBackends() 53 | 54 | if len(backends) != 2 { 55 | t.Errorf("Unexpected backends length %d", len(backends)) 56 | return 57 | } 58 | 59 | if backends[0].Url.Host != "backend-1" { 60 | t.Errorf("Unexpected backends[0] url %s", backends[0].Url.Host) 61 | } 62 | if backends[1].Url.Host != "backend-2" { 63 | t.Errorf("Unexpected backends[1] url %s", backends[1].Url.Host) 64 | } 65 | } 66 | 67 | func TestFailedBackends(t *testing.T) { 68 | backendsSource := generateBackendsForTest(3) 69 | router := NewRouter(time.Second*60, func() ([]*Backend, error) { 70 | return backendsSource, nil 71 | }) 72 | 73 | backendsSource[1].Fail() 74 | backendsSource[2].Fail() 75 | 76 | backends := router.FailedBackends() 77 | 78 | if len(backends) != 2 { 79 | t.Errorf("Unexpected backends length %d", len(backends)) 80 | return 81 | } 82 | 83 | if backends[0].Url.Host != "backend-1" { 84 | t.Errorf("Unexpected backends[0] url %s", backends[0].Url.Host) 85 | } 86 | if backends[1].Url.Host != "backend-2" { 87 | t.Errorf("Unexpected backends[1] url %s", backends[1].Url.Host) 88 | } 89 | } 90 | 91 | func TestShuffledAvailableBackends(t *testing.T) { 92 | 
router := NewRouter(time.Second*60, func() ([]*Backend, error) { 93 | return generateBackendsForTest(3), nil 94 | }) 95 | 96 | backends := router.ShuffledAvailableBackends() 97 | 98 | if len(backends) != 3 { 99 | t.Errorf("Unexpected backends length %d", len(backends)) 100 | return 101 | } 102 | 103 | hosts := make(map[string]bool) 104 | hosts[backends[0].Url.Host] = true 105 | hosts[backends[1].Url.Host] = true 106 | hosts[backends[2].Url.Host] = true 107 | 108 | if exist, ok := hosts["backend-0"]; !(ok && exist) { 109 | t.Errorf("backend-0 not ok: %#v", backends) 110 | } 111 | if exist, ok := hosts["backend-1"]; !(ok && exist) { 112 | t.Errorf("backend-1 not ok: %#v", backends) 113 | } 114 | if exist, ok := hosts["backend-2"]; !(ok && exist) { 115 | t.Errorf("backend-2 not ok: %#v", backends) 116 | } 117 | } 118 | 119 | func TestUpdate(t *testing.T) { 120 | i := 0 121 | router := NewRouter(time.Second*60, func() ([]*Backend, error) { 122 | i++ 123 | return generateBackendsForTest(2 + i), nil 124 | }) 125 | 126 | backends := router.Backends() 127 | 128 | if len(backends) != 3 { 129 | t.Errorf("Unexpected backends length %d", len(backends)) 130 | return 131 | } 132 | 133 | if backends[0].Url.Host != "backend-0" { 134 | t.Errorf("Unexpected backends[0] url %s", backends[0].Url.Host) 135 | } 136 | if backends[1].Url.Host != "backend-1" { 137 | t.Errorf("Unexpected backends[1] url %s", backends[1].Url.Host) 138 | } 139 | if backends[2].Url.Host != "backend-2" { 140 | t.Errorf("Unexpected backends[2] url %s", backends[2].Url.Host) 141 | } 142 | 143 | router.Update() 144 | backends = router.Backends() 145 | 146 | if len(backends) != 4 { 147 | t.Errorf("Unexpected backends length %d", len(backends)) 148 | return 149 | } 150 | 151 | if backends[0].Url.Host != "backend-0" { 152 | t.Errorf("Unexpected backends[0] url %s", backends[0].Url.Host) 153 | } 154 | if backends[1].Url.Host != "backend-1" { 155 | t.Errorf("Unexpected backends[1] url %s", backends[1].Url.Host) 156 | } 
157 | if backends[2].Url.Host != "backend-2" { 158 | t.Errorf("Unexpected backends[2] url %s", backends[2].Url.Host) 159 | } 160 | if backends[3].Url.Host != "backend-3" { 161 | t.Errorf("Unexpected backends[3] url %s", backends[3].Url.Host) 162 | } 163 | } 164 | 165 | func TestUpdateFail(t *testing.T) { 166 | i := -1 167 | router := NewRouter(time.Second*60, func() (backends []*Backend, err error) { 168 | i++ 169 | if i == 1 { 170 | backends = nil 171 | err = fmt.Errorf("hehe") 172 | return 173 | } 174 | backends = generateBackendsForTest(3 + i) 175 | return 176 | }) 177 | 178 | router.Update() 179 | backends := router.Backends() 180 | 181 | if len(backends) != 3 { 182 | t.Errorf("Unexpected backends length %d", len(backends)) 183 | return 184 | } 185 | 186 | if backends[0].Url.Host != "backend-0" { 187 | t.Errorf("Unexpected backends[0] url %s", backends[0].Url.Host) 188 | } 189 | if backends[1].Url.Host != "backend-1" { 190 | t.Errorf("Unexpected backends[1] url %s", backends[1].Url.Host) 191 | } 192 | if backends[2].Url.Host != "backend-2" { 193 | t.Errorf("Unexpected backends[2] url %s", backends[2].Url.Host) 194 | } 195 | 196 | router.Update() 197 | backends = router.Backends() 198 | 199 | if len(backends) != 5 { 200 | t.Errorf("Unexpected backends length %d", len(backends)) 201 | return 202 | } 203 | 204 | if backends[0].Url.Host != "backend-0" { 205 | t.Errorf("Unexpected backends[0] url %s", backends[0].Url.Host) 206 | } 207 | if backends[1].Url.Host != "backend-1" { 208 | t.Errorf("Unexpected backends[1] url %s", backends[1].Url.Host) 209 | } 210 | if backends[2].Url.Host != "backend-2" { 211 | t.Errorf("Unexpected backends[2] url %s", backends[2].Url.Host) 212 | } 213 | if backends[3].Url.Host != "backend-3" { 214 | t.Errorf("Unexpected backends[3] url %s", backends[3].Url.Host) 215 | } 216 | if backends[4].Url.Host != "backend-4" { 217 | t.Errorf("Unexpected backends[4] url %s", backends[4].Url.Host) 218 | } 219 | } 220 | 
-------------------------------------------------------------------------------- /proxystarter.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "encoding/pem" 7 | "fmt" 8 | "github.com/sorah/etcvault/engine" 9 | "github.com/sorah/etcvault/keys" 10 | "github.com/sorah/etcvault/proxy" 11 | "io/ioutil" 12 | "net" 13 | "net/http" 14 | "net/url" 15 | "os" 16 | "strings" 17 | "time" 18 | ) 19 | 20 | func defaultHttpTransport() *http.Transport { 21 | return &http.Transport{ 22 | // DefaultTransport 23 | Dial: (&net.Dialer{ 24 | Timeout: 30 * time.Second, 25 | KeepAlive: 30 * time.Second, 26 | }).Dial, 27 | TLSHandshakeTimeout: 10 * time.Second, 28 | } 29 | } 30 | 31 | func caPool(caPath string) *x509.CertPool { 32 | pool := x509.NewCertPool() 33 | remainingPem, err := ioutil.ReadFile(caPath) 34 | if err != nil { 35 | fmt.Fprintf(os.Stderr, "error loading CA file %s: %s", caPath, err) 36 | os.Exit(1) 37 | } 38 | 39 | for { // load while file ends 40 | var block *pem.Block 41 | block, remainingPem = pem.Decode(remainingPem) 42 | if block == nil { 43 | return pool 44 | } 45 | cert, err := x509.ParseCertificate(block.Bytes) 46 | if err != nil { 47 | fmt.Fprintf(os.Stderr, "error while parsing CA PEM blocks: %s", err.Error()) 48 | os.Exit(1) 49 | } 50 | pool.AddCert(cert) 51 | } 52 | } 53 | 54 | func parseTlsKeypair(certPath, keyPath string) *tls.Config { 55 | certBytes, err := ioutil.ReadFile(certPath) 56 | if err != nil { 57 | fmt.Fprintf(os.Stderr, "error loading certificate %s: %s\n", certPath, err.Error()) 58 | os.Exit(1) 59 | } 60 | keyBytes, err := ioutil.ReadFile(keyPath) 61 | if err != nil { 62 | fmt.Fprintf(os.Stderr, "error loading certificate %s: %s\n", certPath, err.Error()) 63 | os.Exit(1) 64 | } 65 | 66 | keypair, err := tls.X509KeyPair(certBytes, keyBytes) 67 | if err != nil { 68 | fmt.Printf("error loading keypair: %s\n", err.Error()) 69 | } 70 | 71 | 
return &tls.Config{ 72 | Certificates: []tls.Certificate{keypair}, 73 | MinVersion: tls.VersionTLS10, 74 | } 75 | } 76 | 77 | func tlsConfigurationForClientUse(config *tls.Config, caPath string) *tls.Config { 78 | if config == nil { 79 | config = &tls.Config{} 80 | } 81 | 82 | if caPath != "" { 83 | config.RootCAs = caPool(caPath) 84 | } 85 | 86 | return config 87 | } 88 | 89 | func tlsConfigurationForServerUse(config *tls.Config, caPath string) *tls.Config { 90 | if config == nil { 91 | config = &tls.Config{} 92 | } 93 | 94 | if caPath != "" { 95 | config.ClientAuth = tls.RequireAndVerifyClientCert 96 | config.ClientCAs = caPool(caPath) 97 | } else { 98 | config.ClientAuth = tls.NoClientCert 99 | } 100 | 101 | return config 102 | } 103 | 104 | type ProxyStarter struct { 105 | // arguments 106 | Listen *url.URL 107 | AdvertiseUrl string 108 | 109 | keychainDir string 110 | DiscoverySrvDomain string 111 | initialBackendUrlStrings string 112 | 113 | clientCaFilePath string 114 | clientCertFilePath string 115 | clientKeyFilePath string 116 | 117 | peerCaFilePath string 118 | peerCertFilePath string 119 | peerKeyFilePath string 120 | 121 | listenCaFilePath string 122 | listenCertFilePath string 123 | listenKeyFilePath string 124 | 125 | readonly bool 126 | 127 | discoveryInterval time.Duration 128 | 129 | router *proxy.Router 130 | } 131 | 132 | func (starter *ProxyStarter) InitialBackendUrls() []*url.URL { 133 | urlStrings := strings.Split(starter.initialBackendUrlStrings, ",") 134 | urls := make([]*url.URL, len(urlStrings)) 135 | 136 | for i, urlString := range urlStrings { 137 | u, err := url.Parse(urlString) 138 | if err != nil { 139 | fmt.Fprintf(os.Stderr, "failed to parse url %s: %s\n", urlString, err.Error()) 140 | os.Exit(1) 141 | } 142 | urls[i] = u 143 | } 144 | 145 | return urls 146 | } 147 | 148 | func (starter *ProxyStarter) Keychain() *keys.Keychain { 149 | return keys.NewKeychain(starter.keychainDir) 150 | } 151 | 152 | func (starter *ProxyStarter) 
Engine() *engine.Engine { 153 | return engine.NewEngine(starter.Keychain()) 154 | } 155 | 156 | func (starter *ProxyStarter) ListenTlsConfig() *tls.Config { 157 | if starter.listenKeyFilePath != "" && starter.listenCertFilePath != "" { 158 | return parseTlsKeypair(starter.listenCertFilePath, starter.listenKeyFilePath) 159 | } else { 160 | return nil 161 | } 162 | } 163 | 164 | func (starter *ProxyStarter) PeerTlsConfig() *tls.Config { 165 | if starter.peerKeyFilePath != "" && starter.peerCertFilePath != "" { 166 | return parseTlsKeypair(starter.peerCertFilePath, starter.peerKeyFilePath) 167 | } else { 168 | return nil 169 | } 170 | } 171 | 172 | func (starter *ProxyStarter) ClientTlsConfig() *tls.Config { 173 | if starter.clientKeyFilePath != "" && starter.clientCertFilePath != "" { 174 | return parseTlsKeypair(starter.clientCertFilePath, starter.clientKeyFilePath) 175 | } else { 176 | return nil 177 | } 178 | } 179 | 180 | func (starter *ProxyStarter) TlsConfigForServerUse() *tls.Config { 181 | if starter.listenKeyFilePath != "" && starter.listenCertFilePath != "" { 182 | return tlsConfigurationForServerUse(starter.ListenTlsConfig(), starter.listenCaFilePath) 183 | } else if starter.clientKeyFilePath != "" && starter.clientCertFilePath != "" { 184 | return tlsConfigurationForServerUse(starter.ClientTlsConfig(), starter.clientCaFilePath) 185 | } else { 186 | return nil 187 | } 188 | } 189 | 190 | func (starter *ProxyStarter) ClientTlsConfigForClientUse() *tls.Config { 191 | return tlsConfigurationForClientUse(starter.ClientTlsConfig(), starter.clientCaFilePath) 192 | } 193 | 194 | func (starter *ProxyStarter) PeerTlsConfigForClientUse() *tls.Config { 195 | return tlsConfigurationForClientUse(starter.PeerTlsConfig(), starter.peerCaFilePath) 196 | } 197 | 198 | func (starter *ProxyStarter) PeerHttpTransport() *http.Transport { 199 | transport := defaultHttpTransport() 200 | transport.TLSClientConfig = starter.PeerTlsConfigForClientUse() 201 | return transport 202 | } 
203 | 204 | func (starter *ProxyStarter) ClientHttpTransport() *http.Transport { 205 | transport := defaultHttpTransport() 206 | transport.TLSClientConfig = starter.ClientTlsConfigForClientUse() 207 | return transport 208 | } 209 | 210 | func (starter *ProxyStarter) Listener() net.Listener { 211 | listener, err := net.Listen("tcp", starter.Listen.Host) 212 | if err != nil { 213 | fmt.Fprintf(os.Stderr, "failed to listen %s: %s", starter.Listen.String(), err.Error()) 214 | os.Exit(1) 215 | } 216 | 217 | if starter.Listen.Scheme == "https" { 218 | tlsConfig := starter.TlsConfigForServerUse() 219 | listener = tls.NewListener(listener, tlsConfig) 220 | } 221 | 222 | return listener 223 | } 224 | 225 | func (starter *ProxyStarter) BackendUpdateFunc() proxy.BackendUpdateFunc { 226 | if starter.DiscoverySrvDomain != "" { 227 | transport := starter.PeerHttpTransport() 228 | return func() ([]*proxy.Backend, error) { 229 | return proxy.DiscoverBackendsFromDns(transport, starter.DiscoverySrvDomain) 230 | } 231 | } else { 232 | transport := starter.ClientHttpTransport() 233 | return func() ([]*proxy.Backend, error) { 234 | return proxy.DiscoverBackendsFromEtcd(transport, starter.InitialBackendUrls()), nil 235 | } 236 | } 237 | } 238 | 239 | func (starter *ProxyStarter) Router() *proxy.Router { 240 | if starter.router != nil { 241 | return starter.router 242 | } 243 | 244 | starter.router = proxy.NewRouter(starter.discoveryInterval, starter.BackendUpdateFunc()) 245 | err := starter.router.StartUpdate() 246 | if err != nil { 247 | fmt.Fprintf(os.Stderr, "error starting backend discovery: %s", err.Error()) 248 | } 249 | 250 | return starter.router 251 | } 252 | 253 | func (starter *ProxyStarter) Proxy() http.Handler { 254 | if starter.readonly { 255 | return proxy.NewReadonlyProxy(starter.ClientHttpTransport(), starter.Router(), starter.Engine(), starter.AdvertiseUrl) 256 | } else { 257 | return proxy.NewProxy(starter.ClientHttpTransport(), starter.Router(), starter.Engine(), 
starter.AdvertiseUrl) 258 | } 259 | } 260 | 261 | func (starter *ProxyStarter) HttpServer() *http.Server { 262 | return &http.Server{ 263 | Handler: starter.Proxy(), 264 | ReadTimeout: 5 * time.Minute, 265 | } 266 | } 267 | 268 | func (starter *ProxyStarter) Start() { 269 | fmt.Printf("Serving at %s\n", starter.Listen.String()) 270 | starter.HttpServer().Serve(starter.Listener()) 271 | } 272 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | PKGS="./keys ./container ./engine ./proxy" 5 | FORMATS="$PKGS *.go" 6 | 7 | for pkg in $PKGS; do 8 | go test -cover $pkg 9 | done 10 | 11 | fmt_result="$(gofmt -l $FORMATS)" 12 | if [ -n "${fmt_result}" ]; then 13 | echo -e "gofmt checking failed:\n${fmt_result}" 14 | exit 1 15 | fi 16 | 17 | --------------------------------------------------------------------------------